In [1]:
# start with same imports as last time, plus torch
import plotly.express as px
from biotite.sequence.io.fasta import FastaFile

import torch
from rich.progress import track


# Seed torch's global RNG so calls that use the default generator further down
# (randperm for the split, randn for the first model, randint for minibatches)
# give the same results on every Restart & Run All.
torch.manual_seed(12)

<torch._C.Generator at 0x11e80ba50>

In [2]:
# Load every sequence from the FASTA file; the header lines are not needed here.

input_file = "./fasta/acyp.fa"

fasta_file = FastaFile.read(input_file)
proteins = [sequence for _header, sequence in fasta_file.items()]

# length of the longest sequence in the dataset
max_protein_length = max(map(len, proteins))

proteins[:3], max_protein_length

(['MCLLSLAAATVAARRTPLRLLGRGLAAAMSTAGPLKSVDYEVFGRVQGVCFRMYTEGEAKKIGVVGWVKNTSKGTVTGQVQGPEDKVNSMKSWLSKVGSPSSRIDRTNFSNEKTISKLEYSNFSIRY',
  'MSSQIKKSKTTTKKLVKSAPKSVPNAAADDQIFCCQFEVFGHVQDFSGVFFRKHTQKKANELGITGWCMNTTRGTVQGMLEGSLDQMTDMKYWLQHKGSPRSVIEKAVFSENEALPINNFKMFSIRR',
  'MLTKLYLKIVLCLLVALPFLSEVTSQNTDTTMTKLVGVDFEVYGRVQGVFFRKYTQKHSTELGLKGWCMNTDKGTVVGRIEGEKEKVEQMKNWLRYTGSPQSAIDKAEFKNEKELSQPSFTNFEIKK'],
 127)

In [3]:
# Vocabulary and size statistics for the dataset.
chars = sorted(set("".join(proteins)))  # every distinct character, in sorted order
tokens = sum(map(len, proteins))  # total number of characters over all sequences

summary_lines = [
    f"Number of examples in the dataset: {len(proteins)}",
    f"Max protein length: {max_protein_length}",
    f"Number of unique characters in the vocabulary: {len(chars)}",
    f"Vocabulary (amino acids): {''.join(chars)}",
    f"Total tokens: {tokens}",
]
print("\n".join(summary_lines))

Number of examples in the dataset: 26779
Max protein length: 127
Number of unique characters in the vocabulary: 20
Vocabulary (amino acids): ACDEFGHIKLMNPQRSTVWY
Total tokens: 2519424


In [4]:
# Tokenizer: map every vocabulary character to an integer id (stoi) and back (itos).
# Ids 1..len(chars) follow the sorted vocabulary; id 0 is reserved for ".".

stoi = dict(zip(chars, range(1, len(chars) + 1)))
stoi["."] = 0
itos = {index: symbol for symbol, index in stoi.items()}

print(itos)

{1: 'A', 2: 'C', 3: 'D', 4: 'E', 5: 'F', 6: 'G', 7: 'H', 8: 'I', 9: 'K', 10: 'L', 11: 'M', 12: 'N', 13: 'P', 14: 'Q', 15: 'R', 16: 'S', 17: 'T', 18: 'V', 19: 'W', 20: 'Y', 0: '.'}


In [5]:
# Build (context, next-token) training pairs from every protein.

block_size = 3  # context length: how many preceding tokens are used to predict the next one

X, Y = [], []
for protein in proteins:

    # every sequence starts from a window made entirely of token 0 (".")
    window = [0] * block_size
    for residue in protein + ".":  # the appended "." becomes the final target of each sequence
        target = stoi[residue]
        X.append(window)
        Y.append(target)
        if len(X) <= 9:  # echo the first nine pairs as a sanity check
            print(''.join(itos[t] for t in window), '--->', itos[target])
        window = [*window[1:], target]  # drop the oldest token, append the newest

X = torch.tensor(X)
Y = torch.tensor(Y)

X.shape, X.dtype, Y.shape, Y.dtype

... ---> M
..M ---> C
.MC ---> L
MCL ---> L
CLL ---> S
LLS ---> L
LSL ---> A
SLA ---> A
LAA ---> A


(torch.Size([2546203, 3]), torch.int64, torch.Size([2546203]), torch.int64)

In [6]:
# Random 80% / 10% / 10% split of the examples into train, validation and test sets.

shuffled = torch.randperm(X.shape[0])  # random ordering of all example indices

# cut points at 80% and 90% of the shuffled indices
n1 = int(0.8 * len(shuffled))
n2 = int(0.9 * len(shuffled))

# split at n1 and n2 -> [0, n1), [n1, n2), [n2, end)
train_ix, val_ix, test_ix = torch.tensor_split(shuffled, (n1, n2))

(x_train, y_train), (x_val, y_val), (x_test, y_test) = (
    (X[rows], Y[rows]) for rows in (train_ix, val_ix, test_ix)
)

x_train.shape, x_val.shape, x_test.shape

(torch.Size([2036962, 3]), torch.Size([254620, 3]), torch.Size([254621, 3]))

In [7]:
# start to build the model

In [8]:
# embedding table: one 2-dimensional vector per token id

vocab_size = len(chars) + 1  # all vocabulary characters plus the "." token (id 0)

C = torch.randn(vocab_size, 2)

In [9]:
# indexing into embedding table is easy
# C has one row per token id. Indexing it with the integer tensor X looks up a row
# for every entry of X, so the result has X's shape plus a trailing embedding axis.

emb = C[X]
emb.shape

torch.Size([2546203, 3, 2])

In [10]:
# first (hidden) layer: 6 inputs (3 context tokens x 2 embedding dims) -> 100 units

W1 = torch.randn(6, 100)
b1 = torch.randn(100)

In [11]:
# flatten each context's embeddings into one 6-wide row, then apply the tanh hidden layer
flat_emb = emb.view(-1, 6)
h = (flat_emb @ W1 + b1).tanh()

h

tensor([[-0.9801,  0.9530,  0.8693,  ..., -1.0000, -0.9664, -0.8894],
        [-0.9461,  0.9989,  0.8651,  ..., -1.0000, -0.9986, -0.9158],
        [-0.9902,  0.9998,  0.9241,  ..., -1.0000, -0.9997, -0.9959],
        ...,
        [ 0.4237,  0.4925, -0.9961,  ..., -0.9957,  0.9821, -0.8355],
        [-0.0798,  0.7958, -0.9997,  ...,  0.0749,  0.8921, -0.8575],
        [ 0.0146,  0.9768, -0.9995,  ..., -0.9709, -0.3588, -0.4815]])

In [12]:
# expect (number of examples, 100): one row of hidden activations per context
h.shape

torch.Size([2546203, 100])

In [13]:
# second layer: 100 hidden units -> one output per token in the vocabulary

W2 = torch.randn(100, vocab_size)
b2 = torch.randn(vocab_size)

In [14]:
# output layer: one unnormalized score (logit) per candidate next token

logits = torch.matmul(h, W2) + b2

In [15]:
# expect (number of examples, vocab_size): one score per possible next token
logits.shape

torch.Size([2546203, 21])

In [16]:
# exponentiate the logits to get positive, unnormalized scores
counts = torch.exp(logits)

In [17]:
# normalize each row so it sums to 1 (together with the exp above, this is a softmax)
row_totals = counts.sum(dim=1, keepdim=True)
prob = counts / row_totals

In [18]:
# same shape as logits; each row was normalized to sum to 1
prob.shape

torch.Size([2546203, 21])

In [19]:
# negative log-likelihood: mean over all examples of -log(probability given to the true next token)
example_rows = torch.arange(X.shape[0])
true_token_prob = prob[example_rows, Y]
loss = -true_token_prob.log().mean()
loss

tensor(14.7418)

In [20]:
# cleaned up: the same network, with every layer size derived from named
# hyperparameters so that changing embed_size, hidden_size or block_size
# keeps all shapes consistent.

embed_size = 10  # dimensions per token embedding
hidden_size = 200  # units in the tanh hidden layer

g = torch.Generator().manual_seed(2147483647)  # for reproducibility

# embedding table: one embed_size-dimensional row per token id
C = torch.randn((vocab_size, embed_size), generator=g)

# hidden layer: takes the block_size context embeddings flattened into one row
W1 = torch.randn((block_size * embed_size, hidden_size), generator=g)
b1 = torch.randn(hidden_size, generator=g)

# output layer: one logit per token in the vocabulary
W2 = torch.randn((hidden_size, vocab_size), generator=g)
b2 = torch.randn(vocab_size, generator=g)

parameters = [C, W1, b1, W2, b2]

In [21]:
# total number of scalar parameters across all tensors
sum(tensor.numel() for tensor in parameters)

10631

In [22]:
# turn on gradient tracking for every parameter tensor so loss.backward() fills .grad
for p in parameters:
    p.requires_grad_(True)

In [23]:
# candidate learning rates for a sweep: 1000 exponents evenly spaced in [-3, 0],
# i.e. rates log-spaced from 1e-3 up to 1
lre = torch.linspace(-3, 0, steps=1000)
lrs = torch.pow(10, lre)

In [24]:
# per-step training statistics, appended to by the training loop
lri, lossi, stepi = [], [], []

In [25]:
# Train with minibatch SGD and a single step-wise learning-rate drop.

max_steps = 200_000
batch_size = 32
lr_decay_step = 100_000  # steps before this use lr=0.1, the rest use lr=0.01

for i in track(range(max_steps)):

    # minibatch construct: sample batch_size training rows (with replacement)
    ix = torch.randint(0, x_train.shape[0], (batch_size,))

    # forward pass
    emb = C[x_train[ix]]  # (batch_size, block_size, embedding dim)
    # Flatten each context's embeddings into one row. The width is taken from emb
    # itself so this line keeps working if the embedding size or block_size changes.
    h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)  # (batch_size, hidden units)
    logits = h @ W2 + b2  # (batch_size, vocab_size)
    loss = torch.nn.functional.cross_entropy(logits, y_train[ix])

    # backward pass: clear stale gradients, then backpropagate
    for p in parameters:
        p.grad = None
    loss.backward()

    # update: plain SGD step. The in-place update runs under no_grad so autograd
    # does not record it, instead of reaching around autograd through `.data`.
    lr = 0.1 if i < lr_decay_step else 0.01
    with torch.no_grad():
        for p in parameters:
            p -= lr * p.grad

    # track stats (log10 of the loss, which is what the loss plot shows)
    stepi.append(i)
    lossi.append(loss.log10().item())

Output()

In [26]:
# Training curve. `px` is already imported in the first cell, so it is not re-imported here.
# lossi holds log10 of each minibatch loss, which the y-axis label makes explicit.
px.line(
    x=stepi,
    y=lossi,
    labels={"x": "training step", "y": "log10(minibatch loss)"},
    title="Training loss",
    width=512,
)

In [27]:
# Scatter every token's embedding, labelled with its character.
# Only the first two columns of C are drawn, so when C has more than two
# columns this is a partial view of the embedding space.
px.scatter(
    x=C[:, 0].detach().tolist(),  # detach: plain values without autograd history
    y=C[:, 1].detach().tolist(),
    size=[10] * C.shape[0],  # constant marker size
    text=[itos[i] for i in range(C.shape[0])],
    labels={"x": "embedding dim 0", "y": "embedding dim 1"},
    title="Token embeddings (first two dimensions)",
    width=512,
    height=512,
)

In [28]:
# sample from the model

g = torch.Generator().manual_seed(12)  # dedicated generator so the samples are repeatable

with torch.no_grad():  # inference only: no need to build autograd graphs
    for _ in range(20):

        out = []
        context = [0] * block_size  # initialize with all ...
        while True:
            emb = C[torch.tensor([context])]  # (1,block_size,d)
            h = torch.tanh(emb.view(1, -1) @ W1 + b1)
            logits = h @ W2 + b2
            probs = torch.nn.functional.softmax(logits, dim=1)
            ix = torch.multinomial(probs, num_samples=1, generator=g).item()
            # Token 0 (".") marks the end of a sequence. Stop before recording it so
            # the printed record contains residues only, not a trailing ".".
            if ix == 0:
                break
            context = context[1:] + [ix]  # slide the context window forward
            out.append(ix)

        # FASTA-style record: header line, then the sampled sequence
        print(f">sample\n{''.join(itos[i] for i in out)}")

>sample
MNHEALAGWVQGVGFRYHTPPTA.
>sample
MGLVGHVQGVGFRYTTYGRVEGSRDRLRVWEGSVEVVAETTRGTVESTFSISYFSSDEKVAEGPALDLCRWLRSGPRDFTIRY.
>sample
MKDRVEPASPGARVYSDNAADDLEF.
>sample
MNEEDAEIFANAHFSLAGERALVHLEIDRLHEGPAGDDTPGRVEVVAQGTSETQ.
>sample
MPEELRNARQRAHVKAGVDGTVLFHGAVQIEAYALNTLDKVDKTDTKRMIAAGELKFLKRLGVRGWVMSESRLAKYHVAQGLKERFLARGSVEALELFHGVGFRYWAQNEALCARLYTQQERLGLTGWVRNLPDGSVVLELGVVSGLPNNIFNKDDTPEAVLELGLERGSPEARALEKQIAAGPDSTPKVESAAVKGWSIASKLSLVSCKQLNIKAFSEETLTIQY.
>sample
MSPYAAFVEEVPPYA.
>sample
MAEDLPVIACGEEKPEQLVIQCRKGTFSIRY.
>sample
MKVCGEQQRCSLTDVGEVQGVGFKITLAQALEIGGLVKDVCIIAWVHGQVEVVWVHGMAGPALALLSRVEPHSTIAFDIR.
>sample
MKRFEVRWEHWKHRLGLTGYARGNVQGVGPPLTGFVRNLTDISGRVEVELLIVAQGPHSKVQHVKLENIKISFSIAEITLYEAKVFSSFPEWARPQ.
>sample
MIEWEKPGGPSAARLRGFRVSEIQIKGYAKNLSDFNQKGRVEVFAGPTA.
>sample
MTRYGPQESLLEGWVRNLMPKTIDNIKSGPPAARVRY.
>sample
MQKIEKQCVKNQSDNAIVHVKITLDREEEDAEGFNFYIKISIRY.
>sample
MEGSSQSGDILH.
>sample
MSRESGRVQGVGFRTYTE.
>sample
MSKIEGDADKLSFKIVEEVHGRVDRMWAHVSHVEEVQGPKKEESPEAVGFRYFMADVEK