# ChemVAE Implementation

In [1]:
from utils import *
import numpy as np

## Import and Pre-Process Data

In [2]:
args_input = '../data/zinc.csv'
args_train = 100000
args_val = 1000
args_output = True
args_colc = 'SMILES'

In [4]:
# GitHub data directory contains only tar.gz version
# import pandas as pd
# df = pd.read_csv('../data/zinc.tar.gz', compression='gzip', header=0, sep=',', error_bad_lines=False)
# df.columns[0] = 'SMILES'

df = import_data(args_input)
X_train, X_test = return_splits(df, n_train=args_train, n_test=args_val, col_chem=args_colc)
char2idx, idx2char, train_idx, test_idx = create_data(X_train, X_test, colname=args_colc)   
train_oh, test_oh = check_conversions(idx2char, train_idx, X_train, test_idx, X_test)

There are 0 training index conversion errors
There are 0 testing index conversion errors

There are 0 training one-hot conversion errors
There are 0 testing one-hot conversion errors


In [5]:
print(max([len(i) for i in X_train]))
print(max([len(i) for i in X_test]))

57
50


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [8]:
train_oh.shape[1] ## d_input

60

In [9]:
train_oh.shape[2] ## d_output

29

In [220]:
class ChemVAE(nn.Module):
    def __init__(self, d_input, d_output):
        super(ChemVAE, self).__init__()

        self.conv_1 = nn.Conv1d(d_input, 9, kernel_size=9)
        self.conv_2 = nn.Conv1d(9, 9, kernel_size=9)
        self.conv_3 = nn.Conv1d(9, 10, kernel_size=11)
        self.linear_0 = nn.Linear(30, 435)
        self.linear_1 = nn.Linear(435, 292)
        self.linear_2 = nn.Linear(435, 292)

        self.linear_3 = nn.Linear(292, 292)
        self.gru = nn.GRU(292, 501, 3, batch_first=True)
        self.linear_4 = nn.Linear(501, d_output)
        
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()

    def encode(self, x):
        x = self.relu(self.conv_1(x))
        x = self.relu(self.conv_2(x))
        x = self.relu(self.conv_3(x))
        x = x.view(x.size(0), -1)
        x = F.selu(self.linear_0(x))
        z_mu = self.linear_1(x)
        z_logvar = self.linear_2(x)
        return z_mu, z_logvar

    def sampling(self, mu, log_var):
        '''
        Sample from latent space, z ~ N(μ, σ**2)
        '''
        sigma = torch.exp(log_var / 2)

        #epsilon = torch.randn(sigma.size()).float()
        epsilon = torch.randn_like(sigma).float()

        self.z_mean = mu
        self.z_logvar = log_var

        # use the reparameterization trick
        return mu + sigma * epsilon

    def decode(self, z):
        z = F.selu(self.linear_3(z))
        z = z.view(z.size(0), 1, z.size(-1)).repeat(1, 60, 1)
        output, hn = self.gru(z)
        out_reshape = output.contiguous().view(-1, output.size(-1))
        y0 = F.softmax(self.linear_4(out_reshape), dim=1)
        y = y0.contiguous().view(output.size(0), -1, y0.size(-1))
        return y

    def forward(self, x):
        z_mean, z_logvar = self.encode(x)
        z = self.sampling(z_mean, z_logvar)
        output = self.decode(z)
        return output

In [110]:
def calc_kld(z_mean, z_logvar):
    '''
    Calculate KL divergence
    '''
    return -0.5 * torch.mean(1 + z_logvar - z_mean**2 - z_logvar.exp())

In [None]:
args_lr = 0.001
args_dynlr = True
args_batch_size = 200

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = "cpu"

if args_dynlr:
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', 
                                                     factor = 0.8, 
                                                     patience = 3,
                                                     min_lr = 0.0001)

X_train = torch.from_numpy(train_oh.astype(np.float32))
X_test = torch.from_numpy(test_oh.astype(np.float32))

torch.manual_seed(42)

# train_loader = torch.utils.data.DataLoader(X_train, batch_size=200, shuffle=True)

train_loader = torch.utils.data.DataLoader(X_train, 
                                           batch_size=args_batch_size,
                                           shuffle=True, 
                                           num_workers=6,
                                           drop_last = True)

test_loader = torch.utils.data.DataLoader(X_test, 
                                          batch_size=args_batch_size,
                                          shuffle=True, 
                                          num_workers=6,
                                          drop_last = True)

epochs = 100

model = ChemVAE(train_oh.shape[1], train_oh.shape[2]).to(device)
optimizer = optim.Adam(model.parameters(), lr = args_lr)

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        kld = calc_kld(model.z_mean, model.z_logvar)
        loss = F.binary_cross_entropy(output, data, size_average=False) + kld
        loss.backward()
        train_loss += loss
        optimizer.step()
        if epoch % 10 == 0 and batch_idx == 0:
            print(f'{epoch} / {batch_idx + 1}\t{loss:.4f}')
            print('train', train_loss / len(train_loader.dataset))
            print(f'KLD: {kld:.4f}')
    return train_loss / len(train_loader.dataset)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)

#     if batch_idx==0:
#           inp = data.cpu().numpy()
#           outp = output.cpu().detach().numpy()
#           lab = data.cpu().numpy()
#           print("Input:")
#           print(decode_smiles_from_indexes(map(from_one_hot_array, inp[0]), charset))
#           print("Label:")
#           print(decode_smiles_from_indexes(map(from_one_hot_array, lab[0]), charset))
#           sampled = outp[0].reshape(1, 120, len(charset)).argmax(axis=2)[0]
#           print("Output:")
#           print(decode_smiles_from_indexes(sampled, charset))

10 / 1	16609.7051
train tensor(0.1661, device='cuda:0', grad_fn=<DivBackward0>)
KLD: 3.7999
20 / 1	12087.0645
train tensor(0.1209, device='cuda:0', grad_fn=<DivBackward0>)
KLD: 4.0856
30 / 1	9504.0215
train tensor(0.0950, device='cuda:0', grad_fn=<DivBackward0>)
KLD: 3.9255
40 / 1	8807.6455
train tensor(0.0881, device='cuda:0', grad_fn=<DivBackward0>)
KLD: 3.8718
50 / 1	12133.7646
train tensor(0.1213, device='cuda:0', grad_fn=<DivBackward0>)
KLD: 4.1607
60 / 1	7215.5122
train tensor(0.0722, device='cuda:0', grad_fn=<DivBackward0>)
KLD: 3.9230


In [187]:
output = model(X_test)

In [199]:
output_idx = list(output.argmax(axis=2).cpu().numpy())

In [200]:
output_idx[0]

array([25, 11, 11, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
       19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
        4, 26, 26, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28],
      dtype=int64)

In [201]:
output_char = [convert_num2str(i, idx2char) for i in output_idx]

In [203]:
output_char

['CCccccccccccccccccccccccccccccccc1',
 'CCcccccccccccccccccccccccccccccccccc1',
 'CCcccccccccccccccccccccccccccccccc',
 'CCcccccccccccccccccccccccc1',
 'CCcccccccccccccccccccccccccccccccc',
 'CCccccccccccccccccccccccccc1',
 'CCcccccccccccccccccccccc11',
 'CCccccccccccccccccccccccc1',
 'CCccccccccccccccccccccccccccccc',
 'CCccccccccccccccccccccccccc1',
 'CCcccccccccccccccccccccccccccccc1',
 'CCccccccccccccccccccccccccccc1',
 'CCccccccccccccccccccccccccccc1',
 'CCccccccccccccccccccccccccccccccccccc1',
 'CCcccccccccccccccccccccccccccccccc',
 'CCcccccccccccccccccccccccccc1',
 'CCccccccccccccccccccccccc1',
 'CCcccccccccccccccccccccccccccccc',
 'CCccccccccccccccccccccccccccccccccc1',
 'CCccccccccccccccccccccccccccc1',
 'CCcccccccccccccccccccccccccccc1',
 'CCccccccccccccccccccccccccccc1',
 'CCccccccccccccccccccccccccccc1',
 'CCcccccccccccccccccccccccccc1',
 'CCccccccccccccccccccccccccccccc',
 'CCcccccccccccccccccccccccccc1',
 'CCccccccccccccccccccccccccccc1',
 'CCcccccccccccccccccccccccccccc

In [None]:
#criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr = args.lr)
if DYN_LR:
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', 
                                factor = 0.8, 
                                patience = 3,
                                min_lr = 0.0001)

dataloader = torch.utils.data.DataLoader(X_train, 
                                        batch_size=args.batch_size,
                                        shuffle=True, 
                                        num_workers=6,
                                        drop_last = True)

val_dataloader = torch.utils.data.DataLoader(X_test, 
                                        batch_size=args.batch_size,
                                        shuffle=True, 
                                        num_workers=6,
                                        drop_last = True)

best_epoch_loss_val = 100000
x_train_data_per_epoch = X_train.shape[0] - X_train.shape[0]%args.batch_size
x_val_data_per_epoch = X_test.shape[0] - X_test.shape[0]%args.batch_size
print("Div Quantities",x_train_data_per_epoch,x_val_data_per_epoch)
print()
print("###########################################################################")
for epoch in range(args.epochs):
    epoch_loss = 0
    print("Epoch -- {}".format(epoch))

    for i, data in enumerate(dataloader):

        inputs = data.float().to(device)
        #inputs = inputs.reshape(batch_size, -1).float()
        optimizer.zero_grad()

        input_recon = model(inputs)
        latent_loss_val = latent_loss(model.z_mean, model.z_sigma)
        loss = F.binary_cross_entropy(input_recon, inputs, size_average=False) + latent_loss_val
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()


    print("Train Loss -- {:.3f}".format(epoch_loss/x_train_data_per_epoch))
    ###Add 1 Image per Epoch for Visualisation
    data_point_sampled = random.randint(0,args.batch_size-1)

    print("INPUT",inputs[data_point_sampled])
    print("OUTPUT",input_recon[data_point_sampled].reshape(1, 120, len(vocab)))

    print("Input -- ",onehot_to_smiles(inputs[data_point_sampled].reshape(1, 120, len(vocab)).cpu().detach(), inv_dict))
    print("Output -- ",onehot_to_smiles(input_recon[data_point_sampled].reshape(1, 120, len(vocab)).cpu().detach(), inv_dict))

    #####################Validation Phase
    epoch_loss_val = 0
    for i, data in enumerate(val_dataloader):

        inputs = data.to(device).float()
        #inputs = inputs.reshape(batch_size, -1).float()
        input_recon = model(inputs)
        latent_loss_val = latent_loss(model.z_mean, model.z_sigma)
        loss = F.binary_cross_entropy(input_recon, inputs, size_average=False) + latent_loss_val
        epoch_loss_val += loss.item()
    print("Validation Loss -- {:.3f}".format(epoch_loss_val/x_val_data_per_epoch))
    print()
    scheduler.step(epoch_loss_val)

    ###Add 1 Image per Epoch for Visualisation
    #data_point_sampled = random.randint(0,args.batch_size)
    #add_img(inputs[data_point_sampled], inv_dict, "Original_"+str(epoch))
    #add_img(model(inputs[data_point_sampled:data_point_sampled+1]), inv_dict, "Recon_"+str(epoch))

    checkpoint = {'model': model.state_dict(),
                'dict':vocab,
                'inv_dict':inv_dict,
                }

    #Saves when loss is lower than best validation loss till now and all models after 100 epochs
    if epoch_loss_recon_val < best_epoch_loss_val or epoch > 100:
        torch.save(checkpoint, args.save_loc+'/'+str(epoch)+'checkpoint.pth')
    #update best epoch loss
    best_epoch_loss_val = min(epoch_loss_val, best_epoch_loss_val)
#evaluate(model, X_train, vocab, inv_dict)

In [None]:
def init_weights(layer):
    '''
    Initialize weights based on layer type
    
    Args:
        layer (torch.nn): neural network whose weights to initialize
    '''
    if type(layer) == nn.Conv1d:
        init.normal_(m.weight.data)
        m.bias.data.fill_(0.01)
    if type(layer) == nn.Linear:
        n = m.in_features
        y = 1.0/np.sqrt(n)
        m.weight.data.uniform_(-y, y)
        m.bias.data.fill_(0)
    if type(layer) == nn.GRU:
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)      

In [None]:
class Encoder(nn.Module):
    def __init__(self, d_input, d_hidden_1, d_hidden_2):
        super(Encoder, self).__init__()
        self.linear1 = nn.Linear(d_input, d_hidden_1)
        self.linear2 = nn.Linear(d_hidden_1, d_hidden_2)
        self.d_hidden_1 = d_hidden_1
        self.d_hidden_2 = d_hidden_2

    def forward(self, x):
        h = F.relu(self.linear1(x))
        z = F.relu(self.linear2(h))
        return z

class Decoder(torch.nn.Module):
    def __init__(self, d_input, d_hidden_1, d_hidden_2):
        super(Decoder,self).__init__()
        self.linear1 = nn.Linear(d_hidden_2, d_hidden_1)
        self.linear2 = nn.Linear(d_hidden_1, d_input)

    def forward(self, z):
        h = F.relu(self.linear1(z))
        x = F.relu(self.linear2(h))
        return x

class VAE(nn.Module):

    def __init__(self, encoder, decoder, d_latent):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.d_latent = d_latent

        self.d_hidden_1 = self.encoder.d_hidden_1
        self.d_hidden_2 = self.encoder.d_hidden_2

        self._enc_mu = torch.nn.Linear(self.d_hidden_2, self.d_latent)
        self._enc_log_var = torch.nn.Linear(self.d_hidden_2, self.d_latent)

    def _sample_latent(self, h_enc):
        '''
        Sample from latent space, z ~ N(μ, σ**2)
        '''
        mu = self._enc_mu(h_enc)
        log_var = self._enc_log_var(h_enc)
        sigma = torch.exp(log_var / 2)

        #epsilon = torch.randn(sigma.size()).float()
        epsilon = torch.randn_like(sigma).float()

        self.z_mean = mu
        self.z_sigma = sigma

        # use the reparameterization trick
        return mu + sigma * epsilon

    def forward(self, state):
        h_enc = self.encoder(state)
        z = self._sample_latent(h_enc)
        x_recon = self.decoder(z)
        return x_recon

    def get_num_params(self):
        print("Encoder--",sum(p.numel() for p in self.encoder.parameters() if p.requires_grad))
        print("Decoder--",sum(p.numel() for p in self.decoder.parameters() if p.requires_grad))
        print("Total--",sum(p.numel() for p in self.parameters() if p.requires_grad))

In [None]:
vocab_size = train_oh.shape[2]
input_dim = train_oh.shape[1] * vocab_size
d_hidden_1 = 200
d_hidden_2 = 120
d_latent = 60

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [153]:
q_bidir = True
q_d_h = 256
q_n_layers = 1
q_dropout = 0.5
d_n_layers = 3
d_dropout = 0
d_z = 128
d_d_h = 512
# from data import *

class MolVAE(nn.Module):
  def __init__(self, char, vector):
    super().__init__()
    self.vocabulary = vocab
    self.vector = vector
    
    n_vocab, d_emb = len(vocab), vector.size(1)
    self.x_emb = nn.Embedding(n_vocab, d_emb, char2idx['<pad>'])
    self.x_emb.weight.data.copy_(vector)
  
    # encoder
    self.encoder_rnn = nn.GRU(d_emb,q_d_h,num_layers=q_n_layers,batch_first=True,
                              dropout=q_dropout if q_n_layers > 1 else 0,bidirectional=q_bidir)
    q_d_last = q_d_h * (2 if q_bidir else 1)
    self.q_mu = nn.Linear(q_d_last, d_z)
    self.q_logvar = nn.Linear(q_d_last, d_z)
  
    # decoder
    self.decoder_rnn = nn.GRU(d_emb + d_z,d_d_h,num_layers=d_n_layers,batch_first=True,
                              dropout=d_dropout if d_n_layers > 1 else 0)
    self.decoder_latent = nn.Linear(d_z, d_d_h)
    self.decoder_fullyc = nn.Linear(d_d_h, n_vocab)
  
    # save model parameters as nn.ModuleList
    self.encoder = nn.ModuleList([self.encoder_rnn,self.q_mu,self.q_logvar])
    self.decoder = nn.ModuleList([self.decoder_rnn,self.decoder_latent,self.decoder_fullyc])
    self.vae = nn.ModuleList([self.x_emb,self.encoder,self.decoder])
    
  @property
  def device(self):
    return next(self.parameters()).device

  def string2tensor(self, string, device='model'):
    ids = convert_str2num(string, add_bos=True, add_eos=True)
    tensor = torch.tensor(ids, dtype=torch.long,device=self.device if device == 'model' else device)
    return tensor

  def tensor2string(self, tensor):
    ids = tensor.tolist()
    string = convert_num2str(ids, rem_bos=True, rem_eos=True)
    return string
  
  def forward(self,x):
    z, kl_loss = self.forward_encoder(x)
    recon_loss = self.forward_decoder(x, z)
    print("forward")
    return kl_loss, recon_loss
  
  def forward_encoder(self,x):
    x = [self.x_emb(i_x) for i_x in x]
    x = nn.utils.rnn.pack_sequence(x)
    _, h = self.encoder_rnn(x, None)
    h = h[-(1 + int(self.encoder_rnn.bidirectional)):]
    h = torch.cat(h.split(1), dim=-1).squeeze(0)
    mu, logvar = self.q_mu(h), self.q_logvar(h)
    eps = torch.randn_like(mu)
    z = mu + (logvar / 2).exp() * eps
    kl_loss = 0.5 * (logvar.exp() + mu ** 2 - 1 - logvar).sum(1).mean()
    return z, kl_loss
  
  def forward_decoder(self,x, z):
    lengths = [len(i_x) for i_x in x]
    x = nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value= char2idx['<pad>'])
    x_emb = self.x_emb(x)
    z_0 = z.unsqueeze(1).repeat(1, x_emb.size(1), 1)
    x_input = torch.cat([x_emb, z_0], dim=-1)
    x_input = nn.utils.rnn.pack_padded_sequence(x_input, lengths, batch_first=True)
    h_0 = self.decoder_latent(z)
    h_0 = h_0.unsqueeze(0).repeat(self.decoder_rnn.num_layers, 1, 1)
    output, _ = self.decoder_rnn(x_input, h_0)
    output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
    y = self.decoder_fullyc(output)
    
    recon_loss = F.cross_entropy(y[:, :-1].contiguous().view(-1, y.size(-1)),
                                 x[:, 1:].contiguous().view(-1),ignore_index= char2idx['<pad>'])
    return recon_loss
  
  def sample_z_prior(self,n_batch):
    return torch.randn(n_batch,self.q_mu.out_features,device= self.x_emb.weight.device)

  def sample(self,n_batch, max_len=100, z=None, temp=1.0):
    with torch.no_grad():
      if z is None:
        z = self.sample_z_prior(n_batch)
        z = z.to(self.device)
        z_0 = z.unsqueeze(1)
        h = self.decoder_latent(z)
        h = h.unsqueeze(0).repeat(self.decoder_rnn.num_layers, 1, 1)
        w = torch.tensor(char2idx['<bos>'], device=self.device).repeat(n_batch)
        x = torch.tensor([char2idx['<pad>']], device=device).repeat(n_batch, max_len)
        x[:, 0] = char2idx['<bos>']
        end_pads = torch.tensor([max_len], device=self.device).repeat(n_batch)
        eos_mask = torch.zeros(n_batch, dtype=torch.uint8, device=self.device)

        for i in range(1, max_len):
          x_emb = self.x_emb(w).unsqueeze(1)
          x_input = torch.cat([x_emb, z_0], dim=-1)

          o, h = self.decoder_rnn(x_input, h)
          y = self.decoder_fullyc(o.squeeze(1))
          y = F.softmax(y / temp, dim=-1)

          w = torch.multinomial(y, 1)[:, 0]
          x[~eos_mask, i] = w[~eos_mask]
          i_eos_mask = ~eos_mask & (w == char2idx['<eos>'])
          end_pads[i_eos_mask] = i + 1
          eos_mask = eos_mask | i_eos_mask
          
          new_x = []
          for i in range(x.size(0)):
            new_x.append(x[i, :end_pads[i]])

    return [self.tensor2string(i_x) for i_x in new_x]

In [112]:
temp = torch.tensor(train_idx, dtype=torch.long, device=self.device if device == 'model' else device)

In [114]:
temp.shape

torch.Size([1000, 50])

In [136]:
import numpy as np

CHARSET = [' ', '#', '(', ')', '+', '-', '/', '1', '2', '3', '4', '5', '6', '7',
           '8', '=', '@', 'B', 'C', 'F', 'H', 'I', 'N', 'O', 'P', 'S', '[', '\\', ']',
           'c', 'l', 'n', 'o', 'r', 's']

class OneHotFeaturizer(object):
    def __init__(self, charset=CHARSET, padlength=120):
        self.charset = CHARSET
        self.pad_length = padlength

    def featurize(self, smiles):
        return np.array([self.one_hot_encode(smi) for smi in smiles])

    def one_hot_array(self, i):
        return [int(x) for x in [ix == i for ix in range(len(self.charset))]]

    def one_hot_index(self, c):
        return self.charset.index(c)

    def pad_smi(self, smi):
        return smi.ljust(self.pad_length)

    def one_hot_encode(self, smi):
        return np.array([
            self.one_hot_array(self.one_hot_index(x)) for x in self.pad_smi(smi)
            ])

    def one_hot_decode(self, z):
        z1 = []
        for i in range(len(z)):
            s = ''
            for j in range(len(z[i])):
                oh = np.argmax(z[i][j])
                s += self.charset[oh]
            z1.append([s.strip()])
        return z1

    def decode_smiles_from_index(self, vec):
        return ''.join(map(lambda x: CHARSET[x], vec)).strip()

In [132]:
oh_smiles = ohf.featurize(train)

In [146]:
ohf = OneHotFeaturizer()
oh_smiles = ohf.featurize(X_train)
print(oh_smiles.shape)

(1000, 120, 37)


In [98]:
from sklearn import model_selection
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F

class MolecularVAE(nn.Module):
    def __init__(self):
        super(MolecularVAE, self).__init__()

        # The input filter dim should be 35
        #  corresponds to the size of CHARSET
        self.conv1d1 = nn.Conv1d(37, 9, kernel_size=9)  
        self.conv1d2 = nn.Conv1d(9, 9, kernel_size=9)
        self.conv1d3 = nn.Conv1d(9, 10, kernel_size=11)
        self.fc0 = nn.Linear(940, 435)
        self.fc11 = nn.Linear(435, 292)
        self.fc12 = nn.Linear(435, 292)

        self.fc2 = nn.Linear(292, 292)
        self.gru = nn.GRU(292, 501, 3, batch_first=True)
        self.fc3 = nn.Linear(501, 37)

    def encode(self, x):
        h = F.relu(self.conv1d1(x))
        h = F.relu(self.conv1d2(h))
        h = F.relu(self.conv1d3(h))
        h = h.view(h.size(0), -1)
        h = F.selu(self.fc0(h))
        return self.fc11(h), self.fc12(h)

    def reparametrize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = 1e-2 * torch.randn_like(std)
            w = eps.mul(std).add_(mu)
            return w
        else:
            return mu

    def decode(self, z):
        z = F.selu(self.fc2(z))
        z = z.view(z.size(0), 1, z.size(-1)).repeat(1, 60, 1)
        out, h = self.gru(z)
        out_reshape = out.contiguous().view(-1, out.size(-1))
        y0 = F.softmax(self.fc3(out_reshape), dim=1)
        y = y0.contiguous().view(out.size(0), -1, y0.size(-1))
        return y

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar

In [140]:
import numpy as np
from sklearn import model_selection
import torch
import torch.utils.data
from torch import nn, optim
import torch.nn.functional as F
# import h5py

def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, size_average=False)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

X = oh_smiles.astype(np.float32)

train = torch.utils.data.TensorDataset(torch.from_numpy(X))
train_loader = torch.utils.data.DataLoader(train, batch_size=250, shuffle=True)
torch.manual_seed(42)

epochs = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = MolecularVAE().to(device)
optimizer = optim.Adam(model.parameters())

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(train_loader):
        data = data[0].transpose(1,2).to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data.transpose(1,2), mu, logvar)
        loss.backward()
        train_loss += loss
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'{epoch} / {batch_idx}\t{loss:.4f}')
    print('train', train_loss / len(train_loader.dataset))
    return train_loss / len(train_loader.dataset)

def test(epoch):
    model.eval()
    test_loss = 0
    for batch_idx, data in enumerate(test_loader):
        data = data[0].transpose(1,2).to(device)
        recon_batch, mu, logvar = model(data)
        test_loss += loss_function(recon_batch, data.transpose(1,2), mu, logvar).item()
    print('test', test_loss / len(test_loader))

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    if epoch % 1 == 0:
        torch.save(model.state_dict(),
                './weights/vae-{:03d}-{}.pth'.format(epoch, train_loss))



1 / 0	138462.9219
train tensor(436.6977, device='cuda:0', grad_fn=<DivBackward0>)
2 / 0	84523.0391
train tensor(268.6349, device='cuda:0', grad_fn=<DivBackward0>)
3 / 0	56684.4297
train tensor(200.5436, device='cuda:0', grad_fn=<DivBackward0>)
4 / 0	45009.7422
train tensor(178.0354, device='cuda:0', grad_fn=<DivBackward0>)
5 / 0	43012.2305
train tensor(170.2285, device='cuda:0', grad_fn=<DivBackward0>)
6 / 0	41145.2773
train tensor(162.1197, device='cuda:0', grad_fn=<DivBackward0>)
7 / 0	38937.4180
train tensor(151.7357, device='cuda:0', grad_fn=<DivBackward0>)
8 / 0	35954.7070
train tensor(140.3110, device='cuda:0', grad_fn=<DivBackward0>)
9 / 0	33518.8320
train tensor(129.7825, device='cuda:0', grad_fn=<DivBackward0>)
10 / 0	33578.4609
train tensor(133.5314, device='cuda:0', grad_fn=<DivBackward0>)
11 / 0	32651.2031
train tensor(126.3311, device='cuda:0', grad_fn=<DivBackward0>)
12 / 0	31844.3652
train tensor(124.3849, device='cuda:0', grad_fn=<DivBackward0>)
13 / 0	30241.6758
train 

In [151]:
start_vec.shape

torch.Size([1, 120, 37])

In [147]:
start_vec = torch.from_numpy(oh.featurize([start]).astype(np.float32)).to('cuda')

In [152]:
model = MolecularVAE()
model.load_state_dict(torch.load('./weights/vae-100-102.80235290527344.pth'))
model.to('cuda')
model.eval()

start = 'C[C@@H]1CN(C(=O)c2cc(Br)cn2C)CC[C@H]1[NH3+]'
start = start.ljust(120)
oh = OneHotFeaturizer()
start_vec = torch.from_numpy(oh.featurize([start]).astype(np.float32)).to('cuda')

recon_x = model(start_vec)[0].cpu().detach().numpy()
y = np.argmax(recon_x, axis=2)
print(start)
print(oh.decode_smiles_from_index(y[0]))

RuntimeError: Given groups=1, weight of size [9, 37, 9], expected input[1, 120, 37] to have 37 channels, but got 120 channels instead

In [None]:
class Encoder(nn.Module):
    def __init__(self, hidden_size, embedding_size, embedding_dim, rnn, num_layers, bidirectional, device): 
        super(Encoder, self).__init__()        
        self.latent_dim = 1024

        # Encoder Setup

    def forward(self,input_seq):

        return mean, logv


class Decoder(nn.Module):
    def __init__(self): 
        super(Decoder, self).__init__()        
        self.latent_dim = 1024

        # Decoder Setup

    def forward(self, z, actual_input=None):

        #generated_sequence must be BATCH x SEQLENGTH, have type long, and contain
        #the index form of the generated sequences (smiles strings can be generated
        #by passing rows to Lang.indexToSmiles)
        return decoder_output, generated_sequence 


class VAE(nn.Module):      
    def __init__(self): #all sorts of hyper parameters should be passed here            
        super(VAE, self).__init__()        
          
        self.latent_dim = 1024
       
        self.encoder = Encoder()
        self.decoder = Decoder()
                                                                                          
    def forward(self, input_seq):

        mean, logv = self.encoder(input_seq)
        
        # calculate z

        decoder_output, generated_sequence = self.decoder(z,actual_input=input_seq)
                
        return decoder_output, generated_sequence, (mean, logv, z)

# Snippets

In [3]:
from tdc.generation import MolGen
data = MolGen(name = 'MOSES')
split = data.get_split()

Downloading...
100%|████████████████████████████████████| 75.3M/75.3M [00:08<00:00, 8.63MiB/s]
Loading...
Done!


In [8]:
from tdc.chem_utils import MolConvert
converter = MolConvert(src = 'SMILES', dst = 'SELFIES')

In [9]:
converter

<tdc.chem_utils.featurize.molconvert.MolConvert at 0x2a4c5a8c9e8>

In [None]:
## http://bits.csb.pitt.edu/mscbio2066/assign7/
#!/usr/bin/env python3

import gzip
import torch
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import init
import argparse

class Lang:
    '''Predefined mapping from characters to indices for our
    reduced alphabet of SMILES with methods for converting.
    You must use this mapping.'''
    
    def __init__(self):
        self.chartoindex = {'EOS': 0,'SOS': 1, 'C': 2, '(': 3,
                '=': 4, 'O': 5, ')': 6, '[': 7, '-': 8, ']': 9,
                'N': 10, '+': 11, '1': 12, 'P': 13, '2': 14,'3': 15,
                '4': 16, 'S': 17, '#': 18, '5': 19,'6': 20, '7': 21,
                'H': 22, 'I': 23, 'B': 24, 'F': 25, '8': 26, '9': 27
                } 
        self.indextochar = {0: 'EOS', 1: 'SOS', 2: 'C', 3: '(',
                4: '=', 5: 'O', 6: ')', 7: '[', 8: '-', 9: ']',
                10: 'N', 11: '+', 12: '1', 13: 'P', 14: '2', 15: '3',
                16: '4', 17: 'S', 18: '#', 19: '5', 20: '6', 21: '7',
                22: 'H', 23: 'I', 24: 'B', 25: 'F', 26: '8', 27: '9'
                }
        self.nchars = 28
        
    def indexesFromSMILES(self, smiles_str):
        index_list = [self.chartoindex[char] for char in smiles_str]
        index_list.append(self.chartoindex["EOS"])
        return np.array(index_list, dtype=np.uint8)
        
    def indexToSmiles(self,indices):
        '''convert list of indices into a smiles string'''
        smiles_str = ''.join(list(map(lambda x: self.indextochar[int(x)] if x != 0.0 else 'E',indices)))
        return smiles_str.split('E')[0] #Only want values before output 'EOS' token


class SmilesDataset(torch.utils.data.Dataset):
    '''Dataset that reads in a gzipped smiles file and converts to a
    numpy array representation.  Note we encountered memory usage issues
    when using variable sequence length batches and so use a fixed size.
    There are likely more memory efficient ways to store this data.'''
    def __init__(self,data_path,max_length=150):
        self.max_length = max_length
        self.language = Lang()
        #TODO - for faster training you will want to preprocess
        #the training set and read in this processed file instead
        #for faster initialization
        with gzip.open(data_path,'rt') as f:
            N = sum(1 for line in f)
        self.examples = np.zeros((N,max_length),dtype=np.uint8)
        with gzip.open(data_path,'rt') as f:
            for i,line in enumerate(f):
                example = line.rstrip()
                ex = self.language.indexesFromSMILES(example)
                self.examples[i][:len(ex)] = ex
                
    def __len__(self):
        return len(self.examples)
        
    def __getitem__(self, idx):
        return torch.tensor(self.examples[idx], dtype=torch.long)
    
    def getIndexToChar(self):
        return self.language.indextochar

In [None]:
class Encoder(nn.Module):
    def __init__(self, hidden_size, embedding_size, embedding_dim,rnn, num_layers, bidirectional, device): 
        super(Encoder, self).__init__()        
        self.latent_dim = 1024

        # Encoder Setup

    def forward(self,input_seq):

        return mean, logv


class Decoder(nn.Module):
    def __init__(self): 
        super(Decoder, self).__init__()        
        self.latent_dim = 1024

        # Decoder Setup

    def forward(self, z, actual_input=None):

        #generated_sequence must be BATCH x SEQLENGTH, have type long, and contain
        #the index form of the generated sequences (smiles strings can be generated
        #by passing rows to Lang.indexToSmiles)
        return decoder_output, generated_sequence 


class VAE(nn.Module):      
    def __init__(self): #all sorts of hyper parameters should be passed here            
        super(VAE, self).__init__()        
          
        self.latent_dim = 1024
       
        self.encoder = Encoder()
        self.decoder = Decoder()
                                                                                          
    def forward(self, input_seq):

        mean, logv = self.encoder(input_seq)
        
        # calculate z

        decoder_output, generated_sequence = self.decoder(z,actual_input=input_seq)
                
        return decoder_output, generated_sequence, (mean, logv, z)
    




if __name__ == '__main__':
    LATENT_DIM = 1024

    parser = argparse.ArgumentParser('Train a Variational Autoencoder')
    parser.add_argument('--train_data','-T',required=True,help='data to train the VAE with')
    parser.add_argument('--out',default='vae_generate.pth',help='File to save generate function to')
    #more arguments...
    
    args = parser.parse_args()

    dataset = SmilesDataset(args.train_data)

    language_mapping = dataset.getIndexToChar()

    vae = VAE(args.hidden_size,len(language_mapping),args.embedding_dim,args.rnn_type,args.num_layers,args.bidirectional, device='cuda').to('cuda')

    #TRAIN THE MODEL


    # This will create the file that you will submit to evaluate SMILES generated from a normal distribution
    # If you want to implement your decoding as a distinct method in your vae module,
    # you will need to wrap calling it in another module and trace that.
    z_1 = torch.normal(0, 1, size=(1, LATENT_DIM),device='cuda')
    with torch.no_grad():
        vae.decoder.eval()
        traced = torch.jit.trace(vae.decoder, z_1.to('cuda'))

        torch.jit.save(traced,args.out)