# Transformers

Here we develop implementations of transformers for CDR3 sequences.

Resources:
- https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html

In [1]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from NegativeClassOptimization import ml
from NegativeClassOptimization import utils
import NegativeClassOptimization.preprocessing as preprocessing

  from .autonotebook import tqdm as notebook_tqdm
  warn(f"Failed to load image Python extension: {e}")


In [2]:
sn10 = ml.SN10()
print(f"Number of trainable parameters in SN10: {utils.num_trainable_params(sn10)}")

Number of trainable parameters in SN10: 2221


In [69]:
class Transformer(nn.Module):
    """
    Text classifier based on a pytorch TransformerEncoder.
    """

    def __init__(
        self,
        vocab_size, 
        d_model,
        nhead=8,
        dim_feedforward=2048,
        num_layers=6,
        dropout=0.1,
        activation="relu",
        classifier_dropout=0.1,
        ):

        super().__init__()

        assert d_model % nhead == 0, "nheads must divide evenly into d_model"

        self.emb = nn.Embedding(vocab_size, d_model)

        self.pos_encoder = PositionalEncoding(
            d_model=d_model,
            dropout=dropout,
            vocab_size=vocab_size,
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )
        self.classifier = nn.Linear(d_model, 1)
        self.d_model = d_model

    def forward(self, x):
        x = self.emb(x) * math.sqrt(self.d_model)
        print(x.shape)
        x = self.pos_encoder(x)
        print(x.shape)
        x = self.transformer_encoder(x)
        print(x.shape)
        x = x.mean(dim=1)
        print(x.shape)
        x = x.mean(dim=0)
        print(x.shape)
        x = self.classifier(x)
        print(x.shape)
        x = nn.Sigmoid()(x)
        return x


class PositionalEncoding(nn.Module):
    """
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html
    """

    def __init__(self, d_model, vocab_size=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(vocab_size, d_model)
        position = torch.arange(0, vocab_size, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)


    def forward(self, x):
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)

In [73]:
transformer = Transformer(
    vocab_size=20, 
    d_model=10,  # 20?
    nhead=2,
    dim_feedforward=10,
    num_layers=2,
    dropout=0.1,
    activation="relu",
    classifier_dropout=0.1,
    )

utils.num_trainable_params(transformer)

# transformer = torch.nn.Transformer(
#     d_model=20,
#     nhead=2,
#     num_encoder_layers=2,
#     num_decoder_layers=2,
#     dim_feedforward=16,
#     dropout=0.1,
#     activation="relu",
# )

# src = torch.rand(32, 1, 20)  # S, N, E
# tgt = torch.rand(17, 1, 20)  # T, N, E
# transformer(src, tgt).shape

# utils.num_trainable_params(transformer)

1611

In [74]:
transformer(torch.rand(13, 20).round().type(torch.long))

torch.Size([13, 20, 10])
torch.Size([13, 20, 10])
torch.Size([13, 20, 10])
torch.Size([13, 10])
torch.Size([10])
torch.Size([1])


tensor([0.5224], grad_fn=<SigmoidBackward0>)

In [None]:
cnn = ml.CNN()
utils.num_trainable_params(cnn), cnn.forward(torch.rand(1, 1, 11, 20))

(299, tensor([[0.4975]], grad_fn=<SigmoidBackward0>))

## Predict from one CDR3 / CDR3 batch for each model

In [None]:
df = utils.load_global_dataframe()

ag_pos = "3VRL"
ag_neg = "1ADQ"
df = df.loc[df["Antigen"].isin([ag_pos, ag_neg])].copy()
df = df.drop_duplicates(["Slide"])

N = 1000
df = df.sample(n=N, random_state=42)
df = df.sample(frac=1, random_state=42)

df.head(2)

Unnamed: 0,ID_slide_Variant,CDR3,Best,Slide,Energy,Structure,UID,Antigen
35067,1123588_03a,CAKTLFYDGYYRYFDVW,True,TLFYDGYYRYF,-96.0,128933-BRRSLUDUUS,1ADQ_1123588_03a,1ADQ
32364,4368719_02a,CARWDYGSLLFAYW,True,RWDYGSLLFAY,-96.54,137191-BRDSDLSRRU,1ADQ_4368719_02a,1ADQ


In [None]:
train_data, test_data, train_loader, test_loader = preprocessing.preprocess_data_for_pytorch_binary(
    df_train_val=df.iloc[:int(N*0.8)],
    df_test_closed=df.iloc[int(N*0.8):],
    ag_pos=ag_pos,
    batch_size=32,
    scale_onehot=False,
)



In [None]:
X = train_data[0][0]

Forward loop for 1 sequence for each model.

In [None]:
sn10.forward(X)

X_tr = X.reshape(11, 1, 20)
transformer.forward(X_tr, X_tr)

X_cnn = X.reshape(1, 1, 11, 20)
cnn.forward(X_cnn)

tensor([[0.1958]], grad_fn=<SigmoidBackward0>)

torch.Size([11, 1, 20])

Forward loop for batch of sequences for each model.

In [None]:
X_batch = ml.Xy_from_loader(train_loader)[0]

sn10.forward(X_batch)

X_tr = X_batch.reshape(11, -1, 20)
transformer.forward(X_tr, X_tr)

X_cnn = X_batch.reshape(-1, 1, 11, 20)
cnn.forward(X_cnn);

## Train each model

SN10

In [None]:
model = sn10.to("cpu")

learning_rate = 0.01
epochs = 5
loss_fn = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    ml.train_loop(train_loader, model, loss_fn, optimizer)
    ml.test_loop(test_loader, model, loss_fn)

Epoch 1
-------------------------------
loss: 0.643021  [    0/  800]
Test Error: 
 Acc: 100.0 Avg loss: 0.585705 

Epoch 2
-------------------------------
loss: 0.581752  [    0/  800]


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


Test Error: 
 Acc: 100.0 Avg loss: 0.525524 

Epoch 3
-------------------------------
loss: 0.518498  [    0/  800]
Test Error: 
 Acc: 100.0 Avg loss: 0.469637 

Epoch 4
-------------------------------
loss: 0.469251  [    0/  800]
Test Error: 
 Acc: 100.0 Avg loss: 0.416152 

Epoch 5
-------------------------------
loss: 0.407727  [    0/  800]
Test Error: 
 Acc: 100.0 Avg loss: 0.366155 



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


Transformer

In [None]:
model = transformer.to("cpu")

learning_rate = 0.01
epochs = 5
loss_fn = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    ml.train_loop(train_loader, model, loss_fn, optimizer)
    ml.test_loop(test_loader, model, loss_fn)

CNN

In [None]:
model = cnn.to("cpu")

learning_rate = 0.01
epochs = 5
loss_fn = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    ml.train_loop(train_loader, model, loss_fn, optimizer)
    ml.test_loop(test_loader, model, loss_fn)

Epoch 1
-------------------------------
loss: 0.706944  [    0/  800]


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


Test Error: 
 Acc: 100.0 Avg loss: 0.611178 

Epoch 2
-------------------------------
loss: 0.611281  [    0/  800]
Test Error: 
 Acc: 100.0 Avg loss: 0.513477 

Epoch 3
-------------------------------
loss: 0.510288  [    0/  800]


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


Test Error: 
 Acc: 100.0 Avg loss: 0.412972 

Epoch 4
-------------------------------
loss: 0.413252  [    0/  800]
Test Error: 
 Acc: 100.0 Avg loss: 0.311248 

Epoch 5
-------------------------------
loss: 0.313254  [    0/  800]
Test Error: 
 Acc: 100.0 Avg loss: 0.216979 



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
