# Test BERT-Pytorch

BERT-pytorch is a PyTorch implementation of the BERT algorithm.

[BERT-pytorch](https://github.com/codertimo/BERT-pytorch)


## Embedding
In the BERT implementation (bert_pytorch/model/bert.py), the masking is done after the second token (x>0) since, in the original BERT paper, the first element of the input is always \[CLS\]. In our model, we will use the variant name as the \[CLS\] token, with the following values:
[wt, alpha, delta, omicron, na], where "na" stands for not assigned.
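
As a minimal sketch of how the variant labels might be turned into indices, the snippet below maps each name to an integer. The label set comes from the text above; the index order, the helper name `variant_index`, and the fallback to "na" for unknown names are assumptions made for illustration only.

```python
# Hypothetical sketch: map a variant name to a label index.
# The index assignment and the "na" fallback are assumptions, not part of the notebook.
VARIANT_LABELS = ["wt", "alpha", "delta", "omicron", "na"]
variant_to_index = {name: i for i, name in enumerate(VARIANT_LABELS)}


def variant_index(name: str) -> int:
    """Return the label index for a variant name, or the "na" index if unknown."""
    return variant_to_index.get(name.lower(), variant_to_index["na"])


print(variant_index("Delta"))   # 2
print(variant_index("lambda"))  # 4, because unknown names fall back to "na"
```

How these indices would share an embedding table with the amino acid tokens is left open here.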

In [3]:
import io
import copy
import math
from Bio import SeqIO
import torch
import torch.nn as nn
# from bert_pytorch.model import BERT

## Tokenization and Vocabulary
In [ProteinBERT](https://academic.oup.com/bioinformatics/article/38/8/2102/6502274), Brandes et al. used 26 unique tokens: the 20 standard amino acids, selenocysteine (U), an undefined amino acid (X), any other amino acid (\<OTHER\>), and three special tokens, \<START\>, \<END\>, and \<PAD\>.

In [62]:
# Based on the source code of protein_bert
ALL_AAS = 'ACDEFGHIKLMNPQRSTUVWXY'
ADDITIONAL_TOKENS = ['<OTHER>', '<START>', '<END>', '<PAD>']

# <START> and <END> tokens are added to each sequence
ADDED_TOKENS_PER_SEQ = 2

n_aas = len(ALL_AAS)
aa_to_token_index = {aa: i for i, aa in enumerate(ALL_AAS)}
additional_token_to_index = {token: i + n_aas for i, token in enumerate(ADDITIONAL_TOKENS)}
token_to_index = {**aa_to_token_index, **additional_token_to_index}
index_to_token = {index: token for token, index in token_to_index.items()}
n_tokens = len(token_to_index)

def tokenize_seq(seq, max_len):
    """Convert a sequence to token indices: <START>, residues, <PAD>..., <END>.

    Sequences of up to max_len - 2 residues yield exactly max_len tokens;
    longer sequences are not truncated.
    """
    other_token_index = additional_token_to_index['<OTHER>']
    token_seq = [additional_token_to_index['<START>']] + [aa_to_token_index.get(aa, other_token_index) for aa in seq]
    if len(token_seq) < max_len - 1:  # -1 is for the <END> token
        len_pad = max_len - 1 - len(token_seq)
        token_seq.extend(token_to_index['<PAD>'] for _ in range(len_pad))
    token_seq += [additional_token_to_index['<END>']]
    return torch.IntTensor(token_seq)
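
The function above pads short sequences but does not shorten long ones, so a sequence with more than `max_len - 2` residues produces more than `max_len` tokens. Below is a minimal sketch of a fixed-length variant. It is an illustration, not part of protein_bert: the name `tokenize_seq_fixed_len` and the choice to truncate from the right are assumptions, and it returns a plain list so that it runs without PyTorch.

```python
ALL_AAS = "ACDEFGHIKLMNPQRSTUVWXY"
ADDITIONAL_TOKENS = ["<OTHER>", "<START>", "<END>", "<PAD>"]
token_to_index = {tok: i for i, tok in enumerate(list(ALL_AAS) + ADDITIONAL_TOKENS)}
index_to_token = {i: tok for tok, i in token_to_index.items()}


def tokenize_seq_fixed_len(seq: str, max_len: int) -> list:
    """Return exactly max_len token indices: <START>, residues, <PAD>..., <END>."""
    other = token_to_index["<OTHER>"]
    residues = seq[: max_len - 2]  # assumption: truncate from the right
    tokens = [token_to_index["<START>"]]
    tokens += [token_to_index.get(aa, other) for aa in residues]
    tokens += [token_to_index["<PAD>"]] * (max_len - 1 - len(tokens))
    tokens.append(token_to_index["<END>"])
    return tokens


short = tokenize_seq_fixed_len("MFVB", max_len=8)
print([index_to_token[i] for i in short])
# ['<START>', 'M', 'F', 'V', '<OTHER>', '<PAD>', '<PAD>', '<END>']

long = tokenize_seq_fixed_len("MFVFLVLLPLVSS", max_len=8)
print(len(long))  # 8
```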

## Amino Acid Token Embeddings
We derive the token embedding from the [torch.nn.Embedding class](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html). The size of the vocabulary equals the number of tokens. This approach allows the model itself to learn the embeddings. If we train the model with virus-specific sequences, the embeddings should reflect the hidden properties of the amino acids in the context of the training sequences. Note that the \<START\> token is always added at the beginning of the sequence and the \<END\> token at the end. \<PAD\> tokens are added before the \<END\> token if the sequence is shorter than the maximum input length.

Note that the "from_pretrained" class method of torch.nn.Embedding can be used to load pretrained embedding weights.
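
A minimal sketch of loading weights with `from_pretrained` follows. The weight tensor here is random and only stands in for weights that would normally be loaded from disk; the small `embedding_dim` is chosen for readability.

```python
import torch
import torch.nn as nn

n_tokens, embedding_dim, padding_idx = 26, 4, 25

# Stand-in for pretrained weights; random values are used only for illustration.
pretrained_weights = torch.randn(n_tokens, embedding_dim)

# freeze=False keeps the weights trainable; the default (freeze=True) fixes them.
embedding = nn.Embedding.from_pretrained(
    pretrained_weights, freeze=False, padding_idx=padding_idx
)

tokens = torch.tensor([23, 10, 4, 24])  # <START>, M, F, <END>
print(embedding(tokens).shape)          # torch.Size([4, 4])
print(embedding.weight.requires_grad)   # True
```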


In [63]:
class TokenEmbedding(nn.Embedding):
    """Learned token embedding. max_len is accepted but not used."""
    def __init__(self, num_embeddings: int, embedding_dim: int = 512, max_len: int = 1500, padding_idx=None):
        super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)

padding_idx = token_to_index['<PAD>']
print(padding_idx)


25


In [33]:
test_wt_seq = """>sp|P0DTC2|SPIKE_SARS2 Spike glycoprotein OS=Severe acute respiratory syndrome coronavirus 2 OX=2697049 GN=S PE=1 SV=1
MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFS
NVTWFHAIHVSGTNGTKRFDNPVLPFNDGVYFASTEKSNIIRGWIFGTTLDSKTQSLLIV
NNATNVVIKVCEFQFCNDPFLGVYYHKNNKSWMESEFRVYSSANNCTFEYVSQPFLMDLE
GKQGNFKNLREFVFKNIDGYFKIYSKHTPINLVRDLPQGFSALEPLVDLPIGINITRFQT
LLALHRSYLTPGDSSSGWTAGAAAYYVGYLQPRTFLLKYNENGTITDAVDCALDPLSETK
CTLKSFTVEKGIYQTSNFRVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISN
CVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIAD
YNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPC
NGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVN
FNFNGLTGTGVLTESNKKFLPFQQFGRDIADTTDAVRDPQTLEILDITPCSFGGVSVITP
GTNTSNQVAVLYQDVNCTEVPVAIHADQLTPTWRVYSTGSNVFQTRAGCLIGAEHVNNSY
ECDIPIGAGICASYQTQTNSPRRARSVASQSIIAYTMSLGAENSVAYSNNSIAIPTNFTI
SVTTEILPVSMTKTSVDCTMYICGDSTECSNLLLQYGSFCTQLNRALTGIAVEQDKNTQE
VFAQVKQIYKTPPIKDFGGFNFSQILPDPSKPSKRSFIEDLLFNKVTLADAGFIKQYGDC
LGDIAARDLICAQKFNGLTVLPPLLTDEMIAQYTSALLAGTITSGWTFGAGAALQIPFAM
QMAYRFNGIGVTQNVLYENQKLIANQFNSAIGKIQDSLSSTASALGKLQDVVNQNAQALN
TLVKQLSSNFGAISSVLNDILSRLDKVEAEVQIDRLITGRLQSLQTYVTQQLIRAAEIRA
SANLAATKMSECVLGQSKRVDFCGKGYHLMSFPQSAPHGVVFLHVTYVPAQEKNFTTAPA
ICHDGKAHFPREGVFVSNGTHWFVTQRNFYEPQIITTDNTFVSGNCDVVIGIVNNTVYDP
LQPELDSFKEELDKYFKNHTSPDVDLGDISGINASVVNIQKEIDRLNEVAKNLNESLIDL
QELGKYEQYIKWPWYIWLGFIAGLIAIVMVTIMLCCMTSCCSCLKGCCSCGSCCKFDEDD
SEPVLKGVKLHYT"""
len(test_wt_seq)

1413

In [64]:
test_seqs = []
fa_parser = SeqIO.parse(io.StringIO(test_wt_seq), 'fasta')
for record in fa_parser:
    seq = record.seq
    test_seqs.append(str(seq))

In [65]:
num_embeddings = n_tokens
embedding_dim = 20
max_len = 1500
embedding = TokenEmbedding(num_embeddings, embedding_dim, max_len=1500, padding_idx=padding_idx)
test_embedding = embedding(tokenize_seq(test_seqs[0], max_len))
print(f'Shape of test sequence embedding: {test_embedding.shape}')


Shape of test sequence embedding: torch.Size([1500, 20])


Let's take a look at the embedding weights:

In [66]:
embedding.weight

Parameter containing:
tensor([[ 2.4855e-01, -6.0301e-01, -8.9350e-01,  1.0500e+00,  1.1623e+00,
          2.1284e+00,  9.9412e-01, -2.1245e+00, -7.0741e-01, -4.8086e-01,
          1.0270e+00, -2.1130e-01,  1.3697e-01,  1.3649e-01,  2.8259e-01,
         -9.0673e-02, -7.8031e-01,  1.0446e+00,  1.7205e-01,  1.3393e+00],
        [-1.2577e+00, -1.7306e+00,  4.7949e-02, -6.1586e-01,  2.3227e+00,
          1.9112e+00, -3.9107e-01, -1.4216e-01,  1.1763e+00, -2.2798e-01,
          2.9397e-01, -1.1841e+00,  4.3349e-02, -7.8844e-01,  1.0073e-01,
          2.0124e-02,  1.9426e-01,  9.9476e-01, -1.3050e+00,  6.6223e-01],
        [-1.0654e+00,  1.2877e+00, -1.5136e+00, -5.2084e-01,  1.2437e+00,
          1.8556e+00, -4.0181e-01, -4.1998e-01, -1.9787e+00, -2.1260e-01,
         -4.7483e-01, -1.3029e-01,  1.7986e+00,  4.4688e-02,  7.9517e-01,
         -4.6092e-01, -2.9534e-01,  2.2252e+00,  2.1702e-01, -1.0765e+00],
        [-1.2580e+00,  1.2980e+00,  9.9607e-01, -8.0020e-01,  1.2469e+00,
          5.4

In [67]:
embedding.weight.shape

torch.Size([26, 20])
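
Passing `padding_idx` has two effects that are easy to check: the \<PAD\> row of a freshly created embedding is all zeros, and it receives no gradient. The short sketch below uses a plain `nn.Embedding` with the same sizes as above; the token indices are written out by hand.

```python
import torch
import torch.nn as nn

padding_idx = 25
embedding = nn.Embedding(26, 20, padding_idx=padding_idx)

# The <PAD> row is initialised to zeros.
print(torch.all(embedding.weight[padding_idx] == 0).item())  # True

# After a backward pass, the <PAD> row has received no gradient.
tokens = torch.tensor([23, 10, 25, 25, 24])  # <START>, M, <PAD>, <PAD>, <END>
embedding(tokens).sum().backward()
print(torch.all(embedding.weight.grad[padding_idx] == 0).item())  # True
print(torch.all(embedding.weight.grad[10] == 1).item())           # True
```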

## Positional Encoding
We will use sine and cosine functions of different frequencies to embed positional information, as in the original Transformer. (The original BERT uses learned position embeddings instead.)
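
Written out, the sinusoidal encoding for position $pos$ and even dimension index $2i$, with model width $d_{\text{model}}$, is:

$$
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
$$

The `div_term` in the implementation computes the same frequency in log space:

$$
\frac{1}{10000^{2i/d_{\text{model}}}} = \exp\!\left(-\frac{2i \,\ln 10000}{d_{\text{model}}}\right)
$$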

In [68]:
class PositionalEncoding(nn.Module):
    """
    Implement the PE function.

    The PE forward function is different from the one in BERT-pytorch. Here we use the original fixed
    sinusoidal method, so PE values are added to the input embeddings and no gradient tracking is used.
    """

    def __init__(self, d_model, dropout, max_len=1500):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0)/d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # A buffer is saved with the module but is not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x has shape (seq_len, d_model); pe has shape (max_len, d_model).
        print(f'x.shape in PositionalEncoding: {x.shape}')
        print(f'x.shape: {x.shape},pe.shape: {self.pe.shape}')
        print(f'pe[:, : x.size(1)]: {self.pe[:, : x.size(1)].shape}')

        # Slice along the position axis so sequences shorter than max_len also work.
        x = x + self.pe[: x.size(0), : x.size(1)].requires_grad_(False)
        return self.dropout(x)
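
As a small check on the values this class stores, the sketch below rebuilds the table in plain Python from the closed-form expression and compares the log-space frequency with the direct power form. The helper name `sinusoidal_pe` is hypothetical, and the tiny sizes are chosen only for readability.

```python
import math


def sinusoidal_pe(max_len: int, d_model: int) -> list:
    """Build the positional encoding table row by row from the closed-form formula."""
    table = []
    for pos in range(max_len):
        row = [0.0] * d_model
        for k in range(0, d_model, 2):  # k is the even dimension index
            angle = pos / (10000.0 ** (k / d_model))
            row[k] = math.sin(angle)
            if k + 1 < d_model:
                row[k + 1] = math.cos(angle)
        table.append(row)
    return table


pe = sinusoidal_pe(max_len=3, d_model=4)
print(pe[0])  # [0.0, 1.0, 0.0, 1.0]

# The log-space frequency used in the class matches the direct power form.
d_model, k = 4, 2
log_space = math.exp(k * -(math.log(10000.0) / d_model))
print(math.isclose(log_space, 1.0 / 10000.0 ** (k / d_model)))  # True
```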

In [69]:
class SeqEncoding(nn.Module):
    """
    Encode an amino acid sequence. In BERT, the input is represented by summing the corresponding token,
    segment (e.g. question and answer, or any segments separated by <SEP>), and position embeddings. In our
    model, we only need the token and position embeddings, so the segment embedding is not implemented here.
    """
    def __init__(self, num_embeddings, embedding_dim, dropout=0.1, max_len=1500, padding_idx=25):
        super().__init__()
        self.token_embedding = TokenEmbedding(num_embeddings, embedding_dim, max_len, padding_idx)
        self.position = PositionalEncoding(embedding_dim, dropout, max_len)
        self.embedding_dim = embedding_dim
        self.max_len = max_len  # stored so forward() does not depend on a global max_len
        self.dropout = nn.Dropout(dropout)

    def forward(self, seq: str):
        x = tokenize_seq(seq, self.max_len)
        x = self.token_embedding(x)
        x = self.position(x)
        # Note: PositionalEncoding already applies dropout, so dropout is applied twice here.
        return self.dropout(x)

In [70]:
test_seq_encode = SeqEncoding(n_tokens, 512, 0.1)
num_parameters_seq_encoding = sum(p.numel() for p in test_seq_encode.parameters() if p.requires_grad)
print(f'Parameters in SeqEncoding: {num_parameters_seq_encoding}')
print(test_seq_encode(test_seqs[0]))

Parameters in SeqEncoding: 13312
x.shape in PositionalEncoding: torch.Size([1500, 512])
x.shape: torch.Size([1500, 512]),pe.shape: torch.Size([1500, 512])
pe[:, : x.size(1)]: torch.Size([1500, 512])
tensor([[ 0.8716,  0.7394,  1.3862,  ...,  1.1798, -0.9199,  2.1998],
        [ 3.1870,  3.0478,  1.5118,  ...,  0.9575, -0.3805,  1.0238],
        [ 0.8325, -0.0000,  1.7789,  ...,  1.4181, -1.5577,  3.7161],
        ...,
        [ 1.2340, -0.0384, -1.0611,  ...,  1.2186,  0.0000,  0.0000],
        [ 0.6344, -1.0591, -0.0857,  ...,  1.2186,  0.1909,  0.0000],
        [-0.0000, -0.0000,  1.9457,  ...,  0.5950, -0.3285,  3.7168]],
       grad_fn=<MulBackward0>)
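
The exact zeros in the output above most likely come from dropout, which is active because a newly constructed module is in training mode. The following sketch shows the difference between training and evaluation mode on a standalone `nn.Dropout`; the input of ones is chosen so that the rescaling is easy to see.

```python
import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.1)
x = torch.ones(1000)

dropout.train()                # default mode after construction
y_train = dropout(x)
# Each entry is either zeroed or scaled by 1 / (1 - p).
print(sorted(set(y_train.tolist())))

dropout.eval()                 # inference mode: dropout is the identity
y_eval = dropout(x)
print(torch.equal(y_eval, x))  # True
```

Calling `test_seq_encode.eval()` before the forward pass would switch off both dropout layers and make the encoding deterministic.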


# Model Definition

Here we define a model based on BERT. Part of the implementation is based on [BERT-pytorch](https://github.com/codertimo/BERT-pytorch).

In [30]:
def clones(module, N):
    """Produce N identical layers"""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
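
A short usage sketch for `clones`: because each layer is a deep copy, the copies start with equal weights but do not share parameters. The `nn.Linear` layer is only a stand-in for a Transformer block.

```python
import copy

import torch
import torch.nn as nn


def clones(module, N):
    """Produce N identical layers"""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


layers = clones(nn.Linear(4, 4), 3)
print(len(layers))  # 3

# The copies start with equal weights but do not share storage.
print(torch.equal(layers[0].weight, layers[1].weight))  # True
with torch.no_grad():
    layers[0].weight.zero_()
print(torch.equal(layers[0].weight, layers[1].weight))  # False
```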

In [None]:
class Transformer(nn.Module):
    """Transformer block (left unimplemented in this notebook)."""

In [None]:
class BERT(nn.Module):
    """
    BERT model
    """

    def __init__(self, vocab_size: int = 26, hidden: int = 768, n_layer: int = 12, attn_heads: int = 12, dropout: float = 0.1):
        """
        vocab_size: vocabulary (token) size
        hidden: BERT model size (used as input size and hidden size)
        n_layer: number of Transformer layers
        attn_heads: number of attention heads
        dropout: dropout ratio
        """

        super().__init__()
        self.hidden = hidden
        self.n_layer = n_layer
        self.attn_heads = attn_heads

        self.feed_forward_hidden = hidden * 4
        self.embedding = TokenEmbedding(vocab_size, embedding_dim=hidden, padding_idx=25)

        self.transformer_blocks = clones(Transformer(hidden, attn_heads, self.feed_forward_hidden, dropout), n_layer)

    def forward(self):
        # The forward pass is not defined in the source.
        raise NotImplementedError
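
The `Transformer` class that `BERT` clones is not defined in this notebook. Below is a minimal sketch of what such a block might look like, built on `nn.MultiheadAttention`. It is not the BERT-pytorch implementation: the class name `TransformerBlockSketch`, the post-layer-norm arrangement, the GELU activation, and the use of `key_padding_mask` for \<PAD\> positions are all assumptions. Only the constructor arguments follow the call in `BERT.__init__`.

```python
import torch
import torch.nn as nn


class TransformerBlockSketch(nn.Module):
    """Hypothetical encoder block: self-attention and feed-forward, each with a residual."""

    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            hidden, attn_heads, dropout=dropout, batch_first=True
        )
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden, feed_forward_hidden),
            nn.GELU(),
            nn.Linear(feed_forward_hidden, hidden),
        )
        self.norm1 = nn.LayerNorm(hidden)
        self.norm2 = nn.LayerNorm(hidden)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        # x: (batch, seq_len, hidden); key_padding_mask: (batch, seq_len), True at <PAD>
        attn_out, _ = self.attention(x, x, x, key_padding_mask=key_padding_mask)
        x = self.norm1(x + self.dropout(attn_out))
        x = self.norm2(x + self.dropout(self.feed_forward(x)))
        return x


block = TransformerBlockSketch(hidden=64, attn_heads=4, feed_forward_hidden=256, dropout=0.1)
tokens = torch.tensor([[23, 10, 4, 25, 25, 24]])  # <START>, M, F, <PAD>, <PAD>, <END>
x = torch.randn(1, 6, 64)                         # stand-in for the sequence encoding
out = block(x, key_padding_mask=(tokens == 25))
print(out.shape)  # torch.Size([1, 6, 64])
```

A forward pass for `BERT` would then embed the tokens and pass the result through each cloned block in turn, with the same padding mask supplied to every block.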
