In [None]:
from transformers import BertTokenizer

molecule_tokenizer = BertTokenizer.from_pretrained("data/drug/molecule_tokenizer")

# get_vocab() returns the token -> id mapping; its length is used as the model's vocab size.
vocab_size = len(molecule_tokenizer.get_vocab())
special_tokens = molecule_tokenizer.all_special_tokens

print(f"load tokenizer\nvocab size: {vocab_size}\nspecial tokens: {special_tokens}")

In [None]:
import os
import pickle

split_cache_path = "data/drug/X.pkl"

if os.path.exists(split_cache_path):
    # Reuse the split written by a previous run.
    # NOTE: pickle.load can execute arbitrary code — only load caches this notebook produced.
    with open(split_cache_path, "rb") as f:
        X_train, X_valid, X_test = pickle.load(f)
else:
    from sklearn.model_selection import train_test_split

    with open("data/drug/chem_qed_filtered.txt", 'r') as f:
        data = f.readlines()
    print(f"load dataset ... # of data: {len(data)}")

    # Hold out 10% for test, then 0.5% of the remaining training lines for validation.
    X_train, X_test = train_test_split(data, test_size=0.1, random_state=42, shuffle=True)
    X_train, X_valid = train_test_split(X_train, test_size=0.005, random_state=42, shuffle=True)

    with open(split_cache_path, "wb") as f:
        pickle.dump([X_train, X_valid, X_test], f)

print(f"load dataset\nX_train: {len(X_train)}\nX_valid: {len(X_valid)}\nX_test: {len(X_test)}")

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, RandomSampler
from transformers import DataCollatorForLanguageModeling

max_seq_len = 128


class MaskedLMDataset(Dataset):
    """Map-style dataset that tokenizes raw text lines one character at a time.

    Newlines are removed from each line and a space is inserted between every
    character before the text is passed to ``tokenizer.encode``. This lets a
    character-level vocabulary map each character to its own token (presumably
    SMILES symbols here — verify against the tokenizer vocab).

    Args:
        data: indexable sequence of raw strings (e.g. lines from ``readlines()``).
        tokenizer: object exposing ``encode(text, max_length=..., truncation=...)``.
        max_length: truncation length forwarded to the tokenizer.
    """

    def __init__(self, data, tokenizer, max_length=128):
        self.data, self.tokenizer, self.max_length = data, tokenizer, max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ids = self.encode(self.data[idx])
        return torch.tensor(ids, dtype=torch.long)

    def encode(self, data):
        """Space-separate the characters of ``data`` and return its truncated token ids."""
        characters = data.replace("\n", "")
        spaced_text = " ".join(characters)  # "CCO" -> "C C O"
        return self.tokenizer.encode(spaced_text, max_length=self.max_length, truncation=True)
    
    
# Training masks tokens with probability 0.3; validation and test use 0.15.
# NOTE(review): masks are sampled at collate time, so valid_loss varies between
# epochs even for a fixed model.
data_collator_train = DataCollatorForLanguageModeling(tokenizer=molecule_tokenizer, mlm=True, mlm_probability=0.3)
data_collator_valid = DataCollatorForLanguageModeling(tokenizer=molecule_tokenizer, mlm=True, mlm_probability=0.15)

# Loader settings shared by every split.
common_loader_kwargs = dict(batch_size=2048, num_workers=16, pin_memory=True, prefetch_factor=10)

train_dataset = MaskedLMDataset(X_train, molecule_tokenizer, max_length=max_seq_len)
# Each epoch draws 1,000,000 examples with replacement instead of one pass over X_train.
train_sampler = RandomSampler(X_train, replacement=True, num_samples=1000000)
train_dataloader = DataLoader(
    train_dataset,
    sampler=train_sampler,
    collate_fn=data_collator_train,
    drop_last=True,
    **common_loader_kwargs,
)

valid_dataset = MaskedLMDataset(X_valid, molecule_tokenizer, max_length=max_seq_len)
valid_dataloader = DataLoader(valid_dataset, collate_fn=data_collator_valid, **common_loader_kwargs)

test_dataset = MaskedLMDataset(X_test, molecule_tokenizer, max_length=max_seq_len)
test_dataloader = DataLoader(test_dataset, collate_fn=data_collator_valid, **common_loader_kwargs)


In [None]:
import torchmetrics
import pytorch_lightning as pl
from transformers import BertConfig, BertForMaskedLM
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Small BERT encoder for masked-LM pretraining on molecule strings.
config = BertConfig(
    vocab_size=vocab_size,
    hidden_size=128,
    num_hidden_layers=8,
    num_attention_heads=8,          # 128 / 8 -> 16-dim attention heads
    intermediate_size=512,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    # Dataset items are truncated to max_seq_len tokens, so this leaves 2 spare positions.
    max_position_embeddings=max_seq_len + 2,
    type_vocab_size=1,              # a single segment type; token_type_ids are never fed
    # Take the pad id from the tokenizer rather than assuming 0, so the embedding
    # padding index matches the id the collator actually pads with.
    pad_token_id=molecule_tokenizer.pad_token_id,
    position_embedding_type="absolute",
)


class Bert(pl.LightningModule):
    """LightningModule wrapping a freshly initialised ``BertForMaskedLM`` for MLM pretraining.

    Each train/valid/test step logs ``{stage}_loss`` and ``{stage}_accuracy``. Accuracy
    counts only positions the data collator selected for prediction (``labels != -100``).

    Args:
        config: ``BertConfig`` used to build the model.
        lr: AdamW learning rate (default matches the previous hard-coded value).
        t_max: ``T_max`` of the ``CosineAnnealingLR`` schedule. Lightning steps it once per
            epoch by default. NOTE(review): ``CosineAnnealingLR`` keeps following the cosine
            after ``T_max``, so with ``t_max=10`` and ``max_epochs=50`` the LR rises and falls
            repeatedly. Set it to the epoch budget if a single decay is intended.
    """

    # Label value the HF MLM collator assigns to positions that are not scored.
    IGNORE_INDEX = -100

    def __init__(self, config, lr=5e-4, t_max=10):
        super().__init__()
        self.save_hyperparameters()
        self.model = BertForMaskedLM(config)

        # NOTE(review): torchmetrics >= 0.11 requires Accuracy(task="multiclass", num_classes=...);
        # the bare constructor only works on older releases.
        self.train_accuracy = torchmetrics.Accuracy()
        self.valid_accuracy = torchmetrics.Accuracy()
        self.test_accuracy = torchmetrics.Accuracy()

    def forward(self, input_ids, labels, attention_mask=None):
        """Run the MLM model and return its output (``.loss`` and ``.logits``).

        The collator in this notebook pads tensor examples but emits no attention mask.
        Without one, every pad token would be attended to. When ``attention_mask`` is
        omitted, it is therefore derived from ``input_ids != config.pad_token_id``.
        Caveat: the collator's random-token replacement can occasionally insert the pad
        id at a real position, which this heuristic also masks. Pass an explicit mask to
        avoid that.
        """
        pad_id = self.model.config.pad_token_id
        if attention_mask is None and pad_id is not None:
            attention_mask = (input_ids != pad_id).long()
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def _shared_step(self, batch, stage, metric, on_step):
        """Forward one batch, log ``{stage}_loss`` / ``{stage}_accuracy``, and return the loss."""
        input_ids = batch["input_ids"]
        labels = batch["labels"]

        # Use the collator's mask when one is provided (e.g. dict-style examples).
        output = self(input_ids, labels, attention_mask=batch.get("attention_mask"))

        scored = labels != self.IGNORE_INDEX
        preds = output.logits.argmax(dim=-1)
        # Update the metric and log the Metric object itself. Lightning then computes
        # the exact epoch-level accuracy over all scored tokens and resets state each
        # epoch, instead of averaging per-batch values.
        metric.update(preds[scored], labels[scored])

        # Log the tensor directly; float(loss) would force a GPU->CPU sync every step.
        self.log(f"{stage}_loss", output.loss, on_step=on_step, on_epoch=True, prog_bar=True)
        self.log(f"{stage}_accuracy", metric, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return output.loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train", self.train_accuracy, on_step=True)

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "valid", self.valid_accuracy, on_step=False)

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, "test", self.test_accuracy, on_step=False)

    def configure_optimizers(self):
        """AdamW with a cosine-annealed learning rate (see ``t_max`` in the class docstring)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.hparams.t_max)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
    

# Keep the 10 checkpoints with the best monitored validation loss.
checkpoint_callback = ModelCheckpoint(
    monitor='valid_loss',
    save_top_k=10,
    dirpath='weights/molecule_bert_pretraining_masking_rate_30',
    filename='molecule_bert-{epoch:03d}-{valid_loss:.4f}-{valid_accuracy:.4f}',
)
callbacks = [checkpoint_callback]

model = Bert(config)
trainer = pl.Trainer(
    max_epochs=50,
    gpus=1,
    enable_progress_bar=True,
    callbacks=callbacks,
    precision=16,
)

In [None]:
# Train on the resampled training loader and validate every epoch on the held-out split.
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)

In [None]:
# File name of a checkpoint inside the ModelCheckpoint directory. Leave empty to use
# the best checkpoint tracked by this session's ModelCheckpoint callback.
ckpt_fname = ""

ckpt_dir = "weights/molecule_bert_pretraining_masking_rate_30/"
if ckpt_fname:
    ckpt_path = ckpt_dir + ckpt_fname
else:
    ckpt_path = trainer.checkpoint_callback.best_model_path
    if not ckpt_path:
        raise ValueError(
            "ckpt_fname is empty and the trainer has no best checkpoint; "
            "run training first or set ckpt_fname"
        )

# load_from_checkpoint is a classmethod that rebuilds the model from the saved
# hyperparameters, so there is no need to construct a throwaway Bert(config) first.
model = Bert.load_from_checkpoint(ckpt_path)
trainer.test(model, test_dataloader)

In [None]:
# Export only the encoder (base_model) of the MLM model in Hugging Face format.
# NOTE(review): BertForMaskedLM's encoder is typically built without a pooler, so loading
# this as BertModel may re-initialise the pooler — confirm downstream code doesn't use it.
pretrained_encoder = model.model.base_model
pretrained_encoder.save_pretrained("weights/molecule_bert")