In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR

import torchtext

import os
import glob
import pickle
import numpy as np
import pandas as pd
from sklearn.utils import shuffle

from model.bert import BERT, MaskedLanguageModeling

from tqdm import tqdm


def load_dataset(path="data/molecule_net/MoleculeNet_train.pickle", test_ratio=0.2):
    """Load the pickled dataset and split it into train and test parts.

    The split is positional, not random: the first ``test_ratio`` fraction of
    the loaded sequence becomes the test split and the remainder becomes the
    train split.

    Args:
        path: Location of the pickled dataset. Defaults to the MoleculeNet
            training pickle used by this notebook.
        test_ratio: Fraction (between 0 and 1) of leading items held out as
            test data.

    Returns:
        A ``(train_data, test_data)`` tuple.

    Raises:
        ValueError: If ``test_ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= test_ratio <= 1.0:
        raise ValueError(f"test_ratio must be within [0, 1], got {test_ratio}")

    print("load dataset ... ")
    # NOTE(security): pickle.load executes arbitrary code embedded in the file;
    # only use it on files from a trusted source.
    with open(path, 'rb') as f:
        data = pickle.load(f)

    # Compute the boundary once so both slices are guaranteed to agree.
    split_idx = int(len(data) * test_ratio)
    test_data  = data[:split_idx]
    train_data = data[split_idx:]

    return train_data, test_data


def load_tokenizer(path="data/molecule_net/MoleculeNet_tokenizer.pickle"):
    """Load the pickled tokenizer object.

    Args:
        path: Location of the pickled tokenizer. Defaults to the MoleculeNet
            tokenizer pickle used by this notebook.

    Returns:
        Whatever object was pickled at ``path``.
    """
    print("load tokenizer ... ")
    # NOTE(security): pickle.load executes arbitrary code embedded in the file;
    # only use it on files from a trusted source.
    with open(path, "rb") as f:
        tokenizer = pickle.load(f)

    return tokenizer


def define_model(vocab_dim, seq_len, embedding_dim, device, num_head=4, num_layer=4,
                 pad_token_id=1, lr=1e-4, weight_decay=0.01):
    """Build the masked-language model with its optimizer, scheduler and loss.

    Args:
        vocab_dim: Vocabulary size; also used as the output dimension of the
            masked-language-modeling head.
        seq_len: Sequence length passed to ``BERT``.
        embedding_dim: Embedding dimension passed to ``BERT``.
        device: Device the model is moved to.
        num_head: ``num_head`` argument passed to ``BERT``.
        num_layer: ``num_layer`` argument passed to ``BERT``.
        pad_token_id: Id of the padding token. It is passed to ``BERT`` and is
            also the target id the loss ignores, so the two cannot drift apart.
        lr: Learning rate for AdamW.
        weight_decay: Weight decay for AdamW.

    Returns:
        A ``(model, optimizer, scheduler, criterion)`` tuple.
    """
    bert_base = BERT(
        vocab_dim=vocab_dim,
        seq_len=seq_len,
        embedding_dim=embedding_dim,
        pad_token_id=pad_token_id,
        num_head=num_head,
        num_layer=num_layer,
    ).to(device)
    model     = MaskedLanguageModeling(bert_base, output_dim=vocab_dim, use_RNN=True).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay)
    # ReduceLROnPlateau is used with its default settings; its step() call
    # needs the monitored metric as an argument.
    scheduler = ReduceLROnPlateau(optimizer)
    # Padding positions carry no information and must not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)

    return model, optimizer, scheduler, criterion


def check_trained_weights(output_path="output/MoleculeNet/*.tsv", trained_weight='weights/MoleculeNet_LM_best.pt',
                          model=None, map_location=None):
    """Restore trained weights when earlier epoch outputs exist.

    The number of files matching ``output_path`` is used as the resume epoch.
    When at least one such file exists, the state dict stored at
    ``trained_weight`` is loaded into the model.

    Args:
        output_path: Glob pattern of the per-epoch output files.
        trained_weight: Path of the saved state dict.
        model: Model to load the weights into. When omitted, the notebook-level
            global ``model`` is used, which keeps the original zero-argument
            call working.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``"cpu"`` to load
            weights saved on a GPU onto a CPU-only machine.

    Returns:
        A ``(start_epoch, model)`` tuple.

    Raises:
        NameError: If ``model`` is omitted and no global ``model`` exists.
    """
    if model is None:
        # Backward compatibility with callers that rely on the global model.
        model = globals().get("model")
        if model is None:
            raise NameError("name 'model' is not defined; pass model=... explicitly")

    # Glob once: the count both gates the load and is the resume epoch.
    start_epoch = len(glob.glob(output_path))
    if start_epoch != 0:
        print(f"load pretrained model : {trained_weight}")
        model.load_state_dict(torch.load(trained_weight, map_location=map_location))

    return start_epoch, model


In [7]:
class MoleculeLangaugeModelDataset(torch.utils.data.Dataset):
    """Masked-language-modeling dataset built on a tokenizer and raw samples.

    Each item is numericalized with the tokenizer, truncated to
    ``seq_len - 2`` tokens, randomly masked, then framed as
    ``[init] tokens [eos] [pad]...`` so every tensor has length ``seq_len``.

    The tokenizer's ``unk_token`` id is used as the mask token id.
    """

    def __init__(self, data, tokenizer, seq_len=128, masking_rate=0.15):
        """
        Args:
            data: Indexable collection of raw samples, each accepted by
                ``tokenizer.numericalize``.
            tokenizer: Object exposing ``numericalize``, ``vocab.stoi`` and the
                ``init_token`` / ``eos_token`` / ``pad_token`` / ``unk_token``
                attributes.
            seq_len: Length of every returned tensor, special tokens included.
            masking_rate: Fraction of tokens to mask (one extra token is always
                masked on top of the rounded fraction).
        """
        super().__init__()

        self.data          = data
        self.tokenizer     = tokenizer
        self.vocab         = tokenizer.vocab
        self.seq_len       = seq_len
        self.masking_rate  = masking_rate

        self.cls_token_id  = self.tokenizer.vocab.stoi[self.tokenizer.init_token]
        self.sep_token_id  = self.tokenizer.vocab.stoi[self.tokenizer.eos_token]
        self.pad_token_id  = self.tokenizer.vocab.stoi[self.tokenizer.pad_token]
        # The unknown token doubles as the mask token.
        self.mask_token_id = self.tokenizer.vocab.stoi[self.tokenizer.unk_token]

    def __getitem__(self, idx):
        """Return ``(train, target, segment_embedding, masking_label)``.

        ``train`` is the masked input, ``target`` the unmasked ids,
        ``segment_embedding`` is all zeros and ``masking_label`` holds 1.0 at
        masked positions; all four have length ``seq_len``.

        Returns ``None`` when the sample cannot be processed, so that
        ``collate_fn`` can drop it. Only ``Exception`` subclasses are
        swallowed; ``KeyboardInterrupt`` and ``SystemExit`` propagate.
        """
        try:
            target = self.tokenizer.numericalize(self.data[idx]).squeeze()
            if target.dim() == 0:
                # squeeze() collapses a single-token sample to a scalar;
                # restore the sequence axis so the sample is not dropped.
                target = target.unsqueeze(0)

            if len(target) < self.seq_len - 2:
                pad_length = self.seq_len - len(target) - 2
            else:
                # Leave room for the init and eos tokens.
                target = target[:self.seq_len-2]
                pad_length = 0

            masked_sent, masking_label = self.masking(target)

            # MLM
            train  = self._add_special_tokens(masked_sent, pad_length)
            target = self._add_special_tokens(target, pad_length)

            # Special tokens and padding are never masked.
            masking_label = torch.cat([
                torch.zeros(1),
                masking_label,
                torch.zeros(1),
                torch.zeros(pad_length)
            ])

            segment_embedding = torch.zeros(target.size(0))

            return train, target, segment_embedding, masking_label
        except Exception:
            # Deliberate best-effort: unusable samples are skipped, not fatal.
            return None

    def _add_special_tokens(self, token_ids, pad_length):
        """Frame ``token_ids`` as ``[init] token_ids [eos] [pad] * pad_length``."""
        return torch.cat([
            torch.tensor([self.cls_token_id]),
            token_ids,
            torch.tensor([self.sep_token_id]),
            torch.full((pad_length,), self.pad_token_id, dtype=torch.long)
        ]).long().contiguous()

    def __len__(self):
        """Number of raw samples."""
        return len(self.data)

    def __iter__(self):
        """Iterate over the raw (un-numericalized) samples."""
        for x in self.data:
            yield x

    def get_vocab(self):
        """Return the tokenizer's vocabulary object."""
        return self.vocab

    # TODO: add random-token replacement among the masked positions.
    def masking(self, x):
        """Randomly mask ``round(len(x) * masking_rate) + 1`` positions of ``x``.

        Args:
            x: 1-D sequence of token ids (tensor or anything
                ``torch.as_tensor`` accepts).

        Returns:
            ``(masked, masking_label)``: a new long tensor with the chosen
            positions replaced by the mask id, and a float tensor with 1.0 at
            those positions. The input is not modified.
        """
        # as_tensor avoids the copy (and warning) torch.tensor makes for a
        # tensor input; masked_fill below is out-of-place.
        x             = torch.as_tensor(x).long()
        num_tokens    = x.size()[0]
        masking_idx   = torch.randperm(num_tokens)[:round(num_tokens * self.masking_rate) + 1]
        masking_label = torch.zeros(num_tokens)
        masking_label[masking_idx] = 1
        x             = x.masked_fill(masking_label.bool(), self.mask_token_id)

        return x, masking_label
    
def collate_fn(batch):
    """Collate a batch after discarding samples that are ``None``.

    The dataset returns ``None`` for samples it could not process; those are
    removed before the remaining samples go to PyTorch's default collation.
    """
    usable_samples = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(usable_samples)

In [8]:
def generate_epoch_prediction_dataloader(data, seq_len, tokenizer, masking_rate, batch_size, collate_fn, shuffle=True, num_workers=8):    
    """Wrap ``data`` in a masked-language-modeling dataset and a DataLoader.

    Args:
        data: Raw samples handed to ``MoleculeLangaugeModelDataset``.
        seq_len: Sequence length for the dataset.
        tokenizer: Tokenizer for the dataset.
        masking_rate: Masking rate for the dataset.
        batch_size: DataLoader batch size.
        collate_fn: Function that merges samples into a batch.
        shuffle: Whether the DataLoader shuffles the samples.
        num_workers: Number of DataLoader worker processes.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    print("start prediction")
    mlm_dataset = MoleculeLangaugeModelDataset(
        data=data,
        seq_len=seq_len,
        tokenizer=tokenizer,
        masking_rate=masking_rate,
    )

    return torch.utils.data.DataLoader(
        mlm_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )

In [42]:
@torch.no_grad()
def predict(model, iterator, device, tokenizer):
    """Collect predicted and true token ids at the masked positions.

    Args:
        model: Callable taking ``(input_ids, segment_ids)`` and returning
            scores whose last axis is arg-maxed to get predicted ids. It is
            switched to eval mode.
        iterator: Iterable of ``(X, target, segment_emb, masking_label)``
            batches.
        device: Device the inputs are moved to before the forward pass.
        tokenizer: Unused; kept for interface compatibility.

    Returns:
        ``(o, t)``: flat lists of predicted ids and target ids, taken only
        where ``masking_label`` is non-zero.
    """
    model.eval()

    o = []
    t = []

    for X, target, segment_emb, masking_label in tqdm(iterator):
        output = model(X.to(device), segment_emb.long().to(device))

        # Arg-max on the model's device and move only the resulting ids to the
        # CPU, instead of copying the whole score tensor across first.
        predicted = output.argmax(dim=-1).cpu()
        mask = masking_label.bool().cpu()

        o.extend(predicted[mask].tolist())
        t.extend(target.cpu()[mask].tolist())

    return o, t

def decode(x, tokenizer):
    """Turn sequences of token ids back into strings.

    Args:
        x: Iterable of id sequences.
        tokenizer: Object whose ``vocab.itos`` maps an id to its token string.

    Returns:
        A list with one string per input sequence, formed by concatenating
        the tokens without a separator.
    """
    return [
        "".join(tokenizer.vocab.itos[token_id] for token_id in id_sequence)
        for id_sequence in x
    ]

In [43]:
import warnings

# Suppress every warning for the rest of the session.
warnings.filterwarnings(action='ignore')

# --- data and tokenizer ------------------------------------------------------
train_data, test_data = load_dataset()
tokenizer = load_tokenizer()

# --- run configuration -------------------------------------------------------
VOCAB_DIM     = len(tokenizer.vocab.itos)
SEQ_LEN       = 256
EMBEDDING_DIM = 512
DEVICE        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BATCH_SIZE    = 512
N_EPOCHS      = 1000
PAITIENCE     = 50

output_path = "output/MoleculeNet"
weight_path = "weights"

# --- model, optimizer, scheduler, loss ---------------------------------------
model, optimizer, scheduler, criterion = define_model(
    vocab_dim=VOCAB_DIM,
    seq_len=SEQ_LEN,
    embedding_dim=EMBEDDING_DIM,
    device=DEVICE,
)

n_paitience = 0
best_valid_loss = float('inf')
optimizer.zero_grad()
optimizer.step()

# Restore the best saved weights when earlier epoch outputs are present.
start_epoch, model = check_trained_weights()

# --- prediction over the held-out split --------------------------------------
samples_for_prediction = test_data
prediction_dataloader  = generate_epoch_prediction_dataloader(
    samples_for_prediction,
    seq_len=SEQ_LEN,
    tokenizer=tokenizer,
    batch_size=BATCH_SIZE,
    masking_rate=0.3,
    collate_fn=collate_fn,
    num_workers=10,
)

output_list, target_list = predict(model, prediction_dataloader, DEVICE, tokenizer)



load dataset ... 
load tokenizer ... 
load pretrained model : weights/MoleculeNet_LM_best.pt
start prediction


  4%|▎         | 1534/42545 [05:49<2:35:39,  4.39it/s]


KeyboardInterrupt: 

In [51]:
# Fall back to the CPU when CUDA is unavailable instead of hard-coding "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 1024
# Evaluate only the first MAX_BATCHES batches; a full pass takes hours.
MAX_BATCHES = 100

prediction_dataloader  = generate_epoch_prediction_dataloader(
    samples_for_prediction,
    seq_len=SEQ_LEN,
    tokenizer=tokenizer,
    batch_size=BATCH_SIZE,
    masking_rate=0.3,
    collate_fn=collate_fn,
    num_workers=10,
)

# Predicted (o) and true (t) token ids at the masked positions.
o = []
t = []

with torch.no_grad():
    model.eval()
    for batch, (X, target, segment_emb, masking_label) in enumerate(tqdm(prediction_dataloader)):
        if batch == MAX_BATCHES:
            break

        output = model(X.to(device), segment_emb.long().to(device))

        # Arg-max on the model's device and move only the resulting ids to the
        # CPU, instead of copying the whole score tensor across first.
        output_ = output.argmax(dim=-1).cpu()
        target_ = target.cpu()
        mask = masking_label.bool().cpu()

        o.extend(output_[mask].tolist())
        t.extend(target_[mask].tolist())

start prediction


  0%|          | 100/21273 [00:46<2:42:21,  2.17it/s]


In [52]:
from sklearn.metrics import accuracy_score

# accuracy_score expects (y_true, y_pred): t holds the true ids, o the predictions.
# Accuracy is symmetric, so the value is unchanged by this argument order.
accuracy_score(t, o)

0.9020997394902591

In [55]:
# Total number of elements across all model parameters.
sum(parameter.numel() for parameter in model.parameters())

17540677