In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torchtext

import pickle
from tqdm import tqdm


# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point these paths at artifacts you produced yourself.

# Alternative, smaller corpus for quick experiments:
# with open("data/molecule_small.pickle", 'rb') as f:
#     molecules = pickle.load(f)

with open("data/molecule_total.pickle", 'rb') as f:
    molecules = pickle.load(f)

# Work on the first 1% of the corpus; the leading 20% of that subset is held
# out for testing and the remainder is used for training.
molecules = molecules[:int(len(molecules) * 0.01)]
n_test = int(len(molecules) * 0.2)
test_data, train_data = molecules[:n_test], molecules[n_test:]

# Reference only: how the pickled tokenizer loaded below was originally built.
# tokenizer = torchtext.legacy.data.Field(tokenize=None,
#                                         init_token='<CLS>',
#                                         eos_token='<SEP>',
#                                         pad_token='<PAD>',
#                                         unk_token='<MASK>',
#                                         lower=False,
#                                         batch_first=False,
#                                         include_lengths=False)

# tokenizer.build_vocab(train_data, min_freq=1)

with open("data/tokenizer.pickle", "rb") as f:
    tokenizer = pickle.load(f)

# Model / training hyper-parameters shared by the cells below.
vocab_dim     = len(tokenizer.vocab.itos)   # number of entries in the vocabulary
seq_len       = 256
embedding_dim = 512
batch_size    = 256

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# device        = "cpu"

In [11]:
# with open("data/tokenizer.pickle", 'wb') as f:
#     pickle.dump(tokenizer, f)
    
# with open("data/molecule_total.pickle", 'wb') as f:
#     pickle.dump(molecules, f)

In [2]:
class MolecularLangaugeModelDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields masked-language-modelling samples.

    Each item is a 4-tuple ``(train, target, segment_embedding, masking_label)``
    of 1-D tensors with ``seq_len`` elements:

    * ``train``  - ``[CLS] + masked tokens + [SEP] + [PAD]...`` (long)
    * ``target`` - ``[CLS] + original tokens + [SEP] + [PAD]...`` (long)
    * ``segment_embedding`` - all zeros (float; callers cast to long)
    * ``masking_label`` - 1.0 at masked positions, 0.0 elsewhere (float)

    Args:
        data: indexable collection of samples accepted by
            ``tokenizer.numericalize``.
        tokenizer: object exposing ``numericalize``, ``vocab.stoi`` and the
            ``init_token`` / ``eos_token`` / ``pad_token`` / ``unk_token``
            attributes. The ``unk_token`` id is used as the mask token.
        seq_len: total length of every returned tensor, including CLS and SEP.
        masking_rate: fraction of tokens to mask (one extra token is always
            masked on top of ``round(n * masking_rate)``).
    """

    def __init__(self, data, tokenizer, seq_len=128, masking_rate=0.15):
        super(MolecularLangaugeModelDataset, self).__init__()

        self.data          = data
        self.tokenizer     = tokenizer
        self.vocab         = tokenizer.vocab
        self.seq_len       = seq_len
        self.masking_rate  = masking_rate

        stoi = self.tokenizer.vocab.stoi
        self.cls_token_id  = stoi[self.tokenizer.init_token]
        self.sep_token_id  = stoi[self.tokenizer.eos_token]
        self.pad_token_id  = stoi[self.tokenizer.pad_token]
        # The tokenizer's unknown-token slot doubles as the mask token.
        self.mask_token_id = stoi[self.tokenizer.unk_token]

    def __getitem__(self, idx):
        """Build the masked/unmasked/padded tensors for sample ``idx``."""
        # reshape(-1) rather than squeeze(): squeeze() collapses a one-token
        # sample to a 0-dim tensor, on which len() raises TypeError.
        # Assumes numericalize returns a tensor holding exactly one sequence
        # (e.g. shape (1, n) or (n, 1)) - TODO confirm against the tokenizer.
        target = self.tokenizer.numericalize(self.data[idx]).reshape(-1)

        # Two slots are reserved for the CLS and SEP tokens.
        if len(target) < self.seq_len - 2:
            pad_length = self.seq_len - len(target) - 2
        else:
            target = target[:self.seq_len - 2]
            pad_length = 0

        masked_sent, masking_label = self.masking(target)

        cls_token = torch.tensor([self.cls_token_id])
        sep_token = torch.tensor([self.sep_token_id])
        # Explicit long dtype so an empty pad tensor does not default to float.
        padding   = torch.full((pad_length,), self.pad_token_id, dtype=torch.long)

        # MLM input: masked tokens wrapped in CLS/SEP and right-padded.
        train = torch.cat([cls_token, masked_sent, sep_token, padding]).long().contiguous()

        # MLM labels: the original tokens with the same wrapping and padding.
        target = torch.cat([cls_token, target, sep_token, padding]).long().contiguous()

        # CLS, SEP and PAD positions are never prediction targets.
        masking_label = torch.cat([
            torch.zeros(1),
            masking_label,
            torch.zeros(1),
            torch.zeros(pad_length)
        ])

        segment_embedding = torch.zeros(target.size(0))

        return train, target, segment_embedding, masking_label

    def __len__(self):
        """Number of samples in the underlying data."""
        return len(self.data)

    def __iter__(self):
        """Iterate over the raw (un-numericalized) samples."""
        for x in self.data:
            yield x

    def get_vocab(self):
        """Return the tokenizer's vocabulary object."""
        return self.vocab

    # TODO: within the masked positions, also replace some tokens with random ones.
    def masking(self, x):
        """Replace a random subset of ``x`` with the mask token.

        Args:
            x: 1-D sequence of token ids.

        Returns:
            ``(masked, masking_label)`` where ``masked`` is a long tensor with
            ``round(len(x) * masking_rate) + 1`` positions (capped at
            ``len(x)``) set to the mask id, and ``masking_label`` is a float
            tensor holding 1.0 at exactly those positions.
        """
        # as_tensor avoids the copy-construct warning torch.tensor() emits
        # when handed an existing tensor; masked_fill below returns a new
        # tensor, so the caller's data is never modified.
        x             = torch.as_tensor(x).long().contiguous()
        n_tokens      = x.size(0)
        n_masked      = round(n_tokens * self.masking_rate) + 1
        masking_idx   = torch.randperm(n_tokens)[:n_masked]
        masking_label = torch.zeros(n_tokens)
        masking_label[masking_idx] = 1
        x             = x.masked_fill(masking_label.bool(), self.mask_token_id)

        return x, masking_label

In [3]:
# https://inhyeokyoo.github.io/project/nlp/bert-issue/

import torch
import torch.nn as nn

class BERT(nn.Module):
    """Embedding layer followed by a 6-layer, 8-head Transformer encoder.

    Args:
        vocab_dim: vocabulary size, forwarded to ``BERTEmbedding``.
        seq_len: sequence length, forwarded to ``BERTEmbedding``.
        embedding_dim: embedding width; also used as the encoder ``d_model``.
        pad_token_id: token id treated as padding when building the
            attention mask.
    """

    def __init__(self, vocab_dim, seq_len, embedding_dim, pad_token_id):
        super(BERT, self).__init__()
        self.pad_token_id  = pad_token_id
        self.nhead         = 8
        self.embedding     = BERTEmbedding(vocab_dim, seq_len, embedding_dim)
        # d_model follows embedding_dim (it was hard-coded to 512) so the
        # encoder width always matches what the embedding layer produces.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=self.nhead, batch_first=True)
        self.encoder_block = nn.TransformerEncoder(self.encoder_layer, num_layers=6)

    def forward(self, data, segment_embedding):
        """Encode a batch of token ids.

        Args:
            data: ``(batch, seq)`` tensor of token ids.
            segment_embedding: segment ids passed through to the embedding
                layer alongside ``data``.

        Returns:
            The output of the Transformer encoder stack.
        """
        pad_mask  = BERT.get_attn_pad_mask(data, data, self.pad_token_id)
        pad_mask  = BERT.expand_mask_for_heads(pad_mask, self.nhead)
        embedding = self.embedding(data, segment_embedding)
        output    = self.encoder_block(embedding, pad_mask)

        return output

    @staticmethod
    def expand_mask_for_heads(pad_attn_mask, nhead):
        """Replicate a ``(batch, len_q, len_k)`` mask once per attention head.

        ``nn.MultiheadAttention`` expects a 3-D mask of shape
        ``(batch * nhead, len_q, len_k)`` laid out batch-major, i.e. entry
        ``b * nhead + h`` belongs to sample ``b``. ``repeat(nhead, 1, 1)``
        produces a head-major layout that hands samples each other's masks
        whenever batch > 1, so ``repeat_interleave`` is used instead.
        """
        return pad_attn_mask.repeat_interleave(nhead, dim=0)

    @staticmethod
    def get_attn_pad_mask(seq_q, seq_k, i_pad):
        """Build a boolean mask that is True where the key token is padding.

        Args:
            seq_q: ``(batch, len_q)`` query token ids (only its shape is used).
            seq_k: ``(batch, len_k)`` key token ids.
            i_pad: id of the padding token.

        Returns:
            Bool tensor of shape ``(batch, len_q, len_k)``.
        """
        batch_size, len_q = seq_q.size()
        batch_size, len_k = seq_k.size()
        pad_attn_mask = seq_k.data.eq(i_pad)
        pad_attn_mask = pad_attn_mask.unsqueeze(1).expand(batch_size, len_q, len_k)

        return pad_attn_mask

In [4]:
class BERTEmbedding(nn.Module):
    """Sum of token, learned positional and segment embeddings.

    Each of the three embeddings has its own dropout layer, applied before
    the three are added together.

    Args:
        vocab_dim: number of rows in the token embedding table.
        seq_len: maximum sequence length (rows in the positional table).
        embedding_dim: width of every embedding vector.
        dropout_rate: dropout probability used by all three dropout layers.
        device: accepted for backward compatibility and ignored; position
            ids are created on the device of the input batch instead.
    """

    def __init__(self, vocab_dim, seq_len, embedding_dim, dropout_rate=0.1, device=None):
        super(BERTEmbedding, self).__init__()
        self.seq_len       = seq_len
        self.vocab_dim     = vocab_dim
        self.embedding_dim = embedding_dim
        self.dropout_rate  = dropout_rate

        # vocab --> embedding
        self.token_embedding      = nn.Embedding(self.vocab_dim, self.embedding_dim)
        self.token_dropout        = nn.Dropout(self.dropout_rate)

        # seq len --> embedding
        self.positional_embedding = nn.Embedding(self.seq_len, self.embedding_dim)
        self.positional_dropout   = nn.Dropout(self.dropout_rate)

        # segment (0, 1) --> embedding
        self.segment_embedding    = nn.Embedding(2, self.embedding_dim)
        self.segment_dropout      = nn.Dropout(self.dropout_rate)

    def forward(self, data, segment_embedding):
        """Embed a batch of token ids.

        Args:
            data: ``(batch, length)`` long tensor of token ids with
                ``length <= seq_len``.
            segment_embedding: long tensor of segment ids (0 or 1) with the
                same shape as ``data``.

        Returns:
            ``(batch, length, embedding_dim)`` tensor.

        Raises:
            ValueError: if ``length`` exceeds ``seq_len``.
        """
        length = data.size(-1)
        if length > self.seq_len:
            raise ValueError(
                f"input length {length} exceeds the maximum sequence length {self.seq_len}"
            )

        token_embedding      = self.token_embedding(data)
        token_embedding      = self.token_dropout(token_embedding)

        # Position ids are built for the actual input length and on the
        # input's own device, so the module works for shorter batches and
        # does not depend on a notebook-level `device` global.
        positional_encoding  = torch.arange(start=0, end=length, step=1, device=data.device).long()
        positional_encoding  = positional_encoding.unsqueeze(0).expand(data.size())
        positional_embedding = self.positional_embedding(positional_encoding)
        positional_embedding = self.positional_dropout(positional_embedding)

        segment_embedding    = self.segment_embedding(segment_embedding)
        segment_embedding    = self.segment_dropout(segment_embedding)

        return token_embedding + positional_embedding + segment_embedding

In [5]:
class MaskedLanguageModeling(nn.Module):
    """Masked-LM head: a linear projection on top of a BERT encoder.

    Args:
        bert: encoder module; its ``embedding.token_embedding`` weight
            determines the input width of the projection.
        output_dim: number of output classes per position.
    """

    def __init__(self, bert, output_dim):
        super(MaskedLanguageModeling, self).__init__()
        self.bert = bert
        # Hidden width is read off the encoder's token-embedding table.
        hidden_size = bert.embedding.token_embedding.weight.shape[1]
        self.fc = nn.Linear(hidden_size, output_dim)

    def forward(self, x, segment_embedding):
        """Run the encoder, then project every position to ``output_dim`` logits."""
        return self.fc(self.bert(x, segment_embedding))

In [6]:
def train(model, iterator, optimizer, criterion, device, clip=1):
    """Run one training epoch and return the mean loss per batch.

    Args:
        model: module called as ``model(inputs, segment_ids)``.
        iterator: sized iterable yielding
            ``(X, target, segment_emb, masking_label)`` batches.
        optimizer: optimizer stepped once per batch.
        criterion: loss applied to the logits and labels at masked positions.
        device: device the batch tensors are moved to.
        clip: maximum gradient norm passed to ``clip_grad_norm_``.
    """
    model.train()

    running_loss = 0
    n_batches = len(iterator)

    for _, (inputs, labels, segments, mask) in tqdm(enumerate(iterator), total=n_batches):
        optimizer.zero_grad()

        logits = model(inputs.to(device), segments.long().to(device))
        n_classes = logits.shape[-1]

        # Keep only the positions flagged in the masking label.
        masked_logits = logits[mask.bool().to(device)].reshape(-1, n_classes)
        masked_labels = labels.reshape(-1)[mask.reshape(-1).bool()].to(device)

        loss = criterion(masked_logits, masked_labels)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(iterator)


@torch.no_grad()
def evaluate(model, iterator, optimizer, criterion, device):
    """Compute the mean per-batch loss over ``iterator`` without gradients.

    Args:
        model: module called as ``model(inputs, segment_ids)``; switched to
            eval mode.
        iterator: sized iterable yielding
            ``(X, target, segment_emb, masking_label)`` batches.
        optimizer: unused; kept so existing call sites keep working.
            Evaluation no longer touches the optimizer's gradients.
        criterion: loss applied to the logits and labels at masked positions.
        device: device the batch tensors are moved to.

    Returns:
        Sum of the batch losses divided by ``len(iterator)``.
    """
    model.eval()

    epoch_loss = 0

    for batch, (X, target, segment_emb, masking_label) in enumerate(iterator):
        output = model(X.to(device), segment_emb.long().to(device))
        output_dim = output.shape[-1]

        # Score only the positions flagged in the masking label.
        output = output[masking_label.bool().to(device)].reshape(-1, output_dim)
        target = target.reshape(-1)[masking_label.reshape(-1).bool()].to(device)

        loss   = criterion(output, target)

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)


@torch.no_grad()
def predict(model, iterator, device):
    """Print the true and predicted token ids at the masked positions.

    For every batch the model is run in eval mode and, for the positions
    flagged in the masking label, the target ids and the argmax of the
    model output are printed.

    Args:
        model: module called as ``model(inputs, segment_ids)``.
        iterator: iterable yielding
            ``(X, target, segment_emb, masking_label)`` batches.
        device: device the batch tensors are moved to.
    """
    model.eval()

    for batch, (X, target, segment_emb, masking_label) in enumerate(iterator):
        # No optimizer call here: the function previously reached for a
        # notebook-global `optimizer` that is not one of its parameters.
        output = model(X.to(device), segment_emb.long().to(device))
        output_dim = output.shape[-1]

        output = output[masking_label.bool().to(device)].reshape(-1, output_dim)
        target = target.reshape(-1)[masking_label.reshape(-1).bool()].to(device)

        print(f"Target: {target.tolist()}\nOutput: {torch.argmax(output, 1).tolist()}\n")

In [7]:
# Training split: reshuffled every epoch.
train_dataset = MolecularLangaugeModelDataset(
    data=train_data,
    tokenizer=tokenizer,
    seq_len=seq_len,
    masking_rate=0.3,
)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Validation split: fixed order.
valid_dataset = MolecularLangaugeModelDataset(
    data=test_data,
    tokenizer=tokenizer,
    seq_len=seq_len,
    masking_rate=0.3,
)
valid_dataloader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# First ten held-out samples, served one at a time for printed spot checks.
predict_dataset = MolecularLangaugeModelDataset(
    data=test_data[:10],
    tokenizer=tokenizer,
    seq_len=seq_len,
    masking_rate=0.3,
)
predict_dataloader = torch.utils.data.DataLoader(
    predict_dataset,
    batch_size=1,
    shuffle=False,
)


In [None]:
import os
import warnings

# NOTE(review): this silences every warning for the rest of the session,
# including ones that may point at real problems.
warnings.filterwarnings(action='ignore')

# Take the padding id from the tokenizer instead of hard-coding it, so the
# attention mask stays correct if the vocabulary layout changes.
pad_token_id = tokenizer.vocab.stoi[tokenizer.pad_token]

bert      = BERT(vocab_dim=vocab_dim, seq_len=seq_len, embedding_dim=embedding_dim, pad_token_id=pad_token_id).to(device)
model     = MaskedLanguageModeling(bert, output_dim=vocab_dim).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01)
scheduler = ReduceLROnPlateau(optimizer, 'min')
criterion = nn.CrossEntropyLoss()

N_EPOCHS  = 1000
PAITIENCE = 30

# One path for both saving and restoring the best checkpoint; the two used
# to differ ('Molecular_LM_best.pt' vs 'MolecularNet_LM_best.pt'), which
# made the early-stop reload fail.
BEST_WEIGHTS_PATH = 'weights/Molecular_LM_best.pt'
os.makedirs(os.path.dirname(BEST_WEIGHTS_PATH), exist_ok=True)

n_paitience = 0
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    # Print sample predictions every fifth epoch (including before training).
    if epoch % 5 == 0:
        predict(model, predict_dataloader, device)

    train_loss = train(model, train_dataloader, optimizer, criterion, device)
    valid_loss = evaluate(model, valid_dataloader, optimizer, criterion, device)

    scheduler.step(valid_loss)

    print(f'Epoch: {epoch + 1:04}')
    print(f'Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}')

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), BEST_WEIGHTS_PATH)
        n_paitience = 0
    else:
        # Also reached when valid_loss is NaN, so a diverged run still stops.
        n_paitience += 1

    # Checked right after the update so no extra epoch is trained once the
    # patience budget is exhausted.
    if n_paitience >= PAITIENCE:
        print("Early stop!")
        if os.path.exists(BEST_WEIGHTS_PATH):
            model.load_state_dict(torch.load(BEST_WEIGHTS_PATH, map_location=device))
        model.eval()
        break

  0%|          | 0/3404 [00:00<?, ?it/s]

Target: [4, 5, 4, 5, 5, 4, 5, 9, 6, 4, 7, 8, 14, 5, 4, 4, 4, 4, 4, 7, 4]
Output: [1, 10, 10, 58, 58, 10, 10, 10, 58, 10, 10, 10, 10, 10, 58, 58, 58, 10, 10, 1, 58]

Target: [4, 4, 10, 6, 8, 10, 7, 4, 8, 4, 4, 7, 9, 4, 12]
Output: [58, 10, 10, 10, 10, 1, 1, 10, 10, 10, 10, 58, 10, 10, 58]

Target: [4, 10, 5, 10, 7, 9, 8, 6, 11, 7, 6, 5, 9, 7, 4, 4, 5, 9, 15, 13]
Output: [10, 10, 10, 1, 10, 10, 58, 10, 10, 58, 1, 10, 58, 10, 58, 10, 10, 1, 10, 10]

Target: [4, 6, 4, 6, 5, 8, 4, 4, 4, 5, 10, 4, 11, 4, 5]
Output: [58, 10, 10, 10, 10, 58, 10, 58, 10, 58, 58, 58, 10, 58, 58]

Target: [6, 4, 6, 5, 9, 8, 10, 4, 6, 7, 8, 4, 4, 5, 4]
Output: [10, 10, 10, 10, 10, 58, 10, 1, 10, 10, 10, 10, 58, 10, 58]

Target: [4, 5, 5, 4, 6, 5, 7, 6, 4, 4, 4, 4, 5, 4, 12, 4]
Output: [58, 10, 58, 1, 1, 58, 58, 58, 10, 58, 58, 58, 58, 58, 10, 10]

Target: [4, 5, 5, 6, 4, 5, 10, 4, 6, 4, 8, 4, 4, 7, 7, 15]
Output: [58, 10, 58, 10, 10, 58, 1, 10, 58, 10, 10, 58, 1, 10, 10, 58]

Target: [4, 16, 6, 5, 7, 7, 4, 10, 4, 

 33%|███▎      | 1137/3404 [06:19<12:57,  2.92it/s]