In [1]:
import random
import pandas as pd
import numpy as np
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn import metrics
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm, trange

from tqdm import tqdm
from sklearn.metrics import f1_score
from transformers import BertConfig, BertModel, BertTokenizer, BertForPreTraining

from Bio.SeqUtils.ProtParam import ProteinAnalysis

In [2]:
device = torch.device('cuda:3') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
CFG = {
    'NUM_WORKERS':4,
    'EPITOPE_MAX_LEN':72,
    'EPOCHS':50,
    'LEARNING_RATE':1e-5,
    'BATCH_SIZE':512,
    'THRESHOLD':0.5,   # 기본적으로 0.5로 사용하지만 data impalance가 심할 경우 더 큰 값을 사용하기도 한다.
    'SEED':41
}

In [4]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(CFG['SEED']) # Seed 고정

In [5]:
tokenizer = BertTokenizer.from_pretrained('model/tokenizer')

In [6]:
def seqtoinput(seq):
    for j in range(len(seq)-1):
        seq = seq[:j+j+1]+ ' ' + seq[j+j+1:]
    return seq
    
def get_protein_feature(seq):
    protein_feature = []
    protein_feature.append(ProteinAnalysis(seq).isoelectric_point())
    protein_feature.append(ProteinAnalysis(seq).aromaticity())
    protein_feature.append(ProteinAnalysis(seq).gravy())
    protein_feature.append(ProteinAnalysis(seq).instability_index())
    return protein_feature
    
def get_preprocessing(data_type, new_df, tokenizer):
    epitope_ids_list = []
    epitope_mask_list = []
    
    protein_features = []
#     epitope_features = []
        
    for epitope, antigen, s_p, e_p in tqdm(zip(new_df['epitope_seq'], new_df['antigen_seq'], new_df['start_position'], new_df['end_position'])):             
        protein_features.append(get_protein_feature(antigen))
#         epitope_features.append(get_peptide_feature(epitope))
        
        epitope = seqtoinput(epitope)
        
        
        epitope_input = tokenizer(epitope, add_special_tokens=True, pad_to_max_length=True, max_length = 72)
        epitope_ids = epitope_input['input_ids']
        epitope_mask = epitope_input['attention_mask']
        
        
        epitope_ids_list.append(epitope_ids)
        epitope_mask_list.append(epitope_mask)

    
    label_list = None
    if data_type != 'test':
        label_list = []
        for label in new_df['label']:
            label_list.append(label)
    print(f'{data_type} dataframe preprocessing was done.')
    return epitope_ids_list, epitope_mask_list, protein_features, label_list

In [7]:
train = pd.read_csv('data/train.csv')
train = train.drop([7511])
test = pd.read_csv('data/test.csv')

train, val = train_test_split(train, train_size=0.8, random_state=12)

train_epitope_ids_list, train_epitope_mask_list, train_protein_features, train_label_list = get_preprocessing('train', train, tokenizer)
val_epitope_ids_list, val_epitope_mask_list, val_protein_features, val_label_list = get_preprocessing('val', val, tokenizer)

0it [00:00, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
152648it [04:02, 630.34it/s]


train dataframe preprocessing was done.


38162it [01:00, 626.61it/s]

val dataframe preprocessing was done.





In [8]:
class CustomDataset(Dataset):
    def __init__(self, epitope_ids_list, epitope_mask_list, protein_features, label_list):
        self.epitope_ids_list = epitope_ids_list
        self.epitope_mask_list = epitope_mask_list
        self.protein_features = protein_features

        self.label_list = label_list
        
    def __getitem__(self, index):
        self.epitope_ids = self.epitope_ids_list[index]
        self.epitope_mask = self.epitope_mask_list[index]
        
        self.protein_feature = self.protein_features[index]        
        
        if self.label_list is not None:
            self.label = self.label_list[index]
            return torch.tensor(self.epitope_ids), torch.tensor(self.epitope_mask), torch.tensor(self.protein_feature), self.label
        else:
            return torch.tensor(self.epitope_ids), torch.tensor(self.epitope_mask), torch.tensor(self.protein_feature)
        
    def __len__(self):
        return len(self.epitope_ids_list)

In [9]:
train_dataset = CustomDataset(train_epitope_ids_list, train_epitope_mask_list, train_protein_features, train_label_list)
train_loader = DataLoader(train_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=True, num_workers=CFG['NUM_WORKERS'])

val_dataset = CustomDataset(val_epitope_ids_list, val_epitope_mask_list, val_protein_features, val_label_list)
val_loader = DataLoader(val_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=False, num_workers=CFG['NUM_WORKERS'])

In [10]:
config = BertConfig(
    vocab_size=30, # default는 영어 기준이므로 내가 만든 vocab size에 맞게 수정해줘야 함
    hidden_size=768,
    num_hidden_layers=12,    # layer num
    num_attention_heads=12,    # transformer attention head number
    intermediate_size=3072,   # transformer 내에 있는 feed-forward network의 dimension size
    hidden_act="gelu",
    hidden_dropout_prob=0.,
    attention_probs_dropout_prob=0.0,
    max_position_embeddings=72,    # embedding size 최대 몇 token까지 input으로 사용할 것인지 지정
    type_vocab_size=2,    # token type ids의 범위 (BERT는 segmentA, segmentB로 2종류)
)

In [11]:
pre = BertForPreTraining(config=config)
pre.save_pretrained('model/bert')

In [12]:
class TransformerModel(nn.Module):
    def __init__(self,
                 protein_hidden_dim=768,
                 pretrained_model='model/bert'
                ):
        super(TransformerModel, self).__init__()              
        # Transformer                
        self.transformer = BertModel.from_pretrained(pretrained_model)
                
        in_channels = protein_hidden_dim + 4
            
        self.classifier = nn.Sequential(
            nn.LeakyReLU(True),
            nn.BatchNorm1d(in_channels),
            nn.Linear(in_channels, protein_hidden_dim//4),
            nn.LeakyReLU(True),
            nn.BatchNorm1d(protein_hidden_dim//4),
            nn.Linear(protein_hidden_dim//4, 1)
        )
        
    def forward(self, epitope_ids_list, epitope_mask_list, protein_features):
        BATCH_SIZE = epitope_ids_list.size(0)
        # Get Embedding Vector
        epitope_x = self.transformer(input_ids=epitope_ids_list, attention_mask=epitope_mask_list)[0]
        
        # transformer [CLS]
        epitope_hidden = epitope_x[:, 0, :]
                
        # Feature Concat -> Binary Classifier
        x = torch.cat([epitope_hidden, protein_features], axis=-1)
        x = self.classifier(x).view(-1)
        return x

In [13]:
def train(model, optimizer, train_loader, val_loader, scheduler, device):
    model.to(device)
    criterion = nn.BCEWithLogitsLoss().to(device) 
    
    best_val_f1 = 0
    for epoch in range(1, CFG['EPOCHS']+1):
        model.train()
        train_loss = []
        for epitope_ids_list, epitope_mask_list, protein_features, label in tqdm(iter(train_loader)):
            epitope_ids_list = epitope_ids_list.to(device)
            epitope_mask_list = epitope_mask_list.to(device)

            protein_features = protein_features.to(device)

            label = label.float().to(device)
            
            optimizer.zero_grad()
            
            output = model(epitope_ids_list, epitope_mask_list, protein_features)
            loss = criterion(output, label)
            
            loss.backward()
            optimizer.step()
            
            train_loss.append(loss.item())
            
            if scheduler is not None:
                scheduler.step()
                    
        val_loss, val_f1, val_acc = validation(model, val_loader, criterion, device)
        print(f'Epoch : [{epoch}] Train Loss : [{np.mean(train_loss):.5f}] Val Loss : [{val_loss:.5f}] Val F1 : [{val_f1:.5f}] Val acc : [{val_acc:.5f}]')
        
        if best_val_f1 < val_f1:
            best_val_f1 = val_f1
#             torch.save(model.module.state_dict(), './feature_transformer_best_model.pth', _use_new_zipfile_serialization=False)
            torch.save(model.module.state_dict(), './epitope_best_model.pth', _use_new_zipfile_serialization=False)
            print('Model Saved.')
    return best_val_f1

In [14]:
from sklearn.metrics import accuracy_score

def validation(model, val_loader, criterion, device):
    model.eval()
    pred_proba_label = []
    true_label = []
    val_loss = []
    with torch.no_grad():
        for epitope_ids_list, epitope_mask_list, protein_features, label in tqdm(iter(val_loader)):
            epitope_ids_list = epitope_ids_list.to(device)
            epitope_mask_list = epitope_mask_list.to(device)

            protein_features = protein_features.to(device)
            
            label = label.float().to(device)
            
            model_pred = model(epitope_ids_list, epitope_mask_list, protein_features)
            loss = criterion(model_pred, label)
            model_pred = torch.sigmoid(model_pred).to('cpu')
            
            pred_proba_label += model_pred.tolist()
            true_label += label.to('cpu').tolist()
            
            val_loss.append(loss.item())
    
    pred_label = np.where(np.array(pred_proba_label)>CFG['THRESHOLD'], 1, 0)
    val_f1 = f1_score(true_label, pred_label, average='macro')
    acc = accuracy_score(true_label, pred_label)
    return np.mean(val_loss), val_f1, acc

In [15]:
model = TransformerModel()
model = nn.DataParallel(model, device_ids=[3, 4])
model.eval()

optimizer = torch.optim.Adam(params = model.parameters(), lr = CFG["LEARNING_RATE"])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000*CFG['EPOCHS'], eta_min=0)

best_score = train(model, optimizer, train_loader, val_loader, scheduler, device)
print(f'Best Validation F1 Score : [{best_score:.5f}]')

Some weights of the model checkpoint at model/bert were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 299/299 [05:07<00:00,  1.03s/it]
100%|██████████| 75/75 [00:31<00:00,  2.36it/s]


Epoch : [1] Train Loss : [0.65799] Val Loss : [0.52931] Val F1 : [0.69771] Val acc : [0.88536]
Model Saved.


100%|██████████| 299/299 [05:48<00:00,  1.17s/it]
100%|██████████| 75/75 [00:33<00:00,  2.23it/s]

Epoch : [2] Train Loss : [0.61886] Val Loss : [0.64141] Val F1 : [0.68421] Val acc : [0.87157]



100%|██████████| 299/299 [05:51<00:00,  1.18s/it]
100%|██████████| 75/75 [00:34<00:00,  2.19it/s]

Epoch : [3] Train Loss : [0.59472] Val Loss : [0.47961] Val F1 : [0.69193] Val acc : [0.90603]



100%|██████████| 299/299 [05:49<00:00,  1.17s/it]
100%|██████████| 75/75 [00:35<00:00,  2.13it/s]

Epoch : [4] Train Loss : [0.58180] Val Loss : [0.64861] Val F1 : [0.67481] Val acc : [0.85420]



100%|██████████| 299/299 [05:54<00:00,  1.18s/it]
100%|██████████| 75/75 [00:33<00:00,  2.24it/s]

Epoch : [5] Train Loss : [0.55670] Val Loss : [0.69558] Val F1 : [0.62256] Val acc : [0.78827]



100%|██████████| 299/299 [05:51<00:00,  1.18s/it]
100%|██████████| 75/75 [00:34<00:00,  2.19it/s]

Epoch : [6] Train Loss : [0.54837] Val Loss : [0.61243] Val F1 : [0.68981] Val acc : [0.87267]



100%|██████████| 299/299 [05:51<00:00,  1.18s/it]
100%|██████████| 75/75 [00:34<00:00,  2.18it/s]

Epoch : [7] Train Loss : [0.53678] Val Loss : [0.37589] Val F1 : [0.61627] Val acc : [0.91772]



100%|██████████| 299/299 [05:51<00:00,  1.18s/it]
100%|██████████| 75/75 [00:35<00:00,  2.09it/s]


Epoch : [8] Train Loss : [0.51761] Val Loss : [0.50704] Val F1 : [0.70479] Val acc : [0.90593]
Model Saved.


100%|██████████| 299/299 [05:48<00:00,  1.17s/it]
100%|██████████| 75/75 [00:35<00:00,  2.11it/s]

Epoch : [9] Train Loss : [0.50515] Val Loss : [0.34052] Val F1 : [0.60549] Val acc : [0.91785]



100%|██████████| 299/299 [05:50<00:00,  1.17s/it]
100%|██████████| 75/75 [00:31<00:00,  2.35it/s]

Epoch : [10] Train Loss : [0.48846] Val Loss : [0.62615] Val F1 : [0.69362] Val acc : [0.87474]



100%|██████████| 299/299 [05:51<00:00,  1.18s/it]
100%|██████████| 75/75 [00:34<00:00,  2.17it/s]

Epoch : [11] Train Loss : [0.48239] Val Loss : [0.67013] Val F1 : [0.64807] Val acc : [0.82105]



100%|██████████| 299/299 [05:52<00:00,  1.18s/it]
100%|██████████| 75/75 [00:34<00:00,  2.20it/s]

Epoch : [12] Train Loss : [0.46646] Val Loss : [0.68084] Val F1 : [0.65632] Val acc : [0.82849]



100%|██████████| 299/299 [05:54<00:00,  1.18s/it]
100%|██████████| 75/75 [00:33<00:00,  2.23it/s]

Epoch : [13] Train Loss : [0.45750] Val Loss : [0.30802] Val F1 : [0.51365] Val acc : [0.91363]



100%|██████████| 299/299 [05:49<00:00,  1.17s/it]
100%|██████████| 75/75 [00:35<00:00,  2.13it/s]

Epoch : [14] Train Loss : [0.43677] Val Loss : [0.42046] Val F1 : [0.69618] Val acc : [0.91672]



100%|██████████| 299/299 [05:52<00:00,  1.18s/it]
100%|██████████| 75/75 [00:33<00:00,  2.22it/s]

Epoch : [15] Train Loss : [0.42723] Val Loss : [0.35320] Val F1 : [0.60362] Val acc : [0.91866]



100%|██████████| 299/299 [05:52<00:00,  1.18s/it]
100%|██████████| 75/75 [00:34<00:00,  2.18it/s]

Epoch : [16] Train Loss : [0.42045] Val Loss : [0.92595] Val F1 : [0.12941] Val acc : [0.13126]



100%|██████████| 299/299 [05:53<00:00,  1.18s/it]
100%|██████████| 75/75 [00:34<00:00,  2.17it/s]


Epoch : [17] Train Loss : [0.39954] Val Loss : [0.42138] Val F1 : [0.71251] Val acc : [0.91119]
Model Saved.


100%|██████████| 299/299 [05:52<00:00,  1.18s/it]
100%|██████████| 75/75 [00:34<00:00,  2.15it/s]


Epoch : [18] Train Loss : [0.39872] Val Loss : [0.45099] Val F1 : [0.71437] Val acc : [0.91172]
Model Saved.


100%|██████████| 299/299 [05:55<00:00,  1.19s/it]
100%|██████████| 75/75 [00:33<00:00,  2.26it/s]

Epoch : [19] Train Loss : [0.37854] Val Loss : [0.52528] Val F1 : [0.70999] Val acc : [0.89036]



100%|██████████| 299/299 [05:52<00:00,  1.18s/it]
100%|██████████| 75/75 [00:35<00:00,  2.13it/s]

Epoch : [20] Train Loss : [0.36713] Val Loss : [0.36827] Val F1 : [0.60544] Val acc : [0.91892]



100%|██████████| 299/299 [05:52<00:00,  1.18s/it]
100%|██████████| 75/75 [00:34<00:00,  2.18it/s]


Epoch : [21] Train Loss : [0.36124] Val Loss : [0.46963] Val F1 : [0.71842] Val acc : [0.90724]
Model Saved.


100%|██████████| 299/299 [05:54<00:00,  1.18s/it]
100%|██████████| 75/75 [00:34<00:00,  2.20it/s]

Epoch : [22] Train Loss : [0.34649] Val Loss : [0.32819] Val F1 : [0.63100] Val acc : [0.91982]



100%|██████████| 299/299 [05:56<00:00,  1.19s/it]
100%|██████████| 75/75 [00:33<00:00,  2.24it/s]

Epoch : [23] Train Loss : [0.35924] Val Loss : [0.45032] Val F1 : [0.70112] Val acc : [0.91704]



100%|██████████| 299/299 [05:53<00:00,  1.18s/it]
100%|██████████| 75/75 [00:32<00:00,  2.27it/s]

Epoch : [24] Train Loss : [0.33630] Val Loss : [0.36684] Val F1 : [0.60604] Val acc : [0.91885]



100%|██████████| 299/299 [05:50<00:00,  1.17s/it]
100%|██████████| 75/75 [00:35<00:00,  2.13it/s]

Epoch : [25] Train Loss : [0.32860] Val Loss : [0.37140] Val F1 : [0.68581] Val acc : [0.91987]



100%|██████████| 299/299 [05:53<00:00,  1.18s/it]
100%|██████████| 75/75 [00:35<00:00,  2.12it/s]

Epoch : [26] Train Loss : [0.30926] Val Loss : [0.40814] Val F1 : [0.59409] Val acc : [0.91735]



100%|██████████| 299/299 [05:52<00:00,  1.18s/it]
100%|██████████| 75/75 [00:35<00:00,  2.12it/s]

Epoch : [27] Train Loss : [0.32235] Val Loss : [0.46138] Val F1 : [0.71246] Val acc : [0.91112]



100%|██████████| 299/299 [05:48<00:00,  1.17s/it]
100%|██████████| 75/75 [00:34<00:00,  2.21it/s]

Epoch : [28] Train Loss : [0.30934] Val Loss : [0.37921] Val F1 : [0.69326] Val acc : [0.91927]



100%|██████████| 299/299 [05:56<00:00,  1.19s/it]
100%|██████████| 75/75 [00:34<00:00,  2.18it/s]

Epoch : [29] Train Loss : [0.31150] Val Loss : [0.38147] Val F1 : [0.70057] Val acc : [0.91924]



100%|██████████| 299/299 [05:53<00:00,  1.18s/it]
100%|██████████| 75/75 [00:33<00:00,  2.27it/s]

Epoch : [30] Train Loss : [0.29872] Val Loss : [0.46457] Val F1 : [0.71175] Val acc : [0.90745]



100%|██████████| 299/299 [05:49<00:00,  1.17s/it]
100%|██████████| 75/75 [00:35<00:00,  2.12it/s]


Epoch : [31] Train Loss : [0.30416] Val Loss : [0.39384] Val F1 : [0.71916] Val acc : [0.90079]
Model Saved.


100%|██████████| 299/299 [05:52<00:00,  1.18s/it]
100%|██████████| 75/75 [00:35<00:00,  2.14it/s]

Epoch : [32] Train Loss : [0.29204] Val Loss : [1.61683] Val F1 : [0.10792] Val acc : [0.11179]



100%|██████████| 299/299 [06:11<00:00,  1.24s/it]
100%|██████████| 75/75 [00:38<00:00,  1.94it/s]

Epoch : [33] Train Loss : [0.28223] Val Loss : [0.28058] Val F1 : [0.68712] Val acc : [0.92029]



100%|██████████| 299/299 [06:31<00:00,  1.31s/it]
100%|██████████| 75/75 [00:42<00:00,  1.75it/s]

Epoch : [34] Train Loss : [0.27100] Val Loss : [0.27057] Val F1 : [0.62462] Val acc : [0.91963]



100%|██████████| 299/299 [06:28<00:00,  1.30s/it]
100%|██████████| 75/75 [00:38<00:00,  1.97it/s]

Epoch : [35] Train Loss : [0.26072] Val Loss : [0.25421] Val F1 : [0.55555] Val acc : [0.91657]



100%|██████████| 299/299 [06:36<00:00,  1.33s/it]
100%|██████████| 75/75 [00:40<00:00,  1.86it/s]

Epoch : [36] Train Loss : [0.25843] Val Loss : [0.26407] Val F1 : [0.50735] Val acc : [0.91324]



100%|██████████| 299/299 [06:39<00:00,  1.34s/it]
100%|██████████| 75/75 [00:40<00:00,  1.84it/s]

Epoch : [37] Train Loss : [0.25109] Val Loss : [0.33309] Val F1 : [0.68815] Val acc : [0.91885]



100%|██████████| 299/299 [06:35<00:00,  1.32s/it]
100%|██████████| 75/75 [00:38<00:00,  1.93it/s]

Epoch : [38] Train Loss : [0.24277] Val Loss : [0.30858] Val F1 : [0.57529] Val acc : [0.91869]



100%|██████████| 299/299 [06:32<00:00,  1.31s/it]
100%|██████████| 75/75 [00:37<00:00,  2.02it/s]

Epoch : [39] Train Loss : [0.24558] Val Loss : [0.24936] Val F1 : [0.61691] Val acc : [0.91982]



100%|██████████| 299/299 [06:36<00:00,  1.33s/it]
100%|██████████| 75/75 [00:38<00:00,  1.96it/s]

Epoch : [40] Train Loss : [0.24022] Val Loss : [0.26092] Val F1 : [0.57776] Val acc : [0.91801]



100%|██████████| 299/299 [06:32<00:00,  1.31s/it]
100%|██████████| 75/75 [00:35<00:00,  2.10it/s]

Epoch : [41] Train Loss : [0.23437] Val Loss : [0.34324] Val F1 : [0.61482] Val acc : [0.92003]



100%|██████████| 299/299 [06:40<00:00,  1.34s/it]
100%|██████████| 75/75 [00:39<00:00,  1.88it/s]

Epoch : [42] Train Loss : [0.23125] Val Loss : [0.25103] Val F1 : [0.61635] Val acc : [0.91979]



100%|██████████| 299/299 [06:41<00:00,  1.34s/it]
100%|██████████| 75/75 [00:41<00:00,  1.80it/s]

Epoch : [43] Train Loss : [0.22564] Val Loss : [0.24615] Val F1 : [0.54227] Val acc : [0.91617]



100%|██████████| 299/299 [06:40<00:00,  1.34s/it]
100%|██████████| 75/75 [00:42<00:00,  1.76it/s]

Epoch : [44] Train Loss : [0.22289] Val Loss : [0.24024] Val F1 : [0.56171] Val acc : [0.91754]



100%|██████████| 299/299 [06:36<00:00,  1.32s/it]
100%|██████████| 75/75 [00:39<00:00,  1.91it/s]


Epoch : [45] Train Loss : [0.22120] Val Loss : [0.32207] Val F1 : [0.71968] Val acc : [0.90574]


OSError: [Errno 28] No space left on device

In [16]:
test_df = pd.read_csv('data/test.csv')
test_epitope_ids_list, test_epitope_mask_list, test_protein_features, test_label_list= get_preprocessing('test', test_df, tokenizer)

120944it [20:07, 100.18it/s]

test dataframe preprocessing was done.





In [17]:
test_dataset = CustomDataset(test_epitope_ids_list, test_epitope_mask_list, test_protein_features, None)
test_loader = DataLoader(test_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=False, num_workers=CFG['NUM_WORKERS'])

In [18]:
model = TransformerModel()
best_checkpoint = torch.load('./epitope_best_model.pth', map_location=device)
model.load_state_dict(best_checkpoint)
model.eval()
model.to(device)

Some weights of the model checkpoint at model/bert were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


TransformerModel(
  (transformer): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30, 768, padding_idx=0)
      (position_embeddings): Embedding(72, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affin

In [19]:
def inference(model, test_loader, device):
    model.eval()
    pred_proba_label = []
    with torch.no_grad():
        for epitope_ids_list, epitope_mask_list, protein_features in tqdm(iter(test_loader)):
            epitope_ids_list = epitope_ids_list.to(device)
            epitope_mask_list = epitope_mask_list.to(device)

            protein_features = protein_features.to(device)
            
            model_pred = model(epitope_ids_list, epitope_mask_list, protein_features)
            model_pred = torch.sigmoid(model_pred).to('cpu')
            
            pred_proba_label += model_pred.tolist()
    
    pred_label = np.where(np.array(pred_proba_label)>CFG['THRESHOLD'], 1, 0)
    return pred_label

In [20]:
preds = inference(model, test_loader, device)
submit = pd.read_csv('data/sample_submission.csv')
submit['label'] = preds

submit.to_csv('submission/epitope_submission.csv', index=False)
print('Done.')

100%|██████████| 237/237 [02:30<00:00,  1.57it/s]


Done.
