In [1]:
import random
import pandas as pd
import numpy as np
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn import metrics
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm, trange

from tqdm import tqdm
from sklearn.metrics import f1_score
from transformers import BertConfig, BertModel, BertTokenizer, BertForPreTraining


In [2]:
device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
CFG = {
    'NUM_WORKERS':4,
    'ANTIGEN_WINDOW':256,
    'ANTIGEN_MAX_LEN':256, # ANTIGEN_WINDOW와 ANTIGEN_MAX_LEN은 같아야합니다.
    'EPITOPE_MAX_LEN':256,
    'EPOCHS':50,
    'LEARNING_RATE':1e-4,
    'BATCH_SIZE':128,
    'THRESHOLD':0.7,   # 기본적으로 0.5로 사용하지만 data impalance가 심할 경우 더 큰 값을 사용하기도 한다.
    'SEED':41
}

In [4]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(CFG['SEED']) # Seed 고정

In [5]:
tokenizer = BertTokenizer.from_pretrained('model/tokenizer')

In [6]:
tokenizer.vocab

OrderedDict([('[PAD]', 0),
             ('[UNK]', 1),
             ('[CLS]', 2),
             ('[SEP]', 3),
             ('[MASK]', 4),
             ('L', 5),
             ('A', 6),
             ('G', 7),
             ('V', 8),
             ('E', 9),
             ('S', 10),
             ('I', 11),
             ('K', 12),
             ('R', 13),
             ('D', 14),
             ('T', 15),
             ('P', 16),
             ('N', 17),
             ('Q', 18),
             ('F', 19),
             ('Y', 20),
             ('M', 21),
             ('H', 22),
             ('C', 23),
             ('W', 24),
             ('X', 25),
             ('U', 26),
             ('B', 27),
             ('Z', 28),
             ('O', 29)])

In [7]:
def seqtoinput(seq):
    for j in range(len(seq)-1):
        seq = seq[:j+j+1]+ ' ' + seq[j+j+1:]
    return seq
    

def get_preprocessing(data_type, new_df, tokenizer):
    epitope_ids_list = []
    epitope_mask_list = []
    
    left_antigen_ids_list = []
    left_antigen_mask_list = []
    
    right_antigen_ids_list = []
    right_antigen_mask_list = []
        
    for epitope, antigen, s_p, e_p in tqdm(zip(new_df['epitope_seq'], new_df['antigen_seq'], new_df['start_position'], new_df['end_position'])):        
        # Left antigen : [start_position-WINDOW : start_position]
        # Right antigen : [end_position : end_position+WINDOW]

        start_position = s_p-CFG['ANTIGEN_WINDOW']-1
        end_position = e_p+CFG['ANTIGEN_WINDOW']
        if start_position < 0:
            start_position = 0
        if end_position > len(antigen):
            end_position = len(antigen)
        
        # left / right antigen sequence 추출
        left_antigen = antigen[int(start_position) : int(s_p)-1]
        
        right_antigen = antigen[int(e_p) : int(end_position)]

        if CFG['EPITOPE_MAX_LEN']<len(epitope):
            epitope = epitope[:CFG['EPITOPE_MAX_LEN']]
        else:
            epitope = epitope[:]
        
        left_antigen = seqtoinput(left_antigen)
        right_antigen = seqtoinput(left_antigen)
        epitope = seqtoinput(epitope)
        
        if len(left_antigen) == 0:
            left_antigen = '[PAD]'
            
        if len(right_antigen) == 0:
            right_antigen = '[PAD]'
        
        left_antigen_input = tokenizer(left_antigen, add_special_tokens=True, pad_to_max_length=True, max_length = 256)
        left_antigen_ids = left_antigen_input['input_ids']
        left_antigen_mask = left_antigen_input['attention_mask']
        
        right_antigen_input = tokenizer(right_antigen, add_special_tokens=True, pad_to_max_length=True, max_length = 256)
        right_antigen_ids = right_antigen_input['input_ids']
        right_antigen_mask = right_antigen_input['attention_mask']
        
        epitope_input = tokenizer(epitope, add_special_tokens=True, pad_to_max_length=True, max_length = 256)
        epitope_ids = epitope_input['input_ids']
        epitope_mask = epitope_input['attention_mask']
        
        
        epitope_ids_list.append(epitope_ids)
        epitope_mask_list.append(epitope_mask)
        
        left_antigen_ids_list.append(left_antigen_ids)
        left_antigen_mask_list.append(left_antigen_mask)
        
        right_antigen_ids_list.append(right_antigen_ids)
        right_antigen_mask_list.append(right_antigen_mask)
    
    label_list = None
    if data_type != 'test':
        label_list = []
        for label in new_df['label']:
            label_list.append(label)
    print(f'{data_type} dataframe preprocessing was done.')
    return epitope_ids_list, epitope_mask_list, left_antigen_ids_list, left_antigen_mask_list, right_antigen_ids_list, right_antigen_mask_list, label_list

In [8]:
train = pd.read_csv('data/train.csv')
test = pd.read_csv('data/test.csv')

train, val = train_test_split(train, train_size=0.8, random_state=12)

train_epitope_ids_list, train_epitope_mask_list, train_left_antigen_ids_list, train_left_antigen_mask_list, train_right_antigen_ids_list, train_right_antigen_mask_list, train_label_list = get_preprocessing('train', train, tokenizer)
val_epitope_ids_list, val_epitope_mask_list, val_left_antigen_ids_list, val_left_antigen_mask_list, val_right_antigen_ids_list, val_right_antigen_mask_list, val_label_list = get_preprocessing('val', val, tokenizer)


0it [00:00, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
152648it [07:55, 320.84it/s]


train dataframe preprocessing was done.


38163it [01:58, 320.85it/s]

val dataframe preprocessing was done.





In [9]:
class CustomDataset(Dataset):
    def __init__(self, epitope_ids_list, epitope_mask_list, left_antigen_ids_list, left_antigen_mask_list, right_antigen_ids_list, right_antigen_mask_list, label_list):
        self.epitope_ids_list = epitope_ids_list
        self.epitope_mask_list = epitope_mask_list
        self.left_antigen_ids_list = left_antigen_ids_list
        self.left_antigen_mask_list = left_antigen_mask_list
        self.right_antigen_ids_list = right_antigen_ids_list
        self.right_antigen_mask_list = right_antigen_mask_list
        self.label_list = label_list
        
    def __getitem__(self, index):
        self.epitope_ids = self.epitope_ids_list[index]
        self.epitope_mask = self.epitope_mask_list[index]
        
        self.left_antigen_ids = self.left_antigen_ids_list[index]
        self.left_antigen_mask = self.left_antigen_mask_list[index]
        
        self.right_antigen_ids = self.right_antigen_ids_list[index]
        self.right_antigen_mask = self.right_antigen_mask_list[index]
        
        if self.label_list is not None:
            self.label = self.label_list[index]
            return torch.tensor(self.epitope_ids), torch.tensor(self.epitope_mask), torch.tensor(self.left_antigen_ids), torch.tensor(self.left_antigen_mask), torch.tensor(self.right_antigen_ids), torch.tensor(self.right_antigen_mask), self.label
        else:
            return torch.tensor(self.epitope_ids), torch.tensor(self.epitope_mask), torch.tensor(self.left_antigen_ids), torch.tensor(self.left_antigen_mask), torch.tensor(self.right_antigen_ids), torch.tensor(self.right_antigen_mask)
        
    def __len__(self):
        return len(self.epitope_ids_list)

In [10]:
train_dataset = CustomDataset(train_epitope_ids_list, train_epitope_mask_list, train_left_antigen_ids_list, train_left_antigen_mask_list, train_right_antigen_ids_list, train_right_antigen_mask_list, train_label_list)
train_loader = DataLoader(train_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=True, num_workers=CFG['NUM_WORKERS'])

val_dataset = CustomDataset(val_epitope_ids_list, val_epitope_mask_list, val_left_antigen_ids_list, val_left_antigen_mask_list, val_right_antigen_ids_list, val_right_antigen_mask_list, val_label_list)
val_loader = DataLoader(val_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=False, num_workers=CFG['NUM_WORKERS'])

In [11]:
config = BertConfig(
    vocab_size=30, # default는 영어 기준이므로 내가 만든 vocab size에 맞게 수정해줘야 함
    hidden_size=1024,
    num_hidden_layers=3,    # layer num
    num_attention_heads=8,    # transformer attention head number
    intermediate_size=4096,   # transformer 내에 있는 feed-forward network의 dimension size
    hidden_act="gelu",
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    max_position_embeddings=500,    # embedding size 최대 몇 token까지 input으로 사용할 것인지 지정
    type_vocab_size=2,    # token type ids의 범위 (BERT는 segmentA, segmentB로 2종류)
)

In [12]:
pre = BertForPreTraining(config=config)
pre.save_pretrained('model/transformer')

In [13]:
class TransformerModel(nn.Module):
    def __init__(self,
                 epitope_length=CFG['EPITOPE_MAX_LEN'],
                 epitope_emb_node=1024,
                 epitope_hidden_dim=1024,
                 left_antigen_length=CFG['ANTIGEN_MAX_LEN'],
                 left_antigen_emb_node=1024,
                 left_antigen_hidden_dim=1024,
                 right_antigen_length=CFG['ANTIGEN_MAX_LEN'],
                 right_antigen_emb_node=1024,
                 right_antigen_hidden_dim=1024,
                 pretrained_model='model/transformer'
                ):
        super(TransformerModel, self).__init__()              
        # Transformer                
        self.epitope_transformer = BertModel.from_pretrained(pretrained_model)
        
        self.left_antigen_transformer = BertModel.from_pretrained(pretrained_model)
        
        self.right_antigen_transformer = BertModel.from_pretrained(pretrained_model)

        in_channels = epitope_hidden_dim+left_antigen_hidden_dim+right_antigen_hidden_dim
            
        self.classifier = nn.Sequential(
            nn.LeakyReLU(True),
            nn.BatchNorm1d(in_channels),
            nn.Linear(in_channels, in_channels//4),
            nn.LeakyReLU(True),
            nn.BatchNorm1d(in_channels//4),
            nn.Linear(in_channels//4, 1)
        )
        
    def forward(self, epitope_x1, epitope_x2, left_antigen_x1, left_antigen_x2, right_antigen_x1, right_antigen_x2):
        BATCH_SIZE = epitope_x1.size(0)
        # Get Embedding Vector
        epitope_x = self.epitope_transformer(input_ids=epitope_x1, attention_mask=epitope_x2)[0]
        
        left_antigen_x = self.left_antigen_transformer(input_ids=left_antigen_x1, attention_mask=left_antigen_x2)[0]
        
        right_antigen_x = self.right_antigen_transformer(input_ids=right_antigen_x1, attention_mask=right_antigen_x2)[0]
        
        
        # LSTM
        epitope_hidden = epitope_x[:, 0, :]

        left_antigen_hidden = left_antigen_x[:, 0, :]
        
        right_antigen_hidden = right_antigen_x[:, 0, :]
        
        # Feature Concat -> Binary Classifier
        x = torch.cat([epitope_hidden, left_antigen_hidden, right_antigen_hidden], axis=-1)
        x = self.classifier(x).view(-1)
        return x

In [14]:
def train(model, optimizer, train_loader, val_loader, scheduler, device):
    model.to(device)
    criterion = nn.BCEWithLogitsLoss().to(device) 
    
    best_val_f1 = 0
    for epoch in range(1, CFG['EPOCHS']+1):
        model.train()
        train_loss = []
        for epitope_ids_list, epitope_mask_list, left_antigen_ids_list, left_antigen_mask_list, right_antigen_ids_list, right_antigen_mask_list, label in tqdm(iter(train_loader)):
            epitope_ids_list = epitope_ids_list.to(device)
            epitope_mask_list = epitope_mask_list.to(device)

            left_antigen_ids_list = left_antigen_ids_list.to(device)
            left_antigen_mask_list = left_antigen_mask_list.to(device)

            right_antigen_ids_list = right_antigen_ids_list.to(device)
            right_antigen_mask_list = right_antigen_mask_list.to(device)

            label = label.float().to(device)
            
            optimizer.zero_grad()
            
            output = model(epitope_ids_list, epitope_mask_list, left_antigen_ids_list, left_antigen_mask_list, right_antigen_ids_list, right_antigen_mask_list)
            loss = criterion(output, label)
            
            loss.backward()
            optimizer.step()
            
            train_loss.append(loss.item())
            
            if scheduler is not None:
                scheduler.step()
                    
        val_loss, val_f1 = validation(model, val_loader, criterion, device)
        print(f'Epoch : [{epoch}] Train Loss : [{np.mean(train_loss):.5f}] Val Loss : [{val_loss:.5f}] Val F1 : [{val_f1:.5f}]')
        
        if best_val_f1 < val_f1:
            best_val_f1 = val_f1
            torch.save(model.module.state_dict(), './transformer_best_model_0.7.pth', _use_new_zipfile_serialization=False)
            print('Model Saved.')
    return best_val_f1

In [15]:
def validation(model, val_loader, criterion, device):
    model.eval()
    pred_proba_label = []
    true_label = []
    val_loss = []
    with torch.no_grad():
        for epitope_ids_list, epitope_mask_list, left_antigen_ids_list, left_antigen_mask_list, right_antigen_ids_list, right_antigen_mask_list, label in tqdm(iter(val_loader)):
            epitope_ids_list = epitope_ids_list.to(device)
            epitope_mask_list = epitope_mask_list.to(device)

            left_antigen_ids_list = left_antigen_ids_list.to(device)
            left_antigen_mask_list = left_antigen_mask_list.to(device)

            right_antigen_ids_list = right_antigen_ids_list.to(device)
            right_antigen_mask_list = right_antigen_mask_list.to(device)

            label = label.float().to(device)
            
            model_pred = model(epitope_ids_list, epitope_mask_list, left_antigen_ids_list, left_antigen_mask_list, right_antigen_ids_list, right_antigen_mask_list)
            loss = criterion(model_pred, label)
            model_pred = torch.sigmoid(model_pred).to('cpu')
            
            pred_proba_label += model_pred.tolist()
            true_label += label.to('cpu').tolist()
            
            val_loss.append(loss.item())
            
    pred_label = np.where(np.array(pred_proba_label)>CFG['THRESHOLD'], 1, 0)
    val_f1 = f1_score(true_label, pred_label, average='macro')
    return np.mean(val_loss), val_f1

In [16]:
model = TransformerModel()
model = nn.DataParallel(model, device_ids=[1, 2, 3, 4, 5])
model.eval()
optimizer = torch.optim.Adam(params = model.parameters(), lr = CFG["LEARNING_RATE"])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10*CFG['EPOCHS'], eta_min=0)

best_score = train(model, optimizer, train_loader, val_loader, scheduler, device)
print(f'Best Validation F1 Score : [{best_score:.5f}]')

Some weights of the model checkpoint at model/transformer were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at model/transformer were not used when initializing BertModel: ['cls.seq_relationship

Epoch : [1] Train Loss : [0.48262] Val Loss : [0.29307] Val F1 : [0.67468]
Model Saved.


100%|██████████| 1193/1193 [31:53<00:00,  1.60s/it]
100%|██████████| 299/299 [03:49<00:00,  1.30it/s]

Epoch : [2] Train Loss : [0.23707] Val Loss : [0.20731] Val F1 : [0.66219]



100%|██████████| 1193/1193 [32:04<00:00,  1.61s/it]
100%|██████████| 299/299 [03:57<00:00,  1.26it/s]

Epoch : [3] Train Loss : [0.20389] Val Loss : [0.19765] Val F1 : [0.65528]



100%|██████████| 1193/1193 [31:57<00:00,  1.61s/it]
100%|██████████| 299/299 [03:54<00:00,  1.28it/s]


Epoch : [4] Train Loss : [0.19619] Val Loss : [0.20138] Val F1 : [0.68231]
Model Saved.


100%|██████████| 1193/1193 [32:05<00:00,  1.61s/it]
100%|██████████| 299/299 [03:55<00:00,  1.27it/s]

Epoch : [5] Train Loss : [0.19222] Val Loss : [0.72043] Val F1 : [0.63351]



100%|██████████| 1193/1193 [31:59<00:00,  1.61s/it]
100%|██████████| 299/299 [03:53<00:00,  1.28it/s]

Epoch : [6] Train Loss : [0.19247] Val Loss : [0.20771] Val F1 : [0.61318]



100%|██████████| 1193/1193 [31:57<00:00,  1.61s/it]
100%|██████████| 299/299 [03:40<00:00,  1.36it/s]

Epoch : [7] Train Loss : [0.18735] Val Loss : [0.19357] Val F1 : [0.67917]



100%|██████████| 1193/1193 [31:43<00:00,  1.60s/it]
100%|██████████| 299/299 [03:48<00:00,  1.31it/s]

Epoch : [8] Train Loss : [0.18054] Val Loss : [0.19042] Val F1 : [0.65354]



100%|██████████| 1193/1193 [31:48<00:00,  1.60s/it]
100%|██████████| 299/299 [03:52<00:00,  1.29it/s]

Epoch : [9] Train Loss : [0.18243] Val Loss : [0.19426] Val F1 : [0.63657]



100%|██████████| 1193/1193 [31:48<00:00,  1.60s/it]
100%|██████████| 299/299 [03:53<00:00,  1.28it/s]

Epoch : [10] Train Loss : [0.18787] Val Loss : [0.29320] Val F1 : [0.49975]



100%|██████████| 1193/1193 [31:37<00:00,  1.59s/it]
100%|██████████| 299/299 [03:52<00:00,  1.29it/s]

Epoch : [11] Train Loss : [0.18510] Val Loss : [0.21133] Val F1 : [0.63704]



100%|██████████| 1193/1193 [31:41<00:00,  1.59s/it]
100%|██████████| 299/299 [03:52<00:00,  1.29it/s]

Epoch : [12] Train Loss : [0.18575] Val Loss : [0.21055] Val F1 : [0.60759]



100%|██████████| 1193/1193 [31:34<00:00,  1.59s/it]
100%|██████████| 299/299 [03:33<00:00,  1.40it/s]

Epoch : [13] Train Loss : [0.17957] Val Loss : [0.19421] Val F1 : [0.67397]



100%|██████████| 1193/1193 [31:45<00:00,  1.60s/it]
100%|██████████| 299/299 [03:51<00:00,  1.29it/s]

Epoch : [14] Train Loss : [0.17924] Val Loss : [0.19519] Val F1 : [0.65369]



100%|██████████| 1193/1193 [31:42<00:00,  1.59s/it]
100%|██████████| 299/299 [03:52<00:00,  1.29it/s]

Epoch : [15] Train Loss : [0.17728] Val Loss : [0.20667] Val F1 : [0.62479]



100%|██████████| 1193/1193 [31:49<00:00,  1.60s/it]
100%|██████████| 299/299 [03:51<00:00,  1.29it/s]

Epoch : [16] Train Loss : [0.17575] Val Loss : [0.27172] Val F1 : [0.51648]



100%|██████████| 1193/1193 [31:43<00:00,  1.60s/it]
100%|██████████| 299/299 [03:53<00:00,  1.28it/s]

Epoch : [17] Train Loss : [0.17534] Val Loss : [0.19734] Val F1 : [0.64391]



100%|██████████| 1193/1193 [31:48<00:00,  1.60s/it]
100%|██████████| 299/299 [03:52<00:00,  1.29it/s]

Epoch : [18] Train Loss : [0.16924] Val Loss : [0.19327] Val F1 : [0.68117]



100%|██████████| 1193/1193 [31:41<00:00,  1.59s/it]
100%|██████████| 299/299 [03:33<00:00,  1.40it/s]


Epoch : [19] Train Loss : [0.16899] Val Loss : [0.19307] Val F1 : [0.69067]
Model Saved.


100%|██████████| 1193/1193 [31:29<00:00,  1.58s/it]
100%|██████████| 299/299 [03:50<00:00,  1.30it/s]


Epoch : [20] Train Loss : [0.17044] Val Loss : [0.24252] Val F1 : [0.74798]
Model Saved.


100%|██████████| 1193/1193 [31:23<00:00,  1.58s/it]
100%|██████████| 299/299 [03:53<00:00,  1.28it/s]

Epoch : [21] Train Loss : [0.17560] Val Loss : [0.29796] Val F1 : [0.57011]



100%|██████████| 1193/1193 [31:25<00:00,  1.58s/it]
100%|██████████| 299/299 [03:49<00:00,  1.30it/s]

Epoch : [22] Train Loss : [0.17135] Val Loss : [0.24256] Val F1 : [0.72384]



100%|██████████| 1193/1193 [31:12<00:00,  1.57s/it]
100%|██████████| 299/299 [03:51<00:00,  1.29it/s]

Epoch : [23] Train Loss : [0.16783] Val Loss : [0.19440] Val F1 : [0.67562]



100%|██████████| 1193/1193 [31:18<00:00,  1.57s/it]
100%|██████████| 299/299 [03:45<00:00,  1.32it/s]

Epoch : [24] Train Loss : [0.16936] Val Loss : [0.19621] Val F1 : [0.67884]



100%|██████████| 1193/1193 [31:15<00:00,  1.57s/it]
100%|██████████| 299/299 [03:38<00:00,  1.37it/s]

Epoch : [25] Train Loss : [0.16979] Val Loss : [0.20842] Val F1 : [0.73945]



100%|██████████| 1193/1193 [31:30<00:00,  1.58s/it]
100%|██████████| 299/299 [03:50<00:00,  1.30it/s]

Epoch : [26] Train Loss : [0.16935] Val Loss : [0.30247] Val F1 : [0.51627]



100%|██████████| 1193/1193 [31:25<00:00,  1.58s/it]
100%|██████████| 299/299 [03:49<00:00,  1.30it/s]

Epoch : [27] Train Loss : [0.16632] Val Loss : [0.24624] Val F1 : [0.57557]



100%|██████████| 1193/1193 [31:31<00:00,  1.59s/it]
100%|██████████| 299/299 [03:51<00:00,  1.29it/s]

Epoch : [28] Train Loss : [0.16311] Val Loss : [0.19452] Val F1 : [0.66530]



100%|██████████| 1193/1193 [31:31<00:00,  1.59s/it]
100%|██████████| 299/299 [03:52<00:00,  1.29it/s]

Epoch : [29] Train Loss : [0.16032] Val Loss : [0.19539] Val F1 : [0.69874]



100%|██████████| 1193/1193 [31:38<00:00,  1.59s/it]
100%|██████████| 299/299 [03:41<00:00,  1.35it/s]

Epoch : [30] Train Loss : [0.16102] Val Loss : [0.20987] Val F1 : [0.63945]



100%|██████████| 1193/1193 [31:32<00:00,  1.59s/it]
100%|██████████| 299/299 [03:47<00:00,  1.31it/s]

Epoch : [31] Train Loss : [0.15978] Val Loss : [0.24463] Val F1 : [0.54966]



 21%|██        | 245/1193 [06:37<25:38,  1.62s/it]


KeyboardInterrupt: 

In [None]:
test_df = pd.read_csv('data/test.csv')
test_epitope_ids_list, test_epitope_mask_list, test_left_antigen_ids_list, test_left_antigen_mask_list, test_right_antigen_ids_list, test_right_antigen_mask_list = get_preprocessing('test', test_df, tokenizer)

In [None]:
test_dataset = CustomDataset(test_epitope_ids_list, test_epitope_mask_list, test_left_antigen_ids_list, test_left_antigen_mask_list, test_right_antigen_ids_list, test_right_antigen_mask_list, None)
test_loader = DataLoader(test_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=False, num_workers=CFG['NUM_WORKERS'])

In [None]:
model = TransformerModel()
best_checkpoint = torch.load('./transformer_best_model_0.7.pth')
model.load_state_dict(best_checkpoint)
model.eval()
model.to(device)

In [None]:
def inference(model, test_loader, device):
    model.eval()
    pred_proba_label = []
    with torch.no_grad():
        for epitope_ids_list, epitope_mask_list, left_antigen_ids_list, left_antigen_mask_list, right_antigen_ids_list, right_antigen_mask_list in tqdm(iter(test_loader)):
            epitope_ids_list = epitope_ids_list.to(device)
            epitope_mask_list = epitope_mask_list.to(device)

            left_antigen_ids_list = left_antigen_ids_list.to(device)
            left_antigen_mask_list = left_antigen_mask_list.to(device)

            right_antigen_ids_list = right_antigen_ids_list.to(device)
            right_antigen_mask_list = right_antigen_mask_list.to(device)
            
            model_pred = model(epitope_ids_list, epitope_mask_list, left_antigen_ids_list, left_antigen_mask_list, right_antigen_ids_list, right_antigen_mask_list)
            model_pred = torch.sigmoid(model_pred).to('cpu')
            
            pred_proba_label += model_pred.tolist()
    
    pred_label = np.where(np.array(pred_proba_label)>CFG['THRESHOLD'], 1, 0)
    return pred_label

In [None]:
preds = inference(model, test_loader, device)

In [None]:
submit = pd.read_csv('data/sample_submission.csv')
submit['label'] = preds

In [None]:
submit.to_csv('submission/submit3.csv', index=False)
print('Done.')