## BERT model of Amino-Seq Masked-Language-Model

## 0. init

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

import string
from typing import Iterable, Tuple

## 1. Build Dataset and DataLoader

In [2]:
file_path = 'X_set.txt'

# Initialize lists to hold the phylogenetic position strings and amino acid sequences
specie_code = []
amino_acid_sequences = []

# Read the file
with open(file_path, 'r') as file:
    for line in file:
        parts = line.strip().split(' ')
        specie_code.append(parts[0])
        amino_acid_sequences.append(parts[1])

amino_acid_sequences[0:3]

['---LSQF--LLMLWVPGSKGEIVLTQSPASVSVSPGERVTISCQASESVGNTYLNWLQQKSGQSPRWLIYQVSKLESGIPARFRGSGSGTDFTFTISRVEAEDVAHYYSQQ-----',
 'MESLSQC--LLMLWVPVSRGAIVLTQSPALVSVSPGERVTISCKASQSVGNTYLSWFRQKPGQSPRGLIYKVSNLPSGVPSRFRGSGAEKDFTLTISRVEAVDGAVYYCAQASYSP',
 'MESLSQC--LLMLWVPVSRGAIVLTQSPASVSVSPGERVTISCKASQSLGNTYLHWFQQKPGQSPRRLIYQVSNLLSGVPSRFSGSGAGKDFSLTISSVEAGDGAVYYCFQGSYDP']

### 1.1. Tokenizer

- There are 20 amino acids, each letter in the chain represents one of them. 
- Converting them into 20 tokens, meaning each amino acid would get a number associated with it. 
- Would also need a special character token, which is "-", something related to multiple-sequence-alignment 

In [3]:
# Creating a set of all amino-acids

amino_acid_set = set()

for seq in amino_acid_sequences:
    for acid in seq:
        if acid != "-":
            amino_acid_set.add(acid)

# 20 amino acids
print(f"Num of Amino Acids: {len(amino_acid_set) }")
amino_acids_list = list(amino_acid_set)

Num of Amino Acids: 20


In [4]:
# Creating a Tokenzer class, which ennodes and decodes an amino acid sequence 

class Tokenizer:
    ''' 
    To encode and decode any amino acid string
    '''
    # class attribute 
    amino_acids = amino_acids_list

    def __init__(self, special_tokens = Iterable[str]):
        # define a vocab
        self.vocab = Tokenizer.amino_acids + list(special_tokens)
        # mapping each vocab to a token (a numeric value)
        self.token2idx = {token:i for i, token in enumerate(self.vocab)} 
        # mapping numeric value back to a token
        self.idx2token = {i:token for token, i  in self.token2idx.items()}

    def encode(self, inputs: Iterable[str]) -> Iterable[int]:
        return [self.token2idx[token] for token in inputs]
    
    def decode(self, inputs: Iterable[int]) -> Iterable[str]:
        return [self.idx2token[idx] for idx in inputs]

    def __len__(self):
        return len(self.vocab)

In [5]:
# creating an instance of the Tokenizer. 
amino_acid_tokenizer = Tokenizer(special_tokens=["-", "[MASK]"])

# let's encode the first amino-acid-sequence and see the first 10 positions
print(f"First 20 amino acids         : {[i for i in amino_acid_sequences[0][0:20]]}")
print(f"First 20 encoded amino acids : {amino_acid_tokenizer.encode(amino_acid_sequences[0])[0:20]}")
print(f"First 20 decoded amino acids : {amino_acid_tokenizer.decode(amino_acid_tokenizer.encode(amino_acid_sequences[0])[0:20])}")

First 20 amino acids         : ['-', '-', '-', 'L', 'S', 'Q', 'F', '-', '-', 'L', 'L', 'M', 'L', 'W', 'V', 'P', 'G', 'S', 'K', 'G']
First 20 encoded amino acids : [20, 20, 20, 18, 6, 13, 3, 20, 20, 18, 18, 8, 18, 10, 14, 15, 2, 6, 0, 2]
First 20 decoded amino acids : ['-', '-', '-', 'L', 'S', 'Q', 'F', '-', '-', 'L', 'L', 'M', 'L', 'W', 'V', 'P', 'G', 'S', 'K', 'G']


In [6]:
print(amino_acid_tokenizer.token2idx)

{'K': 0, 'R': 1, 'G': 2, 'F': 3, 'Y': 4, 'I': 5, 'S': 6, 'C': 7, 'M': 8, 'D': 9, 'W': 10, 'E': 11, 'T': 12, 'Q': 13, 'V': 14, 'P': 15, 'H': 16, 'N': 17, 'L': 18, 'A': 19, '-': 20, '[MASK]': 21}


### 1.2 Creating a Tensor for all amino-seqs

In [7]:
# making sure that the size of each amino-acid-seq is same

len_amino_acid_seq = set()
for seq in amino_acid_sequences:
    len_amino_acid_seq.add(len(seq))

# this set should have only one value 
len_amino_acid_seq
# perfect! all the seq are 116 character long

{116}

In [8]:

def create_amino_acids_tensor(amino_acid_sequences:list, my_tokenizer:Tokenizer):

    amino_acid_tensors = []

    for seq in amino_acid_sequences:
        amino_acid_tensors.append(torch.Tensor(my_tokenizer.encode(seq)).to(torch.int64))

    # stacking them 
    stacked_tensor =  torch.stack(amino_acid_tensors)

    return stacked_tensor

In [9]:
all_amino_acids_tensor = create_amino_acids_tensor(amino_acid_sequences, amino_acid_tokenizer)

In [10]:
all_amino_acids_tensor.shape

# 1001 seqs, each with the length of 116

torch.Size([1001, 116])

### 1.3 Create Training data 

- So what we need is to mask a random position in seq
- let's only mask one posiiton as of now

In [11]:
from torch.utils.data import Dataset, DataLoader

class MaskedAminoSeqDataset(Dataset):
    def __init__(self, input_tensor: torch.Tensor, mask_token: int):
            """
            Dataset for masked amino acid sequence prediction.

            Args:
            input_tensor (torch.Tensor): Input tensor of shape (num_sequences, sequence_length).
            mask_token (int): The token used for masking.
            """
            self.input_tensor = input_tensor
            self.mask_token = mask_token

    def __len__(self):
        return self.input_tensor.shape[0] 

    def __getitem__(self, idx):
        input_seqs, target_amino_acids, mask_positions = \
            self._create_training_data(self.input_tensor, batch_size=1, mask_token=self.mask_token)
        return input_seqs.squeeze(0), target_amino_acids.squeeze(0), mask_positions.squeeze(0)

    def _create_training_data(self, input_tensor: torch.Tensor, batch_size: int, mask_token: int):
        """
        Creates masked training data efficiently using vectorized operations.

        Args:
        input_tensor (torch.Tensor): Input tensor of shape (num_sequences, sequence_length)
        batch_size (int): The desired batch size.
        mask_token (int): The token used for masking.

        Returns:
        tuple: (input_seqs, target_amino_acids, mask_positions)
            - input_seqs: Tensor of shape (batch_size, sequence_length) with masked sequences.
            - target_amino_acids: Tensor of shape (batch_size,) containing the masked amino acids.
            - mask_positions: Tensor of shape (batch_size,) indicating mask positions.
        """
        rows = input_tensor.shape[0]
        seq_len = input_tensor.shape[1]
        # Randomly select 'batch_size' rows (amino acid sequences)
        idx = torch.randint(rows, size=(batch_size,))
        input_seqs = input_tensor[idx].clone()

        # Generate random mask positions within each selected sequence
        mask_positions = torch.randint(seq_len, size=(batch_size, 1))

        # Get the target amino acids at the mask positions
        target_amino_acids = input_seqs.gather(1, mask_positions).squeeze()

        # Create a mask for the selected positions 
        mask = torch.zeros(input_seqs.size(), dtype=torch.bool)
        mask.scatter_(1, mask_positions, 1)

        # Apply the mask to replace the target positions with the mask_token
        input_seqs[mask] = mask_token

        return input_seqs, target_amino_acids, mask_positions.squeeze()


In [12]:
# token id for the MASK
amino_acid_tokenizer.encode(["[MASK]"])

[21]

In [13]:
# Assuming input_tensor is your tensor of amino acid sequences
masked_amino_seq_dataset = MaskedAminoSeqDataset(all_amino_acids_tensor, mask_token=21) 
masked_amino_seq_dataloader = DataLoader(masked_amino_seq_dataset, batch_size=32, shuffle=True)

In [14]:
## each iteration now gives a batch with 32 data points.
for i in masked_amino_seq_dataloader:
    print(f"amino seqs with masked: \n shape: {i[0].shape} \n {i[0]}")
    print(f"targets amino acid:  \n shape: {i[1].shape} \n{i[1]}")
    print(f"mask posittions:  \n shape: {i[2].shape} \n{i[2]}")


    break

amino seqs with masked: 
 shape: torch.Size([32, 116]) 
 tensor([[ 8, 14,  6,  ..., 17,  4, 15],
        [20, 20, 20,  ...,  6,  9, 15],
        [ 8, 11, 12,  ...,  6,  6, 15],
        ...,
        [14,  1, 14,  ...,  6,  9, 15],
        [ 8, 10,  6,  ..., 11,  6, 15],
        [ 8,  2,  6,  ..., 17, 20, 20]])
targets amino acid:  
 shape: torch.Size([32]) 
tensor([12,  6,  6, 18,  6, 20, 13, 11, 18, 19, 10,  6, 14,  2, 13,  6, 18, 10,
        12, 18,  9,  2, 14, 13,  9,  4,  2,  4,  5, 13, 12,  2])
mask posittions:  
 shape: torch.Size([32]) 
tensor([ 76,  31,  47,  53,  39, 114,  62,  36,  74,  70,  13,  45,  95,  61,
         58,  76,   6,  55,  92,  93,  36,  86,  38, 110,  90, 107,   4, 107,
         21,  99,  39,  61])


## 2. Embeddings

We need to embeddings

- amino acid embeddings 
- position embeddings



In [15]:
class SinusoidalPositionEncoding(nn.Module):
    def __init__(self, embed_size, max_seq_length=5000):
        super().__init__()
        self.embed_size = embed_size
        
        pe = torch.zeros(max_seq_length, embed_size)
        position = torch.arange(0, max_seq_length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return self.pe[:, :x.size(1)]


In [16]:
class BERTEmbeddings(nn.Module):

    def __init__(self, vocab_size, embed_size, max_seq_length, dropout=0.1):
        super().__init__()

        self.embed_size = embed_size
        self.token = torch.nn.Embedding(vocab_size, embed_size, dtype=torch.float32)
        self.position = SinusoidalPositionEncoding(embed_size, max_seq_length=max_seq_length)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, x):
    
        word_embed = self.token(x) 
        pos_embed = self.position(x)
        out = word_embed + pos_embed

        return self.dropout(out)

In [17]:
vocab_size = len(amino_acid_tokenizer)
d_model = 64 # embedding size 
max_seq_length = masked_amino_seq_dataset.input_tensor.shape[1]

In [18]:
test_emb = BERTEmbeddings(vocab_size=vocab_size, embed_size=d_model, max_seq_length=max_seq_length)

In [19]:
## each iteration now gives a batch with 32 data points.
for i in masked_amino_seq_dataloader:
    print(f"input batch shape:     {i[0].shape} ")

    print(f"embedded batch shape: {test_emb(i[0]).shape}")

    break

input batch shape:     torch.Size([32, 116]) 
embedded batch shape: torch.Size([32, 116, 64])


## 3. Multi Headed Attention 

In [20]:
class MultiHeadedAttention(nn.Module):
    
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        
        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = torch.nn.Dropout(dropout)

        self.query = torch.nn.Linear(d_model, d_model, dtype=torch.float32)
        self.key = torch.nn.Linear(d_model, d_model, dtype=torch.float32)
        self.value = torch.nn.Linear(d_model, d_model, dtype=torch.float32)
        self.output_linear = torch.nn.Linear(d_model, d_model, dtype=torch.float32)
        
    def forward(self, query, key, value, mask):
        """
        query, key, value of shape: (batch_size, max_len, d_model)
        mask of shape: (batch_size, 1, 1, max_words)
            # Note: mask if not used, it is mainly to tell attention the locations on which 
                it should not learn much, like padding indexes
                - we dont have padding here as of now, so no need it. 
        """
        # (batch_size, max_len, d_model)
        query = self.query(query)
        key = self.key(key)        
        value = self.value(value)   
        
        # (batch_size, max_len, d_model) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
        query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)   
        key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)  
        value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)  
        
        # (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
        scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / math.sqrt(query.size(-1))

        # to mask the pads (diff from the other mask) so the attention does not learn from it
        # # fill 0 mask with super small number so it wont affect the softmax weight
        # # (batch_size, h, max_len, max_len)
        # scores = scores.masked_fill(mask == 0, -1e9)    

        # (batch_size, h, max_len, max_len)
        # softmax to put attention weight for all non-pad tokens
        # max_len X max_len matrix of attention
        weights = F.softmax(scores, dim=-1)           
        weights = self.dropout(weights)

        # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
        context = torch.matmul(weights, value)

        # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, d_model)
        context = context.permute(0, 2, 1, 3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)

        # (batch_size, max_len, d_model)
        return self.output_linear(context)

class FeedForward(torch.nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, middle_dim=2048, dropout=0.1):
        super(FeedForward, self).__init__()
        
        self.fc1 = torch.nn.Linear(d_model, middle_dim)
        self.fc2 = torch.nn.Linear(middle_dim, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        self.activation = torch.nn.GELU()

    def forward(self, x):
        out = self.activation(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out

class EncoderLayer(torch.nn.Module):
    def __init__(
        self, 
        d_model=768,
        heads=12, 
        feed_forward_hidden=768 * 4, 
        dropout=0.1
        ):
        super(EncoderLayer, self).__init__()
        self.layernorm = torch.nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadedAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model, middle_dim=feed_forward_hidden)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, embeddings, mask):
        # embeddings: (batch_size, max_len, d_model)
        # encoder mask: (batch_size, 1, 1, max_len)
        # result: (batch_size, max_len, d_model)
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        # residual layer
        interacted = self.layernorm(interacted + embeddings)
        # bottleneck
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        encoded = self.layernorm(feed_forward_out + interacted)
        return encoded

In [21]:
heads = MultiHeadedAttention(heads = 16, d_model=d_model)

In [22]:
## each iteration now gives a batch with 32 data points.
for i in masked_amino_seq_dataloader:
    print(f"input batch shape:     {i[0].shape} ")
    print(f"mask posiitons shape:   {i[2].shape}")

    print(f"embedded batch shape: {test_emb(i[0]).shape}")
    embded = test_emb(i[0])
    mask = i[2]
    attention_output = heads(embded, embded, embded,  mask)

    print(f"The output from the Attention : {attention_output.shape}")

    break

input batch shape:     torch.Size([32, 116]) 
mask posiitons shape:   torch.Size([32])
embedded batch shape: torch.Size([32, 116, 64])
The output from the Attention : torch.Size([32, 116, 64])


- The output from multiheaded attention goes through a little bit of forward passes, because why not!! 
- so below is a simple Feedforward pass code

In [23]:
class FeedForward(nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, middle_dim=2048, dropout=0.1):
        super().__init__()
        
        self.fc1 = torch.nn.Linear(d_model, middle_dim)
        self.fc2 = torch.nn.Linear(middle_dim, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        self.activation = torch.nn.GELU()

    def forward(self, x):
        out = self.activation(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out

#### Encoder Layer

- putting all together, 

- embedded matrix comes, first it goes to Attention module 
- Then layer normalization 
- then a feed forward part 
- and it again goes through a layer normalization

In [30]:
class EncoderLayer(nn.Module):
    def __init__(
        self, 
        d_model=768,
        heads=12, 
        feed_forward_hidden=768 * 4, 
        dropout=0.1
        ):
        super(EncoderLayer, self).__init__()
        self.layernorm = torch.nn.LayerNorm(d_model, dtype=torch.float32)
        self.self_multihead = MultiHeadedAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model, middle_dim=feed_forward_hidden)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, embeddings, mask):
        # embeddings: (batch_size, max_len, d_model)
        # encoder mask: (batch_size, 1, 1, max_len)
        # result: (batch_size, max_len, d_model)
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        # residual layer
        interacted = interacted.to(torch.float32)
        embeddings = embeddings.to(torch.float32)

        interacted = self.layernorm(interacted + embeddings)
        # bottleneck
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        encoded = self.layernorm(feed_forward_out + interacted)
        return encoded

## 3. Forward pass

In [25]:
class BERT(nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.
    """

    def __init__(self, vocab_size, d_model=768, n_layers=12, heads=12, max_seq_length=500, dropout=0.1):
        """
        :param vocab_size: vocab_size of total words
        :param hidden: BERT model hidden size
        :param n_layers: numbers of Transformer blocks(layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """

        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.heads = heads

        # paper noted they used 4 * hidden_size for ff_network_hidden_size
        self.feed_forward_hidden = d_model * 4

        # embedding for BERT, sum of positional, segment, token embeddings
        self.embedding = BERTEmbeddings(vocab_size=vocab_size, embed_size=d_model, max_seq_length=max_seq_length)

        # multi-layers transformer blocks, deep network
        self.encoder_blocks = torch.nn.ModuleList(
            [EncoderLayer(d_model, heads, d_model * 4, dropout) for _ in range(n_layers)])

    def forward(self, x):
        # attention masking for padded token

        # (batch_size, 1, seq_len, seq_len)
        # mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        # as of now mask has no role to play, it's for not paying attention to Padding idx
        mask = torch.Tensor([1])

        # embedding the indexed sequence to sequence of vectors
        x = self.embedding(x)

        # running over multiple transformer blocks
        for encoder in self.encoder_blocks:
            x = encoder.forward(x, mask)
        return x

In [26]:
vocab_size = len(amino_acid_tokenizer)
d_model = 64 # embedding size 
max_seq_length = masked_amino_seq_dataset.input_tensor.shape[1]

In [31]:
bert_encoder_test = BERT(vocab_size=vocab_size, d_model=d_model, n_layers=3, heads=4, max_seq_length=116)

In [33]:
# let's see if we can run oe forward pass

## each iteration now gives a batch with 32 data points.
for i in masked_amino_seq_dataloader:
    print(f"input batch shape:     {i[0].shape} ")
    print(f"mask posiitons shape:   {i[2].shape}")

    # print(f"embedded batch shape: {test_emb(i[0]).shape}")
    # embded = test_emb(i[0])
    # mask = i[2]
    # attention_output = heads(embded, embded, embded,  mask)

    # print(f"The output from the Attention : {attention_output.shape}")

    bert_output = bert_encoder_test(i[0])
    print(f"bert encoder output shape: {bert_output.shape}")

    
    break

input batch shape:     torch.Size([32, 116]) 
mask posiitons shape:   torch.Size([32])
bert encoder output shape: torch.Size([32, 116, 64])


## 4. Create a Loss Func

## 5. Traning Loop