In [None]:
! pip install rdkit
! pip install torch_geometric

import numpy as np
from rdkit import Chem
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import pandas as pd
import multiprocessing as mp
from utils import *
from model import *
from sklearn.model_selection import train_test_split

# Load data

In [None]:
# Locations of the two input tables.
smiles_path = 'data/smiles.txt'  # smiles dataset (get_data reads its 'Whole' and 'Core' columns)
data_path = 'data/qed_properties.csv'  # properties we want to optimize

# Read both tables with pandas' default CSV settings, SMILES table first.
df_smiles, datas = (pd.read_csv(path) for path in (smiles_path, data_path))

In [None]:
def get_data(r):
    """Collect the SMILES pair and two property values for row ``r``.

    Reads the module-level ``df_smiles`` and ``datas`` frames.

    Args:
        r: Row label, used with ``.loc`` on both frames.

    Returns:
        A three-item list ``[smiles, whole, scaff]``. ``smiles`` is the
        ``[Whole, Core]`` pair from ``df_smiles``; ``whole`` and ``scaff`` are
        one-element lists holding the second and fourth values of the
        ``datas`` row (the first column is skipped).
    """
    pair = [df_smiles.loc[r, column] for column in ('Whole', 'Core')]
    # Drop the leading column, then pick the 1st and 3rd remaining values.
    properties = datas.loc[r].to_list()[1:]
    whole_value, scaffold_value = properties[0], properties[2]
    return [pair, [whole_value], [scaffold_value]]

In [None]:
# Build one record per row of df_smiles using six worker processes.
pool = mp.Pool(processes=6)
try:
    results = pool.map(get_data, range(len(df_smiles)))
finally:
    # Always release the workers, even when a task raises, so a failed run
    # does not leave worker processes behind in the kernel.
    pool.close()
    pool.join()

In [None]:
# Transpose the per-row [smiles, whole, scaff] records into three parallel tuples.
smiles, whole_conditions, scaffold_conditions = zip(*results)

# Load & train model

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MolVAE(33, 4, 128, 0, rnn=True)
model = model.to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was written on a GPU machine still loads on a CPU-only one.
model.load_state_dict(torch.load('weights/rnn_weights.pt', map_location=device)) # weight of pretrained model

In [None]:
# Recreate the optimizer and restore its saved state so training resumes
# from where the pretrained run stopped.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# map_location keeps the restore working when the state was saved from a GPU
# run but this kernel only has a CPU.
optimizer.load_state_dict(torch.load('weights/rnn_optimizer.pt', map_location=device))

In [None]:
# Hold out 20% of the record positions for testing; the fixed seed keeps the
# split identical between runs.
train_indices, test_indices = train_test_split(
    range(len(smiles)),
    test_size=0.2,
    random_state=1,
)

In [None]:
# Both splits wrap the same SMILES records and differ only in their index list.
train_dataset, test_dataset = (
    GraphData(smiles, None, indices) for indices in (train_indices, test_indices)
)

# Only the training loader shuffles; both use batches of 100.
train_dl = GraphLoader(train_dataset, batch_size=100, shuffle=True, num_workers=6)
test_dl = GraphLoader(test_dataset, batch_size=100, shuffle=False, num_workers=3)

In [None]:
# cyclical annealing

# Weight schedule for vae_loss: train() looks up betas[i] for the i-th batch
# of an epoch, so it needs at least len(train_dl) entries.
# NOTE(review): the arguments presumably mean (start, stop, n_steps, n_cycles,
# ramp ratio) - confirm against frange_cycle_linear in utils.
betas = frange_cycle_linear(0, 1, len(train_dl), 1, 0.8)

In [None]:
def train(model, train_dl, epochs, betas, save=False, save_path=None):
    """Train ``model`` for ``epochs`` passes over ``train_dl``.

    The loss of each batch is ``recon_loss + vae_loss * betas[i]``, where ``i``
    is the index of the batch inside the current epoch. The module-level
    ``optimizer`` and ``device`` are used for the update and for placement.

    Args:
        model: Module called as ``model(whole, steps, None)`` that returns the
            pair ``(recon_loss, vae_loss)``.
        train_dl: Iterable of ``(steps, whole)`` batches.
        epochs: Number of passes over ``train_dl``.
        betas: Indexable weights for ``vae_loss``; needs at least one entry
            per batch of an epoch.
        save: When true, write checkpoints at batch 0, at every 1000th batch
            after it, and at the end of each epoch.
        save_path: Filename prefix for checkpoints; files are written to
            ``{save_path}_weights.pt`` and ``{save_path}_optimizer.pt``.

    Returns:
        ``(comb_losses, recon_losses, vae_losses)``: three lists with the
        per-batch loss values of every epoch, in training order.

    Raises:
        ValueError: If ``save`` is true but ``save_path`` is not given.
    """
    import gc  # local import: the notebook's import cell does not import gc explicitly

    if save and save_path is None:
        raise ValueError('save_path is required when save=True')

    def _checkpoint():
        # Persist model and optimizer state under the caller-supplied prefix.
        torch.save(model.state_dict(), f'{save_path}_weights.pt')
        torch.save(optimizer.state_dict(), f'{save_path}_optimizer.pt')

    def _free_memory():
        # Release cached GPU blocks and unreachable Python objects.
        torch.cuda.empty_cache()
        gc.collect()

    all_comb_losses = []
    all_recon_losses = []
    all_vae_losses = []

    for epoch in range(epochs):
        # Per-epoch lists, so the progress bar reports means of this epoch only.
        comb_losses = []
        recon_losses = []
        vae_losses = []

        model.train()
        progress = tqdm(train_dl)

        for step, batch in enumerate(progress):
            optimizer.zero_grad()

            steps, whole = batch
            whole.condition = whole.condition.float()

            recon_loss, vae_loss = model(whole.to(device), steps.to(device), None)
            beta = betas[step]
            loss = recon_loss + vae_loss * beta

            comb_losses.append(loss.item())
            recon_losses.append(recon_loss.item())
            vae_losses.append(vae_loss.item())

            loss.backward()
            optimizer.step()

            progress.set_description(f'Epoch [{epoch + 1}/{epochs}]')
            progress.set_postfix(total_loss=np.mean(comb_losses), recon_loss=np.mean(recon_losses), vae_loss=np.mean(vae_losses), current_vae_loss=vae_loss.item(), beta=beta)

            if step % 1000 == 0:
                _free_memory()
                if save:
                    _checkpoint()

        _free_memory()
        if save:
            _checkpoint()

        all_comb_losses.extend(comb_losses)
        all_recon_losses.extend(recon_losses)
        all_vae_losses.extend(vae_losses)

    # Returning after the loop (not inside it) lets every requested epoch run.
    return all_comb_losses, all_recon_losses, all_vae_losses
            
def evaluate(test_dl, beta=5e-4):
    """Compute mean losses of the module-level ``model`` over ``test_dl``.

    Runs in eval mode with gradient tracking disabled. Uses the module-level
    ``model`` and ``device``.

    Args:
        test_dl: Iterable of ``(steps, whole)`` batches.
        beta: Weight applied to ``vae_loss`` in the combined loss.

    Returns:
        ``(total, recon, vae)``: the means over all batches of the combined,
        reconstruction and VAE losses. With an empty ``test_dl`` the means are
        ``nan`` (``np.mean`` of an empty list).
    """
    comb_losses = []
    recon_losses = []
    # Starts empty: a seeded 0 would drag the reported mean down.
    vae_losses = []

    model.eval()
    progress = tqdm(test_dl)
    progress.set_description('Evaluating')

    # No parameter update happens here, so skip building the autograd graph.
    with torch.no_grad():
        for batch in progress:
            steps, whole = batch
            whole.condition = torch.Tensor(whole.condition).float()

            recon_loss, vae_loss = model(whole.to(device), steps.to(device), whole.condition)
            loss = recon_loss + vae_loss * beta

            comb_losses.append(loss.item())
            recon_losses.append(recon_loss.item())
            vae_losses.append(vae_loss.item())

            progress.set_postfix(total_loss=np.mean(comb_losses), recon_loss=np.mean(recon_losses), vae_loss=np.mean(vae_losses))

    return np.mean(comb_losses), np.mean(recon_losses), np.mean(vae_losses)

In [None]:
# train() takes the model as its first argument, followed by the loader,
# the epoch count and the beta schedule.
total_loss, recon_loss, vae_loss = train(model, train_dl, 5, betas)

# Transfer Learning

In [None]:
# NOTE(review): `smiles` holds the [Whole, Core] pairs built by get_data, while
# Chem.MolToSmiles takes an RDKit Mol - replace this placeholder with the
# molecules to fine-tune on before running.
fine_tune = smiles[:1000] # placeholder
fine_tune = [Chem.MolToSmiles(x) for x in fine_tune]
# Keep only the SMILES for which make_graph produces a result; identity
# comparison avoids calling a custom != on whatever object it returns.
fine_tune = [x for x in fine_tune if make_graph(x) is not None]

In [None]:
# Pair every fine-tuning molecule with each core returned for it.
ft_smiles = []

for smi in tqdm(fine_tune):
    cores = get_cores(smi)
    # get_cores signals "nothing usable" with None; skip those molecules.
    if cores is None:
        continue
    ft_smiles.extend([smi, core] for core in cores)

In [None]:
# One condition slot per (molecule, core) pair, all left unset.
ft_cond = [None for _ in ft_smiles] # if model is conditioned replace this w/ the list of qeds of the whole molecule

In [None]:
# Every (molecule, core) pair is used for fine-tuning, so the index list
# spans the whole set. This rebinds train_dataset / train_dl from the
# pretraining section.
train_dataset = GraphData(
    ft_smiles,
    ft_cond,
    range(len(ft_smiles)),
    key=False,
)

train_dl = GraphLoader(train_dataset, batch_size=100, shuffle=True, num_workers=6)

In [None]:
# Start fine-tuning from the pretrained weights with a fresh optimizer.
model = MolVAE(33, 4, 128, 0, rnn=True)
model = model.to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was written on a GPU machine still loads on a CPU-only one.
model.load_state_dict(torch.load('weights/rnn_weights.pt', map_location=device))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Size the schedule from the fine-tuning loader: train() indexes betas by
# batch position, and `progress` only exists inside train()/evaluate().
betas = frange_cycle_linear(0, 0.4, len(train_dl), 1, 1) # reduce betas due to less data

In [None]:
# train() needs the epoch count between the loader and the beta schedule.
# NOTE(review): 5 epochs is assumed from the pretraining call - adjust as needed.
comb_loss, recon_loss, vae_loss = train(model, train_dl, 5, betas)