In [1]:
import random
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn import metrics
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm, trange

from tqdm import tqdm
from sklearn.metrics import f1_score
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from transformers import ESMTokenizer, ESMModel
from sklearn import preprocessing

In [2]:
# Hyper-parameters and run configuration shared by the cells below.
CFG = dict(
    NUM_WORKERS=4,
    EPITOPE_MAX_LEN=72,
    EPOCHS=20,
    LEARNING_RATE=5e-5,
    BATCH_SIZE=512,
    # Decision threshold on the sigmoid output. 0.5 is the usual default; a larger
    # value is sometimes used when the class imbalance is severe.
    THRESHOLD=0.5,
    SEED=41,
)

In [3]:
def seed_everything(seed):
    """Seed every RNG used by this notebook so runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch on the CPU and on
    every visible CUDA device. It also forces deterministic cuDNN kernels.

    Args:
        seed: Integer seed.
    """
    random.seed(seed)
    # Setting PYTHONHASHSEED at runtime does not change the hash seed of the
    # running interpreter; it only affects Python processes started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # The model is later wrapped in DataParallel over several GPUs, so seed all
    # of them rather than only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True makes cuDNN pick kernels by timing them. That choice is
    # non-deterministic and defeats `deterministic=True` above.
    torch.backends.cudnn.benchmark = False

seed_everything(CFG['SEED']) # Fix all random seeds for reproducibility

In [4]:
# Tokenizer for the ESM-1b checkpoint; TransformerModel loads the same checkpoint
# name ('facebook/esm-1b') by default, so token ids and the encoder match.
tokenizer = ESMTokenizer.from_pretrained("facebook/esm-1b", do_lower_case=False)

In [5]:
def get_protein_feature(seq):
    """Compute four global physico-chemical descriptors of a protein sequence.

    Args:
        seq: Amino-acid sequence (one-letter codes) accepted by Biopython's
            ``ProteinAnalysis``.

    Returns:
        list: ``[isoelectric_point, aromaticity, gravy, instability_index]`` in
        that order. TransformerModel's head is sized for exactly 4 extra features.
    """
    # Build the analysis object once instead of re-parsing the sequence for
    # every descriptor.
    analysis = ProteinAnalysis(seq)
    return [
        analysis.isoelectric_point(),
        analysis.aromaticity(),
        analysis.gravy(),
        analysis.instability_index(),
    ]
    
def get_preprocessing(data_type, new_df, tokenizer):
    """Tokenize epitopes and compute antigen descriptors for one data split.

    Args:
        data_type: Split name used in the progress message. When it is ``'test'``,
            no labels are read.
        new_df: DataFrame with ``epitope_seq`` and ``antigen_seq`` columns, plus
            ``label`` unless ``data_type == 'test'``.
        tokenizer: Hugging Face tokenizer applied to each epitope.

    Returns:
        tuple: ``(epitope_ids_list, epitope_mask_list, protein_features, label_list)``.

        - ``epitope_ids_list`` and ``epitope_mask_list`` hold one token-id list and
          one attention-mask list per row, padded/truncated to
          ``CFG['EPITOPE_MAX_LEN']``.
        - ``protein_features`` holds ``get_protein_feature(antigen)`` per row.
        - ``label_list`` is ``None`` for the test split.
    """
    epitope_ids_list = []
    epitope_mask_list = []
    protein_features = []

    # Only the epitope and antigen sequences are needed. The start/end position
    # columns were previously read but never used.
    rows = zip(new_df['epitope_seq'], new_df['antigen_seq'])
    for epitope, antigen in tqdm(rows, total=len(new_df)):
        protein_features.append(get_protein_feature(antigen))

        # padding='max_length' is the non-deprecated spelling of
        # pad_to_max_length=True. truncation=True makes explicit the
        # 'longest_first' truncation the tokenizer used to apply implicitly (and
        # warn about). The encoded output is unchanged.
        epitope_input = tokenizer(epitope,
                                  add_special_tokens=True,
                                  padding='max_length',
                                  truncation=True,
                                  is_split_into_words=True,
                                  max_length=CFG['EPITOPE_MAX_LEN'])
        epitope_ids_list.append(epitope_input['input_ids'])
        epitope_mask_list.append(epitope_input['attention_mask'])

    label_list = None
    if data_type != 'test':
        label_list = list(new_df['label'])
    print(f'{data_type} dataframe preprocessing was done.')

    return epitope_ids_list, epitope_mask_list, protein_features, label_list

In [6]:
# Load the labelled data and hold out 20% of it for validation.
# NOTE(review): the split is not stratified on `label`, and it uses its own
# random_state (12) rather than CFG['SEED']. Consider `stratify=train['label']`
# if the classes are imbalanced.
# NOTE(review): the name `train` is rebound later by `def train(...)`, so this
# DataFrame is no longer reachable under that name after that cell runs.
train = pd.read_csv('data/train.csv')

train, val = train_test_split(train, train_size=0.8, random_state=12)

train_epitope_ids_list, train_epitope_mask_list, train_antigen_feature, train_label_list = get_preprocessing('train', train, tokenizer)
val_epitope_ids_list, val_epitope_mask_list, val_antigen_feature, val_label_list = get_preprocessing('val', val, tokenizer)


0it [00:00, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
152648it [03:42, 687.30it/s]
65it [00:00, 611.96it/s]

train dataframe preprocessing was done.


38163it [00:55, 684.45it/s]

val dataframe preprocessing was done.





In [7]:
# Wrap the per-antigen descriptor lists (one get_protein_feature() list per row)
# in DataFrames so the ColumnTransformer below can select columns by dtype.
train_feature = pd.DataFrame(train_antigen_feature)
val_feature = pd.DataFrame(val_antigen_feature)

In [8]:
def get_card_split(df, cols, n=11):
    """
    Splits categorical columns into 2 groups based on cardinality (i.e. # of unique values)
    Parameters
    ----------
    df : Pandas DataFrame
        DataFrame from which the cardinality of the columns is calculated.
    cols : list-like
        Categorical column labels to split. Any list-like is accepted; it is
        converted to a ``pandas.Index`` so boolean masking also works for plain lists.
    n : int, optional (default=11)
        Cardinality threshold used to split the columns.
    Returns
    -------
    card_low : pandas.Index
        Columns with cardinality <= n
    card_high : pandas.Index
        Columns with cardinality > n
    """
    cols = pd.Index(cols)
    # nunique() is ordered like `cols`. Converting to a NumPy bool array makes
    # the mask purely positional, and keeps it boolean even when `cols` is empty.
    cond = (df[cols].nunique() > n).to_numpy(dtype=bool)
    card_high = cols[cond]
    card_low = cols[~cond]
    return card_low, card_high

from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer, MissingIndicator
from sklearn.preprocessing import StandardScaler, OneHotEncoder, OrdinalEncoder

# Numeric columns: mean-impute, then standardise. The statistics are learned on
# the training split only (see fit_transform below).
numeric_transformer = Pipeline(
    steps=[("imputer", SimpleImputer(strategy="mean")), ("scaler", StandardScaler())]
)

# Low-cardinality categorical columns: fill missing values with a sentinel, then one-hot encode.
# NOTE(review): OneHotEncoder's `sparse=` argument was renamed `sparse_output=` in
# scikit-learn 1.2 and removed in 1.4. Update it if the environment is upgraded.
categorical_transformer_low = Pipeline(
    steps=[
        ("imputer", SimpleImputer(strategy="constant", fill_value="missing")),
        ("encoding", OneHotEncoder(handle_unknown="ignore", sparse=False)),
    ]
)

# High-cardinality categorical columns: fill missing values with a sentinel, then ordinal-encode.
categorical_transformer_high = Pipeline(
    steps=[
        ("imputer", SimpleImputer(strategy="constant", fill_value="missing")),
        # 'OrdinalEncoder' raises a ValueError when it encounters an unknown value. Check https://github.com/scikit-learn/scikit-learn/pull/13423
        ("encoding", OrdinalEncoder()),
    ]
)

# Group columns by dtype. The descriptor frame presumably contains only numeric
# columns, in which case both categorical groups are empty. TODO confirm.
numeric_features = train_feature.select_dtypes(include=[np.number]).columns
categorical_features = train_feature.select_dtypes(include=["object"]).columns

categorical_low, categorical_high = get_card_split(
    train_feature, categorical_features
)


preprocessor = ColumnTransformer(
    transformers=[
        ("numeric", numeric_transformer, numeric_features),
        ("categorical_low", categorical_transformer_low, categorical_low),
        ("categorical_high", categorical_transformer_high, categorical_high),
    ]
)

# Fit on the training split only; validation reuses the training statistics.
# NOTE: this rebinds train_feature/val_feature from DataFrames to transformed
# arrays, so re-running this cell without re-running the previous one fails.
train_feature = preprocessor.fit_transform(train_feature)
val_feature = preprocessor.transform(val_feature)

In [10]:
class CustomDataset(Dataset):
    """Map-style dataset pairing tokenized epitopes with antigen descriptor features.

    Each item is ``(epitope_ids, epitope_mask, input_feature[, label])``:

    - the ids and mask are tensors built with ``torch.tensor``;
    - the feature row is a float32 tensor;
    - the label is appended, as stored, only when ``label_list`` is not ``None``.
    """

    def __init__(self,
                 epitope_ids_list,
                 epitope_mask_list,
                 input_feature,
                 label_list):
        self.epitope_ids_list = epitope_ids_list
        self.epitope_mask_list = epitope_mask_list
        self.input_feature = input_feature
        self.label_list = label_list

    def __getitem__(self, index):
        # Keep per-item values in locals. Storing them on `self` mutated the
        # dataset on every access and leaked state between calls.
        epitope_ids = torch.tensor(self.epitope_ids_list[index])
        epitope_mask = torch.tensor(self.epitope_mask_list[index])
        input_feature = torch.tensor(self.input_feature[index], dtype=torch.float32)

        if self.label_list is None:
            return epitope_ids, epitope_mask, input_feature
        return epitope_ids, epitope_mask, input_feature, self.label_list[index]

    def __len__(self):
        return len(self.epitope_ids_list)

In [11]:
# Training batches are shuffled every epoch; validation order is kept fixed.
train_dataset = CustomDataset(train_epitope_ids_list, 
                              train_epitope_mask_list,    
                              train_feature,
                              train_label_list)
train_loader = DataLoader(train_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=True, num_workers=CFG['NUM_WORKERS'])

val_dataset = CustomDataset(val_epitope_ids_list, 
                            val_epitope_mask_list,
                            val_feature,
                            val_label_list)
val_loader = DataLoader(val_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=False, num_workers=CFG['NUM_WORKERS'])

In [18]:
class TransformerModel(nn.Module):
    """ESM encoder plus MLP head for binary epitope classification.

    The first-token embedding of the encoder's last hidden state is concatenated
    with the extra per-sample features. A small MLP then maps the result to one
    logit per sample.

    Args:
        pretrained_model: Checkpoint name/path passed to ``ESMModel.from_pretrained``.
        n_extra_features: Width of the ``input_feature`` tensor concatenated to
            the embedding (4 descriptors from ``get_protein_feature`` by default).
    """
    def __init__(self, pretrained_model='facebook/esm-1b', n_extra_features=4):
        super().__init__()
        # Transformer
        self.esm = ESMModel.from_pretrained(pretrained_model)

        # Read the embedding width from the checkpoint (1280 for ESM-1b) instead
        # of hard-coding it, so other ESM checkpoints also work.
        in_channels = self.esm.config.hidden_size + n_extra_features

        # nn.LeakyReLU's first positional argument is `negative_slope`, not
        # `inplace`. The previous `nn.LeakyReLU(True)` set the slope to 1.0,
        # which made every activation the identity and the head purely linear.
        # Module indices 0..5 are unchanged, so parameter names such as
        # `classifier.1.weight` (and saved checkpoints) still match.
        self.classifier = nn.Sequential(
            nn.LeakyReLU(),
            nn.BatchNorm1d(in_channels),
            nn.Linear(in_channels, in_channels//4),
            nn.LeakyReLU(),
            nn.BatchNorm1d(in_channels//4),
            nn.Linear(in_channels//4, 1)
        )

    def forward(self, epitope_x1, epitope_x2, input_feature):
        """Return one logit per sample, flattened to shape ``(batch,)``.

        Args:
            epitope_x1: Token ids for the encoder.
            epitope_x2: Attention mask for the encoder.
            input_feature: Extra features, presumably ``(batch, n_extra_features)``;
                concatenated on the last dimension.
        """
        # Embedding of the first token of the last hidden state.
        epitope = self.esm(input_ids=epitope_x1, attention_mask=epitope_x2).last_hidden_state
        epitope = epitope[:, 0, :]

        # Feature concat -> binary classifier
        x = torch.cat([epitope, input_feature], dim=-1)
        return self.classifier(x).view(-1)

In [13]:
class WeightedFocalLoss(nn.Module):
    """Alpha-weighted binary focal loss computed from logits.

    Each sample's loss is ``a_t * (1 - p_t) ** gamma * BCE``, where:

    - ``BCE`` is the per-sample binary cross-entropy with logits;
    - ``p_t = exp(-BCE)`` is the predicted probability of the true class;
    - ``a_t`` is ``alpha`` for target 0 and ``1 - alpha`` for target 1.

    The per-sample losses are averaged.

    Args:
        alpha: Weight for class 0; class 1 gets ``1 - alpha``.
        gamma: Focusing parameter; ``gamma=0`` reduces to alpha-weighted BCE.
    """
    def __init__(self, alpha=.25, gamma=2):
        super().__init__()
        # Registered as a buffer so `.to(device)` on the loss moves it. Previously
        # it was moved to a global `device` that had to exist at construction time.
        self.register_buffer('alpha', torch.tensor([alpha, 1-alpha]))
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss.

        Args:
            inputs: Raw logits.
            targets: Float 0/1 targets with the same shape as ``inputs``.
        """
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # Select alpha[0] for negatives and alpha[1] for positives.
        at = self.alpha.gather(0, targets.long().view(-1))
        pt = torch.exp(-bce_loss)
        return (at * (1 - pt) ** self.gamma * bce_loss).mean()

In [14]:
def train(model, optimizer, train_loader, val_loader, scheduler, device,
          save_path='/DAS_Storage4/hyungseok/esm_final.pth'):
    """Train ``model`` for ``CFG['EPOCHS']`` epochs and checkpoint the best epoch.

    After every epoch the model is evaluated with ``validation`` and the metrics
    are printed. Whenever the validation macro-F1 improves, the weights of the
    underlying module are saved to ``save_path``, unwrapped from
    ``nn.DataParallel`` if the model is wrapped.

    Args:
        model: Model returning one logit per sample; may be wrapped in ``nn.DataParallel``.
        optimizer: Optimizer over the model parameters.
        train_loader: Yields ``(ids, mask, features, label)`` batches.
        val_loader: Same batch format, used for validation.
        scheduler: LR scheduler stepped once per batch, or ``None``.
        device: Device the model and batches are moved to.
        save_path: Destination of the best checkpoint. Defaults to the path
            previously hard-coded here.

    Returns:
        float: Best validation macro-F1 seen.
    """
    model.to(device)
    # The loss keeps no per-epoch state, so build it once instead of every epoch.
    criterion = WeightedFocalLoss().to(device)

    best_val_f1 = 0
    for epoch in range(1, CFG['EPOCHS']+1):
        model.train()
        train_loss = []
        for epitope_ids, epitope_mask, input_feature, label in tqdm(train_loader):
            epitope_ids = epitope_ids.to(device)
            epitope_mask = epitope_mask.to(device)
            input_feature = input_feature.to(device)
            label = label.float().to(device)

            optimizer.zero_grad()
            output = model(epitope_ids, epitope_mask, input_feature)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())

            # Stepped per batch, so the scheduler's period must be counted in batches.
            if scheduler is not None:
                scheduler.step()

        val_loss, val_f1, val_acc, val_precision, val_recall = validation(model, val_loader, criterion, device)
        print(f'Epoch : [{epoch}] Train Loss : [{np.mean(train_loss):.5f}] Val Loss : [{val_loss:.5f}] Val F1 : [{val_f1:.5f}] Val acc : [{val_acc:.5f}] Val precision : [{val_precision:.5f}] Val recall : [{val_recall:.5f}]')

        if best_val_f1 < val_f1:
            best_val_f1 = val_f1
            # Save the unwrapped module. `model.module` alone raises
            # AttributeError when the model is not wrapped in DataParallel.
            torch.save(getattr(model, 'module', model).state_dict(), save_path, _use_new_zipfile_serialization=False)
            print('Model Saved.')
    return best_val_f1

In [15]:
from sklearn.metrics import accuracy_score, precision_score, recall_score


def validation(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    A sample's probability is ``sigmoid(logit)``. It is predicted positive when
    that probability is strictly greater than ``CFG['THRESHOLD']``.

    Returns:
        tuple: ``(mean batch loss, macro-F1, accuracy, precision, recall)``.
        Precision and recall use sklearn's binary default (positive class = 1).
    """
    model.eval()
    probabilities = []
    targets = []
    batch_losses = []
    with torch.no_grad():
        for ids, mask, features, label in tqdm(iter(val_loader)):
            label = label.float().to(device)
            logits = model(ids.to(device), mask.to(device), features.to(device))

            batch_losses.append(criterion(logits, label).item())
            probabilities.extend(torch.sigmoid(logits).to('cpu').tolist())
            targets.extend(label.to('cpu').tolist())

    predictions = np.where(np.array(probabilities) > CFG['THRESHOLD'], 1, 0)
    macro_f1 = f1_score(targets, predictions, average='macro')
    accuracy = accuracy_score(targets, predictions)
    precision = precision_score(targets, predictions)
    recall = recall_score(targets, predictions)
    return np.mean(batch_losses), macro_f1, accuracy, precision, recall

In [16]:
# First GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [20]:
# Use up to the 3 GPUs this notebook was written for, but never more than are
# visible. Without CUDA, DataParallel simply wraps the module.
n_gpus = min(3, torch.cuda.device_count())
model = TransformerModel()
model = nn.DataParallel(model, device_ids=list(range(n_gpus)))
optimizer = torch.optim.Adam(params = model.parameters(), lr = CFG["LEARNING_RATE"])
# train() steps the scheduler once per batch, so T_max must be the total number
# of batches. The previous 1000-per-epoch guess exceeded the real batch count,
# which left the cosine schedule only part-way through its decay.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader)*CFG['EPOCHS'], eta_min=0)


Some weights of the model checkpoint at facebook/esm-1b were not used when initializing ESMModel: ['lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing ESMModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ESMModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ESMModel were not initialized from the model checkpoint at facebook/esm-1b and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [22]:
# Fine-tune only the top of the network and freeze every other parameter.
# Trainable: encoder layers 31 and 32, the final encoder LayerNorm
# (emb_layer_norm_after), the pooler, and the classifier head (modules 1, 2, 4, 5).
# NOTE(review): TransformerModel.forward uses last_hidden_state[:, 0], not the
# pooler output, so the unfrozen pooler weights never receive gradients.
_unfrozen_layer_params = [
    'attention.self.query', 'attention.self.key', 'attention.self.value',
    'attention.output.dense', 'attention.LayerNorm',
    'intermediate.dense', 'output.dense', 'LayerNorm',
]
_unfrozen_modules = (
    [f'module.esm.encoder.layer.{layer}.{sub}' for layer in (31, 32) for sub in _unfrozen_layer_params]
    + ['module.esm.encoder.emb_layer_norm_after', 'module.esm.pooler.dense']
    + [f'module.classifier.{idx}' for idx in (1, 2, 4, 5)]
)
TRAINABLE_PARAM_NAMES = {f'{mod}.{kind}' for mod in _unfrozen_modules for kind in ('weight', 'bias')}

for name, para in model.named_parameters():
    para.requires_grad = name in TRAINABLE_PARAM_NAMES

In [None]:
# Fine-tune. train() saves a checkpoint whenever validation macro-F1 improves
# and returns the best score seen.
best_score = train(model, optimizer, train_loader, val_loader, scheduler, device)
print(f'Best Validation F1 Score : [{best_score:.5f}]')

  2%|▏         | 6/299 [00:21<12:50,  2.63s/it]

In [None]:
# Preprocess the unlabelled test split with the same tokenizer and fitted preprocessor.
test_df = pd.read_csv('data/test.csv')
test_epitope_ids_list, test_epitope_mask_list, test_antigen_feature, test_label_list= get_preprocessing('test', test_df, tokenizer)
# Wrap the features in a DataFrame exactly like train/val, so the fitted
# ColumnTransformer gets the same input type it was fitted on.
test_feature = preprocessor.transform(pd.DataFrame(test_antigen_feature))

In [None]:
# The test split has no labels (test_label_list is None), so items are
# (ids, mask, features). shuffle=False keeps predictions in test.csv row order.
# The submission cell presumably relies on sample_submission.csv having the same
# order; confirm.
test_dataset = CustomDataset(test_epitope_ids_list, 
                             test_epitope_mask_list, 
                             test_feature,
                             test_label_list)
test_loader = DataLoader(test_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=False, num_workers=CFG['NUM_WORKERS'])

In [None]:
def inference(model, test_loader, device):
    """Predict hard 0/1 labels for every sample in ``test_loader``.

    A sample is labelled 1 when ``sigmoid(logit) > CFG['THRESHOLD']``, the same
    rule used in validation.

    Returns:
        numpy.ndarray: One 0/1 label per sample, in loader order.
    """
    model.eval()
    probabilities = []
    with torch.no_grad():
        for ids, mask, features in tqdm(iter(test_loader)):
            logits = model(ids.to(device), mask.to(device), features.to(device))
            probabilities.extend(torch.sigmoid(logits).to('cpu').tolist())

    return np.where(np.array(probabilities) > CFG['THRESHOLD'], 1, 0)

In [None]:
# Predict with the best checkpoint saved by train(), not with whatever weights
# the last epoch left in memory.
BEST_MODEL_PATH = '/DAS_Storage4/hyungseok/esm_final.pth'  # TODO: move to CFG; must match train()'s save path
if os.path.exists(BEST_MODEL_PATH):
    getattr(model, 'module', model).load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
else:
    print(f'Checkpoint {BEST_MODEL_PATH} not found; predicting with the in-memory weights.')

preds = inference(model, test_loader, device)
submit = pd.read_csv('data/sample_submission.csv')
submit['label'] = preds

# DataFrame.to_csv does not create missing directories.
os.makedirs('submission', exist_ok=True)
submit.to_csv('submission/esm_final.csv', index=False)
print('Done.')