In [1]:
import random
import pandas as pd
import numpy as np
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn import metrics
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm, trange

from tqdm import tqdm
from sklearn.metrics import f1_score
from transformers import BertConfig, BertModel, BertTokenizer, BertForPreTraining

from Bio.SeqUtils.ProtParam import ProteinAnalysis
from sklearn import preprocessing

In [2]:
device = torch.device('cuda:3') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
CFG = {
    'NUM_WORKERS':4,
    'EPITOPE_MAX_LEN':72,
    'EPOCHS':20,
    'LEARNING_RATE':5e-5,
    'BATCH_SIZE':1024,
    'THRESHOLD':0.5,   # 기본적으로 0.5로 사용하지만 data impalance가 심할 경우 더 큰 값을 사용하기도 한다.
    'SEED':41
}

In [4]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(CFG['SEED']) # Seed 고정

In [5]:
tokenizer = BertTokenizer.from_pretrained('model/tokenizer')

In [6]:
def seqtoinput(seq):
    for j in range(len(seq)-1):
        seq = seq[:j+j+1]+ ' ' + seq[j+j+1:]
    return seq
    
def get_protein_feature(seq, epitope):
    protein_feature = []
    protein_feature.append(ProteinAnalysis(seq).isoelectric_point())
    protein_feature.append(ProteinAnalysis(seq).aromaticity())
    protein_feature.append(ProteinAnalysis(seq).gravy())
    protein_feature.append(ProteinAnalysis(seq).instability_index())
    return protein_feature

def normalization(a):
    standard_scaler = preprocessing.StandardScaler()    
    standard_scaler.fit(a)
    result = standard_scaler.transform(a)
    return result
    
def get_preprocessing(data_type, new_df, tokenizer):
    epitope_ids_list = []
    epitope_mask_list = []
    
    protein_features = []
#     epitope_features = []
        
    for epitope, antigen, s_p, e_p in tqdm(zip(new_df['epitope_seq'], new_df['antigen_seq'], new_df['start_position'], new_df['end_position'])):             
        protein_features.append(get_protein_feature(antigen, epitope))
#         epitope_features.append(get_peptide_feature(epitope))
        
        epitope = seqtoinput(epitope)
        
        
        epitope_input = tokenizer(epitope, add_special_tokens=True, pad_to_max_length=True, max_length = 72)
        epitope_ids = epitope_input['input_ids']
        epitope_mask = epitope_input['attention_mask']
        
        
        epitope_ids_list.append(epitope_ids)
        epitope_mask_list.append(epitope_mask)

    
    label_list = None
    if data_type != 'test':
        label_list = []
        for label in new_df['label']:
            label_list.append(label)
    print(f'{data_type} dataframe preprocessing was done.')
    
    protein_features = normalization(protein_features)
    
    return epitope_ids_list, epitope_mask_list, protein_features, label_list

In [7]:
train = pd.read_csv('data/train.csv')
train = train.drop([7511])
test = pd.read_csv('data/test.csv')

train, val = train_test_split(train, train_size=0.8, random_state=12)

train_epitope_ids_list, train_epitope_mask_list, train_protein_features, train_label_list = get_preprocessing('train', train, tokenizer)
val_epitope_ids_list, val_epitope_mask_list, val_protein_features, val_label_list = get_preprocessing('val', val, tokenizer)

0it [00:00, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
152648it [04:14, 599.23it/s]


train dataframe preprocessing was done.


38162it [01:03, 601.39it/s]

val dataframe preprocessing was done.





In [8]:
class CustomDataset(Dataset):
    def __init__(self, epitope_ids_list, epitope_mask_list, protein_features, label_list):
        self.epitope_ids_list = epitope_ids_list
        self.epitope_mask_list = epitope_mask_list
        self.protein_features = protein_features

        self.label_list = label_list
        
    def __getitem__(self, index):
        self.epitope_ids = self.epitope_ids_list[index]
        self.epitope_mask = self.epitope_mask_list[index]
        
        self.protein_feature = self.protein_features[index]        
        
        if self.label_list is not None:
            self.label = self.label_list[index]
            return torch.tensor(self.epitope_ids), torch.tensor(self.epitope_mask), torch.tensor(self.protein_feature, dtype=torch.float32), self.label
        else:
            return torch.tensor(self.epitope_ids), torch.tensor(self.epitope_mask), torch.tensor(self.protein_feature, dtype=torch.float32)
        
    def __len__(self):
        return len(self.epitope_ids_list)

In [9]:
train_dataset = CustomDataset(train_epitope_ids_list, train_epitope_mask_list, train_protein_features, train_label_list)
train_loader = DataLoader(train_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=True, num_workers=CFG['NUM_WORKERS'])

val_dataset = CustomDataset(val_epitope_ids_list, val_epitope_mask_list, val_protein_features, val_label_list)
val_loader = DataLoader(val_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=False, num_workers=CFG['NUM_WORKERS'])

In [10]:
config = BertConfig(
    vocab_size=30, # default는 영어 기준이므로 내가 만든 vocab size에 맞게 수정해줘야 함
    hidden_size=768,
    num_hidden_layers=12,    # layer num
    num_attention_heads=12,    # transformer attention head number
    intermediate_size=3072,   # transformer 내에 있는 feed-forward network의 dimension size
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.0,
    max_position_embeddings=72,    # embedding size 최대 몇 token까지 input으로 사용할 것인지 지정
    type_vocab_size=2,    # token type ids의 범위 (BERT는 segmentA, segmentB로 2종류)
)

In [11]:
pre = BertForPreTraining(config=config)
pre.save_pretrained('model/bert')

In [12]:
class TransformerModel(nn.Module):
    def __init__(self,
                 protein_hidden_dim=768,
                 pretrained_model='model/bert'
                ):
        super(TransformerModel, self).__init__()              
        # Transformer                
        self.transformer = BertModel.from_pretrained(pretrained_model)
                
        in_channels = protein_hidden_dim + 4
            
        self.classifier = nn.Sequential(
            nn.LeakyReLU(True),
            nn.BatchNorm1d(in_channels),
            nn.Linear(in_channels, protein_hidden_dim//4),
            nn.LeakyReLU(True),
            nn.BatchNorm1d(protein_hidden_dim//4),
            nn.Linear(protein_hidden_dim//4, 1)
        )
        
    def forward(self, epitope_ids_list, epitope_mask_list, protein_features):
        BATCH_SIZE = epitope_ids_list.size(0)
        # Get Embedding Vector
        epitope_x = self.transformer(input_ids=epitope_ids_list, attention_mask=epitope_mask_list)[0]
        
        # transformer [CLS]
        epitope_hidden = epitope_x[:, 0, :]
                        
        # Feature Concat -> Binary Classifier
        x = torch.cat([epitope_hidden, protein_features], axis=-1)
        x = self.classifier(x).view(-1)
        return x

In [13]:
class WeightedFocalLoss(nn.Module):
    "Non weighted version of Focal Loss"
    def __init__(self, alpha=.25, gamma=2):
        super(WeightedFocalLoss, self).__init__()
        self.alpha = torch.tensor([alpha, 1-alpha]).to(device)
        self.gamma = gamma

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        targets = targets.type(torch.long)
        at = self.alpha.gather(0, targets.data.view(-1))
        pt = torch.exp(-BCE_loss)
        F_loss = at*(1-pt)**self.gamma * BCE_loss
        return F_loss.mean()

In [14]:
def train(model, optimizer, train_loader, val_loader, scheduler, device):
    model.to(device)
    criterion = WeightedFocalLoss().to(device) 
#     criterion = FocalLoss().to(device)
    best_val_f1 = 0
    for epoch in range(1, CFG['EPOCHS']+1):
        model.train()
        train_loss = []
        for epitope_ids_list, epitope_mask_list, protein_features, label in tqdm(iter(train_loader)):
            epitope_ids_list = epitope_ids_list.to(device)
            epitope_mask_list = epitope_mask_list.to(device)

            protein_features = protein_features.to(device)

            label = label.float().to(device)
            
            optimizer.zero_grad()
            
            output = model(epitope_ids_list, epitope_mask_list, protein_features)
            loss = criterion(output, label)
            
            loss.backward()
            optimizer.step()
            
            train_loss.append(loss.item())
            
            if scheduler is not None:
                scheduler.step()
                    
        val_loss, val_f1, val_acc, val_f1_T, val_f1_F = validation(model, val_loader, criterion, device)
        print(f'Epoch : [{epoch}] Train Loss : [{np.mean(train_loss):.5f}] Val Loss : [{val_loss:.5f}] Val F1 : [{val_f1:.5f}] Val acc : [{val_acc:.5f}] Val F1_T : [{val_f1_T:.5f}] Val F1_F : [{val_f1_F:.5f}] ')
        
        if best_val_f1 < val_f1:
            best_val_f1 = val_f1
#             torch.save(model.module.state_dict(), './feature_transformer_best_model.pth', _use_new_zipfile_serialization=False)
            torch.save(model.module.state_dict(), '../../../../../DAS_Storage4/daehun/epitope_best_model.pth', _use_new_zipfile_serialization=False)
            print('Model Saved.')
    return best_val_f1

In [15]:
from sklearn.metrics import accuracy_score

def validation(model, val_loader, criterion, device):
    model.eval()
    pred_proba_label = []
    true_label = []
    val_loss = []
    with torch.no_grad():
        for epitope_ids_list, epitope_mask_list, protein_features, label in tqdm(iter(val_loader)):
            epitope_ids_list = epitope_ids_list.to(device)
            epitope_mask_list = epitope_mask_list.to(device)

            protein_features = protein_features.to(device)
            
            label = label.float().to(device)
            
            model_pred = model(epitope_ids_list, epitope_mask_list, protein_features)
            loss = criterion(model_pred, label)
            model_pred = torch.sigmoid(model_pred).to('cpu')
            
            pred_proba_label += model_pred.tolist()
            true_label += label.to('cpu').tolist()
            
            val_loss.append(loss.item())
    
    pred_label = np.where(np.array(pred_proba_label)>CFG['THRESHOLD'], 1, 0)
    val_f1 = f1_score(true_label, pred_label, average='macro')
    f1_T = f1_score(true_label, pred_label)
    f1_F = f1_score(true_label, pred_label, pos_label=0)
    acc = accuracy_score(true_label, pred_label)
    return np.mean(val_loss), val_f1, acc, f1_T, f1_F

In [16]:
model = TransformerModel()
model = nn.DataParallel(model, device_ids=[3, 4, 5, 1])
model.eval()

optimizer = torch.optim.Adam(params = model.parameters(), lr = CFG["LEARNING_RATE"])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000*CFG['EPOCHS'], eta_min=0)

best_score = train(model, optimizer, train_loader, val_loader, scheduler, device)
print(f'Best Validation F1 Score : [{best_score:.5f}]')

Some weights of the model checkpoint at model/bert were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 150/150 [05:05<00:00,  2.04s/it]
100%|██████████| 38/38 [00:29<00:00,  1.28it/s]


Epoch : [1] Train Loss : [0.04841] Val Loss : [0.03874] Val F1 : [0.68195] Val acc : [0.86777] Val F1_T : [0.43883] Val F1_F : [0.92506] 
Model Saved.


100%|██████████| 150/150 [04:35<00:00,  1.84s/it]
100%|██████████| 38/38 [00:30<00:00,  1.23it/s]

Epoch : [2] Train Loss : [0.04310] Val Loss : [0.04280] Val F1 : [0.64768] Val acc : [0.82417] Val F1_T : [0.39831] Val F1_F : [0.89704] 



100%|██████████| 150/150 [04:38<00:00,  1.86s/it]
100%|██████████| 38/38 [00:30<00:00,  1.25it/s]


Epoch : [3] Train Loss : [0.04054] Val Loss : [0.03728] Val F1 : [0.69208] Val acc : [0.87168] Val F1_T : [0.45691] Val F1_F : [0.92724] 
Model Saved.


100%|██████████| 150/150 [04:35<00:00,  1.84s/it]
100%|██████████| 38/38 [00:29<00:00,  1.29it/s]


Epoch : [4] Train Loss : [0.03791] Val Loss : [0.03390] Val F1 : [0.71526] Val acc : [0.89953] Val F1_T : [0.48620] Val F1_F : [0.94432] 
Model Saved.


100%|██████████| 150/150 [04:27<00:00,  1.78s/it]
100%|██████████| 38/38 [00:28<00:00,  1.33it/s]


Epoch : [5] Train Loss : [0.03569] Val Loss : [0.03300] Val F1 : [0.71298] Val acc : [0.91248] Val F1_T : [0.47368] Val F1_F : [0.95227] 


100%|██████████| 150/150 [04:19<00:00,  1.73s/it]
100%|██████████| 38/38 [00:29<00:00,  1.29it/s]


Epoch : [6] Train Loss : [0.03398] Val Loss : [0.03201] Val F1 : [0.71623] Val acc : [0.89162] Val F1_T : [0.49314] Val F1_F : [0.93932] 
Model Saved.


100%|██████████| 150/150 [04:25<00:00,  1.77s/it]
100%|██████████| 38/38 [00:27<00:00,  1.36it/s]


Epoch : [7] Train Loss : [0.03229] Val Loss : [0.03117] Val F1 : [0.71087] Val acc : [0.88813] Val F1_T : [0.48448] Val F1_F : [0.93726] 


100%|██████████| 150/150 [04:11<00:00,  1.68s/it]
100%|██████████| 38/38 [00:28<00:00,  1.35it/s]


Epoch : [8] Train Loss : [0.03081] Val Loss : [0.03028] Val F1 : [0.72497] Val acc : [0.90001] Val F1_T : [0.50557] Val F1_F : [0.94438] 
Model Saved.


100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:27<00:00,  1.36it/s]

Epoch : [9] Train Loss : [0.02969] Val Loss : [0.02949] Val F1 : [0.72062] Val acc : [0.89395] Val F1_T : [0.50056] Val F1_F : [0.94068] 



100%|██████████| 150/150 [04:11<00:00,  1.68s/it]
100%|██████████| 38/38 [00:28<00:00,  1.34it/s]


Epoch : [10] Train Loss : [0.02885] Val Loss : [0.03008] Val F1 : [0.70621] Val acc : [0.88229] Val F1_T : [0.47877] Val F1_F : [0.93365] 


100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:27<00:00,  1.37it/s]


Epoch : [11] Train Loss : [0.02783] Val Loss : [0.02775] Val F1 : [0.72939] Val acc : [0.91248] Val F1_T : [0.50679] Val F1_F : [0.95198] 
Model Saved.


100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:28<00:00,  1.36it/s]

Epoch : [12] Train Loss : [0.02685] Val Loss : [0.02860] Val F1 : [0.71457] Val acc : [0.88337] Val F1_T : [0.49507] Val F1_F : [0.93407] 



100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:27<00:00,  1.36it/s]

Epoch : [13] Train Loss : [0.02604] Val Loss : [0.02896] Val F1 : [0.72333] Val acc : [0.89770] Val F1_T : [0.50369] Val F1_F : [0.94297] 



100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:27<00:00,  1.36it/s]


Epoch : [14] Train Loss : [0.02546] Val Loss : [0.02770] Val F1 : [0.73656] Val acc : [0.91033] Val F1_T : [0.52260] Val F1_F : [0.95052] 
Model Saved.


100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:27<00:00,  1.38it/s]


Epoch : [15] Train Loss : [0.02445] Val Loss : [0.03424] Val F1 : [0.66081] Val acc : [0.82991] Val F1_T : [0.42132] Val F1_F : [0.90030] 


100%|██████████| 150/150 [04:11<00:00,  1.68s/it]
100%|██████████| 38/38 [00:27<00:00,  1.37it/s]


Epoch : [16] Train Loss : [0.02411] Val Loss : [0.02915] Val F1 : [0.72044] Val acc : [0.90373] Val F1_T : [0.49408] Val F1_F : [0.94680] 


100%|██████████| 150/150 [04:11<00:00,  1.67s/it]
100%|██████████| 38/38 [00:28<00:00,  1.35it/s]

Epoch : [17] Train Loss : [0.02261] Val Loss : [0.02912] Val F1 : [0.72483] Val acc : [0.91274] Val F1_T : [0.49743] Val F1_F : [0.95222] 



100%|██████████| 150/150 [04:11<00:00,  1.67s/it]
100%|██████████| 38/38 [00:27<00:00,  1.36it/s]

Epoch : [18] Train Loss : [0.02149] Val Loss : [0.02949] Val F1 : [0.70926] Val acc : [0.88344] Val F1_T : [0.48423] Val F1_F : [0.93430] 



100%|██████████| 150/150 [04:11<00:00,  1.68s/it]
100%|██████████| 38/38 [00:27<00:00,  1.39it/s]


Epoch : [19] Train Loss : [0.02025] Val Loss : [0.03195] Val F1 : [0.69835] Val acc : [0.91657] Val F1_T : [0.44180] Val F1_F : [0.95491] 


100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:27<00:00,  1.37it/s]

Epoch : [20] Train Loss : [0.01840] Val Loss : [0.03166] Val F1 : [0.69838] Val acc : [0.88381] Val F1_T : [0.46189] Val F1_F : [0.93487] 



100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:28<00:00,  1.34it/s]

Epoch : [21] Train Loss : [0.01668] Val Loss : [0.03445] Val F1 : [0.69616] Val acc : [0.88237] Val F1_T : [0.45831] Val F1_F : [0.93402] 



100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:28<00:00,  1.36it/s]


Epoch : [22] Train Loss : [0.01532] Val Loss : [0.03717] Val F1 : [0.67864] Val acc : [0.89830] Val F1_T : [0.41295] Val F1_F : [0.94433] 


100%|██████████| 150/150 [04:09<00:00,  1.67s/it]
100%|██████████| 38/38 [00:27<00:00,  1.36it/s]

Epoch : [23] Train Loss : [0.01342] Val Loss : [0.03707] Val F1 : [0.66189] Val acc : [0.85163] Val F1_T : [0.40861] Val F1_F : [0.91518] 



100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:28<00:00,  1.35it/s]

Epoch : [24] Train Loss : [0.01240] Val Loss : [0.04134] Val F1 : [0.67822] Val acc : [0.89007] Val F1_T : [0.41712] Val F1_F : [0.93931] 



100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:27<00:00,  1.39it/s]

Epoch : [25] Train Loss : [0.01067] Val Loss : [0.04692] Val F1 : [0.68671] Val acc : [0.87621] Val F1_T : [0.44306] Val F1_F : [0.93037] 



100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:28<00:00,  1.34it/s]

Epoch : [26] Train Loss : [0.00931] Val Loss : [0.04646] Val F1 : [0.68355] Val acc : [0.88827] Val F1_T : [0.42903] Val F1_F : [0.93807] 



100%|██████████| 150/150 [04:11<00:00,  1.68s/it]
100%|██████████| 38/38 [00:27<00:00,  1.36it/s]

Epoch : [27] Train Loss : [0.00813] Val Loss : [0.05025] Val F1 : [0.69077] Val acc : [0.90265] Val F1_T : [0.43481] Val F1_F : [0.94674] 



100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:27<00:00,  1.38it/s]

Epoch : [28] Train Loss : [0.00709] Val Loss : [0.05283] Val F1 : [0.70363] Val acc : [0.90113] Val F1_T : [0.46169] Val F1_F : [0.94557] 



100%|██████████| 150/150 [04:09<00:00,  1.66s/it]
100%|██████████| 38/38 [00:27<00:00,  1.37it/s]

Epoch : [29] Train Loss : [0.00632] Val Loss : [0.05215] Val F1 : [0.69061] Val acc : [0.88344] Val F1_T : [0.44635] Val F1_F : [0.93487] 



100%|██████████| 150/150 [04:09<00:00,  1.67s/it]
100%|██████████| 38/38 [00:27<00:00,  1.37it/s]

Epoch : [30] Train Loss : [0.00636] Val Loss : [0.05831] Val F1 : [0.70483] Val acc : [0.90074] Val F1_T : [0.46437] Val F1_F : [0.94530] 



100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:27<00:00,  1.38it/s]

Epoch : [31] Train Loss : [0.00543] Val Loss : [0.05549] Val F1 : [0.68436] Val acc : [0.87697] Val F1_T : [0.43779] Val F1_F : [0.93093] 



100%|██████████| 150/150 [04:09<00:00,  1.66s/it]
100%|██████████| 38/38 [00:27<00:00,  1.39it/s]

Epoch : [32] Train Loss : [0.00530] Val Loss : [0.05508] Val F1 : [0.69364] Val acc : [0.89275] Val F1_T : [0.44667] Val F1_F : [0.94062] 



100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:28<00:00,  1.33it/s]

Epoch : [33] Train Loss : [0.00443] Val Loss : [0.06209] Val F1 : [0.69592] Val acc : [0.90121] Val F1_T : [0.44608] Val F1_F : [0.94577] 



100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:28<00:00,  1.34it/s]

Epoch : [34] Train Loss : [0.00431] Val Loss : [0.06339] Val F1 : [0.69209] Val acc : [0.89647] Val F1_T : [0.44124] Val F1_F : [0.94295] 



100%|██████████| 150/150 [04:11<00:00,  1.68s/it]
100%|██████████| 38/38 [00:27<00:00,  1.37it/s]

Epoch : [35] Train Loss : [0.00394] Val Loss : [0.06497] Val F1 : [0.70354] Val acc : [0.90181] Val F1_T : [0.46110] Val F1_F : [0.94599] 



100%|██████████| 150/150 [04:11<00:00,  1.68s/it]
100%|██████████| 38/38 [00:27<00:00,  1.38it/s]

Epoch : [36] Train Loss : [0.00358] Val Loss : [0.05912] Val F1 : [0.69288] Val acc : [0.89007] Val F1_T : [0.44679] Val F1_F : [0.93897] 



100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:28<00:00,  1.35it/s]

Epoch : [37] Train Loss : [0.00394] Val Loss : [0.06964] Val F1 : [0.70616] Val acc : [0.90559] Val F1_T : [0.46408] Val F1_F : [0.94823] 



100%|██████████| 150/150 [04:09<00:00,  1.66s/it]
100%|██████████| 38/38 [00:28<00:00,  1.35it/s]

Epoch : [38] Train Loss : [0.00327] Val Loss : [0.06097] Val F1 : [0.69643] Val acc : [0.89175] Val F1_T : [0.45292] Val F1_F : [0.93993] 



100%|██████████| 150/150 [04:08<00:00,  1.66s/it]
100%|██████████| 38/38 [00:27<00:00,  1.37it/s]

Epoch : [39] Train Loss : [0.00307] Val Loss : [0.06240] Val F1 : [0.70039] Val acc : [0.90899] Val F1_T : [0.45039] Val F1_F : [0.95039] 



100%|██████████| 150/150 [04:12<00:00,  1.68s/it]
100%|██████████| 38/38 [00:28<00:00,  1.36it/s]


Epoch : [40] Train Loss : [0.00320] Val Loss : [0.07926] Val F1 : [0.69902] Val acc : [0.90711] Val F1_T : [0.44876] Val F1_F : [0.94928] 


100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:27<00:00,  1.37it/s]

Epoch : [41] Train Loss : [0.00311] Val Loss : [0.06688] Val F1 : [0.68685] Val acc : [0.89566] Val F1_T : [0.43114] Val F1_F : [0.94256] 



100%|██████████| 150/150 [04:11<00:00,  1.68s/it]
100%|██████████| 38/38 [00:28<00:00,  1.34it/s]

Epoch : [42] Train Loss : [0.00270] Val Loss : [0.06580] Val F1 : [0.70181] Val acc : [0.90527] Val F1_T : [0.45549] Val F1_F : [0.94812] 



100%|██████████| 150/150 [04:11<00:00,  1.68s/it]
100%|██████████| 38/38 [00:27<00:00,  1.39it/s]


Epoch : [43] Train Loss : [0.00281] Val Loss : [0.07012] Val F1 : [0.68278] Val acc : [0.89256] Val F1_T : [0.42480] Val F1_F : [0.94075] 


100%|██████████| 150/150 [04:12<00:00,  1.68s/it]
100%|██████████| 38/38 [00:27<00:00,  1.37it/s]

Epoch : [44] Train Loss : [0.00388] Val Loss : [0.06777] Val F1 : [0.69644] Val acc : [0.89634] Val F1_T : [0.45010] Val F1_F : [0.94277] 



100%|██████████| 150/150 [04:11<00:00,  1.67s/it]
100%|██████████| 38/38 [00:28<00:00,  1.35it/s]

Epoch : [45] Train Loss : [0.00297] Val Loss : [0.07408] Val F1 : [0.69855] Val acc : [0.90719] Val F1_T : [0.44777] Val F1_F : [0.94933] 



100%|██████████| 150/150 [04:10<00:00,  1.67s/it]
100%|██████████| 38/38 [00:28<00:00,  1.34it/s]

Epoch : [46] Train Loss : [0.00275] Val Loss : [0.06610] Val F1 : [0.70531] Val acc : [0.90986] Val F1_T : [0.45980] Val F1_F : [0.95083] 



100%|██████████| 150/150 [04:09<00:00,  1.66s/it]
100%|██████████| 38/38 [00:27<00:00,  1.38it/s]

Epoch : [47] Train Loss : [0.00244] Val Loss : [0.07345] Val F1 : [0.69894] Val acc : [0.90606] Val F1_T : [0.44922] Val F1_F : [0.94865] 



100%|██████████| 150/150 [04:11<00:00,  1.67s/it]
100%|██████████| 38/38 [00:27<00:00,  1.38it/s]

Epoch : [48] Train Loss : [0.00225] Val Loss : [0.07244] Val F1 : [0.70384] Val acc : [0.91350] Val F1_T : [0.45465] Val F1_F : [0.95302] 



100%|██████████| 150/150 [04:11<00:00,  1.68s/it]
100%|██████████| 38/38 [00:28<00:00,  1.34it/s]


Epoch : [49] Train Loss : [0.00194] Val Loss : [0.07387] Val F1 : [0.70094] Val acc : [0.90823] Val F1_T : [0.45196] Val F1_F : [0.94992] 


100%|██████████| 150/150 [04:09<00:00,  1.66s/it]
100%|██████████| 38/38 [00:27<00:00,  1.36it/s]

Epoch : [50] Train Loss : [0.00215] Val Loss : [0.07264] Val F1 : [0.69280] Val acc : [0.90346] Val F1_T : [0.43841] Val F1_F : [0.94719] 
Best Validation F1 Score : [0.73656]





In [17]:
test_df = pd.read_csv('data/test.csv')
test_epitope_ids_list, test_epitope_mask_list, test_protein_features, test_label_list= get_preprocessing('test', test_df, tokenizer)

120944it [21:01, 95.87it/s] 


test dataframe preprocessing was done.


In [18]:
test_dataset = CustomDataset(test_epitope_ids_list, test_epitope_mask_list, test_protein_features, None)
test_loader = DataLoader(test_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=False, num_workers=CFG['NUM_WORKERS'])

In [19]:
model = TransformerModel()
best_checkpoint = torch.load('../../../../../DAS_Storage4/daehun/epitope_best_model.pth', map_location=device)
model.load_state_dict(best_checkpoint)
model.eval()
model.to(device)

Some weights of the model checkpoint at model/bert were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


TransformerModel(
  (transformer): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30, 768, padding_idx=0)
      (position_embeddings): Embedding(72, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affin

In [20]:
def inference(model, test_loader, device):
    model.eval()
    pred_proba_label = []
    with torch.no_grad():
        for epitope_ids_list, epitope_mask_list, protein_features in tqdm(iter(test_loader)):
            epitope_ids_list = epitope_ids_list.to(device)
            epitope_mask_list = epitope_mask_list.to(device)

            protein_features = protein_features.to(device)
            
            model_pred = model(epitope_ids_list, epitope_mask_list, protein_features)
            model_pred = torch.sigmoid(model_pred).to('cpu')
            
            pred_proba_label += model_pred.tolist()
    
    pred_label = np.where(np.array(pred_proba_label)>CFG['THRESHOLD'], 1, 0)
    return pred_label

In [21]:
preds = inference(model, test_loader, device)
submit = pd.read_csv('data/sample_submission.csv')
submit['label'] = preds

submit.to_csv('submission/epitope_submission.csv', index=False)
print('Done.')

100%|██████████| 119/119 [05:06<00:00,  2.58s/it]

Done.



