# Spike NLP

We wish to use NLP methods to analyze virus protein sequences. After an initial experiment with an LSTM architecture, we decided to use the Transformer architecture. In this notebook, we implement a BERT model. During the process, we learned from existing implementations of BERT, especially [BERT-pytorch](https://github.com/codertimo/BERT-pytorch), [The Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/), and [ProteinBERT](https://academic.oup.com/bioinformatics/article/38/8/2102/6502274), though we had to make changes to accommodate our own research interests. For example, we are interested in token prediction during pre-training, so that self-supervised learning of the language patterns yields contextualized embeddings at the level of individual amino acids. We are not interested in protein functional annotation, so we do not use annotations in our model. In other words, we want to leverage the masked language model (Mask LM) pre-training task to derive embeddings for a fine-tuning model that predicts the phenotype of virus protein sequences. Example phenotypes are the binding kinetics of the virus protein to target receptor proteins and antibodies.



## Embedding
In the BERT implementation (bert_pytorch/model/bert.py), the masking is done after the second token (x>0) since, in the original BERT paper, the first element of the input is always \[CLS\]. In our model, we will use the variant name as the \[CLS\] token, and its possible values are:
[wt, alpha, delta, omicron, na], where "na" stands for not assigned.
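One possible way to realize this is sketched below, purely as an illustration. It assumes the five variant labels are appended after the 26 base tokens and that the variant token occupies position 0 of every sequence. The names `VARIANT_TOKENS`, `variant_to_index`, and `prepend_variant` are hypothetical and do not appear elsewhere in this notebook. Whether the variant token replaces or precedes \<START\> is left open here.

```python
# Hypothetical sketch: variant labels used as [CLS]-style tokens, indexed after the base vocabulary.
ALL_AAS = 'ACDEFGHIKLMNPQRSTUVWXY'
ADDITIONAL_TOKENS = ['<OTHER>', '<START>', '<END>', '<PAD>']
VARIANT_TOKENS = ['wt', 'alpha', 'delta', 'omicron', 'na']  # "na" = not assigned

n_base = len(ALL_AAS) + len(ADDITIONAL_TOKENS)  # 26 base tokens
variant_to_index = {name: n_base + i for i, name in enumerate(VARIANT_TOKENS)}


def prepend_variant(token_ids, variant):
    """Return a new list with the variant token in position 0 (the [CLS] slot)."""
    # Unknown labels fall back to "na" (an assumption of this sketch).
    cls_index = variant_to_index.get(variant.lower(), variant_to_index['na'])
    return [cls_index] + list(token_ids)


example = prepend_variant([10, 4, 18], 'Delta')  # M, F, V in the vocabulary above
print(example)
```

With this layout the embedding table would need 31 rows instead of 26, since the variant labels become part of the vocabulary.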

In [89]:
import io
import copy
import math
from Bio import SeqIO
import torch
import torch.nn as nn
import pandas as pd
import altair as alt
# from bert_pytorch.model import BERT

## Tokenization and Vocabulary
In [ProteinBERT](https://academic.oup.com/bioinformatics/article/38/8/2102/6502274), Brandes et al. used 26 unique tokens: the 20 standard amino acids, selenocysteine (U), an undefined amino acid (X), another amino acid (OTHER), and three special tokens, \<START\>, \<END\>, and \<PAD\>.

In [90]:
# Based on the source code of protein_bert
# TODO: add a <TRUNCATED> token.
ALL_AAS = 'ACDEFGHIKLMNPQRSTUVWXY'
ADDITIONAL_TOKENS = ['<OTHER>', '<START>', '<END>', '<PAD>']

# <START> and <END> are added to every sequence. <PAD> tokens are added to sequences shorter than max_len.
ADDED_TOKENS_PER_SEQ = 2

n_aas = len(ALL_AAS)
aa_to_token_index = {aa: i for i, aa in enumerate(ALL_AAS)}
additional_token_to_index = {token: i + n_aas for i, token in enumerate(ADDITIONAL_TOKENS)}
token_to_index = {**aa_to_token_index, **additional_token_to_index}
index_to_token = {index: token for token, index in token_to_index.items()}
n_tokens = len(token_to_index)

def tokenize_seq(seq, max_len):
    other_token_index = additional_token_to_index['<OTHER>']
    token_seq = [additional_token_to_index['<START>']] + [aa_to_token_index.get(aa, other_token_index) for aa in seq]
    if len(token_seq) < max_len - 1: # -1 is for the <END> token
        len_pad = max_len -1 - len(token_seq)
        token_seq.extend(token_to_index['<PAD>'] for _ in range(len_pad))
    token_seq += [additional_token_to_index['<END>']]
    return torch.IntTensor(token_seq)
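
To make the token layout concrete, the following self-contained sketch mirrors the logic of `tokenize_seq` with plain Python lists (no tensors) and decodes the result back to tokens. The helper names `tokenize_to_list` and `decode` are hypothetical and exist only for this illustration.

```python
ALL_AAS = 'ACDEFGHIKLMNPQRSTUVWXY'
ADDITIONAL_TOKENS = ['<OTHER>', '<START>', '<END>', '<PAD>']
token_to_index = {tok: i for i, tok in enumerate(list(ALL_AAS) + ADDITIONAL_TOKENS)}
index_to_token = {i: tok for tok, i in token_to_index.items()}


def tokenize_to_list(seq, max_len):
    """Build <START> + residues + <PAD>... + <END>, padded to max_len when the sequence is short."""
    ids = [token_to_index['<START>']]
    ids += [token_to_index.get(aa, token_to_index['<OTHER>']) for aa in seq]
    ids += [token_to_index['<PAD>']] * max(0, max_len - 1 - len(ids))
    ids.append(token_to_index['<END>'])
    return ids


def decode(ids):
    """Map token indices back to their string tokens."""
    return [index_to_token[i] for i in ids]


ids = tokenize_to_list('MFB', max_len=7)  # 'B' is not in ALL_AAS, so it maps to <OTHER>
print(ids)
print(decode(ids))
```

Note that a sequence longer than `max_len - 2` residues is not truncated by this logic, so the output is then longer than `max_len`.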

## Amino Acid Token Embeddings
We derive the token embedding from the [torch.nn.Embedding class](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html). The size of the vocabulary equals the number of tokens. This approach allows the embeddings to be learned by the model itself. If we train the model with virus-specific sequences, the embeddings should reflect the hidden properties of the amino acids in the context of the training sequences. Note that the \<START\> token is always added at the beginning of the sequence and the \<END\> token at the end. \<PAD\> tokens are added before the \<END\> token if the sequence is shorter than the maximum input length (max_len).

Note that using the "from_pretrained" class method of torch.nn.Embedding, we can load pretrained weights of the embedding.
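
A minimal sketch of that loading step is shown below. The weight tensor is a random stand-in for illustration only. In practice it would come from a previously trained model, which this notebook does not yet provide.

```python
import torch
import torch.nn as nn

n_tokens, embedding_dim, padding_idx = 26, 6, 25

# Stand-in for weights saved from an earlier run; random values are used for illustration only.
pretrained_weights = torch.randn(n_tokens, embedding_dim)

# freeze=True (the default) keeps the weights fixed; pass freeze=False to fine-tune them.
embedding = nn.Embedding.from_pretrained(pretrained_weights, freeze=True, padding_idx=padding_idx)

tokens = torch.tensor([23, 10, 4, 25, 24])  # <START>, M, F, <PAD>, <END>
vectors = embedding(tokens)
print(vectors.shape)
```

Each output row is simply the corresponding row of the weight matrix, so the lookup for M (index 10) returns `pretrained_weights[10]`.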


In [91]:
class TokenEmbedding(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int = 512, max_len: int = 1500, padding_idx=None):
        super().__init__(num_embeddings, embedding_dim, padding_idx)

padding_idx = token_to_index['<PAD>']
print(padding_idx)


25


## Positional Encoding
We will use sine and cosine functions of different frequencies to embed positional information, as in the original Transformer paper (the original BERT instead learns its position embeddings).
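
For reference, the sinusoidal encoding from the original Transformer paper can be written as follows, where $pos$ is the position in the sequence, $i$ indexes pairs of embedding dimensions, and $d_{\text{model}}$ is the embedding size:

$$
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
$$

The frequency term is usually computed in log space, using the identity

$$
\frac{1}{10000^{2i/d_{\text{model}}}} = \exp\!\left(-\frac{2i\,\ln 10000}{d_{\text{model}}}\right)
$$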

In [92]:
class PositionalEncoding(nn.Module):
    """
    Implement the PE function.

    The PE forward function is different from the one in BERT-pytorch. Here we use the original
    sinusoidal method, so the PE embeddings are added to the input embeddings and no gradient
    tracking is used.
    """

    def __init__(self, d_model, dropout, max_len=1500):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0)/d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

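    def forward(self, x):
        # x: (seq_len, d_model) or (batch, seq_len, d_model). Slice the table along the
        # position axis; the buffer carries no gradient.
        x = x + self.pe[: x.size(-2)].requires_grad_(False)
        return self.dropout(x)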
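
As a quick numerical check of the table built in `__init__`, the sketch below recomputes the same values with the standard library only. The function name `positional_encoding` is hypothetical, and dropout is left out.

```python
import math


def positional_encoding(max_len, d_model):
    """Return a max_len x d_model table of sinusoidal encodings as nested lists."""
    table = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for two_i in range(0, d_model, 2):
            # Same log-space frequency term as div_term in the class above.
            freq = math.exp(two_i * -(math.log(10000.0) / d_model))
            table[pos][two_i] = math.sin(pos * freq)
            if two_i + 1 < d_model:
                table[pos][two_i + 1] = math.cos(pos * freq)
    return table


pe = positional_encoding(max_len=30, d_model=6)
print(pe[0])           # position 0: every sine term is 0 and every cosine term is 1
print(pe[2] == pe[4])  # positions 2 and 4 receive different encoding vectors
```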

In [93]:
class SeqEncoding(nn.Module):
    """
    Encode amino acid sequence. Input sequence is represented by summing the corresponding sequence token,
    segment (e.g. question and answer or any segments separated by <SEP>), and position embeddings. In our 
    model, we only need the token and position embeddings, so the segment embedding is not implemented here.
    """
    def __init__(self, num_embeddings, embedding_dim, dropout=0.1, max_len=1500, padding_idx=25):
        super().__init__()
        self.token_embedding = TokenEmbedding(num_embeddings, embedding_dim, max_len, padding_idx)
        self.position = PositionalEncoding(embedding_dim, dropout, max_len)
        self.embedding_dim = embedding_dim
        self.dropout = nn.Dropout(dropout)
        self.max_len = max_len
        
    def forward(self, seq:str):
        x = tokenize_seq(seq, self.max_len)
        x = self.token_embedding(x)
        x = self.position(x)
        return self.dropout(x)

## Test Sequence and Position Embedding

Let's test the embedding of the first 28 amino acids of the test sequence. Notice that positions 2 and 4 hold the same amino acid (F), yet their embeddings differ in every dimension because they appear at different positions. Dropout is active in this test, so some values are also zeroed at random. For simplicity, we use only 6 dimensions to embed the sequence. In the actual model, we will use many more dimensions.

In [94]:
test_wt_seq = """>sp|P0DTC2|SPIKE_SARS2 Spike glycoprotein OS=Severe acute respiratory syndrome coronavirus 2 OX=2697049 GN=S PE=1 SV=1
MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFS
NVTWFHAIHVSGTNGTKRFDNPVLPFNDGVYFASTEKSNIIRGWIFGTTLDSKTQSLLIV
NNATNVVIKVCEFQFCNDPFLGVYYHKNNKSWMESEFRVYSSANNCTFEYVSQPFLMDLE
GKQGNFKNLREFVFKNIDGYFKIYSKHTPINLVRDLPQGFSALEPLVDLPIGINITRFQT
LLALHRSYLTPGDSSSGWTAGAAAYYVGYLQPRTFLLKYNENGTITDAVDCALDPLSETK
CTLKSFTVEKGIYQTSNFRVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISN
CVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIAD
YNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPC
NGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVN
FNFNGLTGTGVLTESNKKFLPFQQFGRDIADTTDAVRDPQTLEILDITPCSFGGVSVITP
GTNTSNQVAVLYQDVNCTEVPVAIHADQLTPTWRVYSTGSNVFQTRAGCLIGAEHVNNSY
ECDIPIGAGICASYQTQTNSPRRARSVASQSIIAYTMSLGAENSVAYSNNSIAIPTNFTI
SVTTEILPVSMTKTSVDCTMYICGDSTECSNLLLQYGSFCTQLNRALTGIAVEQDKNTQE
VFAQVKQIYKTPPIKDFGGFNFSQILPDPSKPSKRSFIEDLLFNKVTLADAGFIKQYGDC
LGDIAARDLICAQKFNGLTVLPPLLTDEMIAQYTSALLAGTITSGWTFGAGAALQIPFAM
QMAYRFNGIGVTQNVLYENQKLIANQFNSAIGKIQDSLSSTASALGKLQDVVNQNAQALN
TLVKQLSSNFGAISSVLNDILSRLDKVEAEVQIDRLITGRLQSLQTYVTQQLIRAAEIRA
SANLAATKMSECVLGQSKRVDFCGKGYHLMSFPQSAPHGVVFLHVTYVPAQEKNFTTAPA
ICHDGKAHFPREGVFVSNGTHWFVTQRNFYEPQIITTDNTFVSGNCDVVIGIVNNTVYDP
LQPELDSFKEELDKYFKNHTSPDVDLGDISGINASVVNIQKEIDRLNEVAKNLNESLIDL
QELGKYEQYIKWPWYIWLGFIAGLIAIVMVTIMLCCMTSCCSCLKGCCSCGSCCKFDEDD
SEPVLKGVKLHYT"""
len(test_wt_seq)

1413

In [95]:
test_seqs = []
fa_parser = SeqIO.parse(io.StringIO(test_wt_seq), 'fasta')
for record in fa_parser:
    seq = record.seq
    test_seqs.append(str(seq))

In [101]:
def test_encoding():

    max_len = 30
    seq = test_seqs[0][:max_len-2]
    test_seq_encode = SeqEncoding(n_tokens, 6, 0.1, max_len, padding_idx)
    # SeqEncoding already adds the positional encoding, so it is not applied a second time here.
    y = test_seq_encode(seq)
    print(f'Embedding shape: {y.shape}')
    print(f'Parameters shape in sequence embedding: {test_seq_encode.token_embedding.weight.shape}')

    y = y.detach().numpy()


    data = pd.concat([pd.DataFrame({
        "embedding": y[:, dim],
        "dimension":dim,
        "position": list(range(max_len)),
        })for dim in range(6)])
    
    aa = ['<start>']
    aa.extend(_ for _ in seq)
    aa.append('<end>')
    data['aa_token'] = aa * 6

    print(data)

    return (
        alt.Chart(data)
        .mark_line()
        .properties(width=800)
        .encode(x="position", y="embedding", color="dimension:N")
        .interactive())

test_encoding()


Embedding shape: torch.Size([30, 6])
Parameters shape in sequence embedding: torch.Size([26, 6])
    embedding  dimension  position aa_token
0    0.000000          0         0  <start>
1    0.934968          0         1        M
2    1.892376          0         2        F
3    1.308922          0         3        V
4   -0.000000          0         4        F
..        ...        ...       ...      ...
25   3.315263          5        25        P
26   3.314970          5        26        P
27   2.924207          5        27        A
28   1.109090          5        28        Y
29   2.387270          5        29    <end>

[180 rows x 4 columns]


## Attention

In [102]:
class Attention(nn.Module):
    """Single head scaled dot product attention"""
    def __init__(self):
        super().__init__()

    def forward(self, query, key, value, mask=None, dropout=None):
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)    #sqrt(d_k) is the scaling factor.

        if mask is not None:
            scores = scores.masked_fill(mask==0, -1e9)
        
        p_attn = scores.softmax(dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
    
        return torch.matmul(p_attn, value), p_attn 
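
The effect of the mask can be seen in a small standalone example. The sketch below restates the same computation as a plain function (the name `scaled_dot_product_attention` and the toy shapes are assumptions for illustration) and masks the last position as if it were \<PAD\>.

```python
import math
import torch


def scaled_dot_product_attention(query, key, value, mask=None):
    """softmax(Q K^T / sqrt(d_k)) V, with masked positions excluded from the softmax."""
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)  # large negative value -> ~0 after softmax
    weights = scores.softmax(dim=-1)
    return torch.matmul(weights, value), weights


torch.manual_seed(0)
x = torch.randn(1, 4, 8)               # (batch, seq_len, d_k)
mask = torch.tensor([[[1, 1, 1, 0]]])  # last position is padding; broadcast over all queries
out, weights = scaled_dot_product_attention(x, x, x, mask)
print(out.shape)
print(weights[0, :, 3])                # attention paid to the padded position
```

Every row of `weights` still sums to 1, but the column belonging to the masked position is effectively zero, so padding does not contribute to the output.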

In [104]:
class MultiHeadedAttention(nn.Module):
    """
    Multi-head attention

    h: number of heads
    d_model: model size
    
    """
    def __init__(self, h:int, d_model:int, n_linear: int=4, dropout=0.1):
        super().__init__()
        assert d_model % h == 0 # d_model/h is used as d_k and d_v

        self.d_k = d_model // h
        self.h = h
        # n_linear linear layers with the same input and output size; the first three project query, key, and value.
        self.linears = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_linear)])
        # Output linear layer. This follows BERT-pytorch, which uses a separate output layer instead of the last layer of the list as in The Annotated Transformer.
        self.output_linear = nn.Linear(d_model, d_model)
        # Calling self.attn(...) dispatches to Attention.forward through nn.Module.__call__() and _call_impl().
        self.attn = Attention()

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)    # same mask applied to all heads
        n_batches = query.size(0)

        # 1) Linear projections in batch from d_model => h x d_k
        query, key, value = [lin(x).view(n_batches, -1, self.h, self.d_k).transpose(1, 2)
                             for lin, x in zip(self.linears, (query, key, value))]
        
        # 2) Apply attention on all the projected vectors in batch
        x, _ = self.attn(query, key, value, mask=mask, dropout=self.dropout)  # attention weights are discarded

        # 3) "Concat using a view and apply a final linear"
        x = (x.transpose(1, 2)
             .contiguous()
             .view(n_batches, -1, self.h * self.d_k))
        
        del query
        del key
        del value
        return self.output_linear(x)
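
The `view` and `transpose` calls in steps 1 and 3 only rearrange the tensor. The small sketch below, with arbitrary toy sizes chosen for illustration, shows the shapes involved and confirms that merging the heads restores the original layout.

```python
import torch

n_batches, seq_len, h, d_model = 2, 5, 3, 12
d_k = d_model // h

x = torch.randn(n_batches, seq_len, d_model)

# Split: (batch, seq_len, d_model) -> (batch, h, seq_len, d_k)
heads = x.view(n_batches, -1, h, d_k).transpose(1, 2)

# Merge: (batch, h, seq_len, d_k) -> (batch, seq_len, d_model)
merged = heads.transpose(1, 2).contiguous().view(n_batches, -1, h * d_k)

print(heads.shape, merged.shape)
print(torch.equal(merged, x))
```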

## Layer Normalization

Layer normalization with a learnable scale (a_2) and shift (b_2), applied as an affine transform to the normalized features. A small constant (epsilon, or eps) is added to the standard deviation to avoid division by zero when the standard deviation is 0.
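
Written as a formula, with $\mu$ and $\sigma$ the mean and standard deviation taken over the last (feature) dimension of $x$, the operation is:

$$
\mathrm{LayerNorm}(x) = a_2 \odot \frac{x - \mu}{\sigma + \epsilon} + b_2
$$

Two details are worth keeping in mind when comparing against `torch.nn.LayerNorm`: `Tensor.std` applies Bessel's correction by default, and here $\epsilon$ is added to the standard deviation rather than to the variance inside the square root. The two versions therefore give slightly different values.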

In [105]:
class LayerNorm(nn.Module):
    """Construct a layernorm module"""

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x-mean) / (std + self.eps) + self.b_2

## Residual Connection

In [106]:
class SublayerConnection(nn.Module):
    """A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size"
        return x + self.dropout(sublayer(self.norm(x)))

## Positionwise Feed Forward

In [107]:
class PositionwiseFeedForward(nn.Module):
    """Implements FFN equation."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        return self.w_2(self.dropout(self.activation(self.w_1(x))))
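
For reference, the equation implemented by `forward` can be written as follows, with $W_1 \in \mathbb{R}^{d_{ff} \times d_{\text{model}}}$ and $W_2 \in \mathbb{R}^{d_{\text{model}} \times d_{ff}}$ and the same weights applied independently at every position:

$$
\mathrm{FFN}(x) = W_2\,\mathrm{Dropout}\big(\mathrm{GELU}(W_1 x + b_1)\big) + b_2
$$

The original Transformer used ReLU as the activation; GELU is the choice made in BERT.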

## Transformer

In [99]:
class Transformer:
    

IndentationError: expected an indented block (3359644668.py, line 2)
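
The `Transformer` class is still empty at this point in the notebook. As a rough idea of what such a block could look like, here is a hypothetical, self-contained sketch of a pre-norm encoder block. It is not the implementation intended by this notebook: it uses PyTorch's built-in `nn.MultiheadAttention` and `nn.LayerNorm` instead of the classes defined above, and the class name `TransformerBlockSketch` and its argument order are assumptions.

```python
import torch
import torch.nn as nn


class TransformerBlockSketch(nn.Module):
    """Hypothetical pre-norm encoder block: self-attention and feed-forward, each with a residual."""

    def __init__(self, hidden: int, attn_heads: int, feed_forward_hidden: int, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden)
        self.attn = nn.MultiheadAttention(hidden, attn_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden)
        self.ffn = nn.Sequential(
            nn.Linear(hidden, feed_forward_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(feed_forward_hidden, hidden),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        # x: (batch, seq_len, hidden); key_padding_mask: (batch, seq_len), True marks padded positions.
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, key_padding_mask=key_padding_mask, need_weights=False)
        x = x + self.dropout(attn_out)                 # residual around self-attention
        x = x + self.dropout(self.ffn(self.norm2(x)))  # residual around feed-forward
        return x


block = TransformerBlockSketch(hidden=12, attn_heads=3, feed_forward_hidden=48).eval()
x = torch.randn(2, 5, 12)
pad_mask = torch.tensor([[False, False, False, True, True],
                         [False, False, False, False, False]])
out = block(x, key_padding_mask=pad_mask)
print(out.shape)
```

The block keeps the input shape, which is what allows several of them to be stacked.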

# Model Definition

Here we define a model based on BERT. Part of the implementation is based on [BERT-pytorch](https://github.com/codertimo/BERT-pytorch).

In [None]:
def clones(module, N):
    """Produce N identical layers"""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

In [None]:
class BERT(nn.Module):
    """
    BERT model
    """

    def __init__(self, vocab_size: int = 26, hidden: int = 768, n_layer: int = 12, attn_heads: int = 12, dropout: float = 0.1):
        """
        vocab_size: vocabulary (token) size
        hidden: BERT model size (used as input size and hidden size)
        n_layer: number of Transformer layers
        attn_heads: number of attention heads
        dropout: dropout ratio
        """

        super().__init__()
        self.hidden = hidden
        self.n_layer = n_layer
        self.attn_heads = attn_heads

        self.feed_forward_hidden = hidden * 4
        self.embedding = TokenEmbedding(vocab_size, embedding_dim=hidden, padding_idx=25)

        self.transformer_blocks = clones(Transformer(hidden, attn_heads, self.feed_forward_hidden, dropout), n_layer)

    def forward(self, x):
        # The forward pass is not implemented yet in this notebook.
        raise NotImplementedError


In [None]:
#num_parameters_seq_encoding = sum(p.numel() for p in test_seq_encode.parameters() if p.requires_grad)
#print(f'Parameters in SeqEncoding: {num_parameters_seq_encoding}')