In [1]:
import random
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn import metrics
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm, trange

from tqdm import tqdm
from sklearn.metrics import f1_score
from transformers import ESMTokenizer, ESMModel


In [2]:
CFG = {
    'NUM_WORKERS':4,
    'EPITOPE_MAX_LEN':64,
    'EPOCHS':50,
    'LEARNING_RATE':5e-5,
    'BATCH_SIZE':512,
    'THRESHOLD':0.5,   # 기본적으로 0.5로 사용하지만 data impalance가 심할 경우 더 큰 값을 사용하기도 한다.
    'SEED':41
}

In [3]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(CFG['SEED']) # Seed 고정

In [4]:
tokenizer = ESMTokenizer.from_pretrained("facebook/esm-1b", do_lower_case=False)

In [5]:
def get_preprocessing(data_type, new_df, tokenizer):
    epitope_ids_list = []
    epitope_mask_list = []    
        
    for epitope, antigen, s_p, e_p in tqdm(zip(new_df['epitope_seq'], new_df['antigen_seq'], new_df['start_position'], new_df['end_position'])):            
        
#         epitope_features.append(get_peptide_feature(epitope))        
        
        epitope_input = tokenizer(epitope, add_special_tokens=True, pad_to_max_length=True, is_split_into_words = True, max_length = CFG['EPITOPE_MAX_LEN'])
        epitope_ids = epitope_input['input_ids']
        epitope_mask = epitope_input['attention_mask']
        
        
        epitope_ids_list.append(epitope_ids)
        epitope_mask_list.append(epitope_mask)

    
    label_list = None
    if data_type != 'test':
        label_list = []
        for label in new_df['label']:
            label_list.append(label)
    print(f'{data_type} dataframe preprocessing was done.')
    return epitope_ids_list, epitope_mask_list, label_list


In [6]:
train = pd.read_csv('data/train.csv')

train, val = train_test_split(train, train_size=0.8, random_state=12)

train_epitope_ids_list, train_epitope_mask_list, train_label_list = get_preprocessing('train', train, tokenizer)
val_epitope_ids_list, val_epitope_mask_list, val_label_list = get_preprocessing('val', val, tokenizer)



0it [00:00, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
152648it [00:17, 8550.82it/s]


train dataframe preprocessing was done.


38163it [00:04, 8262.22it/s]

val dataframe preprocessing was done.





In [7]:
class CustomDataset(Dataset):
    def __init__(self, 
                 epitope_ids_list, 
                 epitope_mask_list,                  
                 label_list):
        self.epitope_ids_list = epitope_ids_list
        self.epitope_mask_list = epitope_mask_list  
        self.label_list = label_list
        
    def __getitem__(self, index):
        self.epitope_ids = self.epitope_ids_list[index]
        self.epitope_mask = self.epitope_mask_list[index]
              
        
        if self.label_list is not None:
            self.label = self.label_list[index]
            return torch.tensor(self.epitope_ids), torch.tensor(self.epitope_mask), self.label
        else:
            return torch.tensor(self.epitope_ids), torch.tensor(self.epitope_mask)
        
    def __len__(self):
        return len(self.epitope_ids_list)

In [8]:
train_dataset = CustomDataset(train_epitope_ids_list, 
                              train_epitope_mask_list,                             
                              train_label_list)
train_loader = DataLoader(train_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=True, num_workers=CFG['NUM_WORKERS'])

val_dataset = CustomDataset(val_epitope_ids_list, 
                            val_epitope_mask_list,
                            val_label_list)
val_loader = DataLoader(val_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=False, num_workers=CFG['NUM_WORKERS'])

In [9]:
class TransformerModel(nn.Module):
    def __init__(self, pretrained_model='facebook/esm-1b'):
        super(TransformerModel, self).__init__()              
        # Transformer                
        self.esm = ESMModel.from_pretrained(pretrained_model)  
        
        self.max_pool = nn.MaxPool1d(CFG['EPITOPE_MAX_LEN'])
        
        self.cnn4 = nn.Conv1d(in_channels = 1280, out_channels = 256, kernel_size = 4, padding = 'same')
        self.cnn5 = nn.Conv1d(in_channels = 1280, out_channels = 256, kernel_size = 5, padding = 'same')
        self.cnn6 = nn.Conv1d(in_channels = 1280, out_channels = 256, kernel_size = 6, padding = 'same')
        self.cnn7 = nn.Conv1d(in_channels = 1280, out_channels = 256, kernel_size = 7, padding = 'same')
        self.cnn8 = nn.Conv1d(in_channels = 1280, out_channels = 256, kernel_size = 8, padding = 'same')
        self.cnn9 = nn.Conv1d(in_channels = 1280, out_channels = 256, kernel_size = 9, padding = 'same')
        self.cnn10 = nn.Conv1d(in_channels = 1280, out_channels = 256, kernel_size = 10, padding = 'same')
        self.cnn11 = nn.Conv1d(in_channels = 1280, out_channels = 256, kernel_size = 11, padding = 'same')
        self.cnn12 = nn.Conv1d(in_channels = 1280, out_channels = 256, kernel_size = 12, padding = 'same')        
        

        in_channels = 256 * 9
            
        self.classifier = nn.Sequential(
            nn.LeakyReLU(True),
            nn.BatchNorm1d(in_channels),
            nn.Linear(in_channels, in_channels//4),
            nn.LeakyReLU(True),
            nn.BatchNorm1d(in_channels//4),
            nn.Linear(in_channels//4, 1)
        )
        
    def forward(self, epitope_x1, epitope_x2):
        
        # Get Embedding Vector
        epitope = self.esm(input_ids=epitope_x1, attention_mask=epitope_x2).last_hidden_state
        
        epitope = epitope.transpose(1,2).contiguous()
        
        epitope4 = self.max_pool(self.cnn4(epitope)).squeeze()
        epitope5 = self.max_pool(self.cnn5(epitope)).squeeze()
        epitope6 = self.max_pool(self.cnn6(epitope)).squeeze()
        epitope7 = self.max_pool(self.cnn7(epitope)).squeeze()
        epitope8 = self.max_pool(self.cnn8(epitope)).squeeze()
        epitope9 = self.max_pool(self.cnn9(epitope)).squeeze()
        epitope10 = self.max_pool(self.cnn10(epitope)).squeeze()
        epitope11 = self.max_pool(self.cnn11(epitope)).squeeze()
        epitope12 = self.max_pool(self.cnn12(epitope)).squeeze()

        
        
        # Feature Concat -> Binary Classifier                
        
        x = torch.cat([epitope4, epitope5, epitope6, epitope7, epitope8, epitope9, epitope10, epitope11, epitope12], axis=-1)        
        
        x = self.classifier(x).view(-1)
        
        return x

In [10]:
class WeightedFocalLoss(nn.Module):
    "Non weighted version of Focal Loss"
    def __init__(self, alpha=.25, gamma=2):
        
        super(WeightedFocalLoss, self).__init__()
        self.alpha = torch.tensor([alpha, 1-alpha]).to(device)
        self.gamma = gamma
        
    def forward(self, inputs, targets):        
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        targets = targets.type(torch.long)
        at = self.alpha.gather(0, targets.data.view(-1))
        pt = torch.exp(-BCE_loss)
        F_loss = at*(1-pt)**self.gamma * BCE_loss
        return F_loss.mean()

In [11]:
def train(model, optimizer, train_loader, val_loader, device):
    model.to(device)
    
    
    best_val_f1 = 0
    for epoch in range(1, CFG['EPOCHS']+1):
        model.train()
        train_loss = []
        criterion = WeightedFocalLoss().to(device)
        for epitope_ids_list, epitope_mask_list, label in tqdm(iter(train_loader)):            
            
            
            
            epitope_ids_list = epitope_ids_list.to(device)
            epitope_mask_list = epitope_mask_list.to(device)

            label = label.float().to(device)
            
            optimizer.zero_grad()            
            
            output = model(epitope_ids_list, epitope_mask_list)
            loss = criterion(output, label)
            
            loss.backward()
            optimizer.step()
            
            train_loss.append(loss.item())
            
                    
        val_loss, val_f1, val_acc, val_precision, val_recall = validation(model, val_loader, criterion, device)
        print(f'Epoch : [{epoch}] Train Loss : [{np.mean(train_loss):.5f}] Val Loss : [{val_loss:.5f}] Val F1 : [{val_f1:.5f}] Val acc : [{val_acc:.5f}] Val precision : [{val_precision:.5f}] Val recall : [{val_recall:.5f}]')
        
        if best_val_f1 < val_f1:
            best_val_f1 = val_f1
#             torch.save(model.module.state_dict(), './feature_transformer_best_model.pth', _use_new_zipfile_serialization=False)
            torch.save(model.module.state_dict(), '/DAS_Storage4/hyungseok/deepconv.pth', _use_new_zipfile_serialization=False)
            print('Model Saved.')
    return best_val_f1

In [12]:
from sklearn.metrics import accuracy_score, precision_score, recall_score


def validation(model, val_loader, criterion, device):
    model.eval()
    pred_proba_label = []
    true_label = []
    val_loss = []
    with torch.no_grad():
        for epitope_ids_list, epitope_mask_list, label in tqdm(iter(val_loader)):
            epitope_ids_list = epitope_ids_list.to(device)
            epitope_mask_list = epitope_mask_list.to(device)
            
            
            label = label.float().to(device)
            
            model_pred = model(epitope_ids_list, epitope_mask_list)
            
            loss = criterion(model_pred, label)
            model_pred = torch.sigmoid(model_pred).to('cpu')
            
            pred_proba_label += model_pred.tolist()
            true_label += label.to('cpu').tolist()
            
            val_loss.append(loss.item())
    
    pred_label = np.where(np.array(pred_proba_label)>CFG['THRESHOLD'], 1, 0)
    val_f1 = f1_score(true_label, pred_label, average='macro')
    acc = accuracy_score(true_label, pred_label)
    precision = precision_score(true_label, pred_label)
    recall = recall_score(true_label, pred_label)
    return np.mean(val_loss), val_f1, acc, precision, recall

In [13]:
device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')

In [14]:
model = TransformerModel()
model = nn.DataParallel(model, device_ids=[1, 0])

optimizer = torch.optim.Adam(params = model.parameters(), lr = CFG["LEARNING_RATE"])


Some weights of the model checkpoint at facebook/esm-1b were not used when initializing ESMModel: ['lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing ESMModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ESMModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ESMModel were not initialized from the model checkpoint at facebook/esm-1b and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
for name, para in model.named_parameters():
    if name in [
'module.esm.encoder.layer.31.attention.self.query.weight',
'module.esm.encoder.layer.31.attention.self.query.bias',
'module.esm.encoder.layer.31.attention.self.key.weight',
'module.esm.encoder.layer.31.attention.self.key.bias',
'module.esm.encoder.layer.31.attention.self.value.weight',
'module.esm.encoder.layer.31.attention.self.value.bias',
'module.esm.encoder.layer.31.attention.output.dense.weight',
'module.esm.encoder.layer.31.attention.output.dense.bias',
'module.esm.encoder.layer.31.attention.LayerNorm.weight',
'module.esm.encoder.layer.31.attention.LayerNorm.bias',
'module.esm.encoder.layer.31.intermediate.dense.weight',
'module.esm.encoder.layer.31.intermediate.dense.bias',
'module.esm.encoder.layer.31.output.dense.weight',
'module.esm.encoder.layer.31.output.dense.bias',
'module.esm.encoder.layer.31.LayerNorm.weight',
'module.esm.encoder.layer.31.LayerNorm.bias',
'module.esm.encoder.layer.32.attention.self.query.weight',
'module.esm.encoder.layer.32.attention.self.query.bias',
'module.esm.encoder.layer.32.attention.self.key.weight',
'module.esm.encoder.layer.32.attention.self.key.bias',
'module.esm.encoder.layer.32.attention.self.value.weight',
'module.esm.encoder.layer.32.attention.self.value.bias',
'module.esm.encoder.layer.32.attention.output.dense.weight',
'module.esm.encoder.layer.32.attention.output.dense.bias',
'module.esm.encoder.layer.32.attention.LayerNorm.weight',
'module.esm.encoder.layer.32.attention.LayerNorm.bias',
'module.esm.encoder.layer.32.intermediate.dense.weight',
'module.esm.encoder.layer.32.intermediate.dense.bias',
'module.esm.encoder.layer.32.output.dense.weight',
'module.esm.encoder.layer.32.output.dense.bias',
'module.esm.encoder.layer.32.LayerNorm.weight',
'module.esm.encoder.layer.32.LayerNorm.bias',
'module.esm.encoder.emb_layer_norm_after.weight',
'module.esm.encoder.emb_layer_norm_after.bias',
'module.esm.pooler.dense.weight',
'module.esm.pooler.dense.bias',
'module.cnn4.weight',
'module.cnn4.bias',
'module.cnn5.weight',
'module.cnn5.bias',
'module.cnn6.weight',
'module.cnn6.bias',
'module.cnn7.weight',
'module.cnn7.bias',
'module.cnn8.weight',
'module.cnn8.bias',
'module.cnn9.weight',
'module.cnn9.bias',
'module.cnn10.weight',
'module.cnn10.bias',
'module.cnn11.weight',
'module.cnn11.bias',
'module.cnn12.weight',
'module.cnn12.bias',
'module.classifier.1.weight',
'module.classifier.1.bias',
'module.classifier.2.weight',
'module.classifier.2.bias',
'module.classifier.4.weight',
'module.classifier.4.bias',
'module.classifier.5.weight',
'module.classifier.5.bias']:
        para.requires_grad = True
        
    else:
        para.requires_grad = False


In [None]:
best_score = train(model, optimizer, train_loader, val_loader, device)
print(f'Best Validation F1 Score : [{best_score:.5f}]')

  self.padding, self.dilation, self.groups)
100%|██████████| 299/299 [10:55<00:00,  2.19s/it]
100%|██████████| 75/75 [02:05<00:00,  1.67s/it]


Epoch : [1] Train Loss : [0.04678] Val Loss : [0.04059] Val F1 : [0.62856] Val acc : [0.80326] Val precision : [0.26290] Val recall : [0.64657]
Model Saved.


100%|██████████| 299/299 [10:50<00:00,  2.18s/it]
100%|██████████| 75/75 [02:05<00:00,  1.67s/it]


Epoch : [2] Train Loss : [0.03496] Val Loss : [0.07116] Val F1 : [0.28204] Val acc : [0.29159] Val precision : [0.11101] Val recall : [0.97028]


100%|██████████| 299/299 [10:46<00:00,  2.16s/it]
100%|██████████| 75/75 [02:04<00:00,  1.66s/it]


Epoch : [3] Train Loss : [0.02765] Val Loss : [0.03505] Val F1 : [0.65521] Val acc : [0.81773] Val precision : [0.29461] Val recall : [0.72216]
Model Saved.


100%|██████████| 299/299 [10:52<00:00,  2.18s/it]
100%|██████████| 75/75 [02:06<00:00,  1.69s/it]


Epoch : [4] Train Loss : [0.02098] Val Loss : [0.03045] Val F1 : [0.73129] Val acc : [0.91911] Val precision : [0.56790] Val recall : [0.45730]
Model Saved.


 83%|████████▎ | 249/299 [09:06<01:49,  2.19s/it]