<a href="https://colab.research.google.com/github/jhwnoh/UST-GenerativeModels/blob/main/P2_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Define the data, model, and sampling routine

Install RDKit and import the helpers that convert between SMILES strings and molecule objects.

In [None]:
!pip install rdkit

import rdkit
from rdkit.Chem import MolFromSmiles,MolToSmiles
import rdkit
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

Collecting rdkit
  Downloading rdkit-2023.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.5/30.5 MB[0m [31m54.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: rdkit
Successfully installed rdkit-2023.9.2


Import the required packages

In [None]:
import numpy as np
import pandas as pd
from collections import defaultdict
from tqdm import tqdm

import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
import torch.optim as optim

In [None]:
class MolData(Dataset):
  """Character-level SMILES dataset for teacher-forced sequence models.

  Each SMILES string is wrapped as ``'<' + smi + '>'`` and right-padded with
  ``'>'`` to exactly ``Nmax`` characters, then mapped to token ids.

  Args:
    smis: sequence of SMILES strings.
    toks: list of characters that may appear in ``smis``; ``'<'`` (start of
      sequence) and ``'>'`` (end of sequence / padding) are appended to it.
    nmax: fixed sequence length including the start/end tokens (default 120).

  ``__getitem__`` returns ``(x, y)`` as int64 tensors: ``x`` is the full padded
  sequence (``Nmax`` ids) and ``y`` is ``x`` shifted left by one (``Nmax - 1`` ids),
  i.e. the next-token targets.
  """

  def __init__(self,smis,toks,nmax=120):
    self.smis = smis
    self.toks = toks + ['<','>'] #'<'; start of sequence, '>'; end of sequence
    self.Ntok = len(toks)  # vocabulary size without the two special tokens
    self.Nmax = nmax
    # O(1) char -> id lookup; setdefault keeps the FIRST occurrence, matching list.index.
    self._tok_index = {}
    for i, t in enumerate(self.toks):
      self._tok_index.setdefault(t, i)

  def __len__(self):
    return len(self.smis)

  def __getitem__(self,idx):
    smi = '<'+self.smis[idx]+'>'
    # Without this check an over-long string yields a longer tensor, which only
    # fails later (and cryptically) when the DataLoader stacks the batch.
    if len(smi) > self.Nmax:
      raise ValueError(
          f"SMILES at index {idx} needs {len(smi)} tokens (with '<' and '>') "
          f"but Nmax is {self.Nmax}")
    smi += '>'*(self.Nmax-len(smi))

    try:
      ids = [self._tok_index[s] for s in smi]
    except KeyError as err:
      # Same exception type the original list.index lookup raised.
      raise ValueError(f"unknown token {err.args[0]!r} in SMILES at index {idx}") from None

    x = torch.tensor(ids, dtype=torch.long)
    y = x[1:].clone()  # next-token targets (own storage, independent of x)
    return x,y

In [None]:
class MolVAE(nn.Module):
    """GRU sequence VAE over token ids: GRU encoder -> Gaussian latent -> GRU decoder.

    Args:
        dim_x0: number of token ids (embedding rows and output logit width).
        dim_x1: token embedding width.
        dim_h: hidden width of both GRUs.
        n_layer: number of stacked GRU layers (encoder and decoder).
        d_ratio: dropout applied between stacked GRU layers.
        dim_z: latent dimension.

    The latent z is fed to the decoder twice: projected by ``fc_z2`` as the
    initial hidden state of every decoder layer, and concatenated to the token
    embedding at every time step.
    """

    def __init__(self,dim_x0,dim_x1,dim_h,n_layer,d_ratio,dim_z):
        super().__init__()
        self.n_layer = n_layer
        self.emb_layer = nn.Embedding(dim_x0,dim_x1)

        self.enc = nn.GRU(dim_x1,dim_h,
                        num_layers=n_layer,
                        dropout = d_ratio,
                        batch_first = True)

        # Last-layer encoder state -> concatenated (mu, log_var).
        self.fc_z1 = nn.Sequential(
                            nn.Linear(dim_h,dim_h),
                            nn.ReLU(),
                            nn.Linear(dim_h,2*dim_z))

        # z -> initial decoder hidden state.
        self.fc_z2 = nn.Linear(dim_z,dim_h)

        self.dec = nn.GRU(dim_x1+dim_z,dim_h,
                        num_layers=n_layer,
                        dropout = d_ratio,
                        batch_first = True)

        self.out = nn.Sequential(
                  nn.Linear(dim_h,dim_h),
                  nn.ReLU(),
                  nn.Linear(dim_h,dim_x0))

    def forward(self,x):
        """Encode ``x``, sample z by reparameterization, decode with teacher forcing.

        x: (N, L) token ids. Returns (logits, mu, log_var) where logits is
        (N, L-1, dim_x0) and predicts x[:, 1:] from x[:, :-1].
        """
        x_emb = self.emb_layer(x)

        mu,log_var = self.encoder(x_emb)
        eps = torch.randn_like(mu)
        z = mu + eps*torch.exp(log_var/2)  # std = exp(log_var / 2)

        out = self.decoder(x_emb[:,:-1],z)
        return out,mu,log_var

    def encoder(self,x):
        """Map embedded sequences (N, L, dim_x1) to (mu, log_var), each (N, dim_z)."""
        _,h1 = self.enc(x,None)
        h2 = self.fc_z1(h1[-1])  # final hidden state of the top GRU layer
        mu,log_var = torch.chunk(h2,2,dim=-1)
        return mu,log_var

    def decoder(self,x,z):
        """Decode embedded inputs (N, L, dim_x1) conditioned on z; return logits (N, L, dim_x0)."""
        out, _ = self._run_decoder(x, z, self._initial_hidden(z))
        return out

    def sampling(self,x0,z,h0=None,is_first=True):
        """Run the decoder on token ids ``x0`` (N, L), continuing from ``h0``.

        When ``is_first`` is true the hidden state is (re)initialized from z and
        ``h0`` is ignored. Returns (logits (N, L, dim_x0), new hidden state).
        """
        if is_first:
            h0 = self._initial_hidden(z)
        return self._run_decoder(self.emb_layer(x0), z, h0)

    def _initial_hidden(self, z):
        # Same projected z for every stacked layer: (n_layer, N, dim_h).
        return self.fc_z2(z).unsqueeze(0).repeat(self.n_layer,1,1)

    def _run_decoder(self, x_emb, z, h0):
        # Append z to the embedding at every time step, run the GRU, project to logits.
        seq_len = x_emb.shape[1]
        z_seq = z.unsqueeze(1).expand(-1, seq_len, -1)
        rnn_out, h_n = self.dec(torch.cat([x_emb, z_seq], dim=-1), h0)
        return self.out(rnn_out), h_n

In [None]:
def _token_index(tok_lib, tok):
    """Return the first index of ``tok`` in ``tok_lib`` or raise ValueError."""
    hits = np.flatnonzero(tok_lib == tok)
    if hits.size == 0:
        raise ValueError(f"token {tok!r} not found in tok_lib")
    return int(hits[0])


def _sample_token_ids(sampler, dim_z, n_sample, max_len, start_idx, batch_size):
    """Draw z ~ N(0, I) and sample ``max_len`` token ids per z; return (ids, z) arrays."""
    device = next(sampler.parameters()).device
    inits = torch.LongTensor([start_idx]*n_sample)
    loader = DataLoader(inits,batch_size=batch_size)

    sampled = []
    latents = []
    for inp in tqdm(loader):
        x_in = inp.reshape(-1,1).to(device)
        z = torch.randn(len(x_in),dim_z,device=device)
        h = None
        steps = []
        for seq_iter in range(max_len):
            # The first step initializes the decoder state from z.
            out,h = sampler.sampling(x_in,z,h,seq_iter == 0)
            prob = F.softmax(out,dim=-1).squeeze(1)
            x_in = torch.multinomial(prob,1)  # next token, fed back as input
            steps.append(x_in.cpu().numpy())
        if steps:
            sampled.append(np.hstack(steps))
        else:
            sampled.append(np.empty((len(x_in),0),dtype=np.int64))
        latents.append(z.cpu().numpy())

    if not sampled:  # n_sample == 0
        return np.empty((0,max_len),dtype=np.int64), np.empty((0,dim_z),dtype=np.float32)
    return np.vstack(sampled), np.vstack(latents)


def _truncate_at_end(sampled, tok_lib, end_idx):
    """Cut each id row at its first end token; skip rows that never emit one.

    Returns (strings, row_indices) with row_indices pointing into ``sampled``.
    """
    strings = []
    rows = []
    for i,s in enumerate(sampled):
        hits = np.flatnonzero(s == end_idx)
        if hits.size == 0:
            continue
        strings.append(''.join(tok_lib[s[:hits[0]]].tolist()))
        rows.append(i)
    return strings, rows


def Sampling(sampler,dim_z,n_sample,max_len,tok_lib,batch_size=100):
    """Sample SMILES from the VAE prior and keep the RDKit-parsable ones.

    Args:
        sampler: model exposing ``sampling(x0, z, h0, is_first)`` (e.g. MolVAE);
            it is switched to eval mode.
        dim_z: latent dimension of ``sampler``.
        n_sample: number of latent vectors z ~ N(0, I) to decode.
        max_len: number of autoregressive decoding steps per sample.
        tok_lib: array mapping token id -> character; must contain the start
            token '<' and the end token '>'.
        batch_size: number of samples decoded in parallel (default 100).

    Returns:
        (Vals, Lat, n_valid, n_unique): canonical SMILES of every valid sample,
        the latent vector that produced each of them (aligned with Vals), the
        number of valid samples and the number of distinct valid SMILES.
    """
    tok_lib = np.asarray(tok_lib)
    start_idx = _token_index(tok_lib,'<')
    end_idx = _token_index(tok_lib,'>')

    sampler.eval()
    with torch.no_grad():
        Sampled,Zs = _sample_token_ids(sampler,dim_z,n_sample,max_len,start_idx,batch_size)

    Mols,Idx = _truncate_at_end(Sampled,tok_lib,end_idx)

    Vals = []
    Lat = []
    # Mols[k] was produced by row Idx[k]; pair them positionally.
    for smi,row in zip(Mols,Idx):
        if not smi:
            continue  # RDKit parses '' into an empty (non-None) Mol; not a real molecule
        mol = MolFromSmiles(smi)
        if mol is not None:
            Vals.append(MolToSmiles(mol))
            Lat.append(Zs[row])

    return Vals,Lat,len(Vals),len(set(Vals))

# 2. Training

In [None]:
def LinearAnnealing(n_iter, start=0.0, stop=1.0,  n_cycle=4, ratio=0.5):
    """Build a cyclical linear annealing schedule of length ``n_iter``.

    The schedule is split into ``n_cycle`` periods. In each period the value
    ramps linearly from ``start`` toward ``stop`` over the first ``ratio``
    fraction of the period, then stays at ``stop`` until the next period begins.

    Returns:
        numpy float array of ``n_iter`` values.
    """
    schedule = stop * np.ones(n_iter)
    cycle_len = n_iter / n_cycle
    increment = (stop - start) / (cycle_len * ratio)

    for cycle in range(n_cycle):
        offset = cycle * cycle_len
        value = start
        k = 0
        # Ramp until the value passes `stop` or the schedule runs out.
        while value <= stop:
            pos = int(k + offset)
            if pos >= n_iter:
                break
            schedule[pos] = value
            value += increment
            k += 1
    return schedule

In [None]:
df = pd.read_csv('https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/main/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv')
# Keep only the first whitespace-separated field of each entry (drops trailing whitespace).
smis_ = [ss.split()[0] for ss in df['smiles']]

# Character-level vocabulary, sorted so token ids are identical in every session.
# Iteration order of a set of str depends on PYTHONHASHSEED, so an unsorted
# vocabulary would silently scramble what a saved checkpoint's embedding/output
# rows mean after a kernel restart.
toks = sorted({ch for smi in tqdm(smis_) for ch in smi})

n_train = 1000
n_val = 1000

np.random.seed(1)
np.random.shuffle(smis_)

smi_train = smis_[:n_train]
smi_val = smis_[n_train:n_train+n_val]

batch_size = 64

train_data = MolData(smi_train,toks)
tok_lib = np.array(train_data.toks) # For sampling: token id -> character
n_tok = len(tok_lib)                # vocabulary size including '<' and '>'
train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True)

test_data = MolData(smi_val,toks)
test_loader = DataLoader(test_data,batch_size=batch_size,shuffle=False)

DimZ = 156
model = MolVAE(n_tok,64,256,2,0.2,DimZ)

lr = 2e-4
ce_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=lr)

num_epoch = 200
max_norm = 5

LOGs = []  # one tuple of validation metrics per epoch (same values as printed)
# Cyclical KL weight (beta), one value per epoch.
Betas = LinearAnnealing(n_iter=num_epoch,start=0.0,stop=0.2).tolist()

for ep in range(num_epoch):
    model.train()
    for x_in, tgt in tqdm(train_loader):
        tgt = tgt.view(-1)

        x_out,mu,log_var = model(x_in)

        rec = ce_loss(x_out.reshape(-1,n_tok),tgt)
        # Analytic KL(q(z|x) || N(0, I)), averaged over batch and latent dims.
        kld = torch.mean(0.5*(mu**2+torch.exp(log_var)-log_var-1))

        loss = rec + Betas[ep]*kld

        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

    model.eval()
    with torch.no_grad():
        Sim = []
        Mus = []
        Stds = []
        KLDs = 0.0
        Ns = 0
        for x_in, _ in tqdm(test_loader):
            x_out,mu,log_var = model(x_in)

            # Sum over the batch of per-sample mean KL, so KLDs/Ns matches the training term.
            KLDs += torch.sum(torch.mean(0.5*(mu**2+torch.exp(log_var)-log_var-1),-1)).item()
            Ns += len(x_in)

            # Per-molecule token reconstruction accuracy (padding positions included).
            id_out = np.argmax(x_out.cpu().numpy(),-1)
            id_in = x_in[:,1:].cpu().numpy()
            Sim.append(np.mean(id_out==id_in,1).reshape(-1,1))

            Mus.append(mu.cpu().numpy())
            Stds.append(torch.exp(log_var/2).cpu().numpy())

        Sim = np.vstack(Sim)
        Mus = np.vstack(Mus)
        Stds = np.vstack(Stds)
        mols,z_mol,val,uniq = Sampling(model,DimZ,1000,100,tok_lib)

        metrics = (ep,Betas[ep],np.min(Sim),np.max(Sim),np.mean(Sim),np.std(Sim),KLDs/Ns,val,uniq)
        LOGs.append(metrics)
        print(*metrics)


# 3. Use a pre-trained model

In [None]:
DimZ = 156
# Vocabulary size must match tok_lib (built in the training cell) and the checkpoint;
# the checkpoint is only meaningful with the same token order it was trained with.
model = MolVAE(len(tok_lib),128,480,3,0.2,DimZ)

# NOTE(review): torch.load unpickles arbitrary objects - only load checkpoints you trust.
chkpt = torch.load('your/path',map_location='cpu')
# The weights are read from the checkpoint dict, not from the freshly built model.
model.load_state_dict(chkpt['state_dict'])

mols,z_mol,val,uniq = Sampling(model,DimZ,1000,100,tok_lib)

100%|██████████| 1584663/1584663 [00:04<00:00, 319626.70it/s]
100%|██████████| 10/10 [00:21<00:00,  2.10s/it]
