In [1]:
import math 
import random 
from dataclasses import dataclass

import torch
import torch.nn as nn 
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from biotite.sequence.io.fasta import FastaFile


# Fix the RNG seeds up front so that re-running the notebook from a fresh kernel
# repeats the same torch-driven randomness (shuffled split, weight init, sampling).
torch.manual_seed(12)
# NOTE(review): Python's `random` module is seeded here but is not used by any
# cell visible in this notebook - confirm whether a later cell relies on it.
random.seed(42)

In [2]:
# define a Torch dataset class for our data 

class ProteinDataset(Dataset):
    """Dataset for protein sequences with character-level tokenization

    Every character of the vocabulary maps to an integer in 1..len(chars).
    Index 0 is a special token: it opens every input (<START>), follows the
    last residue in the target (<STOP>) and pads the input on the right.
    """

    def __init__(self, proteins, chars, max_protein_length):
        """Create a dataset 
        
        proteins: list of str, protein sequences
        chars: list of str, all the characters in the vocabulary
        max_protein_length: int, the length of the longest sequence 
        """
        self.proteins = proteins
        self.chars = chars
        self.max_protein_length = max_protein_length
        self.stoi = {ch: i + 1 for i, ch in enumerate(chars)}
        self.itos = {i: s for s, i in self.stoi.items()} # inverse mapping
        # snapshot of the sequences for O(1) membership tests in contains()
        self._protein_set = frozenset(proteins)

    def __len__(self):
        return len(self.proteins)

    def contains(self, protein):
        """Return True when `protein` is, character for character, one of the sequences."""
        return protein in self._protein_set

    def get_vocab_size(self):
        return len(self.chars) + 1 # all the possible characters and special 0 token

    def get_output_length(self):
        return self.max_protein_length + 1 # <START> token followed by proteins

    def encode(self, protein):
        """Turn a protein string into a 1-D LongTensor of token indices.

        Raises KeyError for a character that is not in the vocabulary.
        """
        ix = torch.tensor([self.stoi[w] for w in protein], dtype=torch.long)
        return ix

    def decode(self, ix):
        """Turn token indices back into a protein string.

        ix: iterable of ints or a 1-D LongTensor (e.g. the output of encode).
        Raises KeyError for index 0 or any index outside the vocabulary.
        """
        # int() is needed for tensors: a 0-d tensor hashes by identity, so it
        # would never match the integer keys of itos
        word = ''.join(self.itos[int(i)] for i in ix)
        return word

    def __getitem__(self, idx):
        """Return the (input, target) pair for example idx, both of length max_protein_length + 1.

        x: 0, then the encoded residues, then 0-padding
        y: the encoded residues, then a single 0 (<STOP>), then -1
        """
        protein = self.proteins[idx]
        ix = self.encode(protein)
        x = torch.zeros(self.max_protein_length + 1, dtype=torch.long)
        y = torch.zeros(self.max_protein_length + 1, dtype=torch.long)
        x[1:1+len(ix)] = ix
        y[:len(ix)] = ix
        y[len(ix)+1:] = -1 # index -1 will mask the loss at the inactive locations
        return x, y

In [3]:
# example, some short proteins of length 38 
# (the first and third sequences are identical)

proteins = [
    "MCLLSLAAATVAARRTPLRLLGRGLAAAMSTAGPLKSV", 
    "MSSQIKKSKTTTKKLVKSAPKSVPNAAADDQIFCCQFE", 
    "MCLLSLAAATVAARRTPLRLLGRGLAAAMSTAGPLKSV", 
]

# vocabulary: ProteinDataset maps the i-th character of this string to index i + 1
chars = "ACDEFGHIKLMNPQRSTVWY"

# must be at least as long as the longest sequence above
max_length = 38 

dataset = ProteinDataset(proteins, chars, max_length)

# x is the model input (0 token followed by the encoded residues),
# y is the target (the encoded residues followed by the 0 token)
x, y = dataset[0]

x

tensor([ 0, 11,  2, 10, 10, 16, 10,  1,  1,  1, 17, 18,  1,  1, 15, 15, 17, 13,
        10, 15, 10, 10,  6, 15,  6, 10,  1,  1,  1, 11, 16, 17,  1,  6, 13, 10,
         9, 16, 18])

In [4]:
# now wrap this logic into a nice function 

def create_datasets(input_file):
    """Create train and test datasets from a FASTA file (90/10 split)

    input_file: path to a FASTA file; the sequence of every record becomes one example
    returns: (train_dataset, test_dataset), two ProteinDataset objects that share
        the same vocabulary and the same maximum protein length
    raises: ValueError if the file contains no sequences
    """

    # preprocessing of the input text file: keep the sequences, drop the headers
    fasta_file = FastaFile.read(input_file) 
    proteins = [sequence for _header, sequence in fasta_file.items()]
    if not proteins:
        raise ValueError(f"No sequences found in {input_file}")
    max_protein_length = max(len(w) for w in proteins)

    # partition the input data into a training and the test set
    test_set_size = int(len(proteins) * 0.1) # 10% of the whole dataset
    # split on an explicit index: slicing with -test_set_size breaks when the
    # test set is empty, because rp[:-0] is [] and rp[-0:] is the whole list
    n_train = len(proteins) - test_set_size
    rp = torch.randperm(len(proteins)).tolist()
    train_proteins = [proteins[i] for i in rp[:n_train]]
    test_proteins = [proteins[i] for i in rp[n_train:]]
    print(f"Split up the dataset into {len(train_proteins)} training examples and {len(test_proteins)} test examples")

    chars = sorted(list(set(''.join(proteins)))) # all the possible characters
    tokens = sum(len(w) for w in proteins)
    
    print(f"Number of examples in the dataset: {len(proteins)}")
    print(f"Max protein length: {max_protein_length}")
    print(f"Number of unique characters in the vocabulary: {len(chars)}")
    print(f"Vocabulary (amino acids): {''.join(chars)}")
    print(f"Total tokens: {tokens}")

    # wrap in dataset objects; vocabulary and max length come from the full data
    # so both splits encode identically
    train_dataset = ProteinDataset(train_proteins, chars, max_protein_length)
    test_dataset = ProteinDataset(test_proteins, chars, max_protein_length)

    return train_dataset, test_dataset

In [5]:
# example of using this on our dataset 

# relative path to the FASTA file the datasets are built from
input_file = "./fasta/hypf.fa"

train_dataset, test_dataset = create_datasets(input_file)

# x is the model input for the first training example: the 0 token, the encoded
# residues, then 0-padding up to max protein length + 1
x, y = train_dataset[0]

x 

Split up the dataset into 24102 training examples and 2677 test examples
Number of examples in the dataset: 26779
Max protein length: 127
Number of unique characters in the vocabulary: 20
Vocabulary (amino acids): ACDEFGHIKLMNPQRSTVWY
Total tokens: 2519424


tensor([ 0, 11, 10, 15,  9,  7,  8,  5, 18, 16,  6, 15, 18, 14,  6, 18, 20, 20,
        15, 18,  5,  8, 15,  9, 17,  1, 19,  4, 17,  6, 18,  9,  6, 19, 18, 15,
        12, 15, 15,  3,  6, 15, 18,  4,  1, 18, 11,  4,  6,  4, 13,  4,  1,  8,
        12,  4, 18,  8, 15, 15,  2, 15, 14,  6, 13, 13,  6,  1,  5,  8,  3,  9,
         8,  3,  8,  3,  8,  4, 13,  5, 17,  6,  4,  5,  3,  3,  5, 12,  8, 18,
        13, 17, 18,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0])

In [6]:
# ok so to get started, let's keep good habits and record all our params in one place! 

@dataclass
class ModelConfig:
    """Model and training hyper-parameters, all in one place.

    block_size and vocab_size depend on the data and default to None; they have
    to be filled in before a model is built from the config.
    """
    block_size: int = None # length of the input sequences of integers
    vocab_size: int = None # the input integers are in range [0 .. vocab_size -1]
    n_layer: int = 4 # number of Transformer blocks
    embed_dim: int = 64 # width of the token / position embeddings
    n_head: int = 4 # attention heads per block, must divide embed_dim
    learning_rate: float = 4e-3
    batch_size: int = 32 
    # annotated so that these are real dataclass fields: settable through the
    # constructor and included in repr / comparison
    device: str = "cpu"
    max_steps: int = 100 # training stops after this many steps; a negative value disables the limit


config = ModelConfig() 

In [7]:
# define the whole transformer model


class NewGELU(nn.Module):
    """
    Tanh-based GELU activation, the variant found in the Google BERT repo (and in OpenAI GPT).
    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, x):
        # 0.5 * x * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) )), applied element-wise
        inner = x + 0.044715 * torch.pow(x, 3.0)
        gate = 1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * inner)
        return 0.5 * x * gate


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where every position attends only to itself and earlier positions."""

    def __init__(self, config):
        super().__init__()
        dim, heads = config.embed_dim, config.n_head
        assert dim % heads == 0
        # one fused linear layer yields queries, keys and values for all heads
        self.c_attn = nn.Linear(dim, 3 * dim)
        # linear layer applied to the re-assembled head outputs
        self.c_proj = nn.Linear(dim, dim)
        # lower-triangular 0/1 mask, stored as a buffer named "bias"; zeros mark
        # the (query, key) pairs that must not attend (keys to the right of the query)
        ctx = config.block_size
        causal_mask = torch.tril(torch.ones(ctx, ctx)).view(1, 1, ctx, ctx)
        self.register_buffer("bias", causal_mask)
        self.n_head = heads
        self.embed_dim = dim

    def forward(self, x):
        B, T, C = x.size()
        head_size = C // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, nh, T, head_size): heads act as an extra batch dimension
            return t.view(B, T, self.n_head, head_size).transpose(1, 2)

        q, k, v = self.c_attn(x).split(self.embed_dim, dim=2)
        k = split_heads(k)
        q = split_heads(q)
        v = split_heads(v)

        # scaled dot-product scores: (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T)
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # masked-out positions get -inf so the softmax gives them zero weight
        scores = scores.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        heads_out = weights @ v  # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
        # put the heads back side by side: (B, nh, T, hs) -> (B, T, C)
        merged = heads_out.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        return self.c_proj(merged)


class Block(nn.Module):
    """an unassuming Transformer block

    Layer norm is applied before the attention and before the MLP, and each of
    the two sub-layers is wrapped in a residual connection.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.embed_dim)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.embed_dim)
        self.mlp = nn.ModuleDict(
            dict(
                c_fc=nn.Linear(config.embed_dim, 4 * config.embed_dim),
                c_proj=nn.Linear(4 * config.embed_dim, config.embed_dim),
                act=NewGELU(),
            )
        )

    def mlpf(self, x):
        """MLP forward: expand to 4 * embed_dim, apply the activation, project back.

        This is a method rather than a lambda stored on the instance: a stored
        lambda keeps pointing at the original block's MLP after copy.deepcopy
        and prevents the module from being pickled.
        """
        m = self.mlp
        return m.c_proj(m.act(m.c_fc(x)))

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlpf(self.ln_2(x))
        return x


class Transformer(nn.Module):
    """Transformer Language Model, exactly as seen in GPT-2"""

    def __init__(self, config):
        super().__init__()
        self.block_size = config.block_size

        # sub-modules are built in the order wte, wpe, h, ln_f, lm_head
        token_embedding = nn.Embedding(config.vocab_size, config.embed_dim)
        position_embedding = nn.Embedding(config.block_size, config.embed_dim)
        blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        final_norm = nn.LayerNorm(config.embed_dim)
        self.transformer = nn.ModuleDict(
            {
                "wte": token_embedding,
                "wpe": position_embedding,
                "h": blocks,
                "ln_f": final_norm,
            }
        )
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)

        # report number of parameters (note we don't count the decoder parameters in lm_head)
        n_params = sum(p.numel() for p in self.transformer.parameters())
        print("number of parameters: %.2fM" % (n_params / 1e6,))

    def get_block_size(self):
        return self.block_size

    def forward(self, idx, targets=None):
        """Return (logits, loss) for a batch of token indices.

        idx: LongTensor of shape (b, t) with t <= block_size
        targets: optional LongTensor of shape (b, t); positions equal to -1 are
            left out of the loss
        returns: logits of shape (b, t, vocab_size) and the cross-entropy loss,
            or None in place of the loss when no targets are given
        """
        _, seq_len = idx.size()
        assert (
            seq_len <= self.block_size
        ), f"Cannot forward sequence of length {seq_len}, block size is only {self.block_size}"
        positions = torch.arange(0, seq_len, dtype=torch.long, device=idx.device).unsqueeze(0)  # shape (1, t)

        # token embeddings (b, t, embed_dim) plus position embeddings (1, t, embed_dim)
        hidden = self.transformer.wte(idx) + self.transformer.wpe(positions)
        for layer in self.transformer.h:
            hidden = layer(hidden)
        hidden = self.transformer.ln_f(hidden)
        logits = self.lm_head(hidden)

        # without targets there is nothing to score
        if targets is None:
            return logits, None

        flat_logits = logits.view(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, targets.view(-1), ignore_index=-1)
        return logits, loss

In [8]:
# example of making a tiny model 

# context length and vocabulary size depend on the data, so they are filled in
# from the training set here (they default to None in ModelConfig)
config.block_size = train_dataset.get_output_length()
config.vocab_size = train_dataset.get_vocab_size()
model = Transformer(config) 

# last expression of the cell: display the module tree
model 

number of parameters: 0.21M


Transformer(
  (transformer): ModuleDict(
    (wte): Embedding(21, 64)
    (wpe): Embedding(128, 64)
    (h): ModuleList(
      (0-3): 4 x Block(
        (ln_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=64, out_features=192, bias=True)
          (c_proj): Linear(in_features=64, out_features=64, bias=True)
        )
        (ln_2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (mlp): ModuleDict(
          (c_fc): Linear(in_features=64, out_features=256, bias=True)
          (c_proj): Linear(in_features=256, out_features=64, bias=True)
          (act): NewGELU()
        )
      )
    )
    (ln_f): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=64, out_features=21, bias=False)
)

In [10]:
# ok so now let's do a basic training loop

# the batches are shipped to config.device below, so the model has to live there
# too (a no-op for the default "cpu"); do this before the optimizer is created
model.to(config.device)
model.train()

# init optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, betas=(0.9, 0.99), eps=1e-8)

# sample with replacement from an effectively endless stream of examples;
# the loop below decides when to stop
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    sampler=torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=int(1e10)),
    # pinned host memory only helps host-to-GPU copies
    pin_memory=str(config.device).startswith("cuda"),
)

# training loop
best_loss = None
step = 0
for batch in train_loader:

    # get the next batch, ship to device, and unpack it to input and target
    batch = [t.to(config.device) for t in batch]
    X, Y = batch

    # feed into the model
    logits, loss = model(X, Y)

    # calculate the gradient, update the weights
    model.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # logging
    if step % 10 == 0:
        print(f"step {step} | loss {loss.item():.4f}")

    step += 1

    # termination conditions (a negative max_steps means: no step limit)
    if config.max_steps >= 0 and step >= config.max_steps:
        break

step 0 | loss 3.2072
step 10 | loss 2.6525
step 20 | loss 2.5570
step 30 | loss 2.4798
step 40 | loss 2.4857
step 50 | loss 2.2520
step 60 | loss 2.2820
step 70 | loss 2.1116
step 80 | loss 2.1213
step 90 | loss 2.1865


In [11]:
# OK so you're gonna want to generate from this! 

@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, greedy=False, top_k=None):
    """
    Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
    the sequence max_new_tokens times, feeding the predictions back into the model each time.
    Most likely you'll want to make sure to be in model.eval() mode of operation for this.

    model: callable returning (logits, loss) for a (b,t) LongTensor, with a get_block_size() method
    temperature: positive float the logits are divided by before the softmax
    greedy: if True take the most likely token at each step instead of sampling
    top_k: if given, only the k most likely tokens can be chosen; values larger than
        the vocabulary size keep every token
    returns: LongTensor of shape (b, t + max_new_tokens)
    raises: ValueError if temperature is not positive
    """
    if temperature <= 0:
        # dividing by zero or a negative number would give inf/NaN or inverted logits
        raise ValueError(f"temperature must be positive, got {temperature}")
    block_size = model.get_block_size()
    for _ in range(max_new_tokens):
        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
        # forward the model to get the logits for the index in the sequence
        logits, _ = model(idx_cond)
        # pluck the logits at the final step and scale by desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            # clamp k: torch.topk raises when k exceeds the size of the last dimension
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            # v[:, [-1]] is the k-th largest logit of each row; everything below it is dropped
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)
        # either take most likely (greedy) or sample from the distribution 
        if greedy:
            _, idx_next = torch.topk(probs, k=1, dim=-1)
        else:
            idx_next = torch.multinomial(probs, num_samples=1)
        # append sampled index to the running sequence and continue
        idx = torch.cat((idx, idx_next), dim=1)

    return idx


def print_samples(num=10):
    """ draws `num` sequences from the model and prints them, grouped by where they were seen """
    # every sample starts from a single <START> token (index 0)
    start_tokens = torch.zeros(num, 1, dtype=torch.long).to(config.device)
    # -1 because we already start with <START> token (index 0)
    n_new_tokens = train_dataset.get_output_length() - 1
    sampled = generate(model, start_tokens, n_new_tokens, top_k=None, greedy=False).to('cpu')

    # insertion order of the keys is the order the groups are printed in
    buckets = {'in train': [], 'in test': [], 'new': []}

    # drop the leading <START> column, then walk the rows as plain python lists
    for tokens in sampled[:, 1:].tolist():
        # token 0 is the <STOP> token: keep only what comes before its first occurrence
        if 0 in tokens:
            tokens = tokens[:tokens.index(0)]
        protein = train_dataset.decode(tokens)

        # only exact (100% identity) matches are detected here; percent identity
        # against every training sequence would need an expensive alignment
        if train_dataset.contains(protein):
            buckets['in train'].append(protein)
        elif test_dataset.contains(protein):
            buckets['in test'].append(protein)
        else:
            buckets['new'].append(protein)

    print(f"Printing {num} samples from the model:")
    for desc, members in buckets.items():
        print(f"{len(members)} samples that are {desc}:")
        # numbering restarts at 1 inside every group
        for rank, protein in enumerate(members, start=1):
            print(f">sample_{rank} {desc}\n{protein}")
    print("Done printing samples")

In [12]:
# example of generation: draw 3 sequences from the model and print them,
# each as a ">header" line followed by the sequence

print_samples(3)

Printing 3 samples from the model:
0 samples that are in train:
0 samples that are in test:
3 samples that are new:
>sample_1 new
MYKFSSVVSGRVSVQNVGFLRYTRAQAAHFIVVVHAGWVKNCTVSGLIESEYADDRQQHDCHKEGPPTASVYVTPAVWDSEHPRTEIRDVEEIEIK
>sample_2 new
MSAYVGRAWCLVQGVGFRYSATEFTTHARIKGWVRNCPTDGVEANLAGEVKVMVEWVELEKLGMSPPARARVAVEIGDHILDDFQTHVTGY
>sample_3 new
MYEELSRAQMIIHIHTGRVGVWFRVEAVQQELTARLGLLNMEDGRVYIVSGEAGRDVEFEALRDCRKRGARVELIVDEPIVRFEYAPA
Done printing samples


In [13]:
# OK, so we end with a trained model that can generate sequences. How can we tell if these are any good? Next up, we'll build an evaluation suite for our model!