# Our first neural network

Today we will write a neural network that predicts antibody developability.

## Dataset

We will use data from [TDC](https://tdcommons.ai/single_pred_tasks/develop/), a collection of biochemical datasets for deep learning.

In [None]:
!pip install PyTDC
!pip install plotly
!pip install fair-esm
!pip install pytorch-lightning

In [None]:
import pandas as pd
import numpy as np
import torch
from typing import List
import re
import esm
from pytorch_lightning import Trainer, LightningModule
%load_ext tensorboard
torch.manual_seed(42)

In [None]:
from tdc.single_pred import Develop
# Load TDC's SAbDab_Chen developability task and its default split;
# `split` is indexed by "train", "valid" and "test" in the next cells.
data = Develop(name='SAbDab_Chen')
split = data.get_split()

In [None]:
def split_chains(old_input: str) -> List[str]:
    """Parse a stringified list of chain sequences into a Python list.

    The ``Antibody`` column holds the text of a list, e.g.
    ``"['EVQL...', 'DIQM...']"``. Square brackets, single quotes and all
    whitespace are stripped, and the remainder is split on commas.

    >>> split_chains("['EVQL', 'DIQM']")
    ['EVQL', 'DIQM']
    """
    # Raw string: in a plain literal "\[", "\]" and "\s" are invalid escape
    # sequences (SyntaxWarning since Python 3.12).
    translated_input = re.sub(r"[\[\]\s']", "", old_input)
    return translated_input.split(",")

In [None]:
# The Antibody column holds a stringified list of chains: parse it, then
# concatenate the first two chains into a single sequence string per row.
for split_name in ("train", "valid", "test"):
    frame = split[split_name]
    frame["Antibody"] = frame["Antibody"].apply(split_chains)
    frame["Antibody"] = frame["Antibody"].apply(lambda chains: chains[0] + chains[1])
train, valid, test = split["train"], split["valid"], split["test"]

## Featurisation

First, we need to write a function that takes an amino acid sequence such as `EVQLQQSGAEVVRSGAS` and converts it into a tensor of numbers.

The rough steps for that would be:

1. Identify our alphabet (all available amino acids)
1. Assign a number to each amino acid
1. "Translate" our sequences into numbers
1. _Pad_ sequences to give them identical lengths

Then we apply it to every antibody in our input (both of its chain sequences have already been concatenated into a single string above).

In [None]:
# Collect every residue letter that appears in the training sequences and
# give each a stable integer id (in sorted order). The extra "$" id is the
# catch-all that translate_sequence uses for letters outside this alphabet.
alphabet = {residue for antibody in train["Antibody"] for residue in antibody}
alphabet_dict = {residue: index for index, residue in enumerate(sorted(alphabet))}
alphabet_dict["$"] = len(alphabet)
alphabet_dict

In [None]:
# Load a small pretrained ESM-1 protein language model (esm1_t6_43M_UR50S)
# together with its token alphabet.
# NOTE: this rebinds `alphabet`, replacing the amino-acid set built in the
# featurisation cell; the later cells use `alphabet_dict`, not that set.
esm_model, alphabet = esm.pretrained.esm1_t6_43M_UR50S()
batch_converter = alphabet.get_batch_converter()
# Inference mode disables dropout, so embeddings are deterministic. Assign the
# result back to `esm_model`: binding it to `esm` would shadow the library
# module and break `esm.pretrained` when this cell is re-run.
esm_model = esm_model.eval()
esm_model = esm_model.cuda()

In [None]:
def translate_sequence(
    sequence: str,
    max_length: int = 281,
    vocab: "dict[str, int] | None" = None,
) -> List[int]:
    """Encode a residue string as a fixed-length list of integer ids.

    Args:
        sequence: amino-acid letters to encode.
        max_length: length of the returned list; longer sequences are
            truncated, shorter ones are right-padded.
        vocab: letter -> id mapping that must contain the catch-all key "$".
            Defaults to the notebook-level ``alphabet_dict``.

    Returns:
        ``max_length`` ids. Letters missing from ``vocab`` and all padding
        positions use ``vocab["$"]``.
    """
    if vocab is None:
        vocab = alphabet_dict
    unknown_id = vocab["$"]
    encoded = [vocab.get(letter, unknown_id) for letter in sequence[:max_length]]
    # Pad with the "$" id rather than a hard-coded number: a literal such as 20
    # only coincides with "$" when the alphabet has exactly 20 letters.
    encoded += [unknown_id] * (max_length - len(encoded))
    return encoded

def get_esm_gpu(seqs: list, esm_model, batch_size=6, repr_layer: int = 6):
    """Embed protein sequences with an ESM model by mean-pooling residue tokens.

    Args:
        seqs: ``(label, sequence)`` pairs, the input format of the ESM
            ``batch_converter``. Labels do not affect the embedding.
        esm_model: a loaded ESM model. Batches are moved to the device that
            holds its parameters.
        batch_size: number of sequences run through the model at once.
        repr_layer: layer whose token representations are pooled.

    Returns:
        A CPU tensor with one pooled embedding row per input sequence, in
        input order.
    """
    _, _, batch_tokens = batch_converter(seqs)
    loader = torch.utils.data.DataLoader(batch_tokens, batch_size=batch_size)
    device = next(esm_model.parameters()).device
    # Non-residue token ids. The converter returns one rectangular tensor, so
    # shorter sequences are padded; BOS/EOS markers are not residues either.
    pad_id, bos_id, eos_id = esm_model.padding_idx, esm_model.cls_idx, esm_model.eos_idx
    embeddings = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            # Contact maps are not needed for embeddings, so don't compute them.
            results = esm_model(batch, repr_layers=[repr_layer], return_contacts=False)
            token_repr = results["representations"][repr_layer]
            # Average over real residues only. Including padding would make a
            # sequence's embedding depend on the longest sequence in `seqs`.
            mask = (batch != pad_id) & (batch != bos_id) & (batch != eos_id)
            mask = mask.unsqueeze(-1).to(token_repr.dtype)
            summed = (token_repr * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1)  # guard against empty sequences
            embeddings.append((summed / counts).cpu())
    return torch.cat(embeddings)

In [None]:
# ESM use example: embed three toy sequences and inspect the output shape.
seqs_test = [
    ("seq1", "KKKKKKRKRKRKRK"),
    ("seq2", "RRRRRRR"),
    ("seq3", "VKRKRKRKVKVKVKMKMKMK"),
]
seqs_embed = get_esm_gpu(seqs_test, esm_model)
seqs_embed.size()  # expected: one pooled embedding row per sequence (3 rows)

In [None]:
import torch.nn.functional as F
def prepare_data(df: pd.DataFrame) -> torch.Tensor:
    """Encode antibodies as integer ids and append the label as a last column.

    Returns a tensor with one row per row of ``df``: the fixed-length ids
    produced by ``translate_sequence`` followed by the ``Y`` value.
    """
    # Convert the object-dtype Series of lists into a plain nested list, so
    # torch.tensor sees ordinary equal-length int lists regardless of the
    # DataFrame's index labels.
    train_x = torch.tensor(df["Antibody"].apply(translate_sequence).tolist())
    train_y = torch.tensor(df["Y"].to_numpy()).unsqueeze(-1)
    return torch.hstack((train_x, train_y))

def prepare_esm_data(df: pd.DataFrame) -> torch.Tensor:
    """Embed each antibody with ESM and append its label as a last column.

    Returns a tensor with one row per row of ``df``, in ``df`` order: the
    pooled embedding from ``get_esm_gpu`` followed by the ``Y`` value.
    """
    # get_esm_gpu takes (label, sequence) pairs; the label is only an id.
    seq_pairs = list(zip(df["Antibody_ID"], df["Antibody"]))
    train_x = get_esm_gpu(seq_pairs, esm_model)
    # Cast labels to the embedding dtype instead of relying on hstack's type
    # promotion. Float64 labels would otherwise turn the whole tensor into
    # float64, which a float32 classifier cannot consume.
    train_y = torch.tensor(df["Y"].to_numpy(), dtype=train_x.dtype).unsqueeze(-1)
    return torch.hstack((train_x, train_y))

# Embed each split with ESM; every resulting row is [embedding..., label].
train_data, val_data, test_data = (
    prepare_esm_data(frame) for frame in (train, valid, test)
)

In [None]:
from torchmetrics.functional import accuracy

class ESMModel(LightningModule):
    """Small MLP binary classifier on top of pre-computed sequence embeddings.

    Each batch row is ``[features..., label]`` (the layout built by
    ``prepare_esm_data``): all columns except the last are the input and the
    last column is the target for a binary cross-entropy loss. The network
    outputs one raw logit per row.

    Args:
        input_dim: number of feature columns.
        hidden_dim: width of the first hidden layer; the second is half of it.
        dropout: dropout probability applied after each hidden layer.
        lr: AdamW learning rate.
    """

    def __init__(self, input_dim: int = 768, hidden_dim: int = 512, dropout: float = 0.35, lr: float = 0.001):
        super().__init__()
        self.save_hyperparameters()
        self.linear1 = torch.nn.Linear(input_dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim // 2)
        self.linear3 = torch.nn.Linear(hidden_dim // 2, 1)
        self.dropout = torch.nn.Dropout(dropout)
        self.lr = lr

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return one raw logit per input row (last dimension of size 1)."""
        x = self.dropout(F.relu(self.linear1(input)))
        x = self.dropout(F.relu(self.linear2(x)))
        return self.linear3(x)

    def shared_step(self, batch: torch.Tensor, step_type: str):
        """Compute and log loss and accuracy for one batch of a given stage."""
        features = batch[:, :-1]
        target = batch[:, -1]
        logits = self.forward(features).squeeze(-1)
        loss = F.binary_cross_entropy_with_logits(logits, target)
        # A positive logit means sigmoid(logit) > 0.5, i.e. a predicted 1.
        # Computing accuracy directly avoids torchmetrics' logit-detection
        # heuristic and its version-specific signature (newer releases
        # require a `task` argument).
        acc = ((logits > 0).long() == target.long()).float().mean()
        self.log(f"{step_type}_loss", loss)
        self.log(f"{step_type}_acc", acc)
        return dict(loss=loss, acc=acc)

    def training_step(self, batch: torch.Tensor, batch_idx: int = 0) -> dict:
        return self.shared_step(batch, "train")

    def validation_step(self, batch: torch.Tensor, batch_idx: int) -> dict:
        return self.shared_step(batch, "val")

    def test_step(self, batch: torch.Tensor, batch_idx: int) -> dict:
        return self.shared_step(batch, "test")

    def configure_optimizers(self):
        """AdamW plus a ReduceLROnPlateau schedule driven by ``val_loss``."""
        optim = torch.optim.AdamW(self.parameters(), lr=self.lr)
        scheduler = {
            "monitor": "val_loss",
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optim,
                verbose=True,
                factor=0.1,
                patience=60,
            ),
        }
        return [optim], [scheduler]


In [None]:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
BATCH_SIZE = 1024
# Shuffle the training data so mini-batch composition changes between epochs;
# the validation and test loaders keep a fixed order.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE)

model = ESMModel(hidden_dim=1024, lr=0.0001)
trainer = Trainer(
    accelerator="gpu",  # replaces `gpus=1`, which Lightning 2.0 removed
    devices=1,
    log_every_n_steps=10,
    max_epochs=1000,
    callbacks=[
        EarlyStopping("val_loss", patience=200),
        ModelCheckpoint(monitor="val_loss"),
    ],
    gradient_clip_val=50,
)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
# Evaluate the checkpoint with the lowest validation loss, not the final
# weights: early stopping can leave those up to 200 epochs past the best.
trainer.test(model, dataloaders=test_loader, ckpt_path="best")

In [None]:
%tensorboard --logdir lightning_logs