In [10]:
!wget http://molcyclegan.ardigen.com/250k_rndm_zinc_drugs_clean_3_canonized.csv
!wget http://molcyclegan.ardigen.com/X_JTVAE_250k_rndm_zinc.csv

--2020-03-15 18:16:16--  http://molcyclegan.ardigen.com/250k_rndm_zinc_drugs_clean_3_canonized.csv
Resolving molcyclegan.ardigen.com (molcyclegan.ardigen.com)... 188.128.194.238
Connecting to molcyclegan.ardigen.com (molcyclegan.ardigen.com)|188.128.194.238|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 22101155 (21M) [text/csv]
Saving to: ‘250k_rndm_zinc_drugs_clean_3_canonized.csv.1’


2020-03-15 18:16:19 (8.29 MB/s) - ‘250k_rndm_zinc_drugs_clean_3_canonized.csv.1’ saved [22101155/22101155]

--2020-03-15 18:16:19--  http://molcyclegan.ardigen.com/X_JTVAE_250k_rndm_zinc.csv
Resolving molcyclegan.ardigen.com (molcyclegan.ardigen.com)... 188.128.194.238
Connecting to molcyclegan.ardigen.com (molcyclegan.ardigen.com)|188.128.194.238|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 164441376 (157M) [text/csv]
Saving to: ‘X_JTVAE_250k_rndm_zinc.csv.1’


2020-03-15 18:16:35 (9.95 MB/s) - ‘X_JTVAE_250k_rndm_zinc.csv.1’ saved [164441376/164441

In [0]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import torch.optim as optim

In [0]:
data = pd.read_csv('X_JTVAE_250k_rndm_zinc.csv')

In [0]:
smiles = data['SMILES'].values
np.savetxt(r'smiles.txt', smiles, fmt='%s')

In [0]:
class Discriminator(nn.Module):
    def __init__(self, data_shape):
        super(Discriminator, self).__init__()
        self.data_shape = data_shape

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(self.data_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, mol):
        validity = self.model(mol)
        return validity

    def save(self, path):
        save_dict = {
            'model': self.model.state_dict(),
            'data_shape': self.data_shape,
        }
        torch.save(save_dict, path)
        return

    @staticmethod
    def load(path):
        save_dict = torch.load(path)
        D = Discriminator(save_dict['data_shape'])
        D.model.load_state_dict(save_dict["model"])

        return D

In [0]:
import torch.nn as nn
import numpy as np
import torch


class Generator(nn.Module):
    def __init__(self, data_shape, latent_dim=None):
        super(Generator, self).__init__()
        self.data_shape = data_shape

        # latent dim of the generator is one of the hyperparams.
        # by default it is set to the prod of data_shapes
        self.latent_dim = int(np.prod(self.data_shape)) if latent_dim is None else latent_dim

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(self.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.data_shape))),
            # nn.Tanh() # expecting latent vectors to be not normalized
        )

    def forward(self, z):
        out = self.model(z)
        return out

    def save(self, path):
        save_dict = {
            'latent_dim': self.latent_dim,
            'model': self.model.state_dict(),
            'data_shape': self.data_shape,
        }
        torch.save(save_dict, path)

        return

    @staticmethod
    def load(path):
        save_dict = torch.load(path)
        G = Generator(save_dict['data_shape'], latent_dim=save_dict['latent_dim'])
        G.model.load_state_dict(save_dict["model"])

        return G

In [0]:
class Sampler(object):
    """
    Sampling the mols the generator.
    All scripts should use this class for sampling.
    """

    def __init__(self, generator: Generator):
        self.set_generator(generator)

    def set_generator(self, generator):
        self.G = generator

    def sample(self, n):
        # Sample noise as generator input
        z = torch.cuda.FloatTensor(np.random.uniform(-1, 1, (n, self.G.latent_dim)))
        # Generate a batch of mols
        return self.G(z)

In [0]:
from torch.utils import data
import json
import numpy as np


class LatentMolsDataset(data.Dataset):
    def __init__(self, latent_space_mols):
        self.data = latent_space_mols

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


In [0]:
import pickle
import os
import torch
import torch.autograd as autograd
import numpy as np
import json
import time
import sys


class TrainModelRunner:
    # Loss weight for gradient penalty
    lambda_gp = 10

    def __init__(self, input_data_path, output_model_folder, decode_mols_save_path='', n_epochs=2000, starting_epoch=1,
                 batch_size=2500, lr=0.0002, b1=0.5, b2=0.999,  n_critic=5,
                 save_interval=1000, sample_after_training=30000, message=""):
        self.message = message

        # init params
        self.input_data_path = input_data_path
        self.output_model_folder = output_model_folder
        self.n_epochs = n_epochs
        self.starting_epoch = starting_epoch
        self.batch_size = batch_size
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.n_critic = n_critic
        self.save_interval = save_interval
        self.sample_after_training = sample_after_training
        self.decode_mols_save_path = decode_mols_save_path

        # initialize dataloader
        smiles_lat = pd.read_csv(input_data_path)
        latent_space_mols = smiles_lat.drop('SMILES', axis=1).values
        latent_space_mols = latent_space_mols.reshape(latent_space_mols.shape[0], 56)

        self.dataloader = torch.utils.data.DataLoader(LatentMolsDataset(latent_space_mols), shuffle=True,
                                                      batch_size=self.batch_size, drop_last=True)

        # load discriminator
        discriminator_name = 'discriminator.txt' if self.starting_epoch == 1 else str(
            self.starting_epoch) + '_discriminator.txt'
        discriminator_path = os.path.join(output_model_folder, discriminator_name)
        self.D = Discriminator.load(discriminator_path)
        # self.D = Discriminator(latent_space_mols[0].shape)
        # load generator
        generator_name = 'generator.txt' if self.starting_epoch == 1 else str(
            self.starting_epoch) + '_generator.txt'
        generator_path = os.path.join(output_model_folder, generator_name)
        self.G = Generator(.load(generator_path)
        # self.G = Generator(latent_space_mols[0].shape)
        # initialize sampler
        self.Sampler = Sampler(self.G)

        # initialize optimizer
        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=self.lr, betas=(self.b1, self.b2))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        # Tensor
        cuda = True if torch.cuda.is_available() else False
        if cuda:
            self.G.cuda()
            self.D.cuda()
        self.Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    def run(self):

        print("Run began.")
        print("Message: %s" % self.message)
        sys.stdout.flush()

        batches_done = 0
        disc_loss_log = []
        g_loss_log = []

        for epoch in range(self.starting_epoch, self.n_epochs + self.starting_epoch):
            disc_loss_per_batch = []
            g_loss_log_per_batch = []
            for i, real_mols in enumerate(self.dataloader):

                # Configure input
                real_mols = real_mols.type(self.Tensor)
                # real_mols = np.squeeze(real_mols, axis=1)

                # ---------------------
                #  Train Discriminator
                # ---------------------

                self.optimizer_D.zero_grad()

                # Generate a batch of mols from noise
                fake_mols = self.Sampler.sample(real_mols.shape[0])

                # Real mols
                real_validity = self.D(real_mols)
                # Fake mols
                fake_validity = self.D(fake_mols)
                # Gradient penalty
                gradient_penalty = self.compute_gradient_penalty(real_mols.data, fake_mols.data)
                # Adversarial loss
                d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + self.lambda_gp * gradient_penalty
                disc_loss_per_batch.append(d_loss.item())

                d_loss.backward()
                self.optimizer_D.step()
                self.optimizer_G.zero_grad()

                # Train the generator every n_critic steps
                if i % self.n_critic == 0:
                    # -----------------
                    #  Train Generator
                    # -----------------

                    # Generate a batch of mols
                    fake_mols = self.Sampler.sample(real_mols.shape[0])
                    # Loss measures generator's ability to fool the discriminator
                    # Train on fake images
                    fake_validity = self.D(fake_mols)
                    g_loss = -torch.mean(fake_validity)
                    g_loss_log_per_batch.append(g_loss.item())

                    g_loss.backward()
                    self.optimizer_G.step()

                    batches_done += self.n_critic

                # If last batch in the set
                if i == len(self.dataloader) - 1:
                    if epoch % self.save_interval == 0:
                        generator_save_path = os.path.join(self.output_model_folder,
                                                           str(epoch) + '_generator.txt')
                        discriminator_save_path = os.path.join(self.output_model_folder,
                                                               str(epoch) + '_discriminator.txt')
                        self.G.save(generator_save_path)
                        self.D.save(discriminator_save_path)

                    disc_loss_log.append([time.time(), epoch, np.mean(disc_loss_per_batch)])
                    g_loss_log.append([time.time(), epoch, np.mean(g_loss_log_per_batch)])

                    # Print and log
                    print(
                        "[Epoch %d/%d]  [Disc loss: %f] [Gen loss: %f] "
                        % (epoch, self.n_epochs + self.starting_epoch, disc_loss_log[-1][2], g_loss_log[-1][2])
                    )
                    sys.stdout.flush()

        # log the losses
        with open(os.path.join(self.output_model_folder, 'disc_loss.json'), 'w') as json_file:
            json.dump(disc_loss_log, json_file)
        with open(os.path.join(self.output_model_folder, 'gen_loss.json'), 'w') as json_file:
            json.dump(g_loss_log, json_file)

        # Sampling after training
        if self.sample_after_training > 0:
            print("Training finished. Generating sample of latent vectors")
            # sampling mode
            torch.no_grad()
            self.G.eval()

            S = Sampler(generator=self.G)
            latent = S.sample(self.sample_after_training)
            latent = latent.detach().cpu().numpy().tolist()

            sampled_mols_save_path = os.path.join(self.output_model_folder, 'sampled')
            np.save(sampled_mols_save_path+f'_epoch{epoch}', latent)

            # decoding sampled mols

        return 0

    def compute_gradient_penalty(self, real_samples, fake_samples):
        """Calculates the gradient penalty loss for WGAN GP"""
        # Random weight term for interpolation between real and fake samples
        alpha = self.Tensor(np.random.random((real_samples.size(0), 1)))

        # Get random interpolation between real and fake samples
        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
        d_interpolates = self.D(interpolates)
        fake = self.Tensor(real_samples.shape[0], 1).fill_(1.0)

        # Get gradient w.r.t. interpolates
        gradients = autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=fake,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

        return gradient_penalty

In [0]:
trainer = TrainModelRunner('/content/X_JTVAE_250k_rndm_zinc.csv', output_model_folder='/content/model', starting_epoch=0,
                           save_interval=100, message='Starting training', batch_size=2500)

In [0]:
trainer.run()

In [0]:
torch.no_grad()
trainer.G.eval()

S = Sampler(generator=trainer.G)
latent = S.sample(10) #10 samples
latent = latent.detach().cpu().numpy().tolist()

sampled_mols_save_path = os.path.join(trainer.output_model_folder, 'sampled')
np.save(sampled_mols_save_path+f'_epoch{200}', latent)

In [50]:
x = np.load('/content/model/sampled_epoch200.npy')
pd.DataFrame(x).head(10)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55
0,-2.206029,-0.808622,0.227027,-1.398846,-0.100072,2.809301,0.433412,0.617868,2.921274,-0.630039,-0.015929,-1.214874,1.238661,1.919663,1.712633,-2.118173,2.056915,-1.282244,-0.323812,-1.917734,-1.83421,1.901669,1.027108,-1.074955,0.722044,-0.77284,-1.178263,1.204141,0.021915,-0.317627,-0.220762,0.052131,0.456755,0.070525,0.558844,0.530643,0.298532,-0.741765,-0.120004,-0.887192,0.504271,1.033725,0.03654,0.565099,-0.040974,-0.37149,0.317552,1.12457,-1.216062,-0.349802,0.016857,-0.316001,0.645152,-1.070026,-0.2859,0.443744
1,1.682056,-0.609567,-2.252846,1.211958,-1.366811,2.314502,-1.48782,1.330656,0.50841,2.56022,2.649158,3.73915,-0.219646,0.185257,2.063104,-0.191439,1.270549,0.951931,-0.371398,4.485538,-1.377823,-0.388842,-1.598731,-0.44401,-2.914255,1.085253,0.036,-0.626002,0.120057,-0.377855,0.867117,0.566207,0.42251,-0.635424,0.363804,0.223369,-0.074575,0.13914,0.039169,-0.178168,-0.175273,0.185514,-0.198111,0.602899,0.084276,-1.319473,0.293325,0.157579,-0.099214,0.16448,-0.283875,0.160912,0.265407,0.191268,-0.078609,0.022873
2,2.058921,-1.012089,-4.200319,0.859445,1.797249,1.58921,-3.770724,0.441195,1.834026,-1.979702,-0.520992,2.860998,-2.34806,-0.369602,-1.072005,-4.138757,0.461854,-0.450401,-0.299388,-0.568555,-1.126057,-0.018332,3.364588,-2.255006,-2.616036,-0.795993,1.292032,-1.284919,-0.466509,-0.49887,0.47718,-0.26682,1.61243,0.253728,0.177948,-0.054492,0.49981,0.10387,-0.225704,-0.456521,0.024753,0.131566,-0.312732,0.850896,0.593622,-0.478496,0.323845,-0.067242,-0.15722,0.246735,-0.273384,0.359672,0.109671,0.27143,0.086816,0.233399
3,1.496926,3.138901,-2.547379,-0.869359,-0.259328,1.416225,-0.99407,-0.486854,0.142025,0.237491,-1.986327,0.143168,0.062135,-0.776502,1.883481,-0.397742,-1.397495,-0.964522,-1.80816,1.43058,2.632986,-0.582195,0.175714,-1.508701,-1.337555,2.593554,-0.872461,3.568985,0.134516,-0.299169,0.990468,-0.287957,0.345039,-0.259158,0.078833,0.049462,0.319563,-0.35535,0.452621,-0.530031,-0.214212,0.212138,-0.075205,0.615783,-0.042233,-0.567815,0.207993,0.016102,-0.34616,0.084241,-0.218341,0.211247,0.210123,-0.039403,-0.195077,-0.499018
4,0.791522,-4.199571,-4.615912,-0.83769,1.186579,-0.089598,-1.146151,4.647312,1.291135,-0.028676,-1.718559,-1.847364,-1.370815,-1.421265,2.017358,-0.440322,-1.473093,-2.497284,-2.246482,0.538316,-1.22764,-0.100853,-1.246739,-2.504039,-0.14666,2.601627,1.545972,2.32655,-0.976732,-0.191328,0.557248,0.556218,5.861549,0.148932,0.195171,0.479189,0.460419,-0.287761,-0.376164,-0.515892,0.411546,0.644988,-0.570256,0.53602,0.870963,-0.239492,0.083884,0.099846,0.012018,0.201893,0.340828,1.037974,-0.435755,-0.107969,0.086815,0.554617
5,1.041741,0.082437,0.549398,-2.424614,0.661732,3.69233,0.298974,1.701455,-2.166981,3.105842,-0.451497,-2.087396,0.219262,-1.05965,2.685484,-0.943232,-2.12047,0.191212,0.415367,-0.152114,-0.974302,1.787914,0.364055,2.367908,2.612651,-1.235037,0.640992,-0.342927,1.212635,-0.657298,0.674154,-0.25598,-4.612574,-0.206135,0.060253,-0.206599,-0.530276,-0.092508,0.270761,-0.348374,-0.05245,-0.162833,0.291443,0.158516,-0.568759,-0.187536,0.397226,0.076472,0.231934,0.490374,-0.57881,-0.747969,0.173777,0.228035,0.326183,-0.095992
6,0.236834,0.447515,0.942494,1.831178,4.035197,-2.715159,-0.494452,-1.390876,0.988349,-0.32475,2.090473,0.823172,0.336053,-0.798366,-4.623056,-1.926243,-3.884985,3.242795,1.386794,-1.520089,-0.96593,-3.402725,1.61808,-1.161801,-0.997457,-1.92452,2.012156,0.129474,-0.090356,-1.096475,0.341787,0.791985,0.778682,-0.449166,0.449953,-0.095116,-0.952586,0.459136,1.150347,-0.370117,-0.104506,0.267578,-0.497688,0.174427,0.314535,-0.46033,0.096316,0.96455,-0.072824,-1.03922,0.396921,-0.054091,0.634633,-0.134242,-0.533577,-0.755789
7,0.658501,-1.475041,0.748745,-1.426473,-3.699597,-1.634,-2.526913,1.611579,0.905675,-0.149053,2.128803,1.691138,-0.324888,-0.90845,1.949752,0.952909,-2.622906,-1.713309,0.754372,-0.925261,0.833147,-0.777138,-1.055742,0.722648,-2.36138,-0.383122,-0.511467,-0.363378,-0.724928,-0.183622,0.36176,0.512733,4.46567,0.089012,0.133642,0.174444,0.723054,-0.35185,-0.112913,0.225106,-0.012036,0.409238,-0.554399,0.402663,0.550017,-0.153775,-0.121379,-0.276993,0.304675,-0.380483,0.013872,0.904209,-0.299967,0.329411,0.069959,-0.408021
8,1.748999,1.833417,-2.441207,-0.946759,2.490193,-4.485895,-2.535448,-2.59842,-1.294683,-1.386219,-1.564114,-0.846953,0.879919,2.339151,-1.867399,1.283344,1.502694,-3.004918,-0.397677,0.129223,-3.509833,1.425902,0.361515,0.774268,-1.216181,3.697596,0.882637,-0.627216,-0.160732,-0.582654,0.298953,0.306886,-0.629769,-0.199482,-0.321718,0.176495,0.329783,-0.545061,0.351163,-0.011748,-0.224851,0.035192,0.15722,0.37542,-0.353154,0.003838,0.033852,-0.421979,-0.318678,-0.069865,-0.041326,0.035976,-0.313842,-0.428427,0.038867,0.614626
9,1.833144,-1.303949,4.043544,-0.219405,2.552707,-2.697711,-2.173731,-0.637294,1.118104,0.177372,-0.791029,-2.880402,1.589534,1.483211,1.656625,-0.573594,-0.920326,-0.048768,2.529475,-0.892568,0.557739,-0.0139,0.060773,0.489898,2.318283,-1.098597,-0.217817,-1.117312,0.88358,-0.099005,0.751728,0.1693,-3.689019,-0.402877,-0.259092,-0.304338,0.335484,0.279515,-0.072168,0.276214,-0.273628,-0.182982,-0.109527,0.335881,0.285368,0.264924,0.129178,0.119394,0.143332,-0.385476,-0.111564,-0.85459,-0.287042,0.624932,-0.62319,0.550105
