In [1]:
import random
import pandas as pd
import numpy as np
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm
from sklearn.metrics import f1_score
import warnings
warnings.filterwarnings(action='ignore')
from transformers import ESMTokenizer, ESMForMaskedLM, ESMModel
from sklearn.model_selection import train_test_split
from transformers import BertConfig, BertModel, BertTokenizer, BertForPreTraining
from sklearn import preprocessing

# Parameter

In [2]:
CFG = {
    'NUM_WORKERS':4,    
    'EPOCHS':50,
    'LEARNING_RATE':1e-4,
    'BATCH_SIZE':128,
    'THRESHOLD':0.5,
    'SEED':41
}

In [3]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(CFG['SEED']) # Seed 고정

In [4]:
tokenizer = ESMTokenizer.from_pretrained("facebook/esm-1b", do_lower_case=False)

# Preprocessing

In [5]:
def get_preprocessing(data_type, new_df, tokenizer):
    epitope_ids_list = []
    epitope_mask_list = []
    
    
        
    for epitope in tqdm(new_df['epitope_seq']):       
        
        epitope_input = tokenizer(epitope, add_special_tokens=True, pad_to_max_length=True, is_split_into_words = True, max_length = 72)
        epitope_ids = epitope_input['input_ids']
        epitope_mask = epitope_input['attention_mask']
        
        
        epitope_ids_list.append(epitope_ids)
        epitope_mask_list.append(epitope_mask)
        
    
    label_list = None
    if data_type != 'test':
        label_list = []
        for label in new_df['label']:
            label_list.append(label)
    print(f'{data_type} dataframe preprocessing was done.')
    return epitope_ids_list, epitope_mask_list, label_list

def normalization(df):
    standard_scaler = preprocessing.StandardScaler()    
    standard_scaler.fit(df)
    df = standard_scaler.transform(df)
    return df

In [6]:
train = pd.read_csv('data/train.csv')
test = pd.read_csv('data/test.csv')
train1 = pd.read_csv('data/train_AAC.csv')
train2 = pd.read_csv('data/train_DDE.csv')
train3 = pd.read_csv('data/train_CTDC.csv')
train4 = pd.read_csv('data/train_CTDT.csv')
train5 = pd.read_csv('data/train_CTDD.csv')

train_input_feature = pd.merge(train1, train2, how='left', left_on='#', right_on='#', sort = False)
train_input_feature = pd.merge(train_input_feature, train3, how='left', left_on='#', right_on='#', sort = False)
train_input_feature = pd.merge(train_input_feature, train4, how='left', left_on='#', right_on='#', sort = False)
train_input_feature = pd.merge(train_input_feature, train5, how='left', left_on='#', right_on='#', sort = False)

train = pd.merge(train, train_input_feature, how='left', left_on = 'id', right_on='#', sort = False)



train, val = train_test_split(train, train_size=0.8, random_state=12)

train_epitope_ids_list, train_epitope_mask_list, train_label_list = get_preprocessing('train', train, tokenizer)
val_epitope_ids_list, val_epitope_mask_list, val_label_list = get_preprocessing('val', val, tokenizer)

train_feature = normalization(train.drop(['id','epitope_seq','antigen_seq','antigen_code','start_position','end_position',
                                           'number_of_tested','number_of_responses','assay_method_technique','assay_group',
                                           'disease_type','disease_state','reference_date','reference_journal','reference_title',
                                           'reference_IRI','qualitative_label','label','#'], axis = 1))

val_feature = normalization(val.drop(['id','epitope_seq','antigen_seq','antigen_code','start_position','end_position',
                                       'number_of_tested','number_of_responses','assay_method_technique','assay_group',
                                       'disease_type','disease_state','reference_date','reference_journal','reference_title',
                                       'reference_IRI','qualitative_label','label','#'], axis = 1))


  0%|          | 0/152648 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
100%|██████████| 152648/152648 [00:17<00:00, 8594.10it/s]


train dataframe preprocessing was done.


100%|██████████| 38163/38163 [00:04<00:00, 8355.73it/s]


val dataframe preprocessing was done.


In [8]:
train_feature.shape

(152648, 698)

# Load Data

In [12]:
class CustomDataset(Dataset):
    def __init__(self, 
                 epitope_ids_list, 
                 epitope_mask_list,                  
                 input_feature,
                 label_list):
        self.epitope_ids_list = epitope_ids_list
        self.epitope_mask_list = epitope_mask_list        
        self.input_feature = input_feature
        self.label_list = label_list
        
    def __getitem__(self, index):
        self.epitope_ids = self.epitope_ids_list[index]
        self.epitope_mask = self.epitope_mask_list[index]        
        input_feature = self.input_feature[index]
        
        if self.label_list is not None:
            self.label = self.label_list[index]
            return torch.tensor(self.epitope_ids), torch.tensor(self.epitope_mask), torch.FloatTensor(input_feature), self.label
        
        else:
            return torch.tensor(self.epitope_ids), torch.tensor(self.epitope_mask), torch.FloatTensor(input_feature)
        
    def __len__(self):
        return len(self.epitope_ids_list)

In [13]:
train_dataset = CustomDataset(train_epitope_ids_list, 
                              train_epitope_mask_list,                               
                              train_feature,
                              train_label_list)
train_loader = DataLoader(train_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=True, num_workers=CFG['NUM_WORKERS'])

val_dataset = CustomDataset(val_epitope_ids_list, 
                            val_epitope_mask_list,                             
                            val_feature,
                            val_label_list)
val_loader = DataLoader(val_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=False, num_workers=CFG['NUM_WORKERS'])

# Model

In [14]:
class resnet_block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()        
        
        self.cnn1 = nn.Conv1d(in_channels = in_channels, out_channels = out_channels, kernel_size = 8, padding = 'same')
        self.cnn2 = nn.Conv1d(in_channels = out_channels, out_channels = out_channels, kernel_size = 8, padding = 'same')
        
    def forward(self, epitope):
        
        
        _epitope = self.cnn1(epitope)
        epitope = F.relu(self.cnn1(epitope))
        epitope = self.cnn2(epitope)
        
        epitope = F.relu(torch.add(epitope , _epitope))        
        
        
        return epitope
        

In [15]:
class protein_layer(nn.Module):
    
    def __init__(self):
        
        super().__init__()
        
        self.resnet_block1 = resnet_block(1280, 256)       
        self.layers = nn.ModuleList([resnet_block(256, 256) for _ in range(3)])
        
    def forward(self, epitope):
        
        
        epitope = self.resnet_block1(epitope)
        
        for layer in self.layers:            
        
            epitope = layer(epitope)
            
        return epitope
            

In [16]:
class TransformerModel(nn.Module):
    def __init__(self, pretrained_model='facebook/esm-1b'):
        super(TransformerModel, self).__init__()              
        # Transformer                
        self.esm = ESMModel.from_pretrained(pretrained_model)  
        
        self.max_pool = nn.MaxPool1d(72)
        
        self.epitope_layer = protein_layer()

        in_channels = 256 + 698
            
        self.classifier = nn.Sequential(
            nn.LeakyReLU(True),
            nn.BatchNorm1d(in_channels),
            nn.Linear(in_channels, in_channels//4),
            nn.LeakyReLU(True),
            nn.BatchNorm1d(in_channels//4),
            nn.Linear(in_channels//4, 1)
        )
        
    def forward(self, epitope_x1, epitope_x2, input_feature):
        
        # Get Embedding Vector
        epitope = self.esm(input_ids=epitope_x1, attention_mask=epitope_x2).last_hidden_state
        
        epitope = epitope.transpose(1,2).contiguous()
        
        epitope = self.max_pool(self.epitope_layer(epitope)).squeeze()
        
        
        # Feature Concat -> Binary Classifier
        
        x = torch.cat([epitope, input_feature], axis=-1)        
        
        x = self.classifier(x).view(-1)
        return x

# Train & Validation

In [17]:
class WeightedFocalLoss(nn.Module):
    "Non weighted version of Focal Loss"
    def __init__(self, alpha=.25, gamma=2):
        super(WeightedFocalLoss, self).__init__()
        self.alpha = torch.tensor([alpha, 1-alpha]).to(device)
        self.gamma = gamma
    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        targets = targets.type(torch.long)
        at = self.alpha.gather(0, targets.data.view(-1))
        pt = torch.exp(-BCE_loss)
        F_loss = at*(1-pt)**self.gamma * BCE_loss
        return F_loss.mean()

In [18]:
def train(model, optimizer, train_loader, val_loader, scheduler, device):
    model.to(device)    
    criterion = WeightedFocalLoss().to(device) 
    best_val_f1 = 0
    for epoch in range(1, CFG['EPOCHS']+1):
        model.train()
        train_loss = []
        for epitope_ids_list, epitope_mask_list, input_feature, label in tqdm(iter(train_loader)):
            epitope_ids_list = epitope_ids_list.to(device)
            epitope_mask_list = epitope_mask_list.to(device)
            input_feautre = input_feature.to(device)

            label = label.float().to(device)
            
            optimizer.zero_grad()
            
            output = model(epitope_ids_list, epitope_mask_list, input_feature)
            loss = criterion(output, label)
            
            loss.backward()
            optimizer.step()
            
            train_loss.append(loss.item())
            
            if scheduler is not None:
                scheduler.step()
                    
        val_loss, val_f1 = validation(model, val_loader, criterion, device)
        print(f'Epoch : [{epoch}] Train Loss : [{np.mean(train_loss):.5f}] Val Loss : [{val_loss:.5f}] Val F1 : [{val_f1:.5f}]')
        
        if best_val_f1 < val_f1:
            best_val_f1 = val_f1
            torch.save(model.module.state_dict(), 'model/esm_epitope_resnet.pth', _use_new_zipfile_serialization=False)
            print('Model Saved.')
    return best_val_f1

In [19]:
def validation(model, val_loader, criterion, device):
    model.eval()
    pred_proba_label = []
    true_label = []
    val_loss = []
    with torch.no_grad():
        for epitope_ids_list, epitope_mask_list, input_feature, label in tqdm(iter(val_loader)):
            epitope_ids_list = epitope_ids_list.to(device)
            epitope_mask_list = epitope_mask_list.to(device)
            input_feature = input_feature.to(device)


            label = label.float().to(device)
            
            model_pred = model(epitope_ids_list, epitope_mask_list, input_feature)
            loss = criterion(output, label)
            model_pred = torch.sigmoid(model_pred).to('cpu')
            
            pred_proba_label += model_pred.tolist()
            true_label += label.to('cpu').tolist()
            
            val_loss.append(loss.item())
            
    pred_label = np.where(np.array(pred_proba_label)>CFG['THRESHOLD'], 1, 0)
    val_f1 = f1_score(true_label, pred_label, average='macro')
    return np.mean(val_loss), val_f1

# Run

In [20]:
device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')

In [21]:
model = TransformerModel()
model = nn.DataParallel(model, device_ids=[1,0])
optimizer = torch.optim.Adam(params = model.parameters(), lr = CFG["LEARNING_RATE"])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10*CFG['EPOCHS'], eta_min=0)


Some weights of the model checkpoint at facebook/esm-1b were not used when initializing ESMModel: ['lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight']
- This IS expected if you are initializing ESMModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ESMModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ESMModel were not initialized from the model checkpoint at facebook/esm-1b and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [22]:
for name, para in model.named_parameters():
    if name in [
'module.esm.encoder.layer.31.attention.self.query.weight',
'module.esm.encoder.layer.31.attention.self.query.bias',
'module.esm.encoder.layer.31.attention.self.key.weight',
'module.esm.encoder.layer.31.attention.self.key.bias',
'module.esm.encoder.layer.31.attention.self.value.weight',
'module.esm.encoder.layer.31.attention.self.value.bias',
'module.esm.encoder.layer.31.attention.output.dense.weight',
'module.esm.encoder.layer.31.attention.output.dense.bias',
'module.esm.encoder.layer.31.attention.LayerNorm.weight',
'module.esm.encoder.layer.31.attention.LayerNorm.bias',
'module.esm.encoder.layer.31.intermediate.dense.weight',
'module.esm.encoder.layer.31.intermediate.dense.bias',
'module.esm.encoder.layer.31.output.dense.weight',
'module.esm.encoder.layer.31.output.dense.bias',
'module.esm.encoder.layer.31.LayerNorm.weight',
'module.esm.encoder.layer.31.LayerNorm.bias',
'module.esm.encoder.layer.32.attention.self.query.weight',
'module.esm.encoder.layer.32.attention.self.query.bias',
'module.esm.encoder.layer.32.attention.self.key.weight',
'module.esm.encoder.layer.32.attention.self.key.bias',
'module.esm.encoder.layer.32.attention.self.value.weight',
'module.esm.encoder.layer.32.attention.self.value.bias',
'module.esm.encoder.layer.32.attention.output.dense.weight',
'module.esm.encoder.layer.32.attention.output.dense.bias',
'module.esm.encoder.layer.32.attention.LayerNorm.weight',
'module.esm.encoder.layer.32.attention.LayerNorm.bias',
'module.esm.encoder.layer.32.intermediate.dense.weight',
'module.esm.encoder.layer.32.intermediate.dense.bias',
'module.esm.encoder.layer.32.output.dense.weight',
'module.esm.encoder.layer.32.output.dense.bias',
'module.esm.encoder.layer.32.LayerNorm.weight',
'module.esm.encoder.layer.32.LayerNorm.bias',
'module.esm.encoder.emb_layer_norm_after.weight',
'module.esm.encoder.emb_layer_norm_after.bias',
'module.esm.pooler.dense.weight',
'module.esm.pooler.dense.bias',
'module.epitope_layer.resnet_block1.cnn1.weight',
'module.epitope_layer.resnet_block1.cnn1.bias',
'module.epitope_layer.resnet_block1.cnn2.weight',
'module.epitope_layer.resnet_block1.cnn2.bias',
'module.epitope_layer.layers.0.cnn1.weight',
'module.epitope_layer.layers.0.cnn1.bias',
'module.epitope_layer.layers.0.cnn2.weight',
'module.epitope_layer.layers.0.cnn2.bias',
'module.epitope_layer.layers.1.cnn1.weight',
'module.epitope_layer.layers.1.cnn1.bias',
'module.epitope_layer.layers.1.cnn2.weight',
'module.epitope_layer.layers.1.cnn2.bias',
'module.epitope_layer.layers.2.cnn1.weight',
'module.epitope_layer.layers.2.cnn1.bias',
'module.epitope_layer.layers.2.cnn2.weight',
'module.epitope_layer.layers.2.cnn2.bias',
'module.epitope_layer.layers.3.cnn1.weight',
'module.epitope_layer.layers.3.cnn1.bias',
'module.epitope_layer.layers.3.cnn2.weight',
'module.epitope_layer.layers.3.cnn2.bias',
'module.epitope_layer.layers.4.cnn1.weight',
'module.epitope_layer.layers.4.cnn1.bias',
'module.epitope_layer.layers.4.cnn2.weight',
'module.epitope_layer.layers.4.cnn2.bias',
'module.epitope_layer.layers.5.cnn1.weight',
'module.epitope_layer.layers.5.cnn1.bias',
'module.epitope_layer.layers.5.cnn2.weight',
'module.epitope_layer.layers.5.cnn2.bias',
'module.epitope_layer.layers.6.cnn1.weight',
'module.epitope_layer.layers.6.cnn1.bias',
'module.epitope_layer.layers.6.cnn2.weight',
'module.epitope_layer.layers.6.cnn2.bias',
'module.classifier.1.weight',
'module.classifier.1.bias',
'module.classifier.2.weight',
'module.classifier.2.bias',
'module.classifier.4.weight',
'module.classifier.4.bias',
'module.classifier.5.weight',
'module.classifier.5.bias']:
        para.requires_grad = True
        
    else:
        para.requires_grad = False



In [None]:
best_score = train(model, optimizer, train_loader, val_loader, scheduler, device)
print(f'Best Validation F1 Score : [{best_score:.5f}]')

 53%|█████▎    | 635/1193 [15:12<10:06,  1.09s/it]