In [1]:
import os

import wget

In [2]:
#url = 'http://www.futurecrew.com/skaven/song_files/mp3/razorback.mp3'
#filename = wget.download(url)
#filename

In [3]:
wget.download('http://molcyclegan.ardigen.com/250k_rndm_zinc_drugs_clean_3_canonized.csv')
wget.download('http://molcyclegan.ardigen.com/X_JTVAE_250k_rndm_zinc.csv')

'X_JTVAE_250k_rndm_zinc.csv'

In [4]:
# Install pytorch and torchvision from the `pytorch` conda channel.
# As the command output notes, the kernel may need a restart to use updated packages.
%conda install -c pytorch pytorch torchvision

Collecting package metadata (current_repodata.json): ...working... done
Note: you may need to restart the kernel to use updated packages.

Solving environment: ...working... done

# All requested packages already installed.



In [5]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import torch.optim as optim

In [6]:
# Read the downloaded CSV into a DataFrame; its 'SMILES' column is used in the next cell.
data = pd.read_csv('X_JTVAE_250k_rndm_zinc.csv')

In [7]:
# Pull the SMILES strings out of the table and write them to a text file,
# one entry per line.
smiles_column = data['SMILES']
smiles = smiles_column.values
np.savetxt(fname=r'smiles.txt', X=smiles, fmt='%s')

In [8]:
class Discriminator(nn.Module):
    """Fully connected critic that scores latent vectors.

    The network is an MLP (input -> 512 -> 256 -> 1) with LeakyReLU(0.2)
    activations and no activation on the output, so it returns an unbounded
    score per input row.

    Args:
        data_shape: Shape of a single input sample. The number of input
            features is the product of its entries.
    """

    def __init__(self, data_shape):
        super(Discriminator, self).__init__()
        self.data_shape = data_shape

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(self.data_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, mol):
        """Return the score of every row in ``mol`` as a ``(batch, 1)`` tensor."""
        return self.model(mol)

    def save(self, path):
        """Write the weights and ``data_shape`` to ``path`` with ``torch.save``."""
        save_dict = {
            'model': self.model.state_dict(),
            'data_shape': self.data_shape,
        }
        torch.save(save_dict, path)

    @staticmethod
    def load(path):
        """Rebuild a ``Discriminator`` from a file written by :meth:`save`.

        The returned model lives on the CPU; move it with ``.cuda()`` / ``.to()``
        if needed.
        """
        # map_location makes checkpoints saved from a CUDA model loadable on
        # machines without a GPU. The rebuilt module is created on the CPU
        # either way, so this does not change where the result lives.
        save_dict = torch.load(path, map_location='cpu')
        D = Discriminator(save_dict['data_shape'])
        D.model.load_state_dict(save_dict["model"])

        return D

In [9]:
import torch.nn as nn
import numpy as np
import torch


class Generator(nn.Module):
    """MLP generator mapping noise vectors to latent-space vectors.

    Architecture: latent_dim -> 128 -> 256 -> 512 -> 1024 -> prod(data_shape),
    with LeakyReLU(0.2) after every hidden layer and batch normalisation on
    all hidden layers except the first. The output layer is linear.

    Args:
        data_shape: Shape of a single generated sample; the output width is
            the product of its entries.
        latent_dim: Size of the input noise vector. Defaults to the product
            of ``data_shape`` when ``None``.
    """

    def __init__(self, data_shape, latent_dim=None):
        super(Generator, self).__init__()
        self.data_shape = data_shape

        # latent dim of the generator is one of the hyperparams.
        # by default it is set to the prod of data_shapes
        self.latent_dim = int(np.prod(self.data_shape)) if latent_dim is None else latent_dim

        def block(in_feat, out_feat, normalize=True):
            """Linear layer, optional BatchNorm1d, then LeakyReLU."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): the second positional argument of BatchNorm1d
                # is `eps`, not `momentum`. The original code passed 0.8
                # positionally, so the value is kept (now explicitly named)
                # to stay numerically compatible with existing checkpoints.
                # Confirm whether `momentum` was intended before changing it.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(self.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.data_shape))),
            # nn.Tanh() # expecting latent vectors to be not normalized
        )

    def forward(self, z):
        """Map a ``(batch, latent_dim)`` noise tensor to generated vectors."""
        return self.model(z)

    def save(self, path):
        """Write weights, ``latent_dim`` and ``data_shape`` to ``path``."""
        save_dict = {
            'latent_dim': self.latent_dim,
            'model': self.model.state_dict(),
            'data_shape': self.data_shape,
        }
        torch.save(save_dict, path)

    @staticmethod
    def load(path):
        """Rebuild a ``Generator`` from a file written by :meth:`save`.

        The returned model lives on the CPU; move it with ``.cuda()`` / ``.to()``
        if needed.
        """
        # map_location makes checkpoints saved from a CUDA model loadable on
        # machines without a GPU. The rebuilt module is created on the CPU
        # either way, so this does not change where the result lives.
        save_dict = torch.load(path, map_location='cpu')
        G = Generator(save_dict['data_shape'], latent_dim=save_dict['latent_dim'])
        G.model.load_state_dict(save_dict["model"])

        return G

In [10]:
class Sampler(object):
    """
    Sampling the mols from the generator.
    All scripts should use this class for sampling.
    """

    def __init__(self, generator: "Generator"):
        self.set_generator(generator)

    def set_generator(self, generator):
        """Set the generator that subsequent ``sample`` calls will use."""
        self.G = generator

    def sample(self, n):
        """Generate ``n`` samples from uniform noise in ``[-1, 1)``.

        Args:
            n: Number of samples (rows) to generate.

        Returns:
            The generator output for a ``(n, latent_dim)`` float32 noise batch.
        """
        # Build the noise on whatever device the generator lives on instead of
        # hard-coding CUDA, so sampling also works on CPU-only machines and
        # never mixes devices. Fall back to the CPU for parameter-less modules.
        first_param = next(self.G.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

        # Sample noise as generator input
        noise = np.random.uniform(-1, 1, (n, self.G.latent_dim))
        z = torch.as_tensor(noise, dtype=torch.float32, device=device)
        # Generate a batch of mols
        return self.G(z)

In [11]:
from torch.utils import data
import json
import numpy as np


class LatentMolsDataset(data.Dataset):
    """Dataset wrapper around an indexable collection of latent vectors."""

    def __init__(self, latent_space_mols):
        # The collection is stored by reference; nothing is copied or converted.
        self.data = latent_space_mols

    def __len__(self):
        """Number of stored items."""
        n_items = len(self.data)
        return n_items

    def __getitem__(self, index):
        """Return the item stored at position ``index``."""
        item = self.data[index]
        return item


In [12]:
import pickle
import os
import torch
import torch.autograd as autograd
import numpy as np
import json
import time
import sys


class TrainModelRunner:
    """Trains a generator/critic pair with the WGAN-GP objective.

    The input CSV must contain a ``SMILES`` column; every other column is
    treated as one feature of the training vectors.

    NOTE: ``starting_epoch`` only offsets the epoch numbering (and therefore
    the checkpoint file names). Both networks are always freshly initialised;
    previously saved checkpoints are not reloaded. To resume, assign
    ``Discriminator.load(...)`` / ``Generator.load(...)`` results explicitly
    and rebuild the sampler and optimizers.

    Args:
        input_data_path: CSV file with a ``SMILES`` column plus feature columns.
        output_model_folder: Folder receiving checkpoints, loss logs and
            samples. Created if it does not exist.
        decode_mols_save_path: Stored on the instance; not used by this class.
        n_epochs: Number of epochs to train.
        starting_epoch: Number assigned to the first epoch.
        batch_size: Mini-batch size (incomplete final batches are dropped).
        lr, b1, b2: Adam learning rate and betas for both optimizers.
        n_critic: The generator is updated on every ``n_critic``-th batch.
        save_interval: Checkpoints are written when ``epoch % save_interval == 0``.
        sample_after_training: Number of vectors to sample and save after
            training; ``0`` disables sampling.
        message: Free text printed when the run starts.
    """

    # Loss weight for gradient penalty
    lambda_gp = 10

    def __init__(self, input_data_path, output_model_folder, decode_mols_save_path='', n_epochs=2000, starting_epoch=1,
                 batch_size=2500, lr=0.0002, b1=0.5, b2=0.999, n_critic=5,
                 save_interval=1000, sample_after_training=30000, message=""):
        self.message = message

        # init params
        self.input_data_path = input_data_path
        self.output_model_folder = output_model_folder
        self.n_epochs = n_epochs
        self.starting_epoch = starting_epoch
        self.batch_size = batch_size
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.n_critic = n_critic
        self.save_interval = save_interval
        self.sample_after_training = sample_after_training
        self.decode_mols_save_path = decode_mols_save_path

        # Checkpoints, loss logs and samples are all written into this folder;
        # create it up front so a long run does not fail at the first save.
        if self.output_model_folder:
            os.makedirs(self.output_model_folder, exist_ok=True)

        # initialize dataloader
        smiles_lat = pd.read_csv(input_data_path)
        latent_space_mols = smiles_lat.drop('SMILES', axis=1).values
        # Flatten each row; the feature width is inferred from the data rather
        # than being fixed to 56.
        latent_space_mols = latent_space_mols.reshape(latent_space_mols.shape[0], -1)

        self.dataloader = torch.utils.data.DataLoader(LatentMolsDataset(latent_space_mols), shuffle=True,
                                                      batch_size=self.batch_size, drop_last=True)

        # Fresh networks sized from a single sample (see the class NOTE on resuming).
        sample_shape = latent_space_mols[0].shape
        self.D = Discriminator(sample_shape)
        self.G = Generator(sample_shape)
        # initialize sampler
        self.Sampler = Sampler(self.G)

        # Move the networks to the GPU (when available) before the optimizers
        # are built, as recommended by the torch.optim documentation.
        cuda = torch.cuda.is_available()
        if cuda:
            self.G.cuda()
            self.D.cuda()
        self.Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

        # initialize optimizer
        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=self.lr, betas=(self.b1, self.b2))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=self.lr, betas=(self.b1, self.b2))

    def run(self):
        """Train for ``n_epochs`` epochs, write loss logs, optionally sample.

        Side effects: prints one line per epoch, writes ``disc_loss.json`` and
        ``gen_loss.json`` (and periodic checkpoints) into
        ``output_model_folder``, and, when ``sample_after_training > 0``,
        leaves the generator in eval mode and saves the sampled vectors with
        ``np.save``.

        Returns:
            0 on completion.
        """
        print("Run began.")
        print("Message: %s" % self.message)
        sys.stdout.flush()

        disc_loss_log = []
        g_loss_log = []

        # Number of the final epoch, used for the progress display.
        last_epoch = self.starting_epoch + self.n_epochs - 1
        # Keeps `epoch` defined for the sample file name even when n_epochs == 0.
        epoch = self.starting_epoch - 1

        for epoch in range(self.starting_epoch, self.n_epochs + self.starting_epoch):
            disc_loss_per_batch = []
            g_loss_log_per_batch = []
            for i, real_mols in enumerate(self.dataloader):

                # Configure input
                real_mols = real_mols.type(self.Tensor)

                # ---------------------
                #  Train Discriminator
                # ---------------------

                self.optimizer_D.zero_grad()

                # Generate a batch of mols from noise. Detached: the critic
                # update must not back-propagate into the generator.
                fake_mols = self.Sampler.sample(real_mols.shape[0]).detach()

                # Real mols
                real_validity = self.D(real_mols)
                # Fake mols
                fake_validity = self.D(fake_mols)
                # Gradient penalty
                gradient_penalty = self.compute_gradient_penalty(real_mols, fake_mols)
                # Adversarial loss
                d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + self.lambda_gp * gradient_penalty
                disc_loss_per_batch.append(d_loss.item())

                d_loss.backward()
                self.optimizer_D.step()

                # Train the generator every n_critic steps
                if i % self.n_critic == 0:
                    # -----------------
                    #  Train Generator
                    # -----------------
                    self.optimizer_G.zero_grad()

                    # Generate a batch of mols
                    fake_mols = self.Sampler.sample(real_mols.shape[0])
                    # Loss measures generator's ability to fool the discriminator
                    fake_validity = self.D(fake_mols)
                    g_loss = -torch.mean(fake_validity)
                    g_loss_log_per_batch.append(g_loss.item())

                    g_loss.backward()
                    self.optimizer_G.step()

            # Nothing to save or report if the loader yielded no batch
            # (e.g. fewer rows than batch_size with drop_last=True).
            if not disc_loss_per_batch:
                continue

            if epoch % self.save_interval == 0:
                generator_save_path = os.path.join(self.output_model_folder,
                                                   str(epoch) + '_generator.txt')
                discriminator_save_path = os.path.join(self.output_model_folder,
                                                       str(epoch) + '_discriminator.txt')
                self.G.save(generator_save_path)
                self.D.save(discriminator_save_path)

            disc_loss_log.append([time.time(), epoch, np.mean(disc_loss_per_batch)])
            g_loss_log.append([time.time(), epoch, np.mean(g_loss_log_per_batch)])

            # Print and log
            print(
                "[Epoch %d/%d]  [Disc loss: %f] [Gen loss: %f] "
                % (epoch, last_epoch, disc_loss_log[-1][2], g_loss_log[-1][2])
            )
            sys.stdout.flush()

        # log the losses
        with open(os.path.join(self.output_model_folder, 'disc_loss.json'), 'w') as json_file:
            json.dump(disc_loss_log, json_file)
        with open(os.path.join(self.output_model_folder, 'gen_loss.json'), 'w') as json_file:
            json.dump(g_loss_log, json_file)

        # Sampling after training
        if self.sample_after_training > 0:
            print("Training finished. Generating sample of latent vectors")
            # sampling mode
            self.G.eval()

            S = Sampler(generator=self.G)
            # torch.no_grad() only disables autograd when used as a context
            # manager (or decorator); a bare call has no effect.
            with torch.no_grad():
                latent = S.sample(self.sample_after_training)
            latent = latent.detach().cpu().numpy().tolist()

            sampled_mols_save_path = os.path.join(self.output_model_folder, 'sampled')
            np.save(sampled_mols_save_path+f'_epoch{epoch}', latent)

        return 0

    def compute_gradient_penalty(self, real_samples, fake_samples):
        """Calculates the gradient penalty loss for WGAN GP.

        Args:
            real_samples: Batch of real vectors, shape ``(batch, features)``.
            fake_samples: Batch of generated vectors with the same shape.

        Returns:
            Scalar tensor: mean over the batch of ``(||grad||_2 - 1) ** 2``,
            where ``grad`` is the gradient of the critic output with respect
            to random interpolations between real and fake samples.
        """
        # The penalty treats both batches as constants.
        real_samples = real_samples.detach()
        fake_samples = fake_samples.detach()

        # Random weight term for interpolation between real and fake samples
        alpha = torch.as_tensor(np.random.random((real_samples.size(0), 1)),
                                dtype=real_samples.dtype, device=real_samples.device)

        # Get random interpolation between real and fake samples
        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
        d_interpolates = self.D(interpolates)

        # Get gradient w.r.t. interpolates. create_graph=True keeps the result
        # differentiable so the penalty can be back-propagated into the critic.
        gradients = autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

        return gradient_penalty

In [13]:
# Build the training runner on the downloaded CSV. Epoch numbering starts at
# 200 and checkpoints are written whenever the epoch number is a multiple of 100.
trainer = TrainModelRunner(
    input_data_path='X_JTVAE_250k_rndm_zinc.csv',
    output_model_folder='content',
    starting_epoch=200,
    batch_size=2500,
    save_interval=100,
    message='Starting training',
)

In [14]:
# Run the training loop; it prints the losses once per epoch and returns 0 when done.
trainer.run()

Run began.
Message: Starting training
[Epoch 200/2200]  [Disc loss: -4.918017] [Gen loss: -0.419686] 
[Epoch 201/2200]  [Disc loss: -10.800661] [Gen loss: -1.444345] 
[Epoch 202/2200]  [Disc loss: -9.768504] [Gen loss: -0.499753] 
[Epoch 203/2200]  [Disc loss: -8.670980] [Gen loss: -0.277073] 
[Epoch 204/2200]  [Disc loss: -8.569320] [Gen loss: -0.447810] 
[Epoch 205/2200]  [Disc loss: -7.927479] [Gen loss: -0.464426] 
[Epoch 206/2200]  [Disc loss: -7.679630] [Gen loss: -0.648365] 
[Epoch 207/2200]  [Disc loss: -7.189933] [Gen loss: -0.734310] 
[Epoch 208/2200]  [Disc loss: -6.711613] [Gen loss: -0.595327] 
[Epoch 209/2200]  [Disc loss: -6.476677] [Gen loss: -0.740527] 
[Epoch 210/2200]  [Disc loss: -6.196911] [Gen loss: -0.783275] 
[Epoch 211/2200]  [Disc loss: -5.884022] [Gen loss: -0.933188] 
[Epoch 212/2200]  [Disc loss: -5.683596] [Gen loss: -0.894367] 
[Epoch 213/2200]  [Disc loss: -5.482324] [Gen loss: -0.748679] 
[Epoch 214/2200]  [Disc loss: -5.200372] [Gen loss: -0.846595] 
[

0

In [15]:
# Put the generator in evaluation mode before sampling.
trainer.G.eval()

S = Sampler(generator=trainer.G)
# torch.no_grad() only disables gradient tracking when used as a context
# manager (or decorator); calling it as a bare statement has no effect.
with torch.no_grad():
    latent = S.sample(10)  # 10 samples
latent = latent.detach().cpu().numpy().tolist()

# Saved as <output_model_folder>/sampled_epoch200.npy (np.save appends ".npy").
sampled_mols_save_path = os.path.join(trainer.output_model_folder, 'sampled')
np.save(sampled_mols_save_path + '_epoch200', latent)

In [17]:
# Read the saved samples back and show the first ten rows as a table.
x = np.load('content/sampled_epoch200.npy')
sample_table = pd.DataFrame(x)
sample_table.head(10)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,46,47,48,49,50,51,52,53,54,55
0,1.638836,-0.616414,-3.202151,0.009196,-2.119055,1.406643,-0.288569,0.728176,1.062475,-0.720555,...,0.309986,-0.232794,-0.393898,-0.498157,0.453065,-1.587266,-0.212633,-0.147607,-0.112923,0.268307
1,2.098586,-2.920185,-0.92161,-0.21071,1.939441,-0.419797,-1.819821,-2.933575,2.609533,2.776092,...,0.222477,-0.429017,0.079016,0.308999,-0.155979,0.617917,-0.542015,0.161656,0.061596,-0.098435
2,1.60334,0.779736,3.524892,1.078962,-0.617806,0.202965,-0.816952,-1.342952,0.513388,-3.130149,...,0.373169,-0.384275,-0.360362,0.876034,0.296635,-1.419499,-0.048151,0.438009,0.773474,0.320471
3,0.573565,0.641949,0.788232,0.773578,0.278844,1.286416,-0.480327,2.061178,-0.784136,1.422693,...,0.472464,0.298562,-0.242873,0.18565,-0.327381,0.075236,0.261322,-0.225767,0.385955,0.086198
4,-1.46108,-1.294397,2.549797,-2.771791,-1.44919,1.575721,1.540542,0.649556,-2.976316,-2.89235,...,-0.106666,-0.21872,-0.108841,0.194058,-0.343664,1.299873,-1.034401,-0.085661,-0.572886,-0.183947
5,1.243765,1.393844,-0.937783,-2.016922,0.530487,0.798823,1.923109,2.612517,-0.685,0.355754,...,0.629894,-0.736601,-0.451012,0.281935,-0.005658,0.62574,-0.567256,0.354323,0.635459,0.208598
6,-0.280874,-2.460422,2.250018,0.652199,-0.840923,0.464562,-3.036115,-0.623445,2.949517,1.943163,...,0.436232,-0.294117,0.231873,0.306376,-0.140524,0.187396,0.055132,-0.087789,-0.119752,0.412944
7,3.511398,-0.091787,-0.219192,0.313985,-1.573894,2.981424,0.247866,-2.049631,-2.851445,-0.44411,...,0.414965,-0.255028,-0.118179,0.311007,-0.055165,1.258763,-0.117244,0.088544,0.053068,0.374792
8,-1.808075,-1.455871,0.118823,-1.008678,-1.252403,-0.634293,-2.631734,-0.015655,1.059319,1.270841,...,0.583149,-0.622613,-0.294913,0.030339,-0.802155,-1.120866,0.112123,0.16661,-0.054769,-0.881775
9,1.108043,-0.107384,-0.9165,0.421288,-1.178407,1.392352,-3.051135,-1.33591,-2.592261,1.934546,...,0.386187,-0.259708,-0.347451,0.462199,-0.020424,0.163021,0.191441,-0.275179,-0.007054,0.15512
