In [None]:
# Cycle number; it is embedded in the file names generated by this notebook.
cycle_start = 1
# CSV used as the seeding dataset (original note: initialized with 50k-ChemBL.csv).
seeding_dataset = 'data/ChemBL-35-cleaned.csv'
# HDF5 file name derived from the cycle number.
h5_input_data = f'data/{cycle_start}-inp.h5'

# Data File Preparation

In [None]:
import argparse
import pandas as pd
import h5py
import numpy as np
from molecules.utils import one_hot_array, one_hot_index
from functools import reduce
from sklearn.model_selection import train_test_split
from rdkit import Chem
from rdkit.Chem import AllChem
from group_selfies import (
    fragment_mols, 
    Group, 
    MolecularGraph, 
    GroupGrammar, 
    group_encoder
)

from rdkit.Chem import rdmolfiles
from rdkit.Chem.Draw import IPythonConsole

import IPython.display # from ... import display
from test_utils import *
from rdkit import RDLogger

RDLogger.DisableLog('rdApp.*') 

import os
import sys
from rdkit.Chem import RDConfig
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer
from rdkit.Chem import QED
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import gzip
import pandas
import h5py
import numpy as np
from __future__ import print_function
import argparse
import os
import h5py
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn import model_selection
     

def one_hot_array(i, n):
    """Return a length-``n`` list of ints holding 1 at position ``i`` and 0 elsewhere.

    The previous body used the Python 2 builtin ``xrange`` (a ``NameError`` on
    Python 3) and returned a lazy ``map`` object; a concrete list is returned
    instead, matching the later definition of this helper in the notebook.

    Args:
        i: Position to set to 1. A value outside ``range(n)`` gives all zeros.
        n: Length of the returned list.

    Returns:
        list[int]: The one-hot encoded vector.
    """
    return [int(ix == i) for ix in range(n)]

def one_hot_index(vec, charset):
    """Return, for every element of ``vec``, its position in ``charset``.

    A list is returned rather than a lazy ``map`` object so the result can be
    measured with ``len`` and iterated more than once; this also matches the
    later definition of this helper in the notebook.

    Args:
        vec: Iterable of symbols to look up.
        charset: Sequence exposing ``.index`` (e.g. a list of characters).

    Returns:
        list[int]: One index per element of ``vec``.

    Raises:
        ValueError: Propagated from ``charset.index`` when a symbol is missing
            (this is what a list raises).
    """
    return [charset.index(x) for x in vec]

def from_one_hot_array(vec):
    """Return the first position where ``vec`` equals 1, or ``None`` if there is none."""
    hot_positions = np.where(vec == 1)[0]
    # np.where yields an empty index array when nothing matches.
    if hot_positions.shape != (0, ):
        return int(hot_positions[0])
    return None

def decode_smiles_from_indexes(vec, charset):
    """Map each index in ``vec`` to its ``charset`` entry and join them into one string.

    Entries stored as ``bytes`` are decoded as UTF-8; any other entry is used
    as is. Leading and trailing whitespace is stripped from the joined result.
    """
    pieces = []
    for index in vec:
        symbol = charset[index]
        if isinstance(symbol, bytes):
            symbol = str(symbol, 'utf-8')
        pieces.append(symbol)
    return "".join(pieces).strip()

def load_dataset(filename, split = True):
    """Read the encoded datasets and the charset from an HDF5 file.

    The file is opened through a context manager so the handle is released
    even when one of the expected datasets is missing and the lookup raises.

    Args:
        filename: Path of the HDF5 file to read.
        split: When true, ``data_train`` is read as well.

    Returns:
        ``(data_train, data_test, charset)`` when ``split`` is true, otherwise
        ``(data_test, charset)``. Each item is the full content of the
        corresponding HDF5 dataset.
    """
    with h5py.File(filename, 'r') as h5f:
        data_train = h5f['data_train'][:] if split else None
        data_test = h5f['data_test'][:]
        charset = h5f['charset'][:]
    if split:
        return (data_train, data_test, charset)
    return (data_test, charset)

In [None]:
# Upper bound on the number of rows to keep; large enough to mean "all".
MAX_NUM_ROWS = 100_000_000
# Name of the CSV column holding the SMILES strings (alternative: smiles).
SMILES_COL_NAME = 'Original_SMILES'
# Names of two property columns.
PROPERTY_COL_NAME = 'QED'
PROPERTY_COL_NAME2 = 'SA_score'
# Default number of rows handled per chunk by the HDF5 helpers.
CHUNK_SIZE = 1000

class dotdict(dict):
    """A ``dict`` whose keys can also be read, written and deleted as attributes."""

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; missing keys give None.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        # Attribute assignment stores the value under the same-named key.
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        # Deleting an unknown attribute raises KeyError, like ``del d[name]``.
        dict.__delitem__(self, name)

# Settings for the data-preparation stage, bundled in one attribute-style dict.
args = dotdict(
    infile=seeding_dataset,
    outfile=h5_input_data,
    length=MAX_NUM_ROWS,
    smiles_column=SMILES_COL_NAME,
)

In [None]:
def chunk_iterator(dataset, chunk_size=CHUNK_SIZE):
    """Yield ``(indices, chunk)`` pairs that together cover ``dataset`` exactly once.

    The positions ``0 .. len(dataset) - 1`` are split into
    ``len(dataset) // chunk_size`` nearly equal groups. At least one group is
    always used, so a dataset shorter than ``chunk_size`` is returned as a
    single chunk instead of making ``np.array_split`` fail on zero sections.
    Because floor division is kept, a chunk may hold more than ``chunk_size``
    items (e.g. 1999 items with ``chunk_size=1000`` form one chunk).

    Args:
        dataset: Object supporting ``len`` and indexing by an integer array.
        chunk_size: Target number of items per chunk.

    Yields:
        tuple: ``(chunk_ixs, chunk)``, where ``chunk_ixs`` is the integer array
        of positions and ``chunk`` is ``dataset[chunk_ixs]``. Nothing is
        yielded for an empty dataset.
    """
    n_items = len(dataset)
    if n_items == 0:
        return

    # max(1, ...) keeps short datasets from requesting zero sections.
    n_chunks = max(1, n_items // chunk_size)
    for chunk_ixs in np.array_split(np.arange(n_items), n_chunks):
        # NOTE(review): for a pandas Series this lookup is label-based — confirm callers pass positional data.
        chunk = dataset[chunk_ixs]
        yield (chunk_ixs, chunk)

In [None]:
def create_chunk_dataset(h5file, dataset_name, dataset, dataset_shape,
                         chunk_size=CHUNK_SIZE, apply_fn=None):
    """Create a chunked HDF5 dataset and fill it from ``dataset`` chunk by chunk.

    ``chunk_size`` is now forwarded to ``chunk_iterator``. Previously it only
    set the HDF5 chunk shape, while the write loop always used the iterator's
    default chunk size.

    Args:
        h5file: Open, writable HDF5 file (or group) exposing ``create_dataset``.
        dataset_name: Name of the dataset to create.
        dataset: Source data, indexable by an integer array.
        dataset_shape: Full shape of the dataset to create; the first axis is
            the one that gets chunked.
        chunk_size: Rows per HDF5 chunk and target rows per write.
        apply_fn: Optional function applied to every item of a chunk; its
            result is turned into a list before conversion to ``float32``.

    Returns:
        None. The data is written into ``h5file``.
    """
    # Chunk only along the first axis; every other axis keeps its full extent.
    new_data = h5file.create_dataset(dataset_name, dataset_shape,
                                     chunks=tuple([chunk_size] + list(dataset_shape[1:])))

    for chunk_ixs, chunk in chunk_iterator(dataset, chunk_size):
        if apply_fn:
            encoded_data = np.array([list(apply_fn(i)) for i in chunk], dtype=np.float32)
        else:
            encoded_data = np.array(chunk, dtype=np.float32)

        # Write the encoded rows at the positions they came from.
        new_data[chunk_ixs.tolist(), ...] = encoded_data


def one_hot_encoded_fn(row):
    """One-hot encode every symbol of ``row`` against the module-level ``charset``.

    Returns a list with one one-hot vector per symbol (a list, not a map).
    """
    encoded = []
    for symbol_index in one_hot_index(row, charset):
        encoded.append(one_hot_array(symbol_index, len(charset)))
    return encoded

def one_hot_array(i, n):
    """Build a list of ``n`` ints that is 1 where the position equals ``i`` and 0 elsewhere."""
    encoded = []
    for position in range(n):
        encoded.append(int(position == i))
    return encoded

def one_hot_index(vec, charset):
    """Return a list with the ``charset`` position of each element of ``vec``."""
    return list(map(charset.index, vec))

In [None]:
# Load the seeding table and keep only rows whose SMILES string has at most 120 characters.
data = pd.read_csv(args.infile)
keys = data[args.smiles_column].map(len) < 121 # # Filter rows based on SMILES length
data = data[keys]

# Down-sample only when the request does not exceed the rows that survived the
# length filter. Comparing against the unfiltered row count made `sample` raise
# whenever the request fell between the filtered and the unfiltered size.
# The sample is seeded so a re-run selects the same rows.
if args.length <= len(data):
    data = data.sample(n=args.length, random_state=42)

# Ensure that all SMILES strings are padded to 120 characters
structures = data[args.smiles_column].map(lambda x: list(x.ljust(120)))

# `data` is already filtered, so the property columns need no further masking.
if args.property_column:
    properties = data[args.property_column]

if args.property_column2:
    properties2 = data[args.property_column2]

del data  # Clean up to save memory

# 95/5 split over the row labels of the retained structures, with a fixed seed.
train_idx, test_idx = map(np.array, train_test_split(structures.index, test_size=0.05, random_state=42))

# Create the charset from the unique characters in the SMILES strings
# charset = list(reduce(lambda x, y: set(y) | x, structures, set()))

In [None]:
# Fixed symbol table used for one-hot encoding. The position of each symbol is
# its class index, so the order must not change between runs.
# NOTE(review): a character that is not listed here makes charset.index() raise
# ValueError in one_hot_index — confirm the list covers the whole dataset.
charset = ['O', 'a', '[', 'e', 'K', 't', '4', 'o', '1', 'P', ']', 'p', 'l', 'X', '8', '3', 'Z', '-', 'S', 'L', '=', 'F', 'M', '.', 'C', ' ', 'r', 'T', 'N', '2', '0', 'R', '5', 'i', '/', 'b', 's', '+', '9', 'H', 'c', '@', '(', 'I', 'g', 'A', 'B', '7', '6', '#', '%', '\\', ')', 'n']
# Show the charset size; the model's output layer is sized to 54 symbols.
print(len(charset))

# Training

In [None]:
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torch.optim as optim
import gzip
import h5py
import argparse
import os
import h5py
import time
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn import model_selection

In [None]:
from molecules.model import vae_loss, dotdict, one_hot_index, decode_smiles_from_indexes, load_dataset
from molecules.model import load_dataset_chunked

In [None]:
from __future__ import print_function

def from_one_hot_array(vec):
    oh = np.where(vec == 1)
    if oh[0].shape == (0, ):
        return None
    return int(oh[0][0])

In [None]:
class MolecularVAE(nn.Module):
    """Variational autoencoder with a 1-D convolutional encoder and a GRU decoder.

    Inputs are one-hot tensors of shape ``(batch, max_len, charset_size)``. The
    ``max_len`` axis is fed to the convolutions as channels, so they slide
    along the charset axis. All layer sizes that used to be hard-coded for
    ``max_len=120`` and ``charset_size=54`` are now derived from the two
    constructor arguments. The defaults reproduce the previous architecture
    exactly, with the same parameter names and creation order, so existing
    checkpoints and seeded initialisation are unaffected.

    Args:
        max_len: Number of positions per encoded string.
        charset_size: Number of symbols in the charset; must be larger than 26
            so the three unpadded convolutions leave at least one feature.

    Raises:
        ValueError: If ``charset_size`` is too small for the convolution stack.
    """

    LATENT_DIM = 292

    def __init__(self, max_len=120, charset_size=54):
        super().__init__()

        # Kernel sizes 9, 9 and 11 without padding shrink the charset axis by 8 + 8 + 10.
        conv_length = charset_size - 26
        if conv_length < 1:
            raise ValueError(
                f"charset_size must be greater than 26, got {charset_size}")

        self.max_len = max_len
        self.charset_size = charset_size

        self.conv_1 = nn.Conv1d(max_len, 9, kernel_size=9)
        self.conv_2 = nn.Conv1d(9, 9, kernel_size=9)
        self.conv_3 = nn.Conv1d(9, 10, kernel_size=11)
        # 10 output channels times the remaining length (280 for a 54-symbol charset).
        self.linear_0 = nn.Linear(10 * conv_length, 435)
        self.linear_1 = nn.Linear(435, self.LATENT_DIM)
        self.linear_2 = nn.Linear(435, self.LATENT_DIM)

        self.linear_3 = nn.Linear(self.LATENT_DIM, self.LATENT_DIM)
        self.gru = nn.GRU(self.LATENT_DIM, 501, 3, batch_first=True)
        # One output unit per charset symbol.
        self.linear_4 = nn.Linear(501, charset_size)

        self.relu = nn.ReLU()
        # Kept for attribute compatibility; decode() uses F.softmax instead.
        self.softmax = nn.Softmax()

    def encode(self, x):
        """Return ``(z_mean, z_logvar)`` for a batch of one-hot inputs."""
        x = self.relu(self.conv_1(x))
        x = self.relu(self.conv_2(x))
        x = self.relu(self.conv_3(x))
        x = x.view(x.size(0), -1)
        x = F.selu(self.linear_0(x))
        return self.linear_1(x), self.linear_2(x)

    def sampling(self, z_mean, z_logvar):
        """Draw a latent sample around ``z_mean``; the noise is scaled by 1e-2."""
        epsilon = 1e-2 * torch.randn_like(z_logvar)
        return torch.exp(0.5 * z_logvar) * epsilon + z_mean

    def decode(self, z):
        """Decode latent vectors into per-position symbol probabilities.

        Returns a tensor of shape ``(batch, max_len, charset_size)`` whose last
        axis is a softmax distribution.
        """
        z = F.selu(self.linear_3(z))
        # Present the same latent vector to the GRU at every output position.
        z = z.view(z.size(0), 1, z.size(-1)).repeat(1, self.max_len, 1)
        output, hn = self.gru(z)
        out_reshape = output.contiguous().view(-1, output.size(-1))
        y0 = F.softmax(self.linear_4(out_reshape), dim=1)
        y = y0.contiguous().view(output.size(0), -1, y0.size(-1))
        return y

    def forward(self, x):
        """Encode, sample and decode; returns ``(reconstruction, z_mean, z_logvar)``."""
        z_mean, z_logvar = self.encode(x)
        z = self.sampling(z_mean, z_logvar)
        return self.decode(z), z_mean, z_logvar

In [None]:
# Start the training stage with an empty settings container; the
# data-preparation settings previously stored in `args` are discarded here.
args = dotdict()

In [None]:
# Fix the torch seed before the model is built so weight initialisation is repeatable.
torch.manual_seed(42)

# Use the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = MolecularVAE().to(device)

# Per-epoch loss histories, appended to by the training loop.
train_losses_p, val_losses_p = [], []
# Adam with its default learning rate; earlier runs stepped 1e-3 -> 5e-4 -> 1e-4.
optimizer = optim.Adam(model.parameters())

In [None]:
def make_losses_graph(train_losses_p, val_losses_p, ymax=100, ymin=0, printed=20):
    """Plot the training/validation loss curves and print a short summary.

    With no recorded losses, a message is printed and nothing is plotted.
    Previously the empty columns had object dtype and the ``nsmallest`` call
    raised ``TypeError``, which happens when the plotting cell runs before any
    training on a fresh kernel.

    Args:
        train_losses_p: Per-epoch training losses (anything ``float`` accepts).
        val_losses_p: Per-epoch validation losses, same length as the above.
        ymax: Upper limit of the y axis.
        ymin: Lower limit of the y axis.
        printed: Number of most recent epochs to print.

    Returns:
        None.
    """
    if len(train_losses_p) == 0 and len(val_losses_p) == 0:
        print('No losses recorded yet; nothing to plot.')
        return

    df_losses = pd.DataFrame({
        'train': [float(loss) for loss in train_losses_p],
        'val': [float(loss) for loss in val_losses_p],
    })

    sns.set_style("whitegrid")
    sns.lineplot(df_losses)
    plt.ylabel('Losses')
    plt.ylim(ymin, ymax)
    sns.set_style("ticks")

    # The printed table numbers epochs from 1; the plot keeps the 0-based index.
    losses_graph = df_losses.copy()
    losses_graph.index = losses_graph.index + 1
    print(losses_graph.tail(printed))
    # The three epochs with the lowest validation loss.
    print(losses_graph.nsmallest(3,['val']))

In [None]:
# Plot the loss history collected so far, with the y axis limited to 7-15, and
# print the last 20 epochs.
# NOTE(review): on a fresh kernel both loss lists are still empty at this point.
make_losses_graph(train_losses_p, val_losses_p, printed=20, ymax=15, ymin=7)

In [None]:
# Mini-batch size used by the DataLoaders in the training cell.
args.batch_size = 100 # was 500
# Replace the optimizer with a fresh Adam instance at lr=1e-4; any optimizer
# state accumulated by the previous instance is dropped.
optimizer = optim.Adam(model.parameters(), lr=1e-4) # First 60 with 1e-3, next 60 with 5.e-4, lastly with 1.e-4 | 35/17 train/val when 200 batch size and 10-4

In [None]:
# Wall-clock reference for the elapsed-time figure printed after each epoch.
time0 = time.time()
# Last epoch number to train up to (inclusive).
args.epochs = 300 # was 30
args.report_epochs = 1
# Float result of true division; only referenced by a commented-out condition in train().
args.report_epochs2 = args.epochs/5
# Use the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def train(epoch):
    """Run one training epoch and a validation pass, then write a checkpoint.

    The training loss is normalised by the number of samples actually
    processed in this epoch. It was previously divided by a hard-coded dataset
    size (2258093), which silently mis-scales the reported loss whenever the
    input file changes.

    Relies on notebook globals: ``model``, ``optimizer``, ``device``, ``args``,
    ``h5_input_data``, ``time0``, ``train_losses_p`` and ``val_losses_p``.

    Args:
        epoch: Epoch number, used in the progress message and the checkpoint
            file name.

    Returns:
        float: Accumulated training loss divided by the number of training
        samples seen.

    Raises:
        RuntimeError: If the chunked loader yields no training samples.
    """
    time1 = time.time()
    model.train()
    train_loss = 0
    n_train_samples = 0
    counter = 0
    # Chunk after which one input/reconstruction pair is printed.
    # NOTE(review): the upper bound 45 and the "44 batches" message assume a fixed chunk count — confirm.
    printed_batch = np.random.randint(0, 45)

    torch.cuda.empty_cache()
    # The training data arrives in large chunks; the test split is loaded whole.
    train_batches, data_test, charset = load_dataset_chunked(h5_input_data, split=True, batch_size=50000)
    data_test = torch.utils.data.TensorDataset(torch.from_numpy(data_test))
    test_loader = torch.utils.data.DataLoader(data_test, batch_size=args.batch_size, shuffle=True)

    for batch in train_batches:
        if counter == 20:
            timeit = (time.time()-time1)/60.
            timeit_all = 2 * timeit
            print(f"Processing batch {counter:.0f} shaped: {batch.shape}, done in {timeit:.2f} min, so 44 batches around {timeit_all:.0f} min")

        batch_tensor = torch.from_numpy(batch)
        current_train_batch = torch.utils.data.TensorDataset(batch_tensor)
        train_loader = torch.utils.data.DataLoader(current_train_batch, batch_size=args.batch_size, shuffle=True)

        for batch_idx, data in enumerate(train_loader):
            data = data[0]
            data = data.to(device)
            optimizer.zero_grad()
            output, mean, logvar = model(data)
            # presumably vae_loss is summed over the mini-batch — TODO confirm
            loss = vae_loss(output, data, mean, logvar)
            loss.backward()
            train_loss += loss.item()
            n_train_samples += data.size(0)
            optimizer.step()

        if counter == printed_batch:
            # Show the first sample of the last mini-batch next to its reconstruction.
            inp = data[0].cpu().numpy()
            outp = output.cpu().detach().numpy()
            sampled = outp[0].reshape(1, 120, len(charset)).argmax(axis=2)[0]
            print("Input/Label vs Reconstructed from training set, ", str(counter), " :")
            print(decode_smiles_from_indexes(map(from_one_hot_array, inp), charset))
            print(decode_smiles_from_indexes(sampled, charset))
        counter = counter + 1

    if n_train_samples == 0:
        raise RuntimeError("No training samples were loaded from " + str(h5_input_data))

    train_loss_p = train_loss / n_train_samples
    train_losses_p.append(train_loss_p)
    print(train_loss_p)

    # Validation step
    model.eval()
    with torch.no_grad():
        val_loss = 0
        for batch_idx, val_data in enumerate(test_loader):
            val_data = val_data[0].to(device)
            output_val, mean, log_var = model(val_data)
            val_loss += vae_loss(output_val, val_data, mean, log_var).item()

        val_loss_p = val_loss / len(test_loader.dataset)
        val_losses_p.append(val_loss_p)

        # Show the first sample of the last validation mini-batch.
        inp = val_data[0].cpu().numpy()
        outp = output_val.cpu().detach().numpy()
        sampled = outp[0].reshape(1, 120, len(charset)).argmax(axis=2)[0]
        print("Input/Label vs Reconstructed from validation set:")
        print(decode_smiles_from_indexes(map(from_one_hot_array, inp), charset))
        print(decode_smiles_from_indexes(sampled, charset))

    timeit = (time.time()-time0)/60.
    print(f'Epoch [{epoch}], Loss: {train_loss_p:.4f}, Validation Loss: {val_loss_p:.4f}, Time: {timeit:.2f} min')

    # A checkpoint is written after every epoch.
    #if epoch % args.report_epochs2 == 0:
    PATH = "checkpoint_"+str(epoch)+".pth"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses_p': train_losses_p,
        'val_losses_p': val_losses_p,}, PATH)

    return train_loss_p


# keep here
# Resume after the last recorded epoch: the loss history holds one entry per
# completed epoch, so re-running this cell continues instead of restarting.
last_epoch = len(train_losses_p)
for epoch in range(last_epoch +1, args.epochs + 1):
    print('Initiating Epoch ', epoch)
    # train() appends to the loss histories and writes a checkpoint itself.
    train_loss = train(epoch)

In [None]:
# Report how many epochs are recorded in the training and validation histories.
print('trained on ', len(train_losses_p), len(val_losses_p), ' epochs')

In [None]:
# Collect both loss histories as plain Python floats, one column per split.
df_losses_dict = {
    'train': list(map(float, train_losses_p)),
    'val': list(map(float, val_losses_p)),
}

df_losses = pd.DataFrame(df_losses_dict).reset_index(drop=True)

# Draw the curves with the grid style, then switch the global style to "ticks".
sns.set_style("whitegrid")
sns.lineplot(df_losses)
plt.ylabel('Losses')
sns.set_style("ticks")
# Optional tweaks left disabled: fixed axis limits and a logarithmic y axis.

In [None]:
# https://pytorch.org/tutorials/beginner/saving_loading_models.html
# Checkpoint file for this cycle; the cycle number is part of the name.
PATH = f"VAE_new_{cycle_start}-cycle.pth"

# Persist the model and optimizer state together with both loss histories.
torch.save(
    dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        train_losses_p=train_losses_p,
        val_losses_p=val_losses_p,
    ),
    PATH,
)

In [None]:
all_means = []
all_labels = []

# Stream the training data again. The DataLoader built inside train() is a
# local variable of that function, so referring to `train_loader` here raised
# NameError.
latent_batches, _, _ = load_dataset_chunked(h5_input_data, split=True, batch_size=50000)

model.eval()
with torch.no_grad():
    for latent_batch in latent_batches:
        latent_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(torch.from_numpy(latent_batch)),
            batch_size=args.batch_size,
            shuffle=False,
        )
        for img in latent_loader:
            # Only the latent mean is kept; the log-variance is ignored.
            mean, _ = model.encode(img[0].to(device))
            all_means.append(mean.cpu())
            # Decode every input back to its string so labels line up with the means.
            for imagine in img[0].numpy():
                all_labels.append(decode_smiles_from_indexes(map(from_one_hot_array, imagine), charset))

latent_data = torch.cat(all_means, dim=0)

# Scatter the first two latent dimensions.
plt.figure(figsize=(8, 6))
plt.scatter(latent_data[:, 0].numpy(), latent_data[:, 1].numpy(), alpha=0.5, s=1)
plt.xlabel("Latent Dimension 1")
plt.ylabel("Latent Dimension 2")
plt.grid()
plt.show()