In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import os
import glob
import pickle
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from model.bert import BERT, MLMHead
from utils.molecule_dataloader import MoleculeLangaugeModelDataset, collate_fn
from utils.trainer import train, evaluate, predict


def load_dataset():
    print("load dataset ... ")
    with open("data/molecule_net/molecule_total.pickle", 'rb') as f:
        train_data = pickle.load(f)
        
    train_data = train_data[:100000]
#     with open("data/molecule_net/molecule_small.pickle", "rb") as f:
#         train_data = pickle.load(f)
    
    
    train_data, test_data = train_test_split(train_data, test_size=0.2, shuffle=True, random_state=42)
    train_data, valid_data = train_test_split(train_data, test_size=0.2, shuffle=True, random_state=42)
    
    return train_data, valid_data, test_data


def load_tokenizer():
    print("load tokenizer ... ")
    with open("data/molecule_net/molecule_tokenizer.pickle", "rb") as f:
        tokenizer = pickle.load(f)

    return tokenizer


train_data, valid_data, test_data = load_dataset()
tokenizer = load_tokenizer()

load dataset ... 
load tokenizer ... 


In [2]:
seq_len = 100
d_model = 128
dim_feedforward = 512
dropout_rate = 0.1
pad_token_id = 0
nhead = 8
num_layers = 8
# use_RNN = False
use_RNN = True
batch_size = 512 * 4
masking_rate = 0.3
vocab_dim = len(tokenizer[0])
learning_rate = 0.0001

train_dataset = MoleculeLangaugeModelDataset(data=train_data, seq_len=seq_len, tokenizer=tokenizer, masking_rate=masking_rate)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=collate_fn, pin_memory=True)

valid_dataset = MoleculeLangaugeModelDataset(data=valid_data, seq_len=seq_len, tokenizer=tokenizer, masking_rate=masking_rate)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=8, collate_fn=collate_fn, pin_memory=True)

test_dataset = MoleculeLangaugeModelDataset(data=test_data, seq_len=seq_len, tokenizer=tokenizer, masking_rate=masking_rate)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8, collate_fn=collate_fn, pin_memory=True)

DEVICE = "cuda"

bert_base = BERT(vocab_dim, seq_len, d_model, dim_feedforward, pad_token_id, nhead, num_layers, dropout_rate)
model = MLMHead(bert_base, d_model, vocab_dim, use_RNN).to(DEVICE)
# model = MLMHead(bert_base, d_model, vocab_dim, use_RNN)

criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
# scheduler = CosineAnnealingWarmupRestarts(optimizer, first_cycle_steps=200, cycle_mult=1.0,
#                                           max_lr=0.005, min_lr=0.00001, warmup_steps=50, gamma=1.0)


In [3]:
import os
from tqdm import tqdm
import warnings
warnings.filterwarnings(action='ignore')

N_EPOCHS = 100
PAITIENCE = 10
start_epoch = 0
n_paitience = 0
best_valid_loss = float('inf')
optimizer.zero_grad()

project_name = "BERT_deargen_similar"
output_path = f"output/{project_name}"
weight_path = f"weights/{project_name}"

os.makedirs(output_path, exist_ok=True)
os.makedirs(weight_path, exist_ok=True)
   
for epoch in range(start_epoch, N_EPOCHS):
    print(f'Epoch: {epoch:04}')
    model.train()

    epoch_loss = 0
    epoch_corrects = 0
    epoch_num_data = 0

    for X, target, masking_label in tqdm(train_dataloader):
        
        optimizer.zero_grad()
        
        output = model(X.to("cuda"))
        output_ = torch.argmax(output.clone().detach().to("cpu"), axis=-1)
        target_ = target.clone().detach().to('cpu')
        output_dim = output.shape[-1]
        
        output = output.reshape(-1, output_dim)
        target = target.reshape(-1).to("cuda")
    
        loss = criterion(output, target)
        loss.backward()
        
        epoch_loss += loss.item()

        optimizer.step()
        
        epoch_corrects += torch.sum(output_[target_ != 0] == target_[target_ != 0])
        epoch_num_data += torch.numel(output_[target_ != 0])
    
    print(f"loss: {epoch_loss} accuracy: {epoch_corrects / epoch_num_data * 100}")
    

Epoch: 0000


100%|██████████| 32/32 [00:23<00:00,  1.38it/s]


loss: 112.78132677078247 accuracy: 31.84148597717285
Epoch: 0001


100%|██████████| 32/32 [00:22<00:00,  1.40it/s]


loss: 91.19699573516846 accuracy: 37.683956146240234
Epoch: 0002


100%|██████████| 32/32 [00:22<00:00,  1.39it/s]


loss: 82.96349954605103 accuracy: 37.683956146240234
Epoch: 0003


100%|██████████| 32/32 [00:22<00:00,  1.39it/s]


loss: 78.67789721488953 accuracy: 37.683956146240234
Epoch: 0004


100%|██████████| 32/32 [00:22<00:00,  1.40it/s]


loss: 76.17940521240234 accuracy: 37.683956146240234
Epoch: 0005


 34%|███▍      | 11/32 [00:09<00:17,  1.18it/s]


KeyboardInterrupt: 

In [11]:
for X, target, masking_label in tqdm(train_dataloader):
    output = model(X.to("cuda"))
    break

  0%|          | 0/32 [00:01<?, ?it/s]


In [12]:
X.shape

torch.Size([2048, 100])

In [13]:
target.shape

torch.Size([2048, 100])

In [14]:
output.shape

torch.Size([2048, 100, 69])

In [19]:
import torch.nn.functional as F

F.cross_entropy(output.reshape((-1, 69)).to("cpu"), target.to("cpu"))

ValueError: Expected input batch_size (204800) to match target batch_size (2048).