In [1]:
import random
import pandas as pd
import numpy as np
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn import metrics
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm, trange

from tqdm import tqdm
from sklearn.metrics import f1_score
from transformers import BertConfig, BertModel, BertTokenizer, BertForPreTraining


In [2]:
device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
CFG = {
    'NUM_WORKERS':32,
    'ANTIGEN_WINDOW':128,
    'ANTIGEN_MAX_LEN':128, # ANTIGEN_WINDOW와 ANTIGEN_MAX_LEN은 같아야합니다.
    'EPITOPE_MAX_LEN':256,
    'EPOCHS':50,
    'LEARNING_RATE':1e-4,
    'BATCH_SIZE':128,
    'THRESHOLD':0.5,
    'SEED':41
}

In [4]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(CFG['SEED']) # Seed 고정

In [5]:
tokenizer = BertTokenizer.from_pretrained('model/tokenizer')

In [6]:
tokenizer.vocab

OrderedDict([('[PAD]', 0),
             ('[UNK]', 1),
             ('[CLS]', 2),
             ('[SEP]', 3),
             ('[MASK]', 4),
             ('L', 5),
             ('A', 6),
             ('G', 7),
             ('V', 8),
             ('E', 9),
             ('S', 10),
             ('I', 11),
             ('K', 12),
             ('R', 13),
             ('D', 14),
             ('T', 15),
             ('P', 16),
             ('N', 17),
             ('Q', 18),
             ('F', 19),
             ('Y', 20),
             ('M', 21),
             ('H', 22),
             ('C', 23),
             ('W', 24),
             ('X', 25),
             ('U', 26),
             ('B', 27),
             ('Z', 28),
             ('O', 29)])

In [7]:
def seqtoinput(seq):
    for j in range(len(seq)-1):
        seq = seq[:j+j+1]+ ' ' + seq[j+j+1:]
    return seq
    

def get_preprocessing(data_type, new_df, tokenizer):
    epitope_ids_list = []
    epitope_mask_list = []
    
    antigen_ids_list = []
    antigen_mask_list = []
        
    for epitope, antigen, s_p, e_p in tqdm(zip(new_df['epitope_seq'], new_df['antigen_seq'], new_df['start_position'], new_df['end_position'])):        
        # Left antigen : [start_position-WINDOW : start_position]
        # Right antigen : [end_position : end_position+WINDOW]
        mean = int((s_p+e_p)/2)
        start_position = mean-CFG['ANTIGEN_WINDOW']-1
        end_position = mean+CFG['ANTIGEN_WINDOW']
        if start_position < 0:
            start_position = 0
        if end_position > len(antigen):
            end_position = len(antigen)
        
        antigen = antigen[int(start_position):int(end_position)]
        # left / right antigen sequence 추출

        if CFG['EPITOPE_MAX_LEN']<len(epitope):
            epitope = epitope[:CFG['EPITOPE_MAX_LEN']]
        else:
            epitope = epitope[:]
        
        antigen = seqtoinput(antigen)
        epitope = seqtoinput(epitope)
        
        
        antigen_input = tokenizer(antigen, add_special_tokens=True, pad_to_max_length=True, max_length = 512)
        antigen_ids = antigen_input['input_ids']
        antigen_mask = antigen_input['attention_mask']
        
        
        epitope_input = tokenizer(epitope, add_special_tokens=True, pad_to_max_length=True, max_length = 256)
        epitope_ids = epitope_input['input_ids']
        epitope_mask = epitope_input['attention_mask']
        
        
        epitope_ids_list.append(epitope_ids)
        epitope_mask_list.append(epitope_mask)
        
        antigen_ids_list.append(antigen_ids)
        antigen_mask_list.append(antigen_mask)

    
    label_list = None
    if data_type != 'test':
        label_list = []
        for label in new_df['label']:
            label_list.append(label)
    print(f'{data_type} dataframe preprocessing was done.')
    return epitope_ids_list, epitope_mask_list, antigen_ids_list, antigen_mask_list, label_list

In [None]:
train = pd.read_csv('data/train.csv')
test = pd.read_csv('data/test.csv')

train, val = train_test_split(train, train_size=0.9, random_state=12)

train_epitope_ids_list, train_epitope_mask_list, train_antigen_ids_list, train_antigen_mask_list, train_label_list = get_preprocessing('train', train, tokenizer)
val_epitope_ids_list, val_epitope_mask_list, val_antigen_ids_list, val_antigen_mask_list, val_label_list = get_preprocessing('val', val, tokenizer)


0it [00:00, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
39309it [01:12, 560.88it/s]

In [None]:
class CustomDataset(Dataset):
    def __init__(self, epitope_ids_list, epitope_mask_list, antigen_ids_list, antigen_mask_list, label_list):
        self.epitope_ids_list = epitope_ids_list
        self.epitope_mask_list = epitope_mask_list
        self.antigen_ids_list = antigen_ids_list
        self.antigen_mask_list = antigen_mask_list

        self.label_list = label_list
        
    def __getitem__(self, index):
        self.epitope_ids = self.epitope_ids_list[index]
        self.epitope_mask = self.epitope_mask_list[index]
        
        self.left_antigen_ids = self.left_antigen_ids_list[index]
        self.left_antigen_mask = self.left_antigen_mask_list[index]
        
        self.right_antigen_ids = self.right_antigen_ids_list[index]
        self.right_antigen_mask = self.right_antigen_mask_list[index]
        
        if self.label_list is not None:
            self.label = self.label_list[index]
            return torch.tensor(self.epitope_ids), torch.tensor(self.epitope_mask), torch.tensor(self.left_antigen_ids), torch.tensor(self.left_antigen_mask), torch.tensor(self.right_antigen_ids), torch.tensor(self.right_antigen_mask), self.label
        else:
            return torch.tensor(self.epitope_ids), torch.tensor(self.epitope_mask), torch.tensor(self.left_antigen_ids), torch.tensor(self.left_antigen_mask), torch.tensor(self.right_antigen_ids), torch.tensor(self.right_antigen_mask)
        
    def __len__(self):
        return len(self.epitope_ids_list)

In [None]:
train_dataset = CustomDataset(train_epitope_ids_list, train_epitope_mask_list, train_antigen_ids_list, train_antigen_mask_list, train_label_list)
train_loader = DataLoader(train_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=True, num_workers=CFG['NUM_WORKERS'])

val_dataset = CustomDataset(val_epitope_ids_list, val_epitope_mask_list, val_antigen_ids_list, val_antigen_mask_list, val_label_list)
val_loader = DataLoader(val_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=False, num_workers=CFG['NUM_WORKERS'])

In [None]:
config = BertConfig(
    vocab_size=30, # default는 영어 기준이므로 내가 만든 vocab size에 맞게 수정해줘야 함
    hidden_size=1024,
    num_hidden_layers=3,    # layer num
    num_attention_heads=8,    # transformer attention head number
    intermediate_size=4096,   # transformer 내에 있는 feed-forward network의 dimension size
    hidden_act="gelu",
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    max_position_embeddings=500,    # embedding size 최대 몇 token까지 input으로 사용할 것인지 지정
    type_vocab_size=2,    # token type ids의 범위 (BERT는 segmentA, segmentB로 2종류)
)

In [None]:
pre = BertForPreTraining(config=config)
pre.save_pretrained('model/transformer')

In [None]:
class TransformerModel(nn.Module):
    def __init__(self,
                 epitope_length=CFG['EPITOPE_MAX_LEN'],
                 epitope_emb_node=1024,
                 epitope_hidden_dim=1024,
                 antigen_length=CFG['ANTIGEN_MAX_LEN'],
                 antigen_emb_node=1024,
                 antigen_hidden_dim=1024,
                 pretrained_model='model/transformer'
                ):
        super(BaseModel, self).__init__()              
        # Transformer                
        self.epitope_transformer = BertModel.from_pretrained(pretrained_model)
        
        self.antigen_transformer = BertMoel.from_pretrained(pretrained_model)
        
        in_channels = epitope_hidden_dim+antigen_hidden_dim
            
        self.classifier = nn.Sequential(
            nn.LeakyReLU(True),
            nn.BatchNorm1d(in_channels),
            nn.Linear(in_channels, in_channels//4),
            nn.LeakyReLU(True),
            nn.BatchNorm1d(in_channels//4),
            nn.Linear(in_channels//4, 1)
        )
        
    def forward(self, epitope_x1, epitope_x2, antigen_x1, antigen_x2):
        BATCH_SIZE = epitope_x1.size(0)
        # Get Embedding Vector
        epitope_x = self.epitope_transformer(input_ids=epitope_x1, attention_mask=epitope_x2)[0]
        
        antigen_x = self.antigen_transformer(input_ids=antigen_x1, attention_mask=antigen_x2)[0]
                
        
        # LSTM
        epitope_hidden = epitope_x[:, 0, :]

        antigen_hidden = antigen_x[:, 0, :]
                
        # Feature Concat -> Binary Classifier
        x = torch.cat([epitope_hidden, antigen_hidden], axis=-1)
        x = self.classifier(x).view(-1)
        return x

In [None]:
def train(model, optimizer, train_loader, val_loader, scheduler, device):
    model.to(device)
    criterion = nn.BCEWithLogitsLoss().to(device) 
    
    best_val_f1 = 0
    for epoch in range(1, CFG['EPOCHS']+1):
        model.train()
        train_loss = []
        for epitope_ids_list, epitope_mask_list, antigen_ids_list, antigen_mask_list, label_list in tqdm(iter(train_loader)):
            epitope_ids_list = epitope_ids_list.to(device)
            epitope_mask_list = epitope_mask_list.to(device)

            antigen_ids_list = antigen_ids_list.to(device)
            antigen_mask_list = antigen_mask_list.to(device)

            label = label.float().to(device)
            
            optimizer.zero_grad()
            
            output = model(epitope_ids_list, epitope_mask_list, antigen_ids_list, antigen_mask_list)
            loss = criterion(output, label)
            
            loss.backward()
            optimizer.step()
            
            train_loss.append(loss.item())
            
            if scheduler is not None:
                scheduler.step()
                    
        val_loss, val_f1 = validation(model, val_loader, criterion, device)
        print(f'Epoch : [{epoch}] Train Loss : [{np.mean(train_loss):.5f}] Val Loss : [{val_loss:.5f}] Val F1 : [{val_f1:.5f}]')
        
        if best_val_f1 < val_f1:
            best_val_f1 = val_f1
            torch.save(model.state_dict(), './antigen_transformer_best_model.pth', _use_new_zipfile_serialization=False)
            print('Model Saved.')
    return best_val_f1

In [None]:
def validation(model, val_loader, criterion, device):
    model.eval()
    pred_proba_label = []
    true_label = []
    val_loss = []
    with torch.no_grad():
        for epitope_ids_list, epitope_mask_list, antigen_ids_list, antigen_mask_list, label_list in tqdm(iter(val_loader)):
            epitope_ids_list = epitope_ids_list.to(device)
            epitope_mask_list = epitope_mask_list.to(device)

            antigen_ids_list = antigen_ids_list.to(device)
            antigen_mask_list = antigen_mask_list.to(device)

            label = label.float().to(device)
            
            model_pred = model(epitope_ids_list, epitope_mask_list, antigen_ids_list, antigen_mask_list)
            loss = criterion(model_pred, label)
            model_pred = torch.sigmoid(model_pred).to('cpu')
            
            pred_proba_label += model_pred.tolist()
            true_label += label.to('cpu').tolist()
            
            val_loss.append(loss.item())
            
    pred_label = np.where(np.array(pred_proba_label)>CFG['THRESHOLD'], 1, 0)
    val_f1 = f1_score(true_label, pred_label, average='macro')
    return np.mean(val_loss), val_f1

In [None]:
model = TransformerModel()
model = nn.DataParallel(model, device_ids=[1, 2, 3, 4, 5])
model.eval()
optimizer = torch.optim.Adam(params = model.parameters(), lr = CFG["LEARNING_RATE"])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000*CFG['EPOCHS'], eta_min=0)

best_score = train(model, optimizer, train_loader, val_loader, scheduler, device)
print(f'Best Validation F1 Score : [{best_score:.5f}]')

In [None]:
test_df = pd.read_csv('data/test.csv')
test_epitope_ids_list, test_epitope_mask_list, test_left_antigen_ids_list, test_left_antigen_mask_list, test_right_antigen_ids_list, test_right_antigen_mask_list = get_preprocessing('test', test_df, tokenizer)

In [None]:
test_dataset = CustomDataset(test_epitope_ids_list, test_epitope_mask_list, test_left_antigen_ids_list, test_left_antigen_mask_list, test_right_antigen_ids_list, test_right_antigen_mask_list, None)
test_loader = DataLoader(test_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=False, num_workers=CFG['NUM_WORKERS'])

In [None]:
model = TransformerModel()
best_checkpoint = torch.load('./antigen_transformer_best_model.pth')
model.load_state_dict(best_checkpoint)
model.eval()
model.to(device)

In [None]:
def inference(model, test_loader, device):
    model.eval()
    pred_proba_label = []
    with torch.no_grad():
        for epitope_ids_list, epitope_mask_list, antigen_ids_list, antigen_mask_list in tqdm(iter(test_loader)):
            epitope_ids_list = epitope_ids_list.to(device)
            epitope_mask_list = epitope_mask_list.to(device)

            antigen_ids_list = antigen_ids_list.to(device)
            antigen_mask_list = antigen_mask_list.to(device)
            
            model_pred = model(epitope_ids_list, epitope_mask_list, antigen_ids_list, antigen_mask_list)
            model_pred = torch.sigmoid(model_pred).to('cpu')
            
            pred_proba_label += model_pred.tolist()
    
    pred_label = np.where(np.array(pred_proba_label)>CFG['THRESHOLD'], 1, 0)
    return pred_label

In [None]:
preds = inference(model, test_loader, device)

In [None]:
submit = pd.read_csv('data/sample_submission.csv')
submit['label'] = preds

In [None]:
submit.to_csv('submission/submit4.csv', index=False)
print('Done.')