In [None]:
# --- Run configuration ---
# Index of the current cycle; it is embedded in the per-cycle file names below.
cycle = 1
# initialized with 50k-ChemBL.csv
seeding_dataset = 'data/ChemBL-35-cleaned.csv'

# Per-cycle CSV files derived from `cycle`.
trainingset_path = f'data/{cycle}-training_set_org.csv'
latent_dataset_path = f'data/{cycle}-training_set_latent_space.csv'

# HDF5 file the charset is read from.
charset_path = 'data/1-0.001-inp.h5'
# Checkpoint file restored into the model.
PATH = "checkpoint_239.pth"

In [None]:
from rdkit import Chem
from rdkit.Chem import AllChem
from group_selfies import (
    fragment_mols, 
    Group, 
    MolecularGraph, 
    GroupGrammar, 
    group_encoder
)

from rdkit.Chem import rdmolfiles
from rdkit.Chem.Draw import IPythonConsole

import IPython.display # from ... import display
from test_utils import *
from rdkit import RDLogger

RDLogger.DisableLog('rdApp.*') 

import os
import sys
from rdkit.Chem import RDConfig
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer
from rdkit.Chem import QED

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import gzip
import pandas
import h5py
import numpy as np
from __future__ import print_function
import argparse
import os
import h5py
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn import model_selection


def one_hot_array(i, n):
    """Return a length-``n`` list of ints that is 1 at position ``i`` and 0 elsewhere.

    Args:
        i: Index of the hot position. If it is outside ``range(n)`` the
            result contains only zeros.
        n: Length of the returned list.

    Returns:
        list[int]: The one-hot encoding as a concrete list.
    """
    # `xrange` does not exist on Python 3 and `map` is lazy there; build the
    # list eagerly so callers get an indexable, reusable sequence.
    return [int(ix == i) for ix in range(n)]

def one_hot_index(vec, charset):
    """Map every symbol in ``vec`` to its position in ``charset``.

    Args:
        vec: Iterable of symbols.
        charset: Object exposing ``.index(symbol)`` (e.g. a list or a string).

    Returns:
        list[int]: One index per symbol, in input order.

    Raises:
        ValueError: If a symbol is not present in ``charset`` (raised by
            ``charset.index``).
    """
    # Materialise the result: on Python 3 a bare `map` is a one-shot iterator
    # that cannot be indexed, measured with len(), or iterated twice.
    return [charset.index(symbol) for symbol in vec]

def from_one_hot_array(vec):
    """Return the first index at which ``vec`` equals 1, or ``None`` if there is none."""
    # np.where on a condition returns a tuple of index arrays; the first
    # entry holds the positions along the leading axis.
    hot_positions = np.where(vec == 1)[0]
    if hot_positions.shape != (0, ):
        return int(hot_positions[0])
    return None

def decode_smiles_from_indexes(vec, charset):
    """Join the charset entries selected by ``vec`` into one whitespace-stripped string.

    Entries that are ``bytes`` (including ``numpy.bytes_``) are decoded as
    UTF-8; any other entry is used as-is.
    """
    def _as_text(symbol):
        if isinstance(symbol, bytes):
            return str(symbol, 'utf-8')
        return symbol

    pieces = (_as_text(charset[position]) for position in vec)
    return "".join(pieces).strip()

def load_dataset(filename, split = True):
    """Read the ``data_train`` / ``data_test`` / ``charset`` datasets from an HDF5 file.

    Args:
        filename: Path of the HDF5 file.
        split: When True, ``data_train`` is read as well and returned first;
            when False it is not read at all.

    Returns:
        ``(data_train, data_test, charset)`` if ``split`` is True, otherwise
        ``(data_test, charset)``. Every element is fully read into memory.
    """
    # The context manager closes the file even when one of the reads raises
    # (e.g. a missing dataset), which the manual close() could not guarantee.
    with h5py.File(filename, 'r') as h5f:
        data_train = h5f['data_train'][:] if split else None
        data_test = h5f['data_test'][:]
        charset = h5f['charset'][:]

    if split:
        return (data_train, data_test, charset)
    return (data_test, charset)

class MolecularVAE(nn.Module):
    """VAE with a 1-D convolutional encoder and a GRU decoder over one-hot sequences.

    The encoder consumes tensors of shape ``(batch, max_len, charset_size)``
    (``max_len`` is the Conv1d channel axis) and the decoder emits
    per-position probabilities of the same shape.

    Args:
        max_len: Sequence length; used as the input channel count of the
            first convolution and as the number of decoder time steps.
        charset_size: Size of the one-hot alphabet. Must be at least 27
            because the three un-padded convolutions shorten that axis by 26.
        latent_dim: Width of the latent vector.

    The defaults (120, 54, 292) reproduce the previously hard-coded layer
    sizes, so ``MolecularVAE()`` builds the same architecture and the
    parameter names stay compatible with existing checkpoints.
    """

    def __init__(self, max_len=120, charset_size=54, latent_dim=292):
        super(MolecularVAE, self).__init__()

        # Kernels 9, 9 and 11 without padding remove 8 + 8 + 10 = 26 positions.
        conv_out_len = charset_size - 26
        if conv_out_len < 1:
            raise ValueError(
                "charset_size must be at least 27 for the convolution stack, got %r"
                % (charset_size,)
            )

        self.max_len = max_len
        self.charset_size = charset_size
        self.latent_dim = latent_dim

        # Layer creation order is unchanged so seeded initialisation is identical.
        self.conv_1 = nn.Conv1d(max_len, 9, kernel_size=9)
        self.conv_2 = nn.Conv1d(9, 9, kernel_size=9)
        self.conv_3 = nn.Conv1d(9, 10, kernel_size=11)
        # 10 output channels * remaining length (280 for a 54-symbol charset, 70 for 33).
        self.linear_0 = nn.Linear(10 * conv_out_len, 435)
        self.linear_1 = nn.Linear(435, latent_dim)  # produces z_mean
        self.linear_2 = nn.Linear(435, latent_dim)  # produces z_logvar

        self.linear_3 = nn.Linear(latent_dim, latent_dim)
        self.gru = nn.GRU(latent_dim, 501, 3, batch_first=True)
        self.linear_4 = nn.Linear(501, charset_size)

        self.relu = nn.ReLU()
        # Not used by the methods below (decode calls F.softmax); kept so the
        # attribute remains available to existing code.
        self.softmax = nn.Softmax()

    def encode(self, x):
        """Return ``(z_mean, z_logvar)`` for a batch ``x`` of one-hot sequences."""
        x = self.relu(self.conv_1(x))
        x = self.relu(self.conv_2(x))
        x = self.relu(self.conv_3(x))
        x = x.view(x.size(0), -1)  # flatten channels and positions per sample
        x = F.selu(self.linear_0(x))
        return self.linear_1(x), self.linear_2(x)

    def sampling(self, z_mean, z_logvar):
        """Reparameterised sample: ``z_mean + exp(0.5 * z_logvar) * noise``.

        The Gaussian noise is scaled by 1e-2 before use.
        """
        epsilon = 1e-2 * torch.randn_like(z_logvar)
        return torch.exp(0.5 * z_logvar) * epsilon + z_mean

    def decode(self, z):
        """Decode latent vectors into ``(batch, max_len, charset_size)`` probabilities."""
        z = F.selu(self.linear_3(z))
        # Feed the same latent vector to the GRU at every time step.
        z = z.view(z.size(0), 1, z.size(-1)).repeat(1, self.max_len, 1)
        output, hn = self.gru(z)
        # Collapse batch and time so the projection and softmax act per position.
        out_reshape = output.contiguous().view(-1, output.size(-1))
        y0 = F.softmax(self.linear_4(out_reshape), dim=1)
        y = y0.contiguous().view(output.size(0), -1, y0.size(-1))
        return y

    def forward(self, x):
        """Encode, sample and decode; returns ``(reconstruction, z_mean, z_logvar)``."""
        z_mean, z_logvar = self.encode(x)
        z = self.sampling(z_mean, z_logvar)
        return self.decode(z), z_mean, z_logvar


def load_dataset_chunked(filename, split=True, batch_size=10000):
    """Load ``data_test`` and ``charset`` eagerly and optionally stream ``data_train``.

    Args:
        filename: Path of the HDF5 file.
        split: When True, also return a generator over ``data_train``.
        batch_size: Number of leading-axis rows per yielded chunk.

    Returns:
        ``(batch_generator, data_test, charset)`` if ``split`` is True,
        otherwise ``(data_test, charset)``. ``data_test`` is read fully into
        memory as float32. ``charset`` is a list of ``str`` when stored as a
        string/object dataset, otherwise a float32 array.

    Raises:
        ValueError: If ``split`` is True and ``batch_size`` is not positive.

    Notes:
        With ``split=True`` the file must stay open while the generator is
        consumed. It is closed when the generator is exhausted or closed;
        a generator that is never started leaves closing to h5py's own
        object clean-up.
    """
    if split and batch_size <= 0:
        raise ValueError("batch_size must be a positive integer, got %r" % (batch_size,))

    h5f = h5py.File(filename, 'r')
    try:
        # np.asarray copies when required; np.array(..., copy=False) raises on
        # NumPy 2 whenever a copy cannot be avoided, as when reading from disk.
        data_test = np.asarray(h5f['data_test'], dtype='float32')

        charset_ds = h5f['charset']
        if charset_ds.dtype.kind in {'S', 'O'}:  # string or object dataset
            charset = [x.decode('utf-8') if isinstance(x, bytes) else x for x in charset_ds[:]]
        else:
            charset = np.asarray(charset_ds, dtype='float32')

        if split:
            # Keep a handle only; rows are read lazily inside the generator.
            data_train = h5f['data_train']
            total_samples = data_train.shape[0]
    except BaseException:
        h5f.close()
        raise

    if not split:
        # Everything returned is already in memory, so the file can be released.
        h5f.close()
        return (data_test, charset)

    def data_batch_generator():
        """Yield consecutive ``batch_size``-row slices of ``data_train`` read from disk."""
        try:
            for i in range(0, total_samples, batch_size):
                yield data_train[i:i+batch_size]
        finally:
            h5f.close()

    return (data_batch_generator(), data_test, charset)

In [None]:
import time
# Seed the global torch RNG before the model is constructed.
torch.manual_seed(42)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = MolecularVAE().to(device)

# Seed again after model construction so the RNG state seen by the following
# cells is the freshly seeded one.
torch.manual_seed(42)

class dotdict(dict):
    """dict whose items can also be read, written and deleted as attributes.

    Reading an attribute that is not a key yields ``None`` (``dict.get``
    semantics); deleting one that is not a key raises ``KeyError``.
    """

    def __getattr__(self, name):
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)

# Empty attribute-style container.
args = dotdict()
# NOTE(review): `device` is already assigned with this same expression earlier in this cell.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU machine also loads when `device` fell back to 'cpu'.
checkpoint = torch.load(PATH, map_location=device, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])

# Training bookkeeping stored alongside the weights.
epoch = checkpoint['epoch']
train_losses_p = checkpoint['train_losses_p']
val_losses_p = checkpoint['val_losses_p']

# Switch to evaluation mode; as the last expression this also displays the model.
model.eval()

In [None]:
latent_data = pd.read_csv(latent_dataset_path)
# Only the charset is needed here; split=False avoids reading the whole
# `data_train` dataset from the HDF5 file just to throw it away.
_, charset = load_dataset(charset_path, split=False)
training_set_org = pd.read_csv(trainingset_path)

# Random

In [None]:
from torch.distributions.uniform import Uniform
#max_latent_values = latent_data.max(axis=0).values.tolist()
#min_latent_values = latent_data.min(axis=0).values.tolist()
#r1 = min(min_latent_values)
#r2 = max(max_latent_values)

def generating_samples(per_cycle, latent_dim):
    """Draw ``per_cycle`` latent vectors and decode them with the global ``model``.

    The distribution is chosen by the module-level ``sample_type``:

    * ``'randn'``     - standard normal (mean 0, variance 1);
    * ``'uniform'``   - uniform on the interval given by the globals ``r1``, ``r2``
      via ``Tensor.uniform_``;
    * ``'uniform_d'`` - the same interval via ``torch.distributions.Uniform``.

    Args:
        per_cycle: Number of latent vectors to draw.
        latent_dim: Width of each latent vector.

    Returns:
        numpy.ndarray of shape ``(per_cycle, -1)``: the decoder output with
        each sample flattened to one row.

    Raises:
        ValueError: If ``sample_type`` is not one of the supported modes.
            Previously ``None`` was returned, which the caller silently
            wrote into its float buffer as NaN.
        NameError: If a uniform mode is selected while ``r1``/``r2`` are undefined.
    """
    if sample_type == 'randn':
        z_mean = torch.randn(per_cycle, latent_dim).to(device)
    elif sample_type in ('uniform', 'uniform_d'):
        if 'r1' not in globals() or 'r2' not in globals():
            raise NameError(
                "sample_type=%r needs the sampling bounds r1 and r2 to be defined" % (sample_type,)
            )
        if sample_type == 'uniform':
            z_mean = torch.empty(per_cycle, latent_dim, dtype=torch.float32).uniform_(r1, r2).to(device)
        else:
            z_mean = Uniform(r1, r2).sample((per_cycle, latent_dim)).to(device)
    else:
        raise ValueError(
            "unknown sample_type %r; expected 'randn', 'uniform' or 'uniform_d'" % (sample_type,)
        )

    # Pure inference: skip building the autograd graph through the decoder.
    with torch.no_grad():
        new_data = model.decode(z_mean)

    return new_data.cpu().numpy().reshape(per_cycle, -1)

def generating_samples_cycles(cycles, per_cycle, latent_dim):
    """Run ``generating_samples`` ``cycles`` times and stack the rows in one array.

    Returns:
        numpy.ndarray of shape ``(cycles * per_cycle, len(charset) * 120)``.
    """
    flat_width = len(charset)*120
    total_rows = cycles * per_cycle
    # Pre-allocated once; each round fills its own contiguous slice of rows.
    new_data_numpy = np.empty((total_rows, flat_width))

    row_start = 0
    for _ in range(cycles):
        row_stop = row_start + per_cycle
        new_data_numpy[row_start:row_stop] = generating_samples(per_cycle=per_cycle, latent_dim=latent_dim)
        row_start = row_stop

    return new_data_numpy

# Sampling mode read by generating_samples: 'randn', 'uniform' or 'uniform_d'.
sample_type = 'randn'

In [None]:
time0 = time.time()
valid_smiles = []
total_generated = 0  # running number of decoded samples across all rounds
#print('min and max latent: ', r1, r2)

for cycles in range(1, 101):
    new_data_numpy = generating_samples_cycles(cycles=10, per_cycle=100, latent_dim=292)

    # One reshape/argmax for the whole batch: most likely charset index per position.
    index_matrix = new_data_numpy.reshape(len(new_data_numpy), 120, len(charset)).argmax(axis=2)
    all_smiles = [decode_smiles_from_indexes(row, charset) for row in index_matrix]
    total_generated += len(all_smiles)

    for smi in all_smiles:
        # An empty decode is not a molecule; RDKit would parse it into an
        # atom-less Mol rather than None, so reject it explicitly.
        if not smi:
            continue
        # With sanitize=True, MolFromSmiles already returns None when parsing
        # or sanitization fails, so no second SanitizeMol pass is needed.
        if Chem.MolFromSmiles(smi, sanitize=True) is not None:
            valid_smiles.append(smi)

    if cycles%10 == 0:
        # Percentage is based on the true running total instead of assuming
        # every round produced as many samples as the most recent one.
        print('%.2f' % (len(valid_smiles) / total_generated * 100), '% of generated samples are valid samples, ', len(valid_smiles), ' out of: ', total_generated, 'in: {0:2.2f} min'.format( (time.time()-time0)/60.))

In [None]:
from collections import OrderedDict
# dict preserves insertion order, so this de-duplicates while keeping the
# first occurrence of every SMILES; a list (rather than a keys view) can be
# indexed and passed on directly.
unique_smiles = list(dict.fromkeys(valid_smiles))
if valid_smiles:
    print('%.2f' % (len(unique_smiles) / len(valid_smiles)*100),  '% of generated valid samples are unique samples.')
else:
    # Avoid ZeroDivisionError when nothing valid was generated.
    print('No valid samples were generated; the unique fraction is undefined.')

In [None]:
def _qed_and_sa_score(smiles):
    """Return ``(QED, SA score)`` for one SMILES string, using NaN for what cannot be computed.

    If the molecule cannot be parsed or QED fails, both values are NaN. If
    only the SA score fails, the QED value is kept.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)#+'OP(C)(=O)F')
        if mol is None:
            return np.nan, np.nan
        qed = QED.default(mol)
    except Exception:
        return np.nan, np.nan
    try:
        sas_score = sascorer.calculateScore(mol)
    except Exception:
        sas_score = np.nan
    return qed, sas_score


print('Calculating QED/SAS for ', len(unique_smiles), ' molecules out of all samples')
df_generated = pd.DataFrame(list(unique_smiles), columns=["Original_SMILES"]).drop_duplicates(subset=['Original_SMILES'])

# Score everything first and assign whole columns: this avoids row-by-row
# `.at` writes and guarantees the columns exist even for an empty frame, so
# the dropna below cannot fail with a KeyError.
scores = [_qed_and_sa_score(smi) for smi in df_generated['Original_SMILES']]
df_generated["QED"] = [qed for qed, _ in scores]
df_generated["SA_score"] = [sas for _, sas in scores]
df_generated["Origin"] = 'random'

# Keep fully scored molecules only, best QED first.
new = df_generated.dropna(subset=['QED', 'SA_score']).sort_values(['QED'], ascending=False)

In [None]:
# Summary statistics of the scored molecules.
new.describe()

In [None]:
# Histogram of the QED values in 10 bins, without a KDE overlay.
sns.histplot(new['QED'], kde=False, bins=10)

In [None]:
# Write the scored molecules (sorted by QED, descending) to disk without the index column.
new.to_csv('random.csv', index=False)

# Check encoded training set

In [None]:
# Decode the stored latent vectors in mini-batches and without autograd, so
# the GRU activations for the whole training set never have to sit in memory
# at once. Results are gathered on the CPU.
DECODE_BATCH_SIZE = 1000
latent_tensor = torch.tensor(latent_data.to_numpy(), dtype=torch.float32)

with torch.no_grad():
    if len(latent_tensor) <= DECODE_BATCH_SIZE:
        newkl = model.decode(latent_tensor.to(device)).cpu()
    else:
        newkl = torch.cat([
            model.decode(latent_tensor[start:start + DECODE_BATCH_SIZE].to(device)).cpu()
            for start in range(0, len(latent_tensor), DECODE_BATCH_SIZE)
        ])

In [None]:
# Fresh, independent accumulators for the training-set check.
all_smiles, valid_smiles, invalid_smiles = [], [], []
# One flattened row per decoded sample.
attempt = newkl.detach().cpu().numpy().reshape(len(newkl), -1)
attempt.shape

In [None]:
    # One reshape/argmax for the whole array: most likely charset index per position.
    index_matrix = attempt.reshape(len(attempt), 120, len(charset)).argmax(axis=2)
    all_smiles.extend(decode_smiles_from_indexes(row, charset) for row in index_matrix)

    for smi in all_smiles:
        # Treat an empty decode as unparsable instead of handing it to RDKit,
        # which would parse it into an atom-less Mol rather than None.
        m = Chem.MolFromSmiles(smi, sanitize=False) if smi else None
        if m is None:
            invalid_smiles.append(smi)
            continue
        try:
            Chem.SanitizeMol(m)
        except Exception:
            # Parsed but chemically invalid: record it instead of silently
            # dropping it, so valid + invalid always adds up to all_smiles.
            invalid_smiles.append(smi)
        else:
            valid_smiles.append(smi)

    # Counts are integers; report them as such.
    print(len(valid_smiles), 'of', len(all_smiles), 'decoded training-set SMILES are valid')

In [None]:
# Show the first invalid decode, or None when every decode was valid
# (plain indexing would raise IndexError on an empty list).
invalid_smiles[0] if invalid_smiles else None