In [1]:
from tokenizers.trainers import BpeTrainer 
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from makemore import BpeProteinDataset
from biotite.sequence.io.fasta import FastaFile

In [2]:
# read in data

In [3]:
# Inclusive bounds on sequence length (in characters) for a sequence to be kept.
MIN_SEQ_LEN = 256
MAX_SEQ_LEN = 512

fasta = FastaFile.read("./datasets/bglb.fa")

# Keep only the sequence strings whose length is within the bounds; the FASTA
# headers are discarded here.
proteins = [
    sequence
    for _header, sequence in fasta.items()
    if MIN_SEQ_LEN <= len(sequence) <= MAX_SEQ_LEN
]

# Show the first two retained sequences and how many were retained in total.
proteins[:2], len(proteins)

(['MTGPDQFPQFPPGFVWGVATASYQIEGAVTEDGRGPSVWDTFSHTPGRTHDGDTGDVADDHYHRYPEDLDLMAGLGVNAYRFSIAWPRIQPTGEGPVNPAGLAFYDRLVDAMLAKGITPAATLYHWDLPQALEDKGGWTQRDIPHYFAEYTAAVADKLGDRVGLWCTLNEPFIVTAFGYVLGVHAPGRQLFTDAFAVAHHQLLGHGLAVEALRAANVTGTIGVVNALAPVHPDSDDPADHVAAGILDTLMNRTYTDPLLLGRYPEETPAVYAGADLSVVKDGDLSTISTPIDFFGVNFYNPHRVRAAAPEQFGTGPLNFETVEYPGVPTTAMGWPVVPEAFTELLTGLHERYGEKLPPIYITENGAAYDDEPGPDGRVRDDDRIAYLDRHLRAVHAAMAAGADVRGYFCWSFLDNFEWAEGYQKRFGLVRVDYETLERTPKASYDWYRSVIALAAAVGP',
  'MHHLPQDFVWGVATAAYQIEGAVDVDGRSPSIWDTFGRVPGAIANGDTGDVACDHYHRWPEDLGLIRELGVDAYRFSVAWPRVIPTGTGAVNAKGLAFYDRLVDELLAAGIRPFVTLYHWDLPQVLQDKGGWPARHTAEAFADYAAVVAGALGDRVTDWTTVNEPLCVAWIGHLEGTMAPGLRDLHRAIDASHHVLLGHGLATQAIRAAASARADVGIVLNPSPADAATDRPEDAAAAVRADGHTNRWWLDPLYGRGYPRDMVETYGYEPPVLDGDLETIATPTDYLGVNYYFRAVVADDPSGPAPYAKQVEVSGRHTAMGWEVNPGGLARILTRIAADYAPERIFVTEQGSAWPDVVEPDGSIADKDRIDYLEQHLEAINAAVDAGVPLAGYFVWSLLDNLEWAYGYDKRFGLVHVDYATQQRTMKASGRRYAEIVREHRAARG'],
 15230)

In [4]:
# Length of the longest retained sequence, measured in characters of the raw
# sequence string (not in tokenizer tokens).
# NOTE(review): this is computed before any tokenization — confirm that a
# character count is the unit its consumers expect.
max_word_length = max(map(len, proteins))

max_word_length

512

In [5]:
# train the tokenizer 

In [6]:
# Train a BPE tokenizer on the filtered protein sequences and write it to disk.
special_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]

bpe_model = BPE(unk_token="[UNK]")
tokenizer = Tokenizer(bpe_model)
tokenizer.pre_tokenizer = Whitespace()

trainer = BpeTrainer(
    vocab_size=30_000,
    min_frequency=0,
    special_tokens=special_tokens,
)

tokenizer.train_from_iterator(proteins, trainer)

# Persist the trained tokenizer so it can be reloaded without retraining.
tokenizer.save("tokenizer-bglb.json")

In [7]:
# Reload the tokenizer from its saved JSON file and rebind `tokenizer` to the
# loaded instance.
tokenizer_path = "tokenizer-bglb.json"
tokenizer = Tokenizer.from_file(tokenizer_path)

tokenizer

<tokenizers.Tokenizer at 0x7f804b334200>

In [8]:
# Sanity check: encode a short string and display the resulting token ids.
mcat_encoding = tokenizer.encode("MCAT")
mcat_encoding.ids

[15, 1498]

In [11]:
# Sanity check: turn a pair of token ids back into text.
# The recorded output shows the decoded tokens separated by a space.
mcat_ids = [15, 1498]
tokenizer.decode(mcat_ids)

'M CAT'

In [12]:
# Report the vocabulary size of the loaded tokenizer.
tokenizer.get_vocab_size()

30000

In [13]:
# Build the dataset from the filtered sequences, the tokenizer, and the maximum
# sequence length.
# NOTE(review): max_word_length was measured in characters of the raw sequences,
# not in tokens — confirm that this is what BpeProteinDataset expects for its
# third argument.
dataset = BpeProteinDataset(proteins, tokenizer, max_word_length)

In [14]:
# Peek at the first item the dataset yields, then stop iterating.
for item in dataset:
    print(item)
    break 

(tensor([    0,  2797, 10970,   733,  2404,  2290, 12386,  3516,   308,  2117,
        10211,  7850,  5414,   124,  5277,  8157,     8, 22490,   845,  1102,
          171,  1269, 26455,  4770,   103,   432,   281,   698,   419,   190,
        27792,   128,    27,  3445,  1057,  6766,   867,   112,   272,  1082,
          158,  1556,    69,  6985,     8,  1776,   209,    78,   181, 12522,
          590,  4030,  1236,  2819,    83,   460,  1233,   403,   316,    50,
          302,  1569,  1632,   558,   564,    68,  7617,    80, 13887,  6499,
         2414,  2222,  1054, 24265,    36,   641,    48,    18,   329,  7635,
         1665,  5052,  8020,  8131,   365, 11894, 26574,  4715, 20760,  1689,
         5116, 20263,  2630,  3991,   312,  9613, 17876, 24847,  6709,  8531,
         3657,  2752, 16147,   401,  2095,   602,    97,   294,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,