In [1]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn

import pandas as pd
import numpy as np

from sklearn import metrics
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm, trange

import os
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import math
from torch.optim.lr_scheduler import _LRScheduler

from transformers import BertModel, BertTokenizer
import esm

# Parameters & ESM model

In [2]:
# Training hyper-parameters.
BATCH_SIZE = 512
NUM_THREADS = 20          # DataLoader worker processes
NUM_EPOCHS = 50
LR = 2e-6
# Token widths that load_data right-pads every antigen / epitope tensor up to.
# NOTE(review): ESM-1b is usually documented as handling at most 1022 residues
# (1024 tokens incl. BOS/EOS); 5000 is likely beyond what the model accepts — confirm.
antigen_max_len = 5000
epitope_max_len = 300

# Seed the RNGs used by shuffling and weight initialisation so runs are reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Pretrained ESM-1b protein language model plus the alphabet that tokenises its input.
_esm_model_and_alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
model, alphabet = _esm_model_and_alphabet
del _esm_model_and_alphabet

# Called on [(id, sequence), ...]; load_data keeps only its third return value (tokens).
batch_converter = alphabet.get_batch_converter()

# Load Data

In [4]:
# Raw CSV splits.
# NOTE: the name `train` is rebound to the train() function further down, so any
# cell that reads this DataFrame must run before that definition is executed.
train, test = (pd.read_csv(path) for path in ('data/train.csv', 'data/test.csv'))

In [5]:
def _pad_tokens(tokens, length, pad_idx):
    """Right-pad a 2-D token tensor with ``pad_idx`` up to ``length`` columns.

    Tensors that are already ``length`` columns wide (or wider) are returned
    unchanged; nothing is truncated.
    """
    rows, width = tokens.shape
    if width >= length:
        return tokens
    padded = tokens.new_full((rows, length), pad_idx)
    padded[:, :width] = tokens
    return padded


def load_data(data_type, data, converter=None, antigen_len=None, epitope_len=None, pad_idx=1):
    """Tokenise the antigen and epitope sequences of ``data`` with an ESM batch converter.

    Args:
        data_type: split name used in the log line; ``'test'`` means the frame
            has no labels.
        data: DataFrame with ``id``, ``antigen_seq`` and ``epitope_seq`` columns
            (plus ``label`` unless ``data_type == 'test'``). Rows are read
            positionally, so any index works.
        converter: batch converter returning ``(labels, strs, tokens)``;
            defaults to the notebook-level ``batch_converter``.
        antigen_len: width the antigen token tensor is right-padded to;
            defaults to ``antigen_max_len``.
        epitope_len: width the epitope token tensor is right-padded to;
            defaults to ``epitope_max_len``.
        pad_idx: token id used for padding. This is 1 in the original code;
            presumably ESM's ``<pad>`` id, with ``alphabet.padding_idx`` as the
            source of truth.

    Returns:
        ``(antigen_tokens, epitope_tokens, labels)``: two 2-D token tensors and
        a list of labels, or ``None`` in place of the labels when
        ``data_type == 'test'``.
    """
    if converter is None:
        converter = batch_converter
    if antigen_len is None:
        antigen_len = antigen_max_len
    if epitope_len is None:
        epitope_len = epitope_max_len

    # zip() walks the rows by position. The previous `data['id'][i]` was a
    # label-based lookup that failed on frames whose index is not 0..n-1.
    antigen = list(zip(data['id'], data['antigen_seq']))
    epitope = list(zip(data['id'], data['epitope_seq']))

    _, _, antigen_batch_tokens = converter(antigen)
    _, _, epitope_batch_tokens = converter(epitope)

    # Pad the tensors directly instead of round-tripping through Python lists.
    # NOTE(review): sequences longer than the target width are NOT truncated.
    antigen_batch_tokens = _pad_tokens(antigen_batch_tokens, antigen_len, pad_idx)
    epitope_batch_tokens = _pad_tokens(epitope_batch_tokens, epitope_len, pad_idx)

    label_list = None
    if data_type != 'test':
        label_list = list(data['label'])
    print(f'{data_type} dataframe preprocessing was done.')

    return antigen_batch_tokens, epitope_batch_tokens, label_list

In [6]:
# Hold out the last 20% of rows, in file order, for validation.
# NOTE(review): no shuffling/stratification. If data/train.csv is sorted (e.g. by
# label), this split is biased; consider train_test_split(..., stratify=...).
train_len = int(len(train) * 0.8)
# drop=True: without it reset_index leaks the old index as a spurious 'index' column.
train_df = train.iloc[:train_len].reset_index(drop=True)
val_df = train.iloc[train_len:].reset_index(drop=True)

In [None]:
# Tokenise + pad each labelled split (ESM batch conversion happens inside load_data).
tokenised = {split: load_data(split, frame) for split, frame in (('train', train_df), ('valid', val_df))}
train_antigen_input_ids, train_epitope_input_ids, train_label = tokenised['train']
valid_antigen_input_ids, valid_epitope_input_ids, valid_label = tokenised['valid']
del tokenised

In [None]:
class CustomDataset(Dataset):
    """Index-aligned view over antigen tokens, epitope tokens and optional labels.

    ``dataset[i]`` is ``(antigen, epitope, torch.tensor(label))`` when labels
    were supplied and ``(antigen, epitope)`` when ``label`` is ``None``.
    The length is that of ``epitope``.
    """

    def __init__(self, antigen, epitope, label):
        self.antigen, self.epitope, self.label = antigen, epitope, label

    def __getitem__(self, index):
        pair = (self.antigen[index], self.epitope[index])
        if self.label is None:
            return pair
        return (*pair, torch.tensor(self.label[index]))

    def __len__(self):
        return len(self.epitope)

In [None]:
# Loader settings shared by both splits; only shuffling differs.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_THREADS)

train_dataset = CustomDataset(train_antigen_input_ids, train_epitope_input_ids, train_label)
valid_dataset = CustomDataset(valid_antigen_input_ids, valid_epitope_input_ids, valid_label)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)   # reshuffled every epoch
val_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)

# Model

In [None]:
class ClassificationModel(nn.Module):
    """Antigen/epitope pair classifier on top of a pretrained ESM encoder.

    Both sequences go through the same ESM model. The first-token (index 0)
    embedding of layer ``repr_layer`` is taken for each sequence. The two
    vectors are concatenated (epitope first) and fed to a 3-layer MLP that
    emits one logit per pair.

    Args:
        pretrained_model: ESM model called as ``model(tokens, repr_layers=[...])``
            and returning a dict whose ``["representations"][layer]`` is indexed
            as (batch, position, feature).
        repr_layer: layer whose representations are read (33 in the original).
        embed_dim: feature width of those representations (1280 in the original).
        hidden_dim: width of the MLP hidden layers (256 in the original).
    """

    def __init__(self, pretrained_model, repr_layer=33, embed_dim=1280, hidden_dim=256):
        super(ClassificationModel, self).__init__()
        self.esm = pretrained_model
        self.repr_layer = repr_layer

        # Keep the Sequential layout (Linear at indices 0/2/4) so state_dict keys
        # and the `linear.N.weight` names used for (un)freezing are unchanged.
        self.linear = nn.Sequential(nn.Linear(embed_dim * 2, hidden_dim),
                                    nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim),
                                    nn.ReLU(),
                                    nn.Linear(hidden_dim, 1))

    def _encode(self, tokens):
        """Return the first-token embedding of layer ``self.repr_layer`` for ``tokens``."""
        # return_contacts=True was dropped: it asks ESM to additionally predict
        # contact maps, which were never read here and only cost memory/compute.
        out = self.esm(tokens, repr_layers=[self.repr_layer])
        return out["representations"][self.repr_layer][:, 0, :]

    def forward(self, antigen, epitope):
        """Return one logit per (antigen, epitope) pair, shape (batch, 1)."""
        antigen_hidden = self._encode(antigen)
        epitope_hidden = self._encode(epitope)

        hidden = torch.cat([epitope_hidden, antigen_hidden], dim=1)
        return self.linear(hidden)

In [None]:
def train(model, criterion, optimizer, train_loader, val_loader, epochs, device):
    """Fit ``model`` for ``epochs`` epochs and record per-epoch losses.

    Each epoch makes one pass over ``train_loader`` with an optimizer step per
    batch, then scores ``val_loader`` with ``evaluate``.

    Args:
        model: module called as ``model(antigen, epitope)``.
        criterion: loss applied as ``criterion(output, target)``, e.g.
            ``nn.BCEWithLogitsLoss``.
        optimizer: optimizer over the trainable parameters of ``model``.
        train_loader: yields ``(antigen, epitope, target)`` batches.
        val_loader: same batch layout; used only for validation.
        epochs: number of passes over ``train_loader``.
        device: device that inputs and targets are moved to.

    Returns:
        dict with ``"train_loss"`` (mean batch loss per epoch) and
        ``"val_loss"`` (``evaluate`` result per epoch).
    """
    train_losses = []
    val_losses = []
    for epoch in trange(epochs, desc="Epoch"):
        model.train()
        train_loss = 0
        for antigen, epitope, target in train_loader:
            optimizer.zero_grad()
            antigen = antigen.to(device)
            epitope = epitope.to(device)
            output = model(antigen, epitope)

            # BCEWithLogitsLoss needs a float target with the logits' shape.
            # CustomDataset yields torch.tensor(label) per sample (collated to
            # shape (batch,)), while ClassificationModel emits (batch, 1).
            target = target.to(device=device, dtype=output.dtype).reshape_as(output)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        epoch_train_loss = train_loss / len(train_loader)
        print(f"Training loss is {epoch_train_loss}")
        train_losses.append(epoch_train_loss)
        val_loss = evaluate(model=model, criterion=criterion, dataloader=val_loader, device=device)
        val_losses.append(val_loss)
        print("Epoch {} complete! Validation Loss : {}".format(epoch, val_loss))

    return {
        "train_loss": train_losses,
        "val_loss": val_losses,
    }

def evaluate(model, criterion, dataloader, device):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    Runs in eval mode without gradient tracking.

    Args:
        model: module called as ``model(antigen, epitope)``.
        criterion: loss applied as ``criterion(output, target)``.
        dataloader: yields ``(antigen, epitope, target)`` batches.
        device: device that inputs and targets are moved to.

    Returns:
        Average of the per-batch losses, or ``nan`` if ``dataloader`` is empty.
    """
    model.eval()
    total_loss, num_batches = 0.0, 0

    with torch.no_grad():
        for antigen, epitope, target in dataloader:
            antigen = antigen.to(device)
            epitope = epitope.to(device)
            output = model(antigen, epitope)

            # BCEWithLogitsLoss rejects integer targets and (batch,) vs (batch, 1)
            # mismatches, so align the target's dtype and shape with the logits.
            target = target.to(device=device, dtype=output.dtype).reshape_as(output)
            total_loss += criterion(output, target).item()
            num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches


def predict(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and collect one output row per sample.

    Accepts labelled batches ``(antigen, epitope, target)`` as well as
    unlabelled ``(antigen, epitope)`` batches, which is what CustomDataset
    yields when ``label is None``. Any target is ignored.

    Args:
        model: module called as ``model(antigen, epitope)``.
        dataloader: iterable of batches as described above.
        device: device that the inputs are moved to.

    Returns:
        list with one tensor per sample, each a row of the model output
        (shape ``(1,)`` for ClassificationModel), moved to the CPU.
    """
    # Inference mode (e.g. disables dropout); the original never switched out
    # of whatever mode the model was last left in.
    model.eval()
    predicted_label = []
    with torch.no_grad():
        for batch in dataloader:
            antigen = batch[0].to(device)
            epitope = batch[1].to(device)
            output = model(antigen, epitope)
            # Iterating a (batch, k) tensor yields its rows, matching the original
            # `list += tensor` accumulation; the rows are kept on the CPU so the
            # device memory is not pinned.
            predicted_label.extend(output.cpu())
    return predicted_label

In [None]:
# Pick the primary device. The original layout uses GPUs 1-4 with cuda:1 as the
# primary; fall back to cuda:0 / CPU so the cell also runs on smaller machines.
n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    device = torch.device("cuda:1")
elif n_gpus == 1:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Classification_model = ClassificationModel(model)
criterion = nn.BCEWithLogitsLoss()

if n_gpus > 1:
    # Only use the GPU ids that exist; the first id must be `device` (cuda:1).
    gpu_ids = [idx for idx in (1, 2, 3, 4) if idx < n_gpus]
    print("Let's use", len(gpu_ids), "GPUs!")
    # Wrap the whole classifier. The original wrapped only the ESM backbone after
    # Classification_model had already captured the unwrapped one, so the
    # DataParallel wrap never took effect.
    Classification_model = nn.DataParallel(Classification_model, device_ids=gpu_ids)

Classification_model = Classification_model.to(device)

# Optimise the classifier's parameters. The original passed the bare ESM model's
# parameters, so the linear head was never updated.
optimizer = optim.Adam(params=Classification_model.parameters(), lr=LR)

In [None]:
# Freeze the pretrained ESM backbone; only the classification head is trained.
for para in model.parameters():
    para.requires_grad = False

# Re-enable gradients on the MLP head. The original assigned a throwaway
# `para_requires_grad` variable instead of `para.requires_grad`. Unwrap
# nn.DataParallel (if present) so `linear` is reachable either way.
head_owner = getattr(Classification_model, "module", Classification_model)
for para in head_owner.linear.parameters():
    para.requires_grad = True

In [None]:
# Train the classifier wrapper (ESM backbone + MLP head). The bare ESM `model`
# expects `(tokens, repr_layers=...)`, not the `(antigen, epitope)` that train() passes.
data = train(model=Classification_model,
             criterion=criterion,
             optimizer=optimizer,
             train_loader=train_loader,
             val_loader=val_loader,
             epochs=NUM_EPOCHS,
             device=device)