# SN10 classifier and integrated gradients attribution

In this notebook, we develop the SN10 classifier used in `Absolut!` and attribute its predictions with the integrated-gradients method. Other attribution methods may also be checked.

In [52]:
import os
from pathlib import Path
from typing import List

import pandas as pd

import torch
from torch import nn
import NegativeClassOptimization.config as config

Load the data used to develop the binary classifier: `3VRL` is the positive antigen and `1ADQ` the negative one.

In [39]:
ag_pos = "3VRL"  # antigen used as the positive class
ag_neg = "1ADQ"  # antigen used as the negative class

# Read the tab-separated table and keep only rows for the two antigens of interest.
df = (
    pd.read_csv(config.DATA_SLACK_1_GLOBAL, sep='\t')
    .loc[lambda frame: frame["Antigen"].isin([ag_pos, ag_neg])]
    .copy()
)

df.head(2)

Unnamed: 0,ID_slide_Variant,CDR3,Best,Slide,Energy,Structure,UID,Antigen
0,1873658_06a,CARPENLLLLLWYFDVW,True,LLLLLWYFDVW,-112.82,137442-BRDSLLUDLS,3VRL_1873658_06a,3VRL
1,7116990_04a,CARGLLLLLWYFDVW,True,LLLLLWYFDVW,-112.82,137442-BRDSLLUDLS,3VRL_7116990_04a,3VRL


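First, handle duplicates: the same `Slide` can appear under both antigens. Keep one label per `Slide`, preferring the positive antigen whenever it is present.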

In [51]:
def infer_antigen_from_duplicate_list(antigens: List[str], pos_antigen: str) -> str:
    """Pick a single antigen label for a slide observed under several antigens.

    Args:
        antigens: Antigens recorded for one slide (at most two supported).
        pos_antigen: Antigen treated as the positive class; it wins whenever
            it is present.

    Returns:
        The only antigen if one is given. Otherwise ``pos_antigen`` if it is in
        ``antigens``, else the lexicographically smallest remaining antigen.

    Raises:
        AssertionError: if more than two antigens are given.
        ValueError: if ``antigens`` is empty.
    """
    assert len(antigens) <= 2, ">2 antigens not supported yet."
    if not antigens:
        raise ValueError("antigens must contain at least one antigen")
    if len(antigens) == 1:
        return antigens[0]
    if pos_antigen in antigens:
        return pos_antigen
    # Sort rather than index into a set: set order for str depends on hash
    # randomization, which made the previous choice non-reproducible across runs.
    return sorted(set(antigens) - {pos_antigen})[0]

# Collapse duplicated slides to one antigen label per Slide.
# Selecting the "Antigen" column before .apply() hands each group's antigens to
# the lambda as a Series and avoids pandas' deprecated behaviour of passing the
# grouping column to DataFrameGroupBy.apply.
# NOTE(review): this rebinds `df` from the filtered DataFrame to a Series
# (index: Slide, values: antigen label), so the cell is not idempotent —
# re-run the loading cell before re-running this one.
df = df.groupby("Slide")["Antigen"].apply(
    lambda slide_antigens: infer_antigen_from_duplicate_list(
        slide_antigens.unique().tolist(), pos_antigen=ag_pos
    )
)

Slide
AAELFWYFDVW    3VRL
AAFITTVGWYF    1ADQ
AAFYGRWYFDV    1ADQ
AAFYYGNLAWF    1ADQ
AAGWLLLFAYW    3VRL
               ... 
YYYSFLWYFDV    1ADQ
YYYSNYELGLW    1ADQ
YYYVLWYFDVW    3VRL
YYYVLYYFDYW    3VRL
YYYYLWYFDVW    3VRL
Length: 83213, dtype: object

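Define the model: SN10 flattens the 11×20 input, applies one hidden layer of 10 ReLU units, and outputs a single sigmoid probability.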

In [30]:
class SN10(nn.Module):
    """Shallow binary classifier: Flatten -> Linear -> ReLU -> Linear -> Sigmoid.

    Each sample is a ``(seq_len, alphabet_size)`` matrix (11 x 20 by default),
    flattened to ``seq_len * alphabet_size`` features and passed through one
    hidden layer of ``hidden_size`` ReLU units. The output is a ``(batch, 1)``
    tensor of sigmoid probabilities for the positive class, not raw logits.
    NOTE(review): the input is presumably a one-hot encoded 11-residue Slide —
    TODO confirm against the encoding cell.

    Args:
        seq_len: Number of rows per input sample.
        alphabet_size: Number of columns per input sample.
        hidden_size: Width of the hidden layer.
    """

    def __init__(self, seq_len: int = 11, alphabet_size: int = 20, hidden_size: int = 10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(seq_len * alphabet_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return positive-class probabilities of shape ``(batch, 1)``."""
        x = self.flatten(x)
        probs = self.linear_relu_stack(x)
        return probs


# Seed torch's RNG so weight initialisation and later random draws are
# reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
model = SN10().to(device)
print(model)

Using cpu device
SN10(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=220, out_features=10, bias=True)
    (1): ReLU()
    (2): Linear(in_features=10, out_features=1, bias=True)
    (3): Sigmoid()
  )
)


In [26]:
# Smoke-test the untrained model on one random input of shape (1, 11, 20).
X = torch.rand(1, 11, 20, device=device)
pred_probab = model(X)
# The model has a single sigmoid output (the positive-class probability), so
# argmax over dim 1 would always return 0; threshold the probability instead.
y_pred = (pred_probab > 0.5).long().squeeze(1)
print(f"Predicted class: {y_pred}")

Predicted class: tensor([0])


In [27]:
pred_probab  # sigmoid output of the model for the random input X, shape (1, 1)

tensor([[0.4560]], grad_fn=<SigmoidBackward0>)

In [28]:
# Training hyper-parameters.
learning_rate = 0.01
epochs = 700

# Binary cross-entropy on the model's sigmoid output, minimised with plain SGD.
loss_fn = nn.BCELoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

In [None]:
# Train SN10 with mini-batch SGD.
# NOTE(review): `trainloader` is not defined in any visible cell; it is assumed
# to yield (features, labels) batches with features shaped like the model input
# and 0/1 labels — TODO add the cell that builds it.
losses = []  # mean training loss per logged epoch (floats)
accur = []   # training accuracy per logged epoch (floats)
for i in range(epochs):
    epoch_loss = 0.0
    n_correct = 0
    n_seen = 0
    for x_train, y_train in trainloader:
        x_train = x_train.to(device)
        # BCELoss needs float targets shaped like the (batch, 1) output.
        y_train = y_train.to(device).float().reshape(-1, 1)

        # Forward pass and loss.
        output = model(x_train)
        loss = loss_fn(output, y_train)

        # Backpropagation and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate plain Python numbers: keeping `loss` itself would retain
        # each batch's autograd graph for the whole run.
        batch_size = y_train.shape[0]
        epoch_loss += loss.item() * batch_size
        n_correct += ((output.detach() > 0.5).float() == y_train).sum().item()
        n_seen += batch_size

    # max(..., 1) guards against an empty loader.
    epoch_loss /= max(n_seen, 1)
    acc = n_correct / max(n_seen, 1)
    if i % 50 == 0:
        losses.append(epoch_loss)
        accur.append(acc)
        print("epoch {}\tloss : {}\t accuracy : {}".format(i, epoch_loss, acc))