In [None]:
!git clone https://github.com/dssikdar/asdrp_QGAN/

Cloning into 'asdrp_QGAN'...
remote: Enumerating objects: 219, done.[K
remote: Counting objects: 100% (219/219), done.[K
remote: Compressing objects: 100% (196/196), done.[K
remote: Total 219 (delta 96), reused 27 (delta 10), pack-reused 0[K
Receiving objects: 100% (219/219), 40.74 MiB | 6.18 MiB/s, done.
Resolving deltas: 100% (96/96), done.


In [None]:
!pip install wget

Collecting wget
  Downloading wget-3.2.zip (10 kB)
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... [?25l[?25hdone
  Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9672 sha256=a54050436c7a21d645112cb063ae4cd9a98d86f2dbec86b298a6bb2ac6e89336
  Stored in directory: /root/.cache/pip/wheels/a1/b6/7c/0e63e34eb06634181c63adacca38b79ff8f35c37e3c13e3c02
Successfully built wget
Installing collected packages: wget
Successfully installed wget-3.2


In [None]:
import wget

In [None]:
import os

# wget.download() saves to a renamed copy (e.g. "name (1).csv") when the target file
# already exists, so skip files that are already present to keep re-runs idempotent.
DATA_URLS = [
    'http://molcyclegan.ardigen.com/250k_rndm_zinc_drugs_clean_3_canonized.csv',
    'http://molcyclegan.ardigen.com/X_JTVAE_250k_rndm_zinc.csv',
]
for url in DATA_URLS:
    filename = url.rsplit('/', 1)[-1]
    if not os.path.exists(filename):
        wget.download(url, out=filename)

'X_JTVAE_250k_rndm_zinc.csv'

In [None]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import torch.optim as optim

In [None]:
# Read the JT-VAE latent table: a 'SMILES' column plus the latent feature columns.
# NOTE(review): the name `data` is rebound to the torch.utils.data module by a later
# `from torch.utils import data` cell; re-run this cell before reusing the DataFrame.
data = pd.read_csv(filepath_or_buffer='X_JTVAE_250k_rndm_zinc.csv')

In [None]:
# Write every SMILES string from the table to smiles.txt, one per line.
smiles = data.SMILES.values
np.savetxt('smiles.txt', smiles, fmt='%s')

In [None]:
class Discriminator(nn.Module):
    """WGAN critic that scores latent-molecule vectors.

    A small MLP (prod(data_shape) -> 512 -> 256 -> 1) with LeakyReLU(0.2)
    activations. The final layer is linear, so the score is unbounded
    (no sigmoid).
    """

    def __init__(self, data_shape):
        """
        Args:
            data_shape: shape of a single input sample; its product is the
                number of input features of the first linear layer.
        """
        super(Discriminator, self).__init__()
        self.data_shape = data_shape

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(self.data_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, mol):
        """Return one validity score per sample.

        The last dimension of ``mol`` must equal prod(data_shape); no
        flattening is performed here.
        """
        validity = self.model(mol)
        return validity

    def save(self, path):
        """Write the layer weights and ``data_shape`` to ``path`` with ``torch.save``."""
        save_dict = {
            'model': self.model.state_dict(),
            'data_shape': self.data_shape,
        }
        torch.save(save_dict, path)

    @staticmethod
    def load(path, map_location='cpu'):
        """Rebuild a Discriminator from a checkpoint written by ``save``.

        Args:
            path: checkpoint file.
            map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
                checkpoint saved from a CUDA model also loads on CPU-only
                machines. The returned module is on CPU either way (it is freshly
                constructed); move it with ``.cuda()``/``.to()`` as needed.

        Returns:
            A new Discriminator with the saved weights.
        """
        save_dict = torch.load(path, map_location=map_location)
        D = Discriminator(save_dict['data_shape'])
        D.model.load_state_dict(save_dict["model"])
        return D

In [None]:
import torch.nn as nn
import numpy as np
import torch


class Generator(nn.Module):
    """MLP generator mapping noise vectors to latent-molecule vectors.

    Architecture: latent_dim -> 128 -> 256 -> 512 -> 1024 -> prod(data_shape),
    with LeakyReLU(0.2) after every hidden layer and BatchNorm1d on every hidden
    layer except the first. The output layer is linear (no tanh), so generated
    vectors are not squashed into [-1, 1].
    """

    def __init__(self, data_shape, latent_dim=None):
        """
        Args:
            data_shape: shape of one generated sample; its product is the output width.
            latent_dim: size of the input noise vector; defaults to prod(data_shape).
        """
        super(Generator, self).__init__()
        self.data_shape = data_shape
        out_features = int(np.prod(self.data_shape))

        # The latent (noise) dimension is a hyperparameter; by default it equals
        # the flattened output size.
        self.latent_dim = out_features if latent_dim is None else latent_dim

        def block(in_feat, out_feat, normalize=True):
            """Linear layer, optional BatchNorm1d, then LeakyReLU(0.2)."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): 0.8 is eps (BatchNorm1d's 2nd positional argument),
                # not momentum. It is kept as-is to preserve the behavior of trained
                # models, but was probably meant as a momentum value — confirm.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(self.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, out_features),
            # nn.Tanh() omitted on purpose: target latent vectors are not normalized.
        )

    def forward(self, z):
        """Map a batch of noise vectors ``(batch, latent_dim)`` to generated samples."""
        out = self.model(z)
        return out

    def save(self, path):
        """Write ``latent_dim``, the layer weights and ``data_shape`` to ``path``."""
        save_dict = {
            'latent_dim': self.latent_dim,
            'model': self.model.state_dict(),
            'data_shape': self.data_shape,
        }
        torch.save(save_dict, path)

    @staticmethod
    def load(path, map_location='cpu'):
        """Rebuild a Generator from a checkpoint written by ``save``.

        Args:
            path: checkpoint file.
            map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
                checkpoint saved from a CUDA model also loads on CPU-only
                machines. The returned module is on CPU either way (it is freshly
                constructed); move it with ``.cuda()``/``.to()`` as needed.

        Returns:
            A new Generator with the saved weights and BatchNorm statistics.
        """
        save_dict = torch.load(path, map_location=map_location)
        G = Generator(save_dict['data_shape'], latent_dim=save_dict['latent_dim'])
        G.model.load_state_dict(save_dict["model"])
        return G

In [None]:
class Sampler(object):
    """
    Draws generator outputs from uniform noise.
    All scripts should use this class for sampling.
    """

    def __init__(self, generator: "Generator"):
        self.set_generator(generator)

    def set_generator(self, generator):
        """Replace the wrapped generator."""
        self.G = generator

    def _device(self):
        """Device of the generator's parameters, or CPU if it has none."""
        first_param = next(self.G.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    def sample(self, n):
        """Generate ``n`` samples from noise drawn uniformly from [-1, 1).

        The float32 noise batch ``(n, G.latent_dim)`` is created on the
        generator's device, so this works with or without CUDA. Gradients are
        not disabled here (the training loop backpropagates through the
        result); wrap calls in ``with torch.no_grad():`` for pure inference.

        Args:
            n: number of samples to generate.

        Returns:
            The generator's output for the noise batch.
        """
        noise = np.random.uniform(-1, 1, (n, self.G.latent_dim))
        z = torch.as_tensor(noise, dtype=torch.float32, device=self._device())
        return self.G(z)

In [None]:
from torch.utils import data
import json
import numpy as np


class LatentMolsDataset(data.Dataset):
    """Map-style dataset over a pre-computed collection of latent molecule vectors.

    Item ``i`` is simply ``latent_space_mols[i]``; no transformation is applied.
    """

    def __init__(self, latent_space_mols):
        # Keep a reference to the backing container (no copy is made).
        self.data = latent_space_mols

    def __len__(self):
        backing = self.data
        return len(backing)

    def __getitem__(self, index):
        row = self.data[index]
        return row

In [None]:
import pickle
import os
import torch
import torch.autograd as autograd
import numpy as np
import json
import time
import sys


class TrainModelRunner:
    """Trains a WGAN-GP (generator + critic) on pre-computed latent molecule vectors.

    The constructor reads a CSV with a 'SMILES' column plus numeric latent
    columns, builds a shuffled DataLoader (incomplete last batch dropped), and
    creates the Generator/Discriminator with Adam optimizers. ``run()`` trains
    for ``n_epochs`` epochs numbered from ``starting_epoch``, checkpoints both
    models every ``save_interval`` epochs, then writes loss logs and optionally
    a batch of sampled latent vectors.
    """

    # Loss weight for gradient penalty
    lambda_gp = 10

    def __init__(self, input_data_path, output_model_folder, decode_mols_save_path='', n_epochs=2000, starting_epoch=1,
                 batch_size=2500, lr=0.0002, b1=0.5, b2=0.999, n_critic=5,
                 save_interval=1000, sample_after_training=30000, message="", resume=False):
        """
        Args:
            input_data_path: CSV with a 'SMILES' column; all other columns are
                used as latent features.
            output_model_folder: directory for checkpoints, loss logs and
                samples; it must already exist.
            decode_mols_save_path: stored on the instance; not used by this class.
            n_epochs: number of epochs to train.
            starting_epoch: number of the first epoch (affects logging,
                checkpoint names, and which checkpoint ``resume`` loads).
            batch_size: DataLoader batch size.
            lr, b1, b2: Adam learning rate and betas, shared by G and D.
            n_critic: the generator is updated on every ``n_critic``-th batch.
            save_interval: checkpoint when ``epoch % save_interval == 0``.
            sample_after_training: number of vectors to sample after training
                (0 disables sampling).
            message: free-text note printed when ``run()`` starts.
            resume: if True, load G and D from the checkpoints for
                ``starting_epoch`` in ``output_model_folder``. The default (False)
                initializes fresh models even when ``starting_epoch > 1``.
        """
        self.message = message

        # init params
        self.input_data_path = input_data_path
        self.output_model_folder = output_model_folder
        self.n_epochs = n_epochs
        self.starting_epoch = starting_epoch
        self.batch_size = batch_size
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.n_critic = n_critic
        self.save_interval = save_interval
        self.sample_after_training = sample_after_training
        self.decode_mols_save_path = decode_mols_save_path

        # initialize dataloader
        smiles_lat = pd.read_csv(input_data_path)
        latent_space_mols = smiles_lat.drop('SMILES', axis=1).values
        # One flat feature vector per row. The width comes from the data instead of
        # being hard-coded (the JT-VAE file has 56 latent columns).
        n_features = int(np.prod(latent_space_mols.shape[1:]))
        latent_space_mols = latent_space_mols.reshape(latent_space_mols.shape[0], n_features)

        self.dataloader = torch.utils.data.DataLoader(LatentMolsDataset(latent_space_mols), shuffle=True,
                                                      batch_size=self.batch_size, drop_last=True)

        # Checkpoint names for `starting_epoch` (epoch 1 uses un-prefixed names).
        prefix = '' if self.starting_epoch == 1 else str(self.starting_epoch) + '_'
        discriminator_path = os.path.join(output_model_folder, prefix + 'discriminator.txt')
        generator_path = os.path.join(output_model_folder, prefix + 'generator.txt')
        if resume:
            self.D = Discriminator.load(discriminator_path)
            self.G = Generator.load(generator_path)
        else:
            # Fresh models sized from one data row (D first, then G, as before).
            self.D = Discriminator(latent_space_mols[0].shape)
            self.G = Generator(latent_space_mols[0].shape)
        # initialize sampler
        self.Sampler = Sampler(self.G)

        # initialize optimizer
        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=self.lr, betas=(self.b1, self.b2))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        # Move models to the GPU when available; Tensor is the matching tensor type.
        cuda = torch.cuda.is_available()
        if cuda:
            self.G.cuda()
            self.D.cuda()
        self.Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    def _save_checkpoint(self, epoch):
        """Save G and D as '<epoch>_generator.txt' / '<epoch>_discriminator.txt'."""
        self.G.save(os.path.join(self.output_model_folder, str(epoch) + '_generator.txt'))
        self.D.save(os.path.join(self.output_model_folder, str(epoch) + '_discriminator.txt'))

    def run(self):
        """Train for ``n_epochs`` epochs, then write loss logs and optional samples.

        Per epoch, the mean critic and generator losses are printed and appended
        to 'disc_loss.json' / 'gen_loss.json' (entries ``[time, epoch, loss]``).
        If ``sample_after_training > 0``, generated vectors are saved as
        'sampled_epoch<last epoch>.npy' in ``output_model_folder``.

        Returns:
            0 on completion.
        """
        print("Run began.")
        print("Message: %s" % self.message)
        sys.stdout.flush()

        # A previous run() or notebook code may have left the models in eval mode,
        # which would freeze BatchNorm statistics during training.
        self.G.train()
        self.D.train()

        disc_loss_log = []
        g_loss_log = []
        last_epoch = self.starting_epoch + self.n_epochs - 1

        for epoch in range(self.starting_epoch, last_epoch + 1):
            disc_loss_per_batch = []
            g_loss_per_batch = []
            for i, real_mols in enumerate(self.dataloader):
                # Configure input
                real_mols = real_mols.type(self.Tensor)

                # ---------------------
                #  Train Discriminator
                # ---------------------
                self.optimizer_D.zero_grad()

                # Detached: the critic step does not need generator gradients.
                fake_mols = self.Sampler.sample(real_mols.shape[0]).detach()
                real_validity = self.D(real_mols)
                fake_validity = self.D(fake_mols)
                gradient_penalty = self.compute_gradient_penalty(real_mols.detach(), fake_mols)
                # Wasserstein critic loss plus gradient penalty
                d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + self.lambda_gp * gradient_penalty
                disc_loss_per_batch.append(d_loss.item())

                d_loss.backward()
                self.optimizer_D.step()

                # Train the generator every n_critic steps
                if i % self.n_critic == 0:
                    # -----------------
                    #  Train Generator
                    # -----------------
                    self.optimizer_G.zero_grad()
                    fake_mols = self.Sampler.sample(real_mols.shape[0])
                    # Loss measures the generator's ability to fool the critic.
                    fake_validity = self.D(fake_mols)
                    g_loss = -torch.mean(fake_validity)
                    g_loss_per_batch.append(g_loss.item())

                    g_loss.backward()
                    self.optimizer_G.step()

            if not disc_loss_per_batch:
                # Dataset smaller than one batch (drop_last=True): nothing to log or save.
                continue

            if epoch % self.save_interval == 0:
                self._save_checkpoint(epoch)

            disc_loss_log.append([time.time(), epoch, np.mean(disc_loss_per_batch)])
            g_loss_log.append([time.time(), epoch, np.mean(g_loss_per_batch)])

            print(
                "[Epoch %d/%d]  [Disc loss: %f] [Gen loss: %f] "
                % (epoch, last_epoch, disc_loss_log[-1][2], g_loss_log[-1][2])
            )
            sys.stdout.flush()

        # log the losses
        with open(os.path.join(self.output_model_folder, 'disc_loss.json'), 'w') as json_file:
            json.dump(disc_loss_log, json_file)
        with open(os.path.join(self.output_model_folder, 'gen_loss.json'), 'w') as json_file:
            json.dump(g_loss_log, json_file)

        # Sampling after training
        if self.sample_after_training > 0:
            print("Training finished. Generating sample of latent vectors")
            self.G.eval()
            S = Sampler(generator=self.G)
            # torch.no_grad() only takes effect as a context manager; a bare call is a no-op.
            with torch.no_grad():
                latent = S.sample(self.sample_after_training)
            latent = latent.cpu().numpy().tolist()

            sampled_mols_save_path = os.path.join(self.output_model_folder, 'sampled')
            np.save(sampled_mols_save_path + f'_epoch{last_epoch}', latent)

        return 0

    def compute_gradient_penalty(self, real_samples, fake_samples):
        """Calculates the gradient penalty loss for WGAN GP.

        Evaluates the critic at random points on the lines between paired real
        and fake samples and returns mean((||grad D(x_hat)||_2 - 1)^2).

        Args:
            real_samples, fake_samples: tensors of the same shape ``(batch, features)``
                (the per-row interpolation weight has shape ``(batch, 1)``).

        Returns:
            Scalar tensor, differentiable w.r.t. the critic parameters.
        """
        # Random weight term for interpolation between real and fake samples
        alpha = self.Tensor(np.random.random((real_samples.size(0), 1)))

        # Get random interpolation between real and fake samples
        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
        d_interpolates = self.D(interpolates)
        fake = self.Tensor(real_samples.shape[0], 1).fill_(1.0)

        # Gradient w.r.t. interpolates; create_graph so the penalty itself can be backpropagated.
        gradients = autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=fake,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

        return gradient_penalty

In [None]:
import os

# run() writes checkpoints, loss logs and samples into this folder, so make sure it exists
# (git does not track empty directories, so a fresh clone may not have it).
OUTPUT_MODEL_FOLDER = 'asdrp_QGAN/MolGAN-master/sussyoutput'
os.makedirs(OUTPUT_MODEL_FOLDER, exist_ok=True)

# starting_epoch=200 only offsets epoch numbering and checkpoint names here;
# the models are freshly initialized (no checkpoint is loaded).
trainer = TrainModelRunner('X_JTVAE_250k_rndm_zinc.csv', output_model_folder=OUTPUT_MODEL_FOLDER, starting_epoch=200,
                           save_interval=100, message='Starting training', batch_size=2000)

In [None]:
# Train the GAN: prints one loss line per epoch, writes checkpoints/logs/samples to
# trainer.output_model_folder, and returns 0 (shown as this cell's output).
trainer.run()

Run began.
Message: Starting training
[Epoch 200/2200]  [Disc loss: -6.147428] [Gen loss: -0.772320] 
[Epoch 201/2200]  [Disc loss: -11.277042] [Gen loss: -0.600219] 
[Epoch 202/2200]  [Disc loss: -9.544904] [Gen loss: -0.734304] 
[Epoch 203/2200]  [Disc loss: -9.059465] [Gen loss: -0.666370] 
[Epoch 204/2200]  [Disc loss: -8.377227] [Gen loss: -0.814009] 
[Epoch 205/2200]  [Disc loss: -7.747880] [Gen loss: -1.014137] 
[Epoch 206/2200]  [Disc loss: -7.186549] [Gen loss: -0.921502] 
[Epoch 207/2200]  [Disc loss: -6.730028] [Gen loss: -0.854973] 
[Epoch 208/2200]  [Disc loss: -6.565425] [Gen loss: -0.742636] 
[Epoch 209/2200]  [Disc loss: -6.468088] [Gen loss: -0.506219] 
[Epoch 210/2200]  [Disc loss: -6.420172] [Gen loss: -0.623610] 
[Epoch 211/2200]  [Disc loss: -6.334639] [Gen loss: -0.640086] 
[Epoch 212/2200]  [Disc loss: -6.042214] [Gen loss: -0.907074] 
[Epoch 213/2200]  [Disc loss: -5.801292] [Gen loss: -1.164438] 
[Epoch 214/2200]  [Disc loss: -5.593796] [Gen loss: -0.930459] 
[

0

In [None]:
# Switch the generator to inference mode (BatchNorm layers use their running statistics).
# NOTE: a bare `torch.no_grad()` call is a no-op; disable gradients with
# `with torch.no_grad():` around the sampling code instead.
trainer.G.eval()

Generator(
  (model): Sequential(
    (0): Linear(in_features=56, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): BatchNorm1d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Linear(in_features=256, out_features=512, bias=True)
    (6): BatchNorm1d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Linear(in_features=512, out_features=1024, bias=True)
    (9): BatchNorm1d(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Linear(in_features=1024, out_features=56, bias=True)
  )
)

In [None]:
N_SAMPLES = 10  # number of latent vectors to generate

S = Sampler(generator=trainer.G)
# Inference only: no autograd graph is needed for these samples.
with torch.no_grad():
    latent = S.sample(N_SAMPLES)
latent = latent.cpu().numpy().tolist()

In [None]:
# Label the file with the run's starting epoch (200 here) rather than a hard-coded literal.
# The final trained epoch is deliberately not used: when sample_after_training > 0, run()
# already saves 'sampled_epoch<last epoch>.npy', which this would overwrite.
sampled_mols_save_path = os.path.join(trainer.output_model_folder, 'sampled')
np.save(sampled_mols_save_path + f'_epoch{trainer.starting_epoch}', latent)

In [None]:
# run() names its post-training samples after the last trained epoch
# (starting_epoch + n_epochs - 1, i.e. 2199 for this configuration).
last_epoch = trainer.starting_epoch + trainer.n_epochs - 1
x = np.load(os.path.join(trainer.output_model_folder, f'sampled_epoch{last_epoch}.npy'))


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55
0,2.108176,1.893925,-3.531663,-0.611552,0.386928,-3.160022,1.457177,-1.1939,-0.062447,-1.117513,1.200686,-3.744967,-1.917441,0.677957,-0.528983,2.053841,0.437652,0.864186,1.252618,0.575265,0.976551,-2.21902,-0.433706,1.54131,2.447764,-1.420221,-0.345573,-0.049259,-1.433996,-0.473958,-0.061518,-0.243983,6.301006,-0.005938,0.271337,0.553598,-0.022877,-0.318652,0.211859,-0.331486,0.1354,0.507311,-0.0863,0.423206,0.429623,-0.036758,0.507386,0.041732,-0.137661,0.171211,0.178237,1.409303,0.224058,0.603406,-0.124114,0.258185
1,-1.601339,0.174058,0.08321,-0.788245,0.364303,2.409225,-0.395136,-0.195722,0.362953,0.323255,2.160412,0.796323,-1.706544,1.579946,-0.545106,0.446278,-0.778673,-4.236077,0.085227,-2.789953,1.088765,0.739895,-0.678084,-1.48003,-2.545705,-1.226079,0.649361,0.588944,0.192852,-0.295155,0.11689,-0.169057,0.243616,-0.185312,0.178626,0.339491,0.167866,-0.577285,0.016296,-0.136196,0.163278,0.409702,-0.133714,0.207612,-0.212279,-0.636208,0.170323,0.267853,-0.010772,-0.036889,0.218873,-0.264131,0.172077,0.10014,0.129457,0.518875
2,0.207651,-1.724208,0.729644,-1.764438,-1.325767,-0.572624,0.636105,0.810259,2.043674,0.569043,0.17538,1.318225,1.475044,3.139221,-0.20459,-0.878074,1.450194,-0.488714,1.179344,1.505802,0.751478,0.63027,-1.621246,-1.964098,-1.455383,1.749321,3.091902,2.849564,0.560035,-0.586099,-0.218906,-0.449067,-5.488092,-0.140497,0.663939,0.223002,0.319845,-0.468372,-0.772095,0.027254,0.156951,-0.173376,0.177082,0.030242,-0.618764,-0.985512,-0.094498,0.077197,-0.194528,0.536833,-0.141101,-0.836092,0.077225,-0.040108,-0.255305,0.497587
3,-0.220726,3.110145,-0.886518,2.642729,3.047859,1.553695,1.47387,-1.732403,-1.269199,-0.169491,-0.689389,-1.854081,-1.182386,-1.340529,0.10559,-2.901687,0.875737,-1.429951,-0.511327,-0.017867,-1.639271,2.381382,0.598968,2.076913,-0.185359,-0.112717,-1.455705,-2.558298,-1.203199,-0.81512,-0.431266,-0.309226,7.11986,-0.104389,0.644677,0.267055,0.453848,0.016537,0.104072,0.219923,0.68567,0.247947,-0.076761,0.434822,0.366815,0.089258,-0.019362,-0.162237,0.241175,0.312404,0.532124,1.12385,0.195017,1.137102,0.15887,-0.378291
4,-1.302804,0.541726,0.77534,2.477511,4.415025,-0.227878,1.812546,-1.770043,-1.704075,1.094787,-1.174227,0.377891,0.796054,0.685887,-3.34765,-3.026154,0.40982,0.313774,0.692881,-1.003518,0.321564,-0.860844,-0.507852,2.690151,1.486841,0.357927,-1.469642,-1.949453,1.122194,-0.299658,0.583442,0.183581,-5.79023,0.054076,0.280517,0.039578,0.151142,-0.278478,-0.201802,-0.248606,0.165125,0.320416,0.479366,0.351012,-0.347549,-0.352713,0.343215,0.075667,-0.439752,-0.021589,-0.353433,-1.160199,-0.153297,-0.239455,0.112124,0.485329
5,0.850899,-0.367891,0.0533,4.327947,-0.807246,1.41058,3.054457,1.122534,-0.885894,-0.384086,3.713985,1.187921,3.487158,-1.308813,0.667317,-0.178507,1.35538,-1.545679,-1.471369,0.338529,1.792373,0.210256,-2.797135,1.3218,-1.173632,-1.690038,0.063852,0.209533,0.192796,-0.361357,0.35494,0.666062,0.489887,0.057242,0.440699,0.618505,0.304703,0.223065,-0.414209,-0.075034,0.422172,0.566487,-0.525887,0.360869,0.448938,-0.134354,0.242784,0.038018,-0.26142,-0.424177,0.382973,-0.272298,-0.031653,0.429844,0.645901,0.092834
6,1.882317,-0.348617,0.015124,-0.35304,0.856482,-3.705119,-0.631203,1.119163,1.990649,2.728358,-1.843945,1.364928,0.609366,2.045374,0.707159,-1.603126,1.93753,-2.799426,-2.861299,1.513339,-0.589136,0.066376,-3.997556,-1.13853,-2.248581,3.586187,-2.256194,-3.26282,0.141995,-0.71657,0.463983,0.210473,0.402558,-0.241086,0.230814,0.227646,0.083044,0.215937,-0.109724,0.039225,0.366316,0.093046,0.125975,0.133543,0.237096,-0.267968,0.329049,-0.465808,-0.501705,0.582014,0.000966,0.193867,-0.309983,0.30675,0.507072,0.400872
7,-1.883415,-0.861364,-1.387948,-3.751523,0.466555,0.772368,-1.371509,2.243115,0.850862,-0.819226,-0.828084,-0.971318,1.532037,3.137849,-3.202545,0.303064,0.187929,-1.946013,2.093164,-3.082435,0.359136,1.632171,-1.450697,-0.553949,-2.649521,-1.92867,0.468992,-4.060611,0.200461,0.124838,0.679952,-0.063609,0.183703,-0.070417,0.541722,0.106253,-0.046521,0.278379,-0.108084,-0.411816,0.197806,0.334131,-0.434464,0.502487,0.43993,-0.352097,0.055169,0.292626,0.041102,0.164666,0.366699,-0.221942,0.014066,0.400174,-0.282275,-0.191831
8,1.124615,-0.632122,0.027952,0.483904,0.694588,0.321421,-1.22865,-1.004581,0.933281,-1.571095,2.227562,0.569001,-0.971475,1.185517,-0.80903,-0.933985,1.964947,0.673757,-1.030414,-0.971317,-1.247316,1.889507,-1.688542,-0.528522,-0.642218,1.948211,1.95815,3.892483,-0.038769,-0.173264,0.661169,0.15132,0.272111,-0.398363,0.031218,0.055217,-0.025554,-0.068006,0.17567,-0.217705,-0.133551,0.007492,0.266154,0.482403,-0.131734,-0.189266,0.020141,0.507343,-0.756525,-0.308155,-0.273907,0.003519,-0.011051,-0.404833,-0.371575,0.181702
9,-0.771928,-0.214135,-0.031244,0.605053,3.580932,-4.023792,-1.678055,-1.591185,0.019677,0.510885,2.885331,-0.845186,-1.837593,0.470665,2.60408,0.724229,-1.669727,2.509564,1.597082,-0.893468,-2.170784,-0.54984,-1.166973,0.447936,1.740774,1.128944,-0.914106,-1.249107,-0.038979,-0.116311,0.942655,-0.430814,0.279531,0.179374,0.09095,0.312987,0.081257,-0.279056,0.239012,-0.071061,0.282288,0.371802,-0.009744,0.545176,-0.128488,-0.20179,0.110897,-0.218299,0.052779,0.135324,0.380547,-0.191108,0.121924,0.365756,0.092322,0.334798


In [None]:
# DataFrame has no `data_shape` attribute (that raised AttributeError); use `.shape`
# for the dimensions and show the first rows.
sampled_latent_df = pd.DataFrame(x)
print(sampled_latent_df.shape)
sampled_latent_df.head(10)

AttributeError: ignored

In [None]:
# Count ordered triples (a, b, c) with 1 <= a, b, c <= TARGET and a * b * c == TARGET.
# (The original loop misspelled `range` as `rnage`, which raised NameError.)
import itertools

TARGET = 30
count = sum(
    1
    for a, b, c in itertools.product(range(1, TARGET + 1), repeat=3)
    if a * b * c == TARGET
)

In [None]:
# Number of ordered triples (a, b, c) in 1..30 with a * b * c == 30.
count