# BERT Model for understanding mutations in protein sequences

## 0. init

In [92]:
import numpy as np
import random
import pandas as pd
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import tqdm
from torch.optim import Adam
import time

import string
from typing import Iterable, Tuple, Optional

## 1. Building Dataset and DataLoader

In [93]:
file_path = 'X_set.txt'

# Parallel lists filled from the file: entry i of each list comes from the
# same line. The first column holds the phylogenetic position / specie code,
# the second column holds the (aligned) amino acid sequence.
specie_code = []
amino_acid_sequences = []

# Each non-empty line is expected to look like "<specie_code> <sequence>".
with open(file_path, 'r') as file:
    for line_number, line in enumerate(file, start=1):
        # split() without an argument splits on any run of whitespace, so
        # tabs or repeated spaces between the columns do not yield an empty
        # sequence the way split(' ') would.
        parts = line.split()
        if not parts:
            # tolerate blank lines (e.g. a trailing newline at end of file)
            continue
        if len(parts) < 2:
            raise ValueError(
                f"{file_path}:{line_number}: expected '<specie_code> <sequence>', got {line!r}"
            )
        specie_code.append(parts[0])
        amino_acid_sequences.append(parts[1])

# peek at the first three sequences
amino_acid_sequences[0:3]

['---LSQF--LLMLWVPGSKGEIVLTQSPASVSVSPGERVTISCQASESVGNTYLNWLQQKSGQSPRWLIYQVSKLESGIPARFRGSGSGTDFTFTISRVEAEDVAHYYSQQ-----',
 'MESLSQC--LLMLWVPVSRGAIVLTQSPALVSVSPGERVTISCKASQSVGNTYLSWFRQKPGQSPRGLIYKVSNLPSGVPSRFRGSGAEKDFTLTISRVEAVDGAVYYCAQASYSP',
 'MESLSQC--LLMLWVPVSRGAIVLTQSPASVSVSPGERVTISCKASQSLGNTYLHWFQQKPGQSPRRLIYQVSNLLSGVPSRFSGSGAGKDFSLTISSVEAGDGAVYYCFQGSYDP']

In [94]:
# number of sequences that were read from the file
len(amino_acid_sequences)

1001

In [95]:
# Two more per-sequence inputs are needed besides the sequences themselves.

# 1. protein type: one label per sequence (every sequence is labelled 'A1')
protein_types = ['A1' for _ in amino_acid_sequences]

# 2. per-specie weight: one value per sequence, drawn with torch.rand
specie_weight = torch.rand(len(protein_types))


### 1.1. Tokenizer

- There are 20 amino acids, each letter in the chain represents one of them. 
- Converting them into 20 tokens, meaning each amino acid would get a number associated with it. 
- We also need a special token for "-", the gap character that appears in the multiple-sequence alignment.

In [96]:
# Sanity check: apart from the gap character "-", the sequences should only
# contain 20 distinct amino acid letters.

amino_acid_set = set()

for seq in amino_acid_sequences:
    # add every residue letter of this sequence, skipping the gap character
    amino_acid_set.update(acid for acid in seq if acid != "-")

# expected: 20 amino acids
print(f"Num of Amino Acids: {len(amino_acid_set) }")
amino_acids_list = list(amino_acid_set)

Num of Amino Acids: 20


In [97]:
# Tokenizer class that encodes and decodes an amino acid sequence

class AminoAcidTokenizer:
    """
    Map amino acid symbols to integer ids and back.

    The vocabulary is the 20 amino acid letters in ``amino_acids`` (ids 0-19,
    in list order) followed by any ``special_tokens`` in the order given.

    ``encode`` iterates over its input, so a plain string is encoded one
    character at a time; multi-character tokens such as "[MASK]" must be
    passed as elements of a list.
    """
    # class attribute
    # all 20 types of amino acids
    amino_acids = ['S','D','H','L','T','E','W','N','Y','Q','C','G','V','K','I','R','M','F','A','P']

    def __init__(self, special_tokens: Optional[Iterable[str]] = None):
        """
        Args:
            special_tokens: Extra tokens appended after the amino acids
                (e.g. "-", "[MASK]", "[PAD]"). ``None`` or empty adds nothing.
        """
        # Copy the class-level list: extending it in place would leak the
        # special tokens into every later instance and shift their ids.
        self.vocab = list(AminoAcidTokenizer.amino_acids)
        if special_tokens:
            self.vocab += list(special_tokens)
        # mapping each vocab entry to a token id (a numeric value)
        self.token2idx = {token: i for i, token in enumerate(self.vocab)}
        # mapping the numeric value back to the token
        self.idx2token = {i: token for token, i in self.token2idx.items()}

    def encode(self, inputs: Iterable[str]) -> Iterable[int]:
        """Return the list of ids for ``inputs``; raises KeyError on an unknown token."""
        return [self.token2idx[token] for token in inputs]

    def decode(self, inputs: Iterable[int]) -> Iterable[str]:
        """Return the list of tokens for ``inputs``; raises KeyError on an unknown id."""
        return [self.idx2token[idx] for idx in inputs]

    def __len__(self):
        return len(self.vocab)

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary (same as ``len(self)``)."""
        return len(self.vocab)


In [98]:
# creating an instance of the Tokenizer with the gap, mask and padding tokens appended to the vocab
amino_acid_tokenizer = AminoAcidTokenizer(special_tokens=["-", "[MASK]", "[PAD]"])

# let's encode the first amino-acid-sequence and look at its first 20 positions:
# raw characters, their token ids, and the ids decoded back to characters
print(f"First 20 amino acids         : {[i for i in amino_acid_sequences[0][0:20]]}")
print(f"First 20 encoded amino acids : {amino_acid_tokenizer.encode(amino_acid_sequences[0])[0:20]}")
print(f"First 20 decoded amino acids : {amino_acid_tokenizer.decode(amino_acid_tokenizer.encode(amino_acid_sequences[0])[0:20])}")

First 20 amino acids         : ['-', '-', '-', 'L', 'S', 'Q', 'F', '-', '-', 'L', 'L', 'M', 'L', 'W', 'V', 'P', 'G', 'S', 'K', 'G']
First 20 encoded amino acids : [20, 20, 20, 3, 0, 9, 17, 20, 20, 3, 3, 16, 3, 6, 12, 19, 11, 0, 13, 11]
First 20 decoded amino acids : ['-', '-', '-', 'L', 'S', 'Q', 'F', '-', '-', 'L', 'L', 'M', 'L', 'W', 'V', 'P', 'G', 'S', 'K', 'G']


In [99]:
# every token (amino acids first, then the special tokens) mapped to its integer id
print(amino_acid_tokenizer.token2idx)

{'S': 0, 'D': 1, 'H': 2, 'L': 3, 'T': 4, 'E': 5, 'W': 6, 'N': 7, 'Y': 8, 'Q': 9, 'C': 10, 'G': 11, 'V': 12, 'K': 13, 'I': 14, 'R': 15, 'M': 16, 'F': 17, 'A': 18, 'P': 19, '-': 20, '[MASK]': 21, '[PAD]': 22}


In [100]:
# vocabulary size: the amino acid letters plus the special tokens passed to the tokenizer
amino_acid_tokenizer.vocab_size

23

In [102]:
# similar to the amino acid tokenizer, a tokenizer for protein types

class ProteinTokenizer:
    """
    Map protein type labels to integer ids and back.

    The vocabulary is the labels in ``protiens`` (in list order) followed by
    any ``special_tokens`` in the order given.
    """
    # class attribute (name kept as-is for compatibility with existing callers)
    protiens = ['A1', 'A2']

    def __init__(self, special_tokens: Optional[Iterable[str]] = None):
        """
        Args:
            special_tokens: Extra tokens appended after the protein types.
                ``None`` or empty adds nothing.
        """
        # Copy the class-level list: extending it in place would leak the
        # special tokens into every later instance.
        self.vocab = list(ProteinTokenizer.protiens)
        if special_tokens:
            self.vocab += list(special_tokens)
        # mapping each vocab entry to a token id (a numeric value)
        self.token2idx = {token: i for i, token in enumerate(self.vocab)}
        # mapping the numeric value back to the token
        self.idx2token = {i: token for token, i in self.token2idx.items()}

    def encode(self, inputs: Iterable[str]) -> Iterable[int]:
        """Return the list of ids for ``inputs``; raises KeyError on an unknown label."""
        return [self.token2idx[token] for token in inputs]

    def decode(self, inputs: Iterable[int]) -> Iterable[str]:
        """Return the list of labels for ``inputs``; raises KeyError on an unknown id."""
        return [self.idx2token[idx] for idx in inputs]

    def __len__(self):
        return len(self.vocab)

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary (same as ``len(self)``)."""
        return len(self.vocab)


In [103]:
# no special tokens are passed here, so the vocab is only the class-level protein types
protein_tokenizer = ProteinTokenizer()
protein_tokenizer.token2idx

{'A1': 0, 'A2': 1}

### 1.3 Create Training Data

In [104]:
# tokenizing all protein sequences and protein types
def create_encoded_tensors(amino_acid_sequences: list, protein_types: list,
                           amino_acid_tokenizer: "AminoAcidTokenizer",
                           protein_tokenizer: "ProteinTokenizer"):
    """
    Encode every sequence and its protein type into int64 tensors.

    Args:
        amino_acid_sequences: One iterable of amino acid tokens per sample
            (a plain string is encoded character by character).
        protein_types: One protein type label per sample, aligned with
            ``amino_acid_sequences``.
        amino_acid_tokenizer: Object whose ``encode`` maps sequence tokens to ids.
        protein_tokenizer: Object whose ``encode`` maps protein labels to ids.

    Returns:
        tuple: ``(amino_acid_tensors, protein_type_tensors)``, two lists with
        one tensor per sample. Sequence tensors are 1-D with one id per
        token; protein type tensors are 1-D with a single id.

    Raises:
        ValueError: If the two input lists differ in length. ``zip`` would
            otherwise silently drop the unmatched samples.
    """
    if len(amino_acid_sequences) != len(protein_types):
        raise ValueError(
            f"got {len(amino_acid_sequences)} sequences but {len(protein_types)} protein types"
        )

    amino_acid_tensors = []
    protein_type_tensors = []

    for seq, p_type in zip(amino_acid_sequences, protein_types):
        amino_acid_tensors.append(torch.tensor(amino_acid_tokenizer.encode(seq), dtype=torch.int64))
        # the label is wrapped in a list so the tokenizer encodes it as ONE token
        protein_type_tensors.append(torch.tensor(protein_tokenizer.encode([p_type]), dtype=torch.int64))

    return amino_acid_tensors, protein_type_tensors

In [105]:
# encode every sequence / protein type with the tokenizers built above
encoded_amino_acids, encoded_protein_types = create_encoded_tensors(
    amino_acid_sequences=amino_acid_sequences,
    protein_types=protein_types,
    amino_acid_tokenizer=amino_acid_tokenizer,
    protein_tokenizer=protein_tokenizer,
)

In [106]:
# both lists should have the same length (one entry per sequence)
print(len(encoded_amino_acids), len(encoded_protein_types), sep="\n")

1001
1001


In [169]:
from torch.utils.data import Dataset, DataLoader

class MaskedAminoSeqDataset(Dataset):
    """
    Map-style dataset producing randomly masked amino acid sequences.

    Masking is re-sampled on every ``__getitem__`` call, so the same index
    yields different masked positions each time it is read.
    """

    def __init__(self, encoded_amino_acids: list,
                 encoded_protein_types: list,
                 specie_weight: torch.Tensor,
                 mask_token: int,
                 pad_token: int,
                 max_len: int,
                 min_masks: int = 1,
                 max_masks: int = 5):
        """
        Args:
            encoded_amino_acids (list): One 1-D tensor of token ids per sample.
            encoded_protein_types (list): One encoded protein type tensor per sample.
            specie_weight (torch.Tensor): Indexable by sample; one weight per sample.
            mask_token (int): Token id written at masked positions.
            pad_token (int): Token id used to pad sequences shorter than ``max_len``.
            max_len (int): Length every returned sequence is padded/truncated to.
            min_masks (int, optional): Minimum number of masked positions. Defaults to 1.
            max_masks (int, optional): Maximum number of masked positions; also the
                fixed length of the returned mask-position tensor. Defaults to 5.
        """
        self.encoded_amino_acids = encoded_amino_acids
        self.encoded_protein_types = encoded_protein_types
        self.specie_weight = specie_weight
        self.mask_token = mask_token
        self.pad_token = pad_token
        self.max_len = max_len
        self.min_masks = min_masks
        self.max_masks = max_masks

    def __len__(self):
        return len(self.encoded_amino_acids)

    def __getitem__(self, idx):
        """
        Return ``(masked_seq, target_seq, mask_positions, protein_type, weight)``
        for sample ``idx``; see ``_create_training_data`` for the shapes.
        """
        # The tensors are returned as produced: they are already 1-D, and
        # squeezing them would collapse a length-1 sequence into a scalar.
        return self._create_training_data(
            self.encoded_amino_acids[idx],
            self.encoded_protein_types[idx],
            self.specie_weight[idx],
            self.mask_token,
            self.pad_token,
            self.max_len,
            min_masks=self.min_masks,
            max_masks=self.max_masks,
        )

    def _create_training_data(self, encoded_amino_acid: torch.Tensor,
                              encoded_protein_type: torch.Tensor,
                              specie_weight: torch.Tensor,
                              mask_token: int,
                              pad_token: int,
                              max_len: int,
                              min_masks: int = 1,
                              max_masks: int = 5):
        """
        Create one masked training example from an encoded sequence.

        The sequence is padded or truncated to ``max_len``, then a random
        subset of its real (non-padding) positions is replaced by
        ``mask_token``.

        Args:
            encoded_amino_acid (torch.Tensor): 1-D tensor of token ids.
            encoded_protein_type (torch.Tensor): Encoded protein type.
            specie_weight (torch.Tensor): Weight associated with the species.
            mask_token (int): Token used for masking.
            pad_token (int): Token used for padding.
            max_len (int): Length of the returned sequences.
            min_masks (int, optional): Minimum number of tokens to mask. Defaults to 1.
            max_masks (int, optional): Maximum number of tokens to mask. Defaults to 5.

        Returns:
            tuple: A tuple containing:
                - masked_seq (torch.Tensor): ``(max_len,)`` input with masked tokens.
                - target_seq (torch.Tensor): ``(max_len,)`` targets; the original
                  token at masked positions and -100 everywhere else.
                - fixed_mask_positions (torch.Tensor): ``(max_masks,)`` masked
                  indices, filled with -1 after the positions actually used.
                - encoded_protein_type (torch.Tensor): Encoded protein type
                  (at least 1-D).
                - specie_weight (torch.Tensor): Passed through unchanged.

        Notes:
            - Padding positions are never masked, so the pad token never
              becomes a prediction target.
            - The number of masks is drawn uniformly from
              ``[min_masks, min(max_masks, n_real)]`` where ``n_real`` is the
              number of real tokens kept. If fewer than ``min_masks``
              positions are available, all of them (up to ``max_masks``)
              are masked.
        """
        # Pad or truncate the sequence to max_len
        seq_len = encoded_amino_acid.shape[0]
        if seq_len < max_len:
            padding = torch.full((max_len - seq_len,), pad_token, dtype=encoded_amino_acid.dtype)
            input_seq = torch.cat([encoded_amino_acid, padding])
        else:
            input_seq = encoded_amino_acid[:max_len]

        # Only positions holding real tokens are candidates for masking
        num_real = min(seq_len, max_len)
        most_masks = min(max_masks, num_real)

        # Determine number of masks
        if most_masks < min_masks:
            # not enough real tokens to honour min_masks: mask what is there
            num_masks = max(most_masks, 0)
        else:
            num_masks = torch.randint(min_masks, most_masks + 1, (1,)).item()

        # Choose distinct positions among the real tokens
        mask_positions = torch.randperm(num_real)[:num_masks]
        # Fixed-size tensor so that samples can be stacked into a batch
        fixed_mask_positions = torch.full((max_masks,), -1, dtype=torch.long)
        fixed_mask_positions[:num_masks] = mask_positions

        # Create masked input sequence
        masked_seq = input_seq.clone()
        masked_seq[mask_positions] = mask_token

        # Create target sequence; -100 is often used as the "ignore" label in loss functions
        target_seq = torch.full((max_len,), -100, dtype=input_seq.dtype)
        target_seq[mask_positions] = input_seq[mask_positions]

        # Ensure encoded_protein_type has at least one dimension
        if encoded_protein_type.dim() == 0:
            encoded_protein_type = encoded_protein_type.unsqueeze(0)

        return masked_seq, target_seq, fixed_mask_positions, encoded_protein_type, specie_weight

In [170]:
# Build the masked dataset from the encoded sequences, protein types and weights.
# The special-token ids are looked up in the tokenizer instead of being
# hard-coded, so they stay correct if the vocabulary order ever changes.
masked_amino_seq_dataset = MaskedAminoSeqDataset(
    encoded_amino_acids=encoded_amino_acids,
    encoded_protein_types=encoded_protein_types,
    specie_weight=specie_weight,
    mask_token=amino_acid_tokenizer.token2idx["[MASK]"],
    pad_token=amino_acid_tokenizer.token2idx["[PAD]"],
    max_len=116
)
masked_amino_seq_dataloader = DataLoader(masked_amino_seq_dataset, batch_size=32, shuffle=True)

In [179]:
## Pull one batch from the loader (batch_size=32) and print the shape of each field.
i = t = m = n = q = 0
for data in masked_amino_seq_dataloader:
    print(f"amino seqs with masked: shape: {data[0].shape}")
    print(f"targets amino acid:     shape: {data[1].shape}")
    print(f"mask posittions:        shape: {data[2].shape} ")
    print(f"encoded protein type:   shpae: {data[3].shape}")
    print(f"specie_weight:          shape: {data[4].shape}")

    # keep the masked inputs, targets and mask positions of this batch around
    i, t, m = data[:3]
    break

amino seqs with masked: shape: torch.Size([32, 116])
targets amino acid:     shape: torch.Size([32, 116])
mask posittions:        shape: torch.Size([32, 5]) 
encoded protein type:   shpae: torch.Size([32, 1])
specie_weight:          shape: torch.Size([32])


## 2. Embeddings

- amino acid embeddings 
- position embeddings 
- protein type embeddings

In [180]:
class SinusoidalPositionEncoding(nn.Module):
    """
    Sinusoidal positional encoding module.

    Pre-computes a table of ``max_seq_length`` position vectors in which
    even columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / embed_size)``.

    Args:
        embed_size (int): The size of each embedding vector (even or odd).
        max_seq_length (int, optional): The maximum sequence length to support.
            Defaults to 5000.

    Attributes:
        embed_size (int): The size of each embedding vector.
        pe (Tensor): The pre-computed position encoding matrix of shape
            (1, max_seq_length, embed_size), registered as a buffer (it has
            no trainable parameters).

    Note:
        - The number of rows returned depends on the input sequence length
          in the forward pass.
        - Only the first ``max_seq_length`` positions exist in the table;
          longer inputs are rejected rather than silently truncated.
    """
    def __init__(self, embed_size, max_seq_length=5000):
        super().__init__()
        self.embed_size = embed_size

        pe = torch.zeros(max_seq_length, embed_size)
        position = torch.arange(0, max_seq_length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size))

        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd embed_size there is one cosine column fewer than sine
        # columns, so drop the last frequency to keep the shapes aligned.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_size // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """
        Return the encodings for the first ``x.size(1)`` positions.

        Only the size of dimension 1 of ``x`` is used; its values are ignored.

        Returns:
            Tensor of shape (1, x.size(1), embed_size).

        Raises:
            ValueError: If ``x.size(1)`` exceeds ``max_seq_length``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_length {self.pe.size(1)}"
            )
        return self.pe[:, :seq_len]


In [186]:
class BERTEmbeddings(nn.Module):
    """
    Input embedding layer: the sum of three embeddings followed by dropout.

    - a learned embedding of each amino acid token,
    - a fixed sinusoidal encoding of each position,
    - a learned embedding of the protein type, shared by every position of
      the sequence.

    Args:
        amino_vocab_size (int): Number of amino acid token ids.
        protien_vocab_size (int): Number of protein type ids.
        embed_size (int): Size of every embedding vector (d_model).
        max_seq_length (int): Number of positions in the positional table.
        dropout (float, optional): Dropout probability applied to the sum.
            Defaults to 0.1.
    """

    def __init__(self, amino_vocab_size, protien_vocab_size, embed_size, max_seq_length, dropout=0.1):
        super().__init__()

        self.embed_size = embed_size
        self.amino_acid_token = torch.nn.Embedding(amino_vocab_size, embed_size, dtype=torch.float32)
        self.position = SinusoidalPositionEncoding(embed_size, max_seq_length=max_seq_length)
        self.protien_token = torch.nn.Embedding(protien_vocab_size, embed_size, dtype=torch.float32)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, amino_acid_seqs, protiens):
        """
        Args:
            amino_acid_seqs: (B, C) tensor of amino acid token ids.
            protiens: (B, 1) tensor of protein type ids; a flat (B,) tensor
                is also accepted.

        Returns:
            (B, C, embed_size) tensor of embeddings after dropout.
        """
        # A flat (B,) id tensor would embed to (B, d) and mis-broadcast
        # against (B, C, d); give it an explicit length-1 sequence axis.
        if protiens.dim() == 1:
            protiens = protiens.unsqueeze(1)

        amino_acid_embed = self.amino_acid_token(amino_acid_seqs)   # (B, C, d)
        pos_embed = self.position(amino_acid_seqs)                  # (1, C, d)
        protien_embed = self.protien_token(protiens)                # (B, 1, d)
        # broadcasting adds the position / protein vectors to every token
        out = amino_acid_embed + pos_embed + protien_embed

        return self.dropout(out)

In [187]:
# example: build an embedding layer sized from the two tokenizers
amino_vocab_size = len(amino_acid_tokenizer)
protein_vocab_size = len(protein_tokenizer)
d_model = 64 # embedding size 
max_seq_length = 200 # only sizes the positional-encoding table; it doesn't have to be precise but must be at least the sequence length fed to the model


test_emb = BERTEmbeddings(amino_vocab_size=amino_vocab_size,
                        protien_vocab_size=protein_vocab_size,
                        embed_size=d_model,
                        max_seq_length=max_seq_length)

In [188]:
## embed the first batch from the dataloader and compare input / output shapes
for data in masked_amino_seq_dataloader:
    # data[0] holds the masked sequences, data[3] the encoded protein types
    print(f"input batch shape:     {data[0].shape} ")

    print(f"embedded batch shape: {test_emb(data[0], data[3]).shape}")

    # one batch is enough for the shape check
    break

input batch shape:     torch.Size([32, 116]) 
embedded batch shape: torch.Size([32, 116, 64])
