In [1]:
# 1. pip install torch rich 

import torch
from rich.progress import track
import plotly.express as px
from biotite.sequence.io.fasta import FastaFile


torch.manual_seed(3007)

<torch._C.Generator at 0x1185b9750>

In [21]:
import plotly.io as pio
pio.renderers.default='notebook'


In [2]:
# 2. read in our FASTA file the same way

input_file = "./fasta/acyp.fa"

proteins = []
fasta_file = FastaFile.read(input_file)
for header, sequence in fasta_file.items():
    proteins.append(sequence)

max_protein_length = max(len(w) for w in proteins)

proteins[:3], max_protein_length

(['MCLLSLAAATVAARRTPLRLLGRGLAAAMSTAGPLKSVDYEVFGRVQGVCFRMYTEGEAKKIGVVGWVKNTSKGTVTGQVQGPEDKVNSMKSWLSKVGSPSSRIDRTNFSNEKTISKLEYSNFSIRY',
  'MSSQIKKSKTTTKKLVKSAPKSVPNAAADDQIFCCQFEVFGHVQDFSGVFFRKHTQKKANELGITGWCMNTTRGTVQGMLEGSLDQMTDMKYWLQHKGSPRSVIEKAVFSENEALPINNFKMFSIRR',
  'MLTKLYLKIVLCLLVALPFLSEVTSQNTDTTMTKLVGVDFEVYGRVQGVFFRKYTQKHSTELGLKGWCMNTDKGTVVGRIEGEKEKVEQMKNWLRYTGSPQSAIDKAEFKNEKELSQPSFTNFEIKK'],
 127)

In [3]:
# 3. NLP context: establish vocabulary 

chars = sorted(list(set("".join(proteins))))  # all the possible characters
tokens = sum(len(protein) for protein in proteins)

print(f"Number of examples in the dataset: {len(proteins)}")
print(f"Max protein length: {max_protein_length}")
print(f"Number of unique characters in the vocabulary: {len(chars)}")
print(f"Vocabulary (amino acids): {''.join(chars)}")
print(f"Total tokens: {tokens}")

Number of examples in the dataset: 26779
Max protein length: 127
Number of unique characters in the vocabulary: 20
Vocabulary (amino acids): ACDEFGHIKLMNPQRSTVWY
Total tokens: 2519424


In [4]:
# 4. build our tokenizer, or mapping from tokens to ints

stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi["."] = 0  # assign a special character to 0 

itos = {i: s for s, i in stoi.items()}  # reverse mapping, int to char 

print(itos)

{1: 'A', 2: 'C', 3: 'D', 4: 'E', 5: 'F', 6: 'G', 7: 'H', 8: 'I', 9: 'K', 10: 'L', 11: 'M', 12: 'N', 13: 'P', 14: 'Q', 15: 'R', 16: 'S', 17: 'T', 18: 'V', 19: 'W', 20: 'Y', 0: '.'}


In [5]:
# 5. build the dataset

block_size = 3  # context length: how many characters do we take to predict the next one

X, Y = [], []
for protein in proteins:

    # begin with context of all zeros
    context = [0] * block_size

    # for each character, build a training example 
    for ch in protein + ".":
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        if len(X) < 10:
            print(''.join(itos[i] for i in context), '--->', itos[ix])
        context = context[1:] + [ix]  # crop and append

X = torch.tensor(X)
Y = torch.tensor(Y)

X.shape, X.dtype, Y.shape, Y.dtype

... ---> M
..M ---> C
.MC ---> L
MCL ---> L
CLL ---> S
LLS ---> L
LSL ---> A
SLA ---> A
LAA ---> A


(torch.Size([2546203, 3]), torch.int64, torch.Size([2546203]), torch.int64)

In [6]:
# 6. create training, validation, and test splits

shuffled = torch.randperm(X.shape[0])

n1 = int(0.8 * len(shuffled))
n2 = int(0.9 * len(shuffled))

train_ix = shuffled[:n1]
val_ix = shuffled[n1:n2]
test_ix = shuffled[n2:]

x_train, y_train = X[train_ix], Y[train_ix]
x_val, y_val = X[val_ix], Y[val_ix]
x_test, y_test = X[test_ix], Y[test_ix]

x_train.shape, x_val.shape, x_test.shape

(torch.Size([2036962, 3]), torch.Size([254620, 3]), torch.Size([254621, 3]))

In [7]:
# 7. begin to build the model. first, embedding table

vocab_size = len(chars) + 1

C = torch.randn((vocab_size, 10))

In [8]:
# 8. indexing into embedding table is easy

emb = C[X]
emb.shape

torch.Size([2546203, 3, 10])

In [9]:
# 9. create a linear layer (weight matrix and bias)

W1 = torch.randn((30, 100))
b1 = torch.randn(100)

In [10]:
# 10. hidden state after flatten, multiply by weight, add bias, and tanh 

h = torch.tanh(emb.view(-1, 30) @ W1 + b1)

h.shape 

torch.Size([2546203, 100])

In [11]:
# 11. second later, output 

W2 = torch.randn((100, vocab_size))
b2 = torch.randn(vocab_size)

In [12]:
# 12. logits 

logits = h @ W2 + b2

logits[0]

tensor([ -9.8112,   2.5057, -14.8283,  -3.1949,  -2.8039,   4.5284,   9.2156,
         -4.5143, -14.6304,   0.3441, -11.5997,  -5.3201, -10.2188,  -1.6343,
          8.1713,   9.7798,   2.7054,   5.8052,   5.1445,  -2.0171, -17.2942])

In [13]:
# 13. unnormalized counts and probabilities 

counts = logits.exp()
prob = counts / counts.sum(1, keepdims=True)

prob[0]

tensor([1.7195e-09, 3.8419e-04, 1.1390e-11, 1.2847e-06, 1.8996e-06, 2.9042e-03,
        3.1524e-01, 3.4342e-07, 1.3882e-11, 4.4237e-05, 2.8751e-10, 1.5342e-07,
        1.1439e-09, 6.1181e-06, 1.1094e-01, 5.5421e-01, 4.6912e-04, 1.0412e-02,
        5.3777e-03, 4.1720e-06, 9.6739e-13])

In [14]:
# 14. negative log likelihood loss on the normalized probabilities (same as cross entropy)

loss = -prob[torch.arange(X.shape[0]), Y].log().mean()

loss

tensor(17.4846)

In [15]:
# 15. cleaned up neural network 

block_size = 3 
embed_size = 10 
hidden_size = 200 

C = torch.randn((vocab_size, embed_size))

W1 = torch.randn((embed_size * block_size, hidden_size))
b1 = torch.randn(hidden_size)

W2 = torch.randn((hidden_size, vocab_size))
b2 = torch.randn(vocab_size)

parameters = [C, W1, b1, W2, b2]

In [16]:
# 16. number of trainable params 

print(f"Model has {sum(p.nelement() for p in parameters):,} trainable params")

for p in parameters:
    p.requires_grad = True

Model has 10,631 trainable params


In [17]:
# 17. training loop 

training_loss = []
total_steps = 200_000

for i in track(range(total_steps)):

    # create a minibatch of size 32 
    ix = torch.randint(0, x_train.shape[0], (32,))

    # forward pass
    emb = C[x_train[ix]]  # (32, 3, 10)
    h = torch.tanh(emb.view(-1, 30) @ W1 + b1)  # (32, 200)
    logits = h @ W2 + b2  # (32, vocab_size)
    loss = torch.nn.functional.cross_entropy(logits, y_train[ix])
    # print(loss.item())

    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    # update
    lr = 0.1 if i < 100000 else 0.01
    for p in parameters:
        p.data += -lr * p.grad

    # track stats
    training_loss.append(loss.log10().item())

print(loss.item())

Output()

2.4765069484710693


In [22]:
# 18. visualize training loss over minibatches 

px.line(x=range(total_steps), y=training_loss, width=512)

In [23]:
# 21. visualize embedding space 

px.scatter(
    x=C[:, 0].data,
    y=C[:, 1].data,
    size=[10] * C.shape[0],
    text=[itos[i] for i in range(C.shape[0])],
    width=512,
    height=512,
)

In [20]:
# 20 (final). sample from the model

for _ in range(20):
    out = []
    context = [0] * block_size  # initialize with all zeros 
    while True:
        emb = C[torch.tensor([context])]  # (1, block_size, d)
        h = torch.tanh(emb.view(1, -1) @ W1 + b1)
        logits = h @ W2 + b2
        probs = torch.nn.functional.softmax(logits, dim=1)
        ix = torch.multinomial(probs, num_samples=1).item()
        context = context[1:] + [ix]
        out.append(ix)
        if ix == 0:
            break

    print(f">sample\n{''.join(itos[i] for i in out)}")

>sample
MGITGEFKNFYNTAGGVEGKQQDFVENLETVAGWVRNISSMDEDDFEIRYARVDGHVQGVWFRRQWVHGKVERAGDATQYTARVDLTAPGARNQPDTTIFHGRVEYEGSESRADGRVETVTGINGIKGKVQGVSYVDYARNDSSVDAVEKLLAALVSLGLSGFEILVRNLPDGEAALLGWCRSGPPLAAGEELLRGYEGPRDEGRVEQLALWIRNERALFDGSVEAVFEGFSLANGLVGTDDFEVI.
>sample
MNKQISAVRKTQLFVSGVSYEVMKLGRVQGVFFSKAQIAVTWSSAGPPLASVDSVSLITWLENLDDGNVKLEKR.
>sample
MLKLGLDGFDRKRVHVFISGRVEVMKATELRLEGE.
>sample
MIKWLKKITAIVY.
>sample
MKMESLRGTPFTGHVEAMGIYSSFAVKYTTQHVDYVAEGGYDSFLGTATVEVLYEGSDDFDIEEKSPRQAESLAGSSNQPDRVAGWVRNALAGWVRNLADGRVQVQGVCFRQSQAPQGLVQGVFYRAATGGVEYNGKVMVADVEMKVE.
>sample
MAENIDKMLTWATPGRREIAVEYEQHKSDGTVEILIQSVGFRMSKMRFFALHVAVRLTAWVENLPDGRVQGVFFHGRVQGVGFRRFEVKFDELSVDIVSGRVQTFFVDGIVG.
>sample
MSEFINELIAWVHGRVQGVLGLAGWVRNESGFEKIWITR.
>sample
MIRAGFAKYIIVDNILV.
>sample
MAEIGANRQDTGEFTDGRVEALVRWATPAHGEAETARLGLTGFRYTTRILKGYQEMLQETGDPMEKYKTQVPEYSGFEIQWLQGFSGFSIQYGRVQVVVEGAVQNLPDGSADEVHFEGPRNEEEKAICEWIRGRVQGVGFRYHTKIQGRVQGVGFRWAVAEELGVFFEKVVQGVYFRGFEIRY.
>sample
MARGRVQGVGFRYSVADGRVEAFEVTMKIHYFPTT