In [1]:
# Enable the IPython autoreload extension so that edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [6]:
from Bio import SeqIO
import csv

import tqdm
import pickle
from torch.utils.data import DataLoader

# import proteinbert_gen.constants as consts
from proteinbert_gen.tokenizer import ProteinTokenizer
from proteinbert_gen.dataset import sprot_train

In [5]:
# Export a filtered subset of Swiss-Prot records to a CSV file.
#
# Each written row is: [running index (1-based), record id, sequence, comment].
# A record is kept only if its sequence length lies in [min_len, max_len] and
# the tokenizer reports the sequence as valid. At most `max_count` rows are written.
ifile = "/data/protein-modeling/uniprot/uniprot_sprot.dat"
ofile = "../data/uniprot_sprot_1m.csv"

debug = False
count = 0
max_count = 1_000_000
min_len = 10
max_len = 510 # 512 - 2 (i.e. start + end)
tokenizer = ProteinTokenizer()

# Open the input first: if it is missing, the existing output file is not
# truncated by opening it in "w" mode.
with open(ifile) as f, open(ofile, "w", newline="") as fout:
    writer = csv.writer(fout)

    for record in SeqIO.parse(f, "swiss"):
        # Skip sequences that are too short or too long.
        if not (min_len <= len(record.seq) <= max_len):
            continue

        if not tokenizer.is_valid_seq(record.seq):
            continue

        count += 1

        if debug:
            print(record.id, record.seq[:32])
            # record.annotations["protein_existence"]; https://www.uniprot.org/help/protein_existence
        pid, seq = record.id, record.seq
        # Records without a "comment" annotation get a placeholder description.
        desc = record.annotations.get("comment", "<NO_DATA>")

        writer.writerow([count, pid, seq, desc])

        # Stop once exactly `max_count` rows have been written. The previous
        # `count > max_count` check ran after the write and therefore emitted
        # max_count + 1 rows.
        if count >= max_count:
            break

In [7]:
# Build a token -> occurrence-count dictionary over the training sequences.
# Every entry of tokenizer.ALL_TOKENS starts at zero, so tokens that never
# occur in the data still appear in the result.
# (tensor-based alternative, kept for reference)
# word_freq = torch.zeros((consts.VOCAB_SIZE,), dtype=torch.int64)
word_freq = dict.fromkeys(tokenizer.ALL_TOKENS, 0)

for sample in tqdm.tqdm(sprot_train):
    residues = sample["seq"]
    for residue in residues:
        # A symbol that is not a key of word_freq raises KeyError here.
        word_freq[residue] = word_freq[residue] + 1

print(word_freq)

# Persist the counts so they can be loaded later without recomputing them.
with open("../data/sprot_1m_word_freq_dict.pkl", "wb") as pkl_out:
    pickle.dump(word_freq, pkl_out)

100%|███████████████████████████████████████████████████████████████████████████| 325230/325230 [00:04<00:00, 69307.81it/s]

{'A': 7230031, 'C': 1124625, 'D': 4463802, 'E': 5439173, 'F': 3237405, 'G': 6113797, 'H': 1878744, 'I': 5098187, 'K': 4810102, 'L': 8086525, 'M': 2096564, 'N': 3206475, 'P': 3758540, 'Q': 3079117, 'R': 4652322, 'S': 5085484, 'T': 4357979, 'U': 226, 'V': 5903758, 'W': 893809, 'X': 4038, 'Y': 2415353, '&': 0, '^': 0, '$': 0, '_': 0}





In [6]:
# Show the number of entries in sprot_train as the cell output.
# NOTE(review): the saved output of this cell (325340) differs from the 325230
# iterations reported by the progress bar in the frequency cell — presumably the
# cells were run against different versions of the dataset; confirm by re-running.
len(sprot_train)

325340