In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torchtext

import pickle

import glob
import numpy as np
import pandas as pd
from sklearn.utils import shuffle

# Full molecule corpus. "data/molecule_small.pickle" can be swapped in here for quick experiments.
# NOTE(review): pickle.load can execute arbitrary code — only load files from a trusted source.
with open("data/molecule_total.pickle", 'rb') as molecule_file:
    molecules = pickle.load(molecule_file)

# Deterministic (unshuffled) split: the first 20% of the corpus is the test set,
# everything after it is the training set.
n_test     = int(len(molecules) * 0.2)
test_data  = molecules[:n_test]
train_data = molecules[n_test:]

# The tokenizer was built once from the training data and pickled, roughly as:
#   tokenizer = torchtext.legacy.data.Field(tokenize=None, init_token='<CLS>', eos_token='<SEP>',
#                                           pad_token='<PAD>', unk_token='<MASK>', lower=False,
#                                           batch_first=False, include_lengths=False)
#   tokenizer.build_vocab(train_data, min_freq=1)
#   with open("data/tokenizer.pickle", 'wb') as f:
#       pickle.dump(tokenizer, f)
with open("data/tokenizer.pickle", "rb") as tokenizer_file:
    tokenizer = pickle.load(tokenizer_file)

# Shared model / training hyper-parameters.
vocab_dim     = len(tokenizer.vocab.itos)
seq_len       = 128   # total length per example, including the [CLS] and [SEP] positions
embedding_dim = 512
batch_size    = 512
device        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device      = "cpu"   # uncomment to force CPU

In [2]:
class MolecularLangaugeModelDataset(torch.utils.data.Dataset):
    """Masked-language-modelling dataset over tokenized molecules.

    Every item becomes a fixed-length sequence ``[CLS] tokens [SEP] [PAD] ...``
    of ``seq_len`` ids (tokens are truncated to ``seq_len - 2``), together with
    a copy in which a random subset of tokens is replaced by the mask id.

    Args:
        data: indexable collection of raw examples; each one is passed to
            ``tokenizer.numericalize``.
        tokenizer: torchtext legacy ``Field``-like object exposing ``vocab``
            (``stoi``/``itos``), ``numericalize`` and the special tokens
            ``init_token``, ``eos_token``, ``pad_token`` and ``unk_token``.
        seq_len: total output length, including ``[CLS]`` and ``[SEP]``.
        masking_rate: fraction of tokens to mask (see ``masking``).
    """
    def __init__(self, data, tokenizer, seq_len=128, masking_rate=0.15):
        super(MolecularLangaugeModelDataset, self).__init__()

        self.data          = data
        self.tokenizer     = tokenizer
        self.vocab         = tokenizer.vocab
        self.seq_len       = seq_len
        self.masking_rate  = masking_rate

        self.cls_token_id  = self.tokenizer.vocab.stoi[self.tokenizer.init_token]
        self.sep_token_id  = self.tokenizer.vocab.stoi[self.tokenizer.eos_token]
        self.pad_token_id  = self.tokenizer.vocab.stoi[self.tokenizer.pad_token]
        # The tokenizer's unk token doubles as the [MASK] token.
        self.mask_token_id = self.tokenizer.vocab.stoi[self.tokenizer.unk_token]

    def __getitem__(self, idx):
        """Return ``(masked_input, target, segment_ids, masking_label)`` for item ``idx``.

        All four tensors have length ``seq_len``: ``masked_input`` and ``target``
        are long token ids, ``segment_ids`` is all zeros (float) and
        ``masking_label`` is a float 0/1 vector marking masked positions.
        Returns ``None`` when the example cannot be numericalized, so that
        ``collate_fn`` can drop it from the batch.
        """
        try:
            # reshape(-1) rather than squeeze(): squeeze() turned 1-token examples
            # into 0-d tensors whose len() raised, silently dropping them.
            tokens = self.tokenizer.numericalize(self.data[idx]).reshape(-1)
        except Exception:
            # Best-effort: skip examples the tokenizer cannot handle (e.g. an
            # out-of-vocabulary token). Errors elsewhere are no longer hidden.
            return None

        body_len   = self.seq_len - 2            # room left after [CLS] and [SEP]
        tokens     = tokens[:body_len].long()
        pad_length = body_len - len(tokens)

        masked_tokens, masking_label = self.masking(tokens)

        cls_ids = torch.tensor([self.cls_token_id], dtype=torch.long)
        sep_ids = torch.tensor([self.sep_token_id], dtype=torch.long)
        pad_ids = torch.full((pad_length,), self.pad_token_id, dtype=torch.long)

        # MLM input (masked) and reconstruction target (unmasked).
        train  = torch.cat([cls_ids, masked_tokens, sep_ids, pad_ids]).contiguous()
        target = torch.cat([cls_ids, tokens, sep_ids, pad_ids]).contiguous()

        # Special tokens and padding are never marked as masked.
        masking_label = torch.cat([
            torch.zeros(1),
            masking_label,
            torch.zeros(1),
            torch.zeros(pad_length),
        ])

        segment_embedding = torch.zeros(target.size(0))

        return train, target, segment_embedding, masking_label

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        """Iterate over the raw (un-numericalized) examples."""
        for x in self.data:
            yield x

    def get_vocab(self):
        return self.vocab

    # TODO: inside the masked positions, also replace some tokens with random
    # tokens (BERT-style), not only with the mask token.
    def masking(self, x):
        """Replace a random subset of positions in ``x`` with the mask token.

        ``round(len(x) * masking_rate) + 1`` positions are chosen (capped at
        ``len(x)``), so a non-empty sequence always gets at least one mask.

        Returns:
            ``(masked_ids, masking_label)``: a long tensor of ids and a float
            0/1 tensor of the same length marking the masked positions.
        """
        # as_tensor avoids the copy-construct warning torch.tensor() emits for tensors;
        # masked_fill below is out-of-place, so the caller's tensor is not modified.
        x             = torch.as_tensor(x, dtype=torch.long).contiguous()
        length        = x.size(0)
        n_masked      = round(length * self.masking_rate) + 1
        masking_idx   = torch.randperm(length)[:n_masked]
        masking_label = torch.zeros(length)
        masking_label[masking_idx] = 1
        x             = x.masked_fill(masking_label.bool(), self.mask_token_id)

        return x, masking_label
    
def collate_fn(batch):
    """Collate a batch after dropping the items the dataset returned as ``None``."""
    kept_items = [item for item in batch if item is not None]
    return torch.utils.data.dataloader.default_collate(kept_items)

In [3]:
# https://inhyeokyoo.github.io/project/nlp/bert-issue/

import torch
import torch.nn as nn

class BERT(nn.Module):
    """BERT-style encoder: ``BERTEmbedding`` followed by a Transformer encoder stack.

    Args:
        vocab_dim: vocabulary size.
        seq_len: maximum sequence length.
        embedding_dim: model width, used for both the embeddings and the encoder
            layers (previously the encoder was hard-coded to 512).
        pad_token_id: token id treated as padding and masked out as an attention key.
        nhead: number of attention heads (must divide ``embedding_dim``).
        num_layers: number of encoder layers.
    """
    def __init__(self, vocab_dim, seq_len, embedding_dim, pad_token_id, nhead=8, num_layers=6):
        super(BERT, self).__init__()
        self.pad_token_id  = pad_token_id
        self.nhead         = nhead
        self.embedding     = BERTEmbedding(vocab_dim, seq_len, embedding_dim)
        # Kept as an attribute (not only passed to TransformerEncoder) so the module
        # and state_dict layout of existing checkpoints is unchanged.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=self.nhead, batch_first=True)
        self.encoder_block = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)

    def forward(self, data, segment_embedding):
        """Encode token ids ``data`` of shape (batch, length).

        ``segment_embedding`` holds the segment ids passed to ``BERTEmbedding``.
        Returns the encoder output (batch-first).
        """
        pad_mask  = BERT.get_attn_pad_mask(data, data, self.pad_token_id)
        attn_mask = BERT._expand_mask_per_head(pad_mask, self.nhead)
        embedding = self.embedding(data, segment_embedding)
        output    = self.encoder_block(embedding, mask=attn_mask)

        return output

    @staticmethod
    def get_attn_pad_mask(seq_q, seq_k, i_pad):
        """Boolean mask of shape (batch, len_q, len_k); True where the key is padding."""
        batch_size, len_q = seq_q.size()
        batch_size, len_k = seq_k.size()
        pad_attn_mask = seq_k.eq(i_pad)
        pad_attn_mask = pad_attn_mask.unsqueeze(1).expand(batch_size, len_q, len_k)

        return pad_attn_mask

    @staticmethod
    def _expand_mask_per_head(mask, nhead):
        """Turn a (batch, L, S) mask into the (batch * nhead, L, S) form attention expects.

        nn.MultiheadAttention orders the flattened dimension batch-major (sample b's
        heads sit at indices b*nhead .. b*nhead + nhead - 1), so each sample's mask
        must be repeated *interleaved*; ``repeat`` tiled the whole batch instead and
        gave heads the padding mask of a different sample.
        """
        return mask.repeat_interleave(nhead, dim=0)

In [4]:
class BERTEmbedding(nn.Module):
    """Sum of token, learned-position and segment embeddings, each with its own dropout.

    Args:
        vocab_dim: vocabulary size.
        seq_len: maximum sequence length (size of the positional table).
        embedding_dim: embedding width.
        dropout_rate: dropout probability applied to each embedding separately.
        device: unused, kept for call compatibility; positions are created on the
            same device as the input ids.
    """
    def __init__(self, vocab_dim, seq_len, embedding_dim, dropout_rate=0.1, device=device):
        super(BERTEmbedding, self).__init__()
        self.seq_len       = seq_len
        self.vocab_dim     = vocab_dim
        self.embedding_dim = embedding_dim
        self.dropout_rate  = dropout_rate

        # vocab --> embedding
        self.token_embedding      = nn.Embedding(self.vocab_dim, self.embedding_dim)
        self.token_dropout        = nn.Dropout(self.dropout_rate)

        # position (0 .. seq_len-1) --> embedding
        self.positional_embedding = nn.Embedding(self.seq_len, self.embedding_dim)
        self.positional_dropout   = nn.Dropout(self.dropout_rate)

        # segment (0, 1) --> embedding
        self.segment_embedding    = nn.Embedding(2, self.embedding_dim)
        self.segment_dropout      = nn.Dropout(self.dropout_rate)

    def forward(self, data, segment_embedding):
        """Embed token ids ``data`` of shape (batch, length), ``length <= seq_len``.

        ``segment_embedding`` holds integer segment ids (0 or 1) of the same shape.
        Returns a float tensor of shape (batch, length, embedding_dim).
        """
        token_embedding      = self.token_dropout(self.token_embedding(data))

        # Build positions on the input's own device (not a notebook global) and for
        # the input's actual length, so other devices and shorter batches work.
        positions            = torch.arange(data.size(1), device=data.device).unsqueeze(0).expand_as(data)
        positional_embedding = self.positional_dropout(self.positional_embedding(positions))

        segment_embedding    = self.segment_dropout(self.segment_embedding(segment_embedding))

        return token_embedding + positional_embedding + segment_embedding

In [5]:
class MaskedLanguageModeling(nn.Module):
    """Masked-language-model head: a BERT encoder followed by a linear projection.

    The projection's input width is read from the encoder's token-embedding
    matrix, so it always matches the encoder's model width.
    """
    def __init__(self, bert, output_dim):
        super(MaskedLanguageModeling, self).__init__()
        hidden_size = bert.embedding.token_embedding.weight.shape[1]
        self.bert = bert
        self.fc   = nn.Linear(hidden_size, output_dim)

    def forward(self, x, segment_embedding):
        """Return per-position logits of size ``output_dim`` for the token ids ``x``."""
        return self.fc(self.bert(x, segment_embedding))

In [7]:
def train(model, iterator, optimizer, criterion, device, clip=1, masked_only=False):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: called as ``model(X, segment_ids)``; must return logits whose last
            dimension is the vocabulary size and whose leading dims match ``target``.
        iterator: sized iterable of ``(X, target, segment_emb, masking_label)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: token-level loss, e.g. ``CrossEntropyLoss(ignore_index=pad_id)``.
        device: device the batch tensors are moved to.
        clip: maximum gradient norm passed to ``clip_grad_norm_``.
        masked_only: if True, the loss is computed only on masked positions
            (``masking_label == 1``), as in BERT's MLM objective. The default
            False keeps the original behaviour: loss over every position that
            ``criterion`` does not ignore.
    """
    model.train()

    epoch_loss = 0

    for X, target, segment_emb, masking_label in iterator:
        optimizer.zero_grad()

        output     = model(X.to(device), segment_emb.long().to(device))
        output_dim = output.shape[-1]

        output = output.reshape(-1, output_dim)
        target = target.reshape(-1).to(device)

        if masked_only:
            selected       = masking_label.reshape(-1).bool().to(device)
            output, target = output[selected], target[selected]

        loss = criterion(output, target)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)


@torch.no_grad()
def evaluate(model, iterator, optimizer, criterion, device, masked_only=False):
    """Evaluate ``model`` over ``iterator`` and return the mean per-batch loss.

    Runs in eval mode without gradient tracking.

    Args:
        model: called as ``model(X, segment_ids)``; must return logits whose last
            dimension is the vocabulary size and whose leading dims match ``target``.
        iterator: sized iterable of ``(X, target, segment_emb, masking_label)`` batches.
        optimizer: unused; kept so existing call sites keep working.
        criterion: token-level loss, e.g. ``CrossEntropyLoss(ignore_index=pad_id)``.
        device: device the batch tensors are moved to.
        masked_only: if True, the loss is computed only on masked positions
            (``masking_label == 1``); the default False scores every position
            that ``criterion`` does not ignore.
    """
    model.eval()

    epoch_loss = 0

    for X, target, segment_emb, masking_label in iterator:
        output     = model(X.to(device), segment_emb.long().to(device))
        output_dim = output.shape[-1]

        output = output.reshape(-1, output_dim)
        target = target.reshape(-1).to(device)

        if masked_only:
            selected       = masking_label.reshape(-1).bool().to(device)
            output, target = output[selected], target[selected]

        loss = criterion(output, target)

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)


@torch.no_grad()
def predict(model, iterator, device, tokenizer):
    """Decode the model's argmax predictions and the targets for every example.

    Results are accumulated over *all* batches of ``iterator``; previously only
    the last batch was returned, and an empty iterator raised UnboundLocalError.

    Returns:
        ``(output_list, target_list)``: decoded prediction and target strings,
        one per example, in iteration order.
    """
    model.eval()

    output_list, target_list = [], []
    for X, target, segment_emb, masking_label in iterator:
        output = model(X.to(device), segment_emb.long().to(device))

        predicted_ids = output.argmax(dim=-1).cpu()
        output_list.extend(decode(predicted_ids, tokenizer))
        target_list.extend(decode(target.cpu(), tokenizer))

    return output_list, target_list


def decode(x, tokenizer):
    """Map each row of token ids in ``x`` to the concatenation of its vocabulary strings."""
    itos = tokenizer.vocab.itos
    return ["".join(itos[token_id] for token_id in row) for row in x]


def generate_epoch_dataloader(data, seq_len, tokenizer, masking_rate, batch_size, collate_fn, shuffle=True, num_workers=6):
    """Build a DataLoader of masked-LM examples for one epoch.

    A fresh ``MolecularLangaugeModelDataset`` is created over ``data`` with the
    given ``seq_len``, ``tokenizer`` and ``masking_rate``, then wrapped in a
    ``DataLoader`` with the given batching, shuffling, workers and collate function.
    """
    return torch.utils.data.DataLoader(
        MolecularLangaugeModelDataset(
            data=data,
            seq_len=seq_len,
            tokenizer=tokenizer,
            masking_rate=masking_rate,
        ),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )


def generate_epoch_prediction_dataloader(data, seq_len, tokenizer, masking_rate, batch_size, collate_fn, shuffle=True, num_workers=5):
    """Build the DataLoader used for qualitative per-epoch predictions.

    It is identical to ``generate_epoch_dataloader`` except for the default worker
    count, so it delegates to it instead of duplicating the construction.
    """
    return generate_epoch_dataloader(
        data,
        seq_len=seq_len,
        tokenizer=tokenizer,
        masking_rate=masking_rate,
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=shuffle,
        num_workers=num_workers,
    )

In [8]:
import warnings

# NOTE(review): this silences *every* warning for the rest of the kernel session,
# which can hide real problems; consider narrowing the filter (category/message).
warnings.filterwarnings(action='ignore')

# Take the padding id from the tokenizer instead of hard-coding 1, so the
# attention pad mask and the loss's ignore_index always agree with the vocab.
pad_token_id = tokenizer.vocab.stoi[tokenizer.pad_token]

bert_base = BERT(vocab_dim=vocab_dim, seq_len=seq_len, embedding_dim=embedding_dim, pad_token_id=pad_token_id).to(device)
model     = MaskedLanguageModeling(bert_base, output_dim=vocab_dim).to(device)

optimizer = optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01)
scheduler = ReduceLROnPlateau(optimizer, 'min')
criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)

In [None]:
import os

# The loop writes logs/predictions to output/ and checkpoints to weights/.
os.makedirs("output", exist_ok=True)
os.makedirs("weights", exist_ok=True)

BEST_WEIGHTS_PATH = 'weights/MolecularNet_LM_best.pt'

start_epoch = 0
n_finished_epochs = len(glob.glob("output/*.tsv"))
if n_finished_epochs != 0:
    print("load pretrained model ... ")
    # One prediction .tsv is written per finished epoch, so their count is the epoch to resume from.
    start_epoch = n_finished_epochs
    # map_location lets a checkpoint saved on another device (e.g. GPU) be loaded on this one.
    model.load_state_dict(torch.load(BEST_WEIGHTS_PATH, map_location=device))
    # NOTE(review): optimizer/scheduler state and best_valid_loss are not restored, so the
    # first resumed epoch always overwrites the best checkpoint.

N_EPOCHS      = 1000
PAITIENCE     = 30   # stop after this many consecutive epochs without a new best valid loss
PREDICT_EVERY = 1    # write qualitative predictions every N epochs

n_paitience = 0
best_valid_loss = float('inf')
# NOTE(review): purpose of this dummy step is unclear (possibly silencing a
# scheduler-ordering warning) — confirm it is still needed.
optimizer.zero_grad()
optimizer.step()

for epoch in range(start_epoch, N_EPOCHS):
    print(f'Epoch: {epoch:04}')
    # Each epoch trains on a random 1% subset with a randomly chosen masking rate.
    epoch_masking_rate = np.random.choice([0.2, 0.3, 0.4, 0.5, 0.6])
    epoch_train_data   = shuffle(train_data, n_samples=int(len(train_data) * 0.01))
    epoch_valid_data   = shuffle(test_data, n_samples=int(len(test_data) * 0.01))
    train_dataloader   = generate_epoch_dataloader(
                                                    epoch_train_data,
                                                    seq_len=seq_len,
                                                    tokenizer=tokenizer,
                                                    batch_size=batch_size,
                                                    masking_rate=epoch_masking_rate,
                                                    collate_fn=collate_fn
                                                    )

    valid_dataloader   = generate_epoch_dataloader(
                                                    epoch_valid_data,
                                                    seq_len=seq_len,
                                                    tokenizer=tokenizer,
                                                    batch_size=batch_size,
                                                    masking_rate=0.3,
                                                    collate_fn=collate_fn
                                                    )

    print(f'Masking rate: {epoch_masking_rate} Train dataset: {len(epoch_train_data)}')

    train_loss = train(model, train_dataloader, optimizer, criterion, device)
    valid_loss = evaluate(model, valid_dataloader, optimizer, criterion, device)

    scheduler.step(valid_loss)

    print(f'Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}')

    with open("output/log.txt", "a") as f:
        f.write("epoch: {0:04d} train loss: {1:.4f}, test loss: {2:.4f}\n".format(epoch, train_loss, valid_loss))

    if epoch % PREDICT_EVERY == 0:
        print("Predictions ...\n")
        sampled_for_prediction = shuffle(epoch_valid_data, n_samples=10)
        prediction_dataloader  = generate_epoch_prediction_dataloader(
                                                                        sampled_for_prediction,
                                                                        seq_len=seq_len,
                                                                        tokenizer=tokenizer,
                                                                        batch_size=len(sampled_for_prediction),
                                                                        masking_rate=0.5,
                                                                        collate_fn=collate_fn
                                                                        )
        output_list, target_list = predict(model, prediction_dataloader, device, tokenizer)
        prediction_results = pd.DataFrame({"output": output_list, "target": target_list})
        prediction_results.to_csv("output/prediction_results_epoch-{0:04d}.tsv".format(epoch), sep="\t", index=False)

    # Update patience first, then decide: this saves an improvement even on the
    # last allowed epoch and stops right after PAITIENCE non-improving epochs,
    # instead of running (and ignoring) one extra epoch.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), BEST_WEIGHTS_PATH)
        n_paitience = 0
    else:
        n_paitience += 1

    if n_paitience >= PAITIENCE:
        print("Early stop!")
        model.load_state_dict(torch.load(BEST_WEIGHTS_PATH, map_location=device))
        model.eval()
        break

load pretrained model ... 
Epoch: 0065
Masking rate: 0.2 Train dataset: 871311
Train Loss: 0.0561 | Valid Loss: 0.0698
Predictions ...

Epoch: 0066
Masking rate: 0.2 Train dataset: 871311
Train Loss: 0.0560 | Valid Loss: 0.0694
Predictions ...

Epoch: 0067
Masking rate: 0.4 Train dataset: 871311
Train Loss: 0.1692 | Valid Loss: 0.0678
Predictions ...

Epoch: 0068
Masking rate: 0.6 Train dataset: 871311
Train Loss: 0.4434 | Valid Loss: 0.0698
Predictions ...

Epoch: 0069
Masking rate: 0.5 Train dataset: 871311
Train Loss: 0.2705 | Valid Loss: 0.0685
Predictions ...

Epoch: 0070
Masking rate: 0.5 Train dataset: 871311
Train Loss: 0.2694 | Valid Loss: 0.0684
Predictions ...

Epoch: 0071
Masking rate: 0.6 Train dataset: 871311
Train Loss: 0.4365 | Valid Loss: 0.0692
Predictions ...

Epoch: 0072
Masking rate: 0.2 Train dataset: 871311
Train Loss: 0.0555 | Valid Loss: 0.0678
Predictions ...

Epoch: 0073
Masking rate: 0.3 Train dataset: 871311
Train Loss: 0.0992 | Valid Loss: 0.0675
Predictio