In [1]:
import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split

from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
import torchmetrics.functional as tm
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

In [2]:
def smiles_for_ids(mol_ids, lookup):
    """Map molecule ids to SMILES strings, preserving the order of ``mol_ids``.

    Args:
        mol_ids: molecule ids, passed straight to ``pd.DataFrame(..., columns=["ID"])``.
        lookup: DataFrame with an ``id`` column and a ``smiles`` column.

    Returns:
        Array with one SMILES string per entry of ``mol_ids``.

    Raises:
        pandas.errors.MergeError: if ``lookup`` contains duplicate ids (a duplicate
            would multiply rows and break the alignment with the targets).
        ValueError: if any id has no SMILES entry in ``lookup``.
    """
    merged = pd.DataFrame(mol_ids, columns=["ID"]).merge(
        lookup, how="left", left_on="ID", right_on="id", validate="many_to_one"
    )
    missing = merged["smiles"].isna()
    if missing.any():
        # Fail here rather than later inside a DataLoader worker, where a NaN
        # would blow up on the " ".join(...) call in SmilesDataset.encode.
        examples = merged.loc[missing, "ID"].head().tolist()
        raise ValueError(
            f"{int(missing.sum())} molecule id(s) have no SMILES entry, e.g. {examples}"
        )
    return merged["smiles"].values


smiles_df = pd.read_csv("data/mol_id_to_smiles.csv")

# NOTE(review): torch.load unpickles the file; only load .pt files from a trusted source.
train_data = torch.load("data/qm9_train_data.pt")
X_train = smiles_for_ids(train_data['mol_id'], smiles_df)
y_train = train_data['mu']

# Fixed random_state keeps the 80/20 train/validation split reproducible.
X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

test_data = torch.load("data/qm9_test_data.pt")
X_test = smiles_for_ids(test_data["mol_id"], smiles_df)

In [3]:
from transformers import AutoTokenizer, AutoModel

# The tokenizer and the encoder are loaded from the same pretrained checkpoint name.
pretrained_name = "jonghyunlee/DrugLikeMoleculeBERT"
tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
encoder = AutoModel.from_pretrained(pretrained_name)

In [4]:
class SmilesDataset(Dataset):
    """Dataset that tokenizes SMILES strings lazily, one item per access.

    Args:
        tokenizer: tokenizer object, called as
            ``tokenizer(text, max_length=..., truncation=True)``.
        X: indexable collection of SMILES strings.
        y: targets aligned with ``X``; required unless ``is_predict`` is True.
        max_length: maximum token length handed to the tokenizer for truncation.
        is_predict: when True, items are encodings only and ``y`` is ignored.

    Raises:
        ValueError: if ``is_predict`` is False and ``y`` is missing or its length
            differs from ``X``.
    """

    def __init__(self, tokenizer, X, y=None, max_length=128, is_predict=False):
        self.X = X
        self.is_predict = is_predict

        if self.is_predict:
            self.y = None
        else:
            # Validate up front so a misconfigured dataset fails at construction
            # instead of at the first __getitem__ call.
            if y is None:
                raise ValueError("y is required when is_predict is False")
            if len(y) != len(X):
                raise ValueError(
                    f"X and y must have the same length, got {len(X)} and {len(y)}"
                )
            self.y = y

        self.tokenizer = tokenizer
        self.max_length = max_length

    def encode(self, sequence):
        """Tokenize one SMILES string.

        The characters of ``sequence`` are joined with single spaces before
        tokenization. The tokenizer is invoked through ``__call__`` because
        ``encode_plus`` is deprecated in favour of it.
        """
        return self.tokenizer(" ".join(sequence), max_length=self.max_length, truncation=True)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(encoding, target)``, or only ``encoding`` in predict mode."""
        smiles = self.encode(self.X[idx])

        if self.is_predict:
            return smiles
        return smiles, self.y[idx]
    
    
def collate_batch_train(batch):
    """Collate ``(encoding, target)`` pairs into a padded batch and a float tensor.

    Padding is delegated to the module-level ``tokenizer`` via ``tokenizer.pad``.
    """
    encodings = [pair[0] for pair in batch]
    targets = [pair[1] for pair in batch]

    padded = tokenizer.pad(encodings, return_tensors="pt")
    return padded, torch.tensor(targets).float()
    
    
def collate_batch_test(batch):
    """Collate a batch of encodings (no targets) into one padded batch.

    Padding is delegated to the module-level ``tokenizer`` via ``tokenizer.pad``.
    """
    encodings = [item for item in batch]
    return tokenizer.pad(encodings, return_tensors="pt")


# Settings shared by the train, validation and test loaders.
loader_kwargs = dict(batch_size=256, num_workers=16, pin_memory=True, prefetch_factor=10)

train_dataset = SmilesDataset(tokenizer, X_train, y_train)
# shuffle=True re-orders the training samples every epoch. Without it each epoch
# sees identical batches and drop_last discards the same tail samples every time.
train_dataloader = DataLoader(train_dataset, collate_fn=collate_batch_train,
                              shuffle=True, drop_last=True, **loader_kwargs)

# Validation and test keep their order and their last partial batch.
valid_dataset = SmilesDataset(tokenizer, X_valid, y_valid)
valid_dataloader = DataLoader(valid_dataset, collate_fn=collate_batch_train,
                              **loader_kwargs)

test_dataset = SmilesDataset(tokenizer, X_test, is_predict=True)
test_dataloader = DataLoader(test_dataset, collate_fn=collate_batch_test,
                             **loader_kwargs)



In [5]:
class BERT(nn.Module):
    """Regression head on top of a pretrained encoder.

    Every entry of ``encoder.encoder.layer`` except the last is frozen. The
    encoder's ``pooler_output`` is layer-normalised, projected to ``hidden_dim``
    and passed through three GELU-activated linear layers before a final
    linear layer that emits one value per sample.

    Args:
        encoder: module exposing ``encoder.layer`` and returning an object with
            a ``pooler_output`` attribute when called with the batch as kwargs.
        input_dim: feature size of ``pooler_output``.
        hidden_dim: base width of the regression head.
    """

    def __init__(self, encoder, input_dim=128, hidden_dim=512):
        super().__init__()
        self.encoder = encoder

        # Freeze all encoder layers but the last one.
        for frozen_layer in self.encoder.encoder.layer[:-1]:
            frozen_layer.requires_grad_(False)

        wide_dim = hidden_dim * 2

        self.align = nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, hidden_dim),
        )
        self.fc1 = nn.Linear(hidden_dim, wide_dim)
        self.fc2 = nn.Linear(wide_dim, wide_dim)
        self.fc3 = nn.Linear(wide_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, 1)

    def forward(self, smiles):
        """Return the head output for a batch of encoder keyword inputs."""
        hidden = self.align(self.encoder(**smiles).pooler_output)

        for dense in (self.fc1, self.fc2, self.fc3):
            hidden = F.gelu(dense(hidden))

        return self.fc_out(hidden)

In [6]:
class MoleculePropertyPredictor(pl.LightningModule):
    """Lightning module that trains a wrapped regression model with L1 loss.

    Note: the metric logged as ``valid_acc`` is the mean squared error, not an
    accuracy. The log key is kept unchanged.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def step(self, batch):
        """Compute ``(L1 loss, MSE)`` for one ``(inputs, targets)`` batch."""
        inputs, targets = batch
        # Drop the trailing dimension of the model output before comparing with targets.
        outputs = self.model(inputs).squeeze(-1)
        l1 = F.l1_loss(outputs, targets)
        mse = tm.mean_squared_error(outputs, targets)
        return l1, mse

    def training_step(self, batch, batch_idx):
        """Log the epoch-level training loss and return it for backprop."""
        l1, _ = self.step(batch)
        self.log('train_loss', l1, on_step=False, on_epoch=True, prog_bar=True)
        return l1

    def validation_step(self, batch, batch_idx):
        """Log the epoch-level validation L1 loss and MSE."""
        l1, mse = self.step(batch)
        self.log_dict(
            {'valid_loss': l1, 'valid_acc': mse},
            on_step=False, on_epoch=True, prog_bar=True,
        )

    def predict_step(self, batch, batch_idx):
        """Return the raw model output for a batch of inputs."""
        return self.model(batch)

    def configure_optimizers(self):
        """AdamW at lr 1e-4 paired with a cosine annealing schedule (T_max=20)."""
        adamw = torch.optim.AdamW(self.parameters(), lr=1e-4)
        cosine = torch.optim.lr_scheduler.CosineAnnealingLR(adamw, T_max=20)
        return dict(optimizer=adamw, lr_scheduler=cosine)
    
    
model = BERT(encoder)
predictor = MoleculePropertyPredictor(model)

# Keep the three checkpoints with the lowest validation loss. The logged
# metrics are embedded in the file name.
callbacks = [
    ModelCheckpoint(monitor='valid_loss', save_top_k=3, dirpath='weights/BERT', filename='BERT-{epoch:03d}-{valid_loss:.4f}-{valid_acc:.4f}'),
]

# `gpus=1` is deprecated (it triggers the rank_zero_deprecation warning);
# accelerator="gpu" with devices=1 is the replacement for the same setup.
trainer = pl.Trainer(max_epochs=500, accelerator="gpu", devices=1, enable_progress_bar=True, callbacks=callbacks)

  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Train on the training loader and validate on the held-out split.
trainer.fit(
    predictor,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type | Params
-------------------------------
0 | model | BERT | 3.8 M 
-------------------------------
2.4 M     Trainable params
1.4 M     Non-trainable params
3.8 M     Total params
15.179    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
# Set to a file name under weights/BERT/ to pin a specific checkpoint.
# Leave empty to use the best checkpoint recorded by the trainer's checkpoint callback.
ckpt_fname = ""

if ckpt_fname:
    ckpt_path = os.path.join("weights/BERT", ckpt_fname)
else:
    ckpt_path = trainer.checkpoint_callback.best_model_path

# An empty name used to resolve to the directory "weights/BERT/" and fail
# inside the loader; fail here with an actionable message instead.
if not ckpt_path or not os.path.isfile(ckpt_path):
    raise FileNotFoundError(
        f"No checkpoint file found at {ckpt_path!r}; set ckpt_fname to a file in weights/BERT/"
    )

# load_from_checkpoint is a classmethod, so call it on the class rather than on an instance.
predictor = MoleculePropertyPredictor.load_from_checkpoint(ckpt_path, model=model)

pred = trainer.predict(predictor, test_dataloader)

In [None]:
# Accumulates one NumPy array per prediction batch.
preds = []


def to_np(x):
    """Detach a tensor from the graph, move it to the CPU and return it as a NumPy array."""
    return x.detach().cpu().numpy()

# Convert every prediction batch to NumPy, stack them along the first axis
# and write the result to disk as text.
preds.extend(to_np(batch_pred) for batch_pred in tqdm(pred))

preds = np.concatenate(preds, axis=0)
np.savetxt('pred.csv', preds)