In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import pytorch_lightning as pl
import torchmetrics
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

from cosine_annealing_warmup import CosineAnnealingWarmupRestarts

import os
import glob
import pickle
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from model.bert import BERT, MLMHead
from utils.molecule_dataloader import MoleculeLangaugeModelDataset, collate_fn
from utils.trainer import train, evaluate, predict


def load_dataset():
    """Load the pickled molecule corpus and split it into train/valid/test.

    The small corpus is read from ``data/molecule_net/molecule_small.pickle``.
    It is split 80/20 into (train+valid)/test, and the first part is split
    80/20 again into train/valid. Both splits shuffle with ``random_state=42``.

    Returns:
        tuple: ``(train_data, valid_data, test_data)``.
    """
    print("load dataset ... ")
    # The full corpus lives in data/molecule_net/molecule_total.pickle (it was
    # previously truncated to its first 100000 entries); the small subset is
    # what is loaded here.
    # NOTE(review): pickle.load executes arbitrary code - only use trusted files.
    with open("data/molecule_net/molecule_small.pickle", "rb") as handle:
        samples = pickle.load(handle)

    split_options = dict(test_size=0.2, shuffle=True, random_state=42)
    train_valid_part, test_part = train_test_split(samples, **split_options)
    train_part, valid_part = train_test_split(train_valid_part, **split_options)

    return train_part, valid_part, test_part


def load_tokenizer():
    """Load and return the pickled tokenizer object.

    Reads ``data/molecule_net/molecule_tokenizer.pickle`` and returns whatever
    object was stored in it, unchanged.
    """
    print("load tokenizer ... ")
    tokenizer_path = "data/molecule_net/molecule_tokenizer.pickle"
    # NOTE(review): pickle.load executes arbitrary code - only use trusted files.
    with open(tokenizer_path, "rb") as handle:
        return pickle.load(handle)


# Materialise the train/validation/test splits and the tokenizer once;
# later cells read these module-level names.
train_data, valid_data, test_data = load_dataset()
tokenizer = load_tokenizer()

load dataset ... 
load tokenizer ... 


In [2]:
# --- Model / data hyperparameters ---------------------------------------
seq_len = 100
d_model = 128
dim_feedforward = 512
dropout_rate = 0.1
pad_token_id = 0
nhead = 8
num_layers = 8
# use_RNN = False
use_RNN = True
batch_size = 512 * 4
masking_rate = 0.15
# presumably tokenizer[0] is the token -> id vocabulary; verify against the
# code that pickles the tokenizer.
vocab_dim = len(tokenizer[0])
learning_rate = 0.0005

# Use the GPU when one is present, otherwise fall back to the CPU so that the
# `.to(DEVICE)` call below does not raise on a machine without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def _make_mlm_loader(data, shuffle):
    """Build the masked-LM dataset for ``data`` and a DataLoader over it.

    Args:
        data: one of the splits returned by ``load_dataset``.
        shuffle: whether the DataLoader reshuffles every epoch.

    Returns:
        tuple: ``(dataset, dataloader)``.
    """
    dataset = MoleculeLangaugeModelDataset(data=data, seq_len=seq_len, tokenizer=tokenizer, masking_rate=masking_rate)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=8, collate_fn=collate_fn, pin_memory=True)
    return dataset, dataloader


# Only the training split is shuffled; validation and test keep a fixed order.
train_dataset, train_dataloader = _make_mlm_loader(train_data, shuffle=True)
valid_dataset, valid_dataloader = _make_mlm_loader(valid_data, shuffle=False)
test_dataset, test_dataloader = _make_mlm_loader(test_data, shuffle=False)

# --- Model, loss, optimizer and LR schedule -----------------------------
bert_base = BERT(vocab_dim, seq_len, d_model, dim_feedforward, pad_token_id, nhead, num_layers, dropout_rate)
model = MLMHead(bert_base, d_model, vocab_dim, use_RNN).to(DEVICE)
# model = MLMHead(bert_base, d_model, vocab_dim, use_RNN)

# Targets equal to the padding id are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)
# NOTE(review): torch.optim.Adam applies weight_decay as an L2 term added to
# the gradient; 0.1 is a strong setting - confirm it is intended (AdamW or a
# smaller value may be what was meant).
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
# scheduler = CosineAnnealingWarmupRestarts(optimizer, first_cycle_steps=200, cycle_mult=1.0,
#                                           max_lr=0.005, min_lr=0.00001, warmup_steps=50, gamma=1.0)


In [3]:
import pytorch_lightning as pl
import torchmetrics

class MoleculeNet(pl.LightningModule):
    """LightningModule that trains a BERT encoder with a masked-LM head.

    Architecture hyperparameters (``vocab_dim``, ``seq_len``, ``d_model``,
    ``dim_feedforward``, ``pad_token_id``, ``nhead``, ``num_layers``,
    ``dropout_rate``, ``use_RNN``) are read from the notebook's module-level
    configuration when the instance is constructed.

    Each batch is unpacked as ``(x, y, masked_label)``; ``masked_label`` is
    not used by any step.

    Args:
        learning_rate: learning rate used by the optimizer built in
            ``configure_optimizers``.
    """

    def __init__(self, learning_rate):
        super().__init__()

        self.learning_rate = learning_rate
        # Targets equal to this id are ignored by the loss.
        self.pad_token_id = pad_token_id

        self.bert_base = BERT(vocab_dim, seq_len, d_model, dim_feedforward, pad_token_id, nhead, num_layers, dropout_rate)
        # Wrap this module's own encoder rather than the notebook-level
        # `bert_base`, so the instance does not share weights with the global
        # `model`. No `.to(device)` here: the Lightning Trainer places the
        # module on the configured device.
        self.model = MLMHead(self.bert_base, d_model, vocab_dim, use_RNN)

#         self.train_accuracy = torchmetrics.Accuracy()
#         self.valid_accuracy = torchmetrics.Accuracy()
#         self.test_accuracy = torchmetrics.Accuracy()

    def _masked_lm_loss(self, batch):
        """Run the model on a batch and return the cross-entropy loss."""
        x, y, _masked_label = batch
        y_hat = self.model(x)
        # NOTE(review): cross_entropy expects the class dimension at dim 1 of
        # `y_hat`; confirm MLMHead's output layout matches `y`.
        return nn.functional.cross_entropy(y_hat, y, ignore_index=self.pad_token_id)

    def training_step(self, batch, batch_idx):
        """Compute, log (``train_loss``) and return the training loss."""
        loss = self._masked_lm_loss(batch)

        self.log("train_loss", loss)
#         self.log("train_accuracy", self.train_accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log it as ``valid_loss``."""
        loss = self._masked_lm_loss(batch)

        self.log("valid_loss", loss)
#         self.log("valid_accuracy", self.valid_accuracy(y_hat, y), on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        """Compute the test loss and log it as ``test_loss``."""
        loss = self._masked_lm_loss(batch)

        self.log("test_loss", loss)
#         self.log("test_accuracy", self.test_accuracy(y_hat, y), on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module.

        Uses the learning rate passed to the constructor.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=0.1)

        return {"optimizer": optimizer}

    
# def define_callbacks(patience, ckpt_path):
#     early_stopping = EarlyStopping('valid_loss', patience=patience)
#     check_points = ModelCheckpoint(monitor="valid_loss", mode="min", dirpath=ckpt_path, save_top_k=1)
    
#     return [early_stopping, check_points]


# Learning rate handed to the MoleculeNet constructor.
LEARNING_RATE = 1e-4

# LightningModule instance that the Trainer cells below fit.
# NOTE(review): the name is spelled "melecule_net"; later cells reference this
# exact spelling, so rename everywhere or not at all.
melecule_net = MoleculeNet(LEARNING_RATE)
# callbacks = define_callbacks(10, "./weights")


In [4]:
# Upper bound on the number of training epochs.
N_EPOCHS = 50
# Single-GPU Lightning Trainer with the progress bar enabled.
trainer = pl.Trainer(gpus=1, max_epochs=N_EPOCHS, enable_progress_bar=True)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
# Fit on the training loader only. No validation loader is passed, so
# Lightning skips the validation loop (see the warning in this cell's output).
trainer.fit(melecule_net, train_dataloader)

  rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [None]:
# Same epoch budget as the earlier Trainer cell.
N_EPOCHS = 50

# Build a fresh Trainer and fit the same `melecule_net` instance again, this
# time also passing the validation loader.
trainer = pl.Trainer(gpus=1, max_epochs=N_EPOCHS, enable_progress_bar=True)
trainer.fit(melecule_net, train_dataloader, valid_dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [3]:
# import os
# import warnings
# warnings.filterwarnings(action='ignore')

# N_EPOCHS = 100
# PAITIENCE = 10
# start_epoch = 0
# n_paitience = 0
# best_valid_loss = float('inf')
# optimizer.zero_grad()

# project_name = "BERT_deargen_similar"
# output_path = f"output/{project_name}"
# weight_path = f"weights/{project_name}"

# os.makedirs(output_path, exist_ok=True)
# os.makedirs(weight_path, exist_ok=True)
   
# for epoch in range(start_epoch, N_EPOCHS):
#     print(f'Epoch: {epoch:04}')

#     train_loss, train_accuracy = train(model, train_dataloader, optimizer, criterion, DEVICE)
#     valid_loss, valid_accuracy = evaluate(model, valid_dataloader, optimizer, criterion, DEVICE)

#     scheduler.step(valid_loss)

#     print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_accuracy:.4f}\nValid Loss: {valid_loss:.4f} | Valid Acc: {valid_accuracy:.4f}')
    
#     with torch.no_grad():
#         model.eval()

#         for i, (X, target, masking_label) in enumerate(test_dataloader):
#             if i < 5:
#                 output = model(X.to("cuda"))
#                 output_ = torch.argmax(output.clone().detach().to("cpu"), axis=-1)
#                 target_ = target.clone().detach().to("cpu")
#                 print(f"prediction results: {np.round((torch.sum(output_ == target_) / torch.numel(output_)).numpy() * 100, 2)} %")
#             else:
#                 break
    
#     with open(os.path.join(output_path, "log.txt"), "a") as f:
#         f.write("epoch: {0:04d} train loss: {1:.4f}, train acc: {2:.4f}, test loss: {3:.4f}, test acc: {4:.4f}\n".format(epoch, train_loss, train_accuracy, valid_loss, valid_accuracy))

#     if n_paitience < PAITIENCE:
#         if best_valid_loss > valid_loss:
#             best_valid_loss = valid_loss
#             torch.save(model.state_dict(), os.path.join(weight_path, 'MoleculeNet_LM_best.pt'))
#             n_paitience = 0
#         elif best_valid_loss <= valid_loss:
#             n_paitience += 1
#     else:
#         print("Early stop!")
#         model.load_state_dict(torch.load(os.path.join(weight_path, 'MoleculeNet_LM_best.pt')))
#         model.eval()
#         break



Epoch: 0000


100%|██████████| 32/32 [00:22<00:00,  1.44it/s]


Train Loss: 2.8355 | Train Acc: 0.3807
Valid Loss: 2.2707 | Valid Acc: 0.3950
prediction results: 19.0 %
prediction results: 18.83 %
prediction results: 18.61 %
prediction results: 18.98 %
prediction results: 19.19 %
Epoch: 0001


100%|██████████| 32/32 [00:22<00:00,  1.42it/s]


Train Loss: 2.0240 | Train Acc: 0.3879
Valid Loss: 1.7838 | Valid Acc: 0.3894
prediction results: 28.07 %
prediction results: 27.82 %
prediction results: 27.56 %
prediction results: 28.02 %
prediction results: 28.43 %
Epoch: 0002


100%|██████████| 32/32 [00:22<00:00,  1.44it/s]


Train Loss: 1.6774 | Train Acc: 0.3817
Valid Loss: 1.5246 | Valid Acc: 0.3870
prediction results: 33.53 %
prediction results: 33.34 %
prediction results: 33.0 %
prediction results: 33.61 %
prediction results: 34.08 %
Epoch: 0003


100%|██████████| 32/32 [00:22<00:00,  1.44it/s]


Train Loss: 1.4626 | Train Acc: 0.3815
Valid Loss: 1.3479 | Valid Acc: 0.3874
prediction results: 34.4 %
prediction results: 34.21 %
prediction results: 33.87 %
prediction results: 34.52 %
prediction results: 34.9 %
Epoch: 0004


100%|██████████| 32/32 [00:22<00:00,  1.43it/s]


Train Loss: 1.3255 | Train Acc: 0.3822
Valid Loss: 1.2434 | Valid Acc: 0.3854
prediction results: 37.17 %
prediction results: 36.95 %
prediction results: 36.6 %
prediction results: 37.25 %
prediction results: 37.8 %
Epoch: 0005


100%|██████████| 32/32 [00:22<00:00,  1.43it/s]


Train Loss: 1.2430 | Train Acc: 0.3831
Valid Loss: 1.1785 | Valid Acc: 0.3903
prediction results: 37.79 %
prediction results: 37.61 %
prediction results: 37.22 %
prediction results: 37.91 %
prediction results: 38.51 %
Epoch: 0006


100%|██████████| 32/32 [00:22<00:00,  1.44it/s]


Train Loss: 1.1915 | Train Acc: 0.3826
Valid Loss: 1.1374 | Valid Acc: 0.3883
prediction results: 38.08 %
prediction results: 37.81 %
prediction results: 37.55 %
prediction results: 38.15 %
prediction results: 38.69 %
Epoch: 0007


100%|██████████| 32/32 [00:22<00:00,  1.42it/s]


Train Loss: 1.1568 | Train Acc: 0.3819
Valid Loss: 1.1092 | Valid Acc: 0.3892
prediction results: 38.32 %
prediction results: 38.09 %
prediction results: 37.85 %
prediction results: 38.5 %
prediction results: 39.04 %
Epoch: 0008


100%|██████████| 32/32 [00:22<00:00,  1.43it/s]


Train Loss: 1.1332 | Train Acc: 0.3820
Valid Loss: 1.0894 | Valid Acc: 0.3867
prediction results: 38.72 %
prediction results: 38.51 %
prediction results: 38.09 %
prediction results: 38.68 %
prediction results: 39.32 %
Epoch: 0009


100%|██████████| 32/32 [00:22<00:00,  1.44it/s]


Train Loss: 1.1177 | Train Acc: 0.3824
Valid Loss: 1.0766 | Valid Acc: 0.3890
prediction results: 38.93 %
prediction results: 38.62 %
prediction results: 38.38 %
prediction results: 38.95 %
prediction results: 39.62 %
Epoch: 0010


100%|██████████| 32/32 [00:22<00:00,  1.44it/s]


Train Loss: 1.1072 | Train Acc: 0.3820
Valid Loss: 1.0676 | Valid Acc: 0.3869
prediction results: 39.09 %
prediction results: 38.8 %
prediction results: 38.59 %
prediction results: 39.22 %
prediction results: 39.74 %
Epoch: 0011


100%|██████████| 32/32 [00:22<00:00,  1.43it/s]


Train Loss: 1.1012 | Train Acc: 0.3828
Valid Loss: 1.0634 | Valid Acc: 0.3866
prediction results: 39.19 %
prediction results: 38.94 %
prediction results: 38.64 %
prediction results: 39.25 %
prediction results: 39.92 %
Epoch: 0012


100%|██████████| 32/32 [00:22<00:00,  1.43it/s]


Train Loss: 1.0961 | Train Acc: 0.3841
Valid Loss: 1.0589 | Valid Acc: 0.3870
prediction results: 39.28 %
prediction results: 39.0 %
prediction results: 38.7 %
prediction results: 39.31 %
prediction results: 39.93 %
Epoch: 0013


100%|██████████| 32/32 [00:22<00:00,  1.43it/s]


Train Loss: 1.0923 | Train Acc: 0.3841
Valid Loss: 1.0562 | Valid Acc: 0.3879
prediction results: 39.22 %
prediction results: 39.03 %
prediction results: 38.76 %
prediction results: 39.37 %
prediction results: 39.98 %
Epoch: 0014


100%|██████████| 32/32 [00:22<00:00,  1.45it/s]


Train Loss: 1.0911 | Train Acc: 0.3841
Valid Loss: 1.0557 | Valid Acc: 0.3895
prediction results: 39.29 %
prediction results: 39.03 %
prediction results: 38.76 %
prediction results: 39.41 %
prediction results: 39.99 %
Epoch: 0015


100%|██████████| 32/32 [00:22<00:00,  1.44it/s]


Train Loss: 1.0917 | Train Acc: 0.3838
Valid Loss: 1.0573 | Valid Acc: 0.3895
prediction results: 39.31 %
prediction results: 39.07 %
prediction results: 38.75 %
prediction results: 39.44 %
prediction results: 39.95 %
Epoch: 0016


100%|██████████| 32/32 [00:22<00:00,  1.43it/s]


Train Loss: 1.0935 | Train Acc: 0.3840
Valid Loss: 1.0600 | Valid Acc: 0.3869
prediction results: 39.28 %
prediction results: 39.02 %
prediction results: 38.77 %
prediction results: 39.45 %
prediction results: 40.02 %
Epoch: 0017


100%|██████████| 32/32 [00:22<00:00,  1.43it/s]


Train Loss: 1.0963 | Train Acc: 0.3853
Valid Loss: 1.0629 | Valid Acc: 0.3888
prediction results: 39.29 %
prediction results: 39.03 %
prediction results: 38.73 %
prediction results: 39.36 %
prediction results: 39.9 %
Epoch: 0018


100%|██████████| 32/32 [00:22<00:00,  1.43it/s]


Train Loss: 1.0999 | Train Acc: 0.3834
Valid Loss: 1.0664 | Valid Acc: 0.3861
prediction results: 39.23 %
prediction results: 38.95 %
prediction results: 38.71 %
prediction results: 39.31 %
prediction results: 39.9 %
Epoch: 0019


100%|██████████| 32/32 [00:22<00:00,  1.43it/s]


Train Loss: 1.1035 | Train Acc: 0.3849
Valid Loss: 1.0705 | Valid Acc: 0.3891
prediction results: 39.13 %
prediction results: 38.89 %
prediction results: 38.62 %
prediction results: 39.26 %
prediction results: 39.87 %
Epoch: 0020


 22%|██▏       | 7/32 [00:06<00:22,  1.12it/s]


KeyboardInterrupt: 