In [1]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [2]:
import gzip
import numpy as np
import pandas as pd
pd.set_option("display.max_columns",100, "display.width",200, "display.max_colwidth",40)
import pickle
from termcolor import colored

import torch
torch.set_printoptions(linewidth=120)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tudata
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

import pytorch_lightning as pl
import pytorch_lightning.callbacks as plc

#from binary_label_metrics import BinaryLabelMetrics

for x in [("numpy","np"),("pandas","pd"),("torch","torch"),("pytorch_lightning","pl")]:
  print(f"{x[0]} {eval(f'{x[1]}.__version__')}")

numpy 1.22.1
pandas 1.3.5
torch 1.9.1
pytorch_lightning 1.6.4


#### Parameters

In [3]:
clargs = dict(
  #number of controls to cases
   CONTROL_TO_CASE = 1,
)

offs = max(map(len,clargs.keys()))
for k,v in clargs.items():
  print(f"{k:>{offs}}: {v}")

del k, v, offs

CONTROL_TO_CASE: 1


In [4]:
with gzip.open("data/data.pkl.gz", "rb", compresslevel=9) as fil:
  data_orig_dfs = pickle.load(fil)

del fil

#### Create Controls By Randomly Matching Antibody/Antigens

In [5]:
np.random.seed(1000)
let2ind = dict((k,v) for v,k in enumerate("ACDEFGHIKLMNPQRSTVWY",1)) #0 for null

data_dfs = dict()
for k,df in data_orig_dfs.items():  
  lst = list()
  df1 = df.loc[:,["antibody_seq","antigen_seq"]]; df1["label"] = 1
  lst.append(df1)
  
  ind1 = np.arange(df1.shape[0])
  for n in range(clargs["CONTROL_TO_CASE"]):
    ind2 = np.copy(ind1)
    while (ind1==ind2).any():
      np.random.shuffle(ind2)
    df2 = pd.DataFrame(dict(antibody_seq=df1.loc[ind1,"antibody_seq"],antigen_seq=df1.loc[ind2,"antigen_seq"],label=0))
    lst.append(df2)
  df3 = pd.concat(lst, ignore_index=True)
  df3[["antibody_ind","antigen_ind"]] = df3[["antibody_seq","antigen_seq"]].applymap(lambda x:[let2ind[l] for l in x])
  df3[["antibody_sz","antigen_sz"]] = df3[["antibody_seq","antigen_seq"]].applymap(len)
  df3.drop(columns=["antibody_seq","antigen_seq"], inplace=True); data_dfs[k] = df3
    
  print(colored(f"{k} orig {df.shape[0]}, case+cont {df3.shape[0]}", attrs=["bold"]), flush=True)
  print(df.head(2),"\n"); print(df3.head(2),"\n")
  print("==distribution==")
  p = [.02,.2,.5,.8,.98,1.]
  print(pd.concat([df3["antibody_sz"].quantile(p,interpolation="nearest").to_frame().T,
                   df3["antigen_sz"].quantile(p,interpolation="nearest").to_frame().T]),"\n")

del df, df1, df2, df3, ind1, ind2, k, lst, n

[1mtrain orig 4105, case+cont 8210[0m
                              antibody_seq                             antibody_cdr                              antigen_seq
0  EVQLLESGGGLVQPGGSLRLSCAASGFTFSSYAMSW...  000000000000000000000000011111111000...  PTNLCPFGEVFNATRFASVYAWNRKRISNCVADYSV...
1  QVQLVESGGGLVQPGGSLRLSCAASGFTLDDYAIGW...  000000000000000000000000011111111000...  TNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVL... 

   label                             antibody_ind                              antigen_ind  antibody_sz  antigen_sz
0      1  [4, 18, 14, 10, 10, 4, 16, 6, 6, 6, ...  [13, 17, 12, 10, 2, 13, 5, 6, 4, 18,...          227         198
1      1  [14, 18, 14, 10, 18, 4, 16, 6, 6, 6,...  [17, 12, 10, 2, 13, 5, 6, 4, 18, 5, ...          133         195 

==distribution==
             0.02  0.20  0.50  0.80  0.98  1.00
antibody_sz   112   121   213   222   233   366
antigen_sz      7    90   212   442  1072  2346 

[1mvalid orig 136, case+cont 272[0m
                              an

In [6]:
#test

" ".join(map(lambda x:f"{x[0]}:{x[1]}", let2ind.items()))

df = pd.DataFrame({"ag":["EVQL","KY","RHG"]})
df["ag_ind"] = df[["ag"]].applymap(lambda x:[let2ind[l] for l in x])
df

tens = pad_sequence([torch.LongTensor(x) for x in df["ag_ind"]],batch_first=True,padding_value=0)
tens

F.one_hot(tens, len(let2ind)+1)

'A:1 C:2 D:3 E:4 F:5 G:6 H:7 I:8 K:9 L:10 M:11 N:12 P:13 Q:14 R:15 S:16 T:17 V:18 W:19 Y:20'

Unnamed: 0,ag,ag_ind
0,EVQL,"[4, 18, 14, 10]"
1,KY,"[9, 20]"
2,RHG,"[15, 7, 6]"


tensor([[ 4, 18, 14, 10],
        [ 9, 20,  0,  0],
        [15,  7,  6,  0]])

tensor([[[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]])

#### Pytorch-lightning DataLoader

In [7]:
class PandasDataset(tudata.Dataset):
  
  def __init__(self, df):
    super().__init__()
    assert df.columns.tolist() == "label,antibody_ind,antigen_ind,antibody_sz,antigen_sz".split(",")
    self.df = df
  
  def __len__(self):
    return self.df.shape[0]
  
  def __getitem__(self, idx):
    return self.df.iloc[idx]


#in:  list of pandas rows
#out: antibody tensor, antibody sizes, antigen tensor, antigen sizes, labels
def pad_collate(batch):
  return \
      (pad_sequence([torch.LongTensor(x["antibody_ind"]) for x in batch],batch_first=True,padding_value=0), \
      torch.LongTensor([x["antibody_sz"] for x in batch]), \
      pad_sequence([torch.LongTensor(x["antigen_ind"]) for x in batch],batch_first=True,padding_value=0), \
      torch.LongTensor([x["antigen_sz"] for x in batch])), \
      torch.FloatTensor([x["label"] for x in batch])


# #test
# dl = tudata.DataLoader(PandasDataset(data_dfs["train"]), 
#                         batch_size=6, num_workers=1, drop_last=False, shuffle=True, collate_fn=pad_collate)
# dl = iter(dl)
# for n in range(2):
#   print(colored(f"==batch {n+1}==",attrs=["bold"]))
#   dl.next(); print()
# del dl

In [8]:
class PandasDataModule(pl.LightningDataModule):
  
  def __init__(self, prm):
    super().__init__()
    self.bs = prm["batch_size"]
  
  def train_dataloader(self):
    train = PandasDataset(data_dfs["train"])
    return tudata.DataLoader(train, batch_size=self.bs, num_workers=4, drop_last=False, shuffle=True, collate_fn=pad_collate)
  
  def val_dataloader(self):
    valid = PandasDataset(data_dfs["valid"])
    return tudata.DataLoader(valid, batch_size=self.bs, num_workers=4, drop_last=False, shuffle=False, collate_fn=pad_collate)


# #test
# lstmdm = PandasDataModule({"batch_size":4})
# dl = iter(lstmdm.val_dataloader())
# for n in range(1):
#   print(colored(f"==batch {n+1}==",attrs=["bold"]))
#   dl.next(); print()
# del lstmdm, dl

#### Pytorch Lightning Module

Examples Using Padding and Packing for LSTMs
- https://suzyahyah.github.io/pytorch/2019/07/01/DataLoader-Pad-Pack-Sequence.html
- https://github.com/HarshTrivedi/packing-unpacking-pytorch-minimal-tutorial

In [9]:
class RNNModel(nn.Module):
  
  def __init__(self, prm):
    super().__init__()
        
    #LSTM
    self.rnn_ab = nn.LSTM(input_size=len(let2ind)+1, hidden_size=prm["lstm_hidden_size"], num_layers=1, batch_first=True)
    self.rnn_ag = nn.LSTM(input_size=len(let2ind)+1, hidden_size=prm["lstm_hidden_size"], num_layers=1, batch_first=True)
    
    #LSTM output to be concatenated
    self.fc1 = nn.Linear(in_features=2*prm["lstm_hidden_size"], out_features=16)
    self.dropout = nn.Dropout(p=.5)
    self.fc2 = nn.Linear(in_features=16, out_features=1)
  
  def forward(self, ab_tens, a, ag_tens, b):    
    #antibody
    ab_out = F.one_hot(ab_tens,len(let2ind)+1).type(torch.FloatTensor).cuda()
    ab_out,(ab_hid,_) = self.rnn_ab(ab_out)
    
    #antigen
    ag_out = F.one_hot(ag_tens,len(let2ind)+1).type(torch.FloatTensor).cuda()
    ag_out,(ag_hid,_) = self.rnn_ab(ag_out)
    
    out = self.fc1(torch.cat([ab_out[:,-1,:],ag_out[:,-1,:]],dim=1))
    out = self.dropout(F.relu(out))
    out = self.fc2(out)
    return out.squeeze()

#class RNNModel(nn.Module):
#  
#  def __init__(self, prm):
#    super().__init__()
#    
#    # #embedding (adding 1 for padding_idx)
#    # self.embed_ab = nn.Embedding(num_embeddings=len(let2ind)+1, embedding_dim=prm["embed_dim"], padding_idx=0)
#    # self.embed_ag = nn.Embedding(num_embeddings=len(let2ind)+1, embedding_dim=prm["embed_dim"], padding_idx=0)
#    
#    #LSTM or GRU
#    #self.rnn_ab = nn.LSTM(input_size=prm["embed_dim"], hidden_size=prm["lstm_hidden_size"], num_layers=1, batch_first=True)
#    #self.rnn_ag = nn.LSTM(input_size=prm["embed_dim"], hidden_size=prm["lstm_hidden_size"], num_layers=1, batch_first=True)
#    self.rnn_ab = nn.LSTM(input_size=len(let2ind)+1, hidden_size=prm["lstm_hidden_size"], num_layers=1, batch_first=True)
#    self.rnn_ag = nn.LSTM(input_size=len(let2ind)+1, hidden_size=prm["lstm_hidden_size"], num_layers=1, batch_first=True)
#    
#    #LSTM output to be concatenated
#    self.fc1 = nn.Linear(in_features=2*prm["lstm_hidden_size"], out_features=16)
#    self.dropout = nn.Dropout(p=.5)
#    self.fc2 = nn.Linear(in_features=16, out_features=1)
#  
#  def forward(self, ab_tens, ab_sz, ag_tens, ag_sz):    
#    #antibody
#    ab_out = F.one_hot(ab_tens,len(let2ind)+1).type(torch.FloatTensor).cuda(); #self.embed_ab(ab_tens)
#    #ab_out = pack_padded_sequence(ab_out, ab_sz.cpu(), batch_first=True, enforce_sorted=False)
#    ab_out,(ab_hid,_) = self.rnn_ab(ab_out)
#    #print(ab_out.shape)
#    #ab_out,_ = pad_packed_sequence(ab_out, batch_first=True)
#    
#    #antigen
#    ag_out = F.one_hot(ag_tens,len(let2ind)+1).type(torch.FloatTensor).cuda(); #self.embed_ag(ag_tens)
#    #ag_out = pack_padded_sequence(ag_out, ag_sz.cpu(), batch_first=True, enforce_sorted=False)
#    ag_out,(ag_hid,_) = self.rnn_ab(ag_out)
#    #ag_out,_ = pad_packed_sequence(ag_out, batch_first=True)
#    
#    out = self.fc1(torch.cat([ab_out[:,-1,:],ag_out[:,-1,:]],dim=1))
#    out = self.dropout(F.relu(out))
#    out = self.fc2(out)
#    return out.squeeze()


# #test
# rnnm = RNNModel({"embed_dim":32,"lstm_hidden_size":64}); rnnm
# lstmdm = PandasDataModule({"batch_size":4})
# dl = lstmdm.train_dataloader()
# for n,x in enumerate(dl):
#   rnnm.forward(*x[0])
#   if n==2:
#     break
# del dl, n, x, lstmdm, rnnm

In [10]:
class RNNPL(pl.LightningModule):
  
  def __init__(self, prm):
    super().__init__()
    self.model = RNNModel(prm["model"])   
    self.lr = prm["lr"]
    self.save_hyperparameters()
  
  def forward(self, x):
    return self.model.forward(*x)
  
  def training_step(self, batch, batch_idx):
    X,y = batch
    yhat = self.forward(X)
    return F.binary_cross_entropy_with_logits(yhat,y)
  
  def validation_step(self, batch, batch_idx):
    X,y = batch
    yhat = self.forward(X)
    loss = F.binary_cross_entropy_with_logits(yhat,y)
    self.log("valid_loss", loss)
  
  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=self.lr)


# #test
# RNNPL({"model":{"embed_dim":16,"lstm_hidden_size":32}, "lr":1E-3})

#### Model Training

In [12]:
pdm = PandasDataModule({"batch_size":32})
modprm = dict(model={"embed_dim":16,"lstm_hidden_size":128}, lr=1E-3)
rnnpl = RNNPL(modprm)

checkpoint_CB = plc.ModelCheckpoint(monitor="valid_loss", save_top_k=1, mode="min"
                                            , dirpath="checkpoints", filename="{epoch:03d}")
earlystopping_CB = plc.early_stopping.EarlyStopping(monitor="valid_loss", patience=2, mode="min")
progressbar_CB = plc.RichProgressBar()

pl.seed_everything(1234)
trainer = pl.Trainer(accelerator="gpu", strategy="dp", max_epochs=20, auto_lr_find=False, auto_scale_batch_size=False, 
                     deterministic=True, logger=False, callbacks=[checkpoint_CB,earlystopping_CB,progressbar_CB])
_ = trainer.fit(rnnpl, datamodule=pdm)
best_model_path = trainer.checkpoint_callback.best_model_path

print(f"         nparam: {sum(p.numel() for p in rnnpl.model.parameters()):,}")
print(f"  current_epoch: {trainer.current_epoch}")
print(f"best_model_path: {best_model_path}")

         nparam: 158,753
  current_epoch: 5
best_model_path: /home/dnori/antibody_antigen_rnn/checkpoints/epoch=002.ckpt


#### Model Evaluation

In [13]:
#rnnpl = RNNPL(modprm).load_from_checkpoint(best_model_path)
#_ = rnnpl.eval(); _ = torch.no_grad()

In [14]:
#blm = BinaryLabelMetrics()
#
#for x in ["train","valid","test"]:
#  print(f"=={x} data==")
#  data = PandasDataset(data_dfs[x])
#  dataDL = tudata.DataLoader(data, batch_size=128, num_workers=4, drop_last=False
#                                    , shuffle=False, collate_fn=pad_collate)
#  
#  Xlst = list(); ylst = list()
#  for X,y in dataDL:
#    Xlst.append(torch.sigmoid(rnnpl.forward(X))) #detach().numpy()
#    ylst.append(y) #detach().numpy()
#  
#  yhat = np.concatenate(Xlst); del Xlst
#  y = np.concatenate(ylst).astype(int); del ylst
#  blm.add_model(x, pd.DataFrame(dict(label=y,score=yhat)))

In [None]:
#blm.plot_roc()