In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import gzip
import h5py
import numpy as np
import argparse
import os
import torch.optim as optim
from model import MolecularICVAE
from utils import decode_smiles_from_indexes, load_dataset

### Load data and create the conditional inputs from the labels



(1) **y_train_ic (batch_size, 1, 33)** is the conditional input that is concatenated with the one-hot SMILES encoding **X_train (batch_size, 120, 33)** to form the **(batch_size, 121, 33)** input of the ICVAE.

(2) **y_train_l (batch_size, 128)** is the conditional target used to constrain the latent vector **(batch_size, 128)** to the molecular property.

Note: only the first two dimensions of the latent vector are used as conditions; the remaining dimensions of **y_train_l** are set to zero.

In [2]:
# Width of the latent vector; the latent-level condition below is tiled to match it.
lat_dim = 128

# One-hot SMILES data and the character set, read from the preprocessed HDF5 file.
X_train, X_test, charset = load_dataset('./data/processed.h5')

# Normalised property labels for the train/test splits.
y_train = np.load("./prop_np/weight/y_train_norm.npy")
y_test = np.load("./prop_np/weight/y_test_norm.npy")

# Input-level condition: each label repeated across 33 columns
# (presumably len(charset), i.e. the last axis of X — TODO confirm).
y_train_ic = np.repeat(np.expand_dims(y_train, 1), 33, axis=-1)
y_test_ic = np.repeat(np.expand_dims(y_test, 1), 33, axis=-1)

# Latent-level condition: label tiled to lat_dim columns, then every column
# after the first two is zeroed so only latent dims 0 and 1 carry the property.
y_train_l = np.repeat(np.expand_dims(y_train, 1), lat_dim, axis=1)
y_test_l = np.repeat(np.expand_dims(y_test, 1), lat_dim, axis=1)
y_train_l[:, 2:] = 0.0
y_test_l[:, 2:] = 0.0

In [3]:
# Convert every NumPy array to a float32 CPU tensor.
torch_X_train, torch_X_test = (torch.from_numpy(arr).float() for arr in (X_train, X_test))

# Input-level condition tensors.
torch_ic_train, torch_ic_test = (torch.from_numpy(arr).float() for arr in (y_train_ic, y_test_ic))

# Latent-level condition tensors (only the first two columns are non-zero).
torch_l_train, torch_l_test = (torch.from_numpy(arr).float() for arr in (y_train_l, y_test_l))

# Raw property labels, passed to the model as its third argument.
torch_lc_train, torch_lc_test = (torch.from_numpy(arr).float() for arr in (y_train, y_test))

# Each sample yields (one-hot, ic condition, latent condition, raw label).
# NOTE: the names `train` / `test` are rebound to functions by the
# `def train` / `def test` cells further down.
train = torch.utils.data.TensorDataset(torch_X_train, torch_ic_train, torch_l_train, torch_lc_train)
test = torch.utils.data.TensorDataset(torch_X_test, torch_ic_test, torch_l_test, torch_lc_test)

In [6]:
# NOTE: `train` / `test` here must still be the TensorDatasets from the previous
# cell. The later `def train(...)` / `def test(...)` cells rebind those names to
# functions, so re-run the dataset cell before re-running this one.
batch_size = 250
train_loader = torch.utils.data.DataLoader(train, shuffle=True, batch_size=batch_size)
# The test loss is a sum over every sample, so shuffling the test set adds nothing.
test_loader = torch.utils.data.DataLoader(test, shuffle=False, batch_size=batch_size)

### correlate the latent vector with molecular properties

The only difference between our model and other VAE-based models is that the KL term uses **(z_mean - y_arg)** to force the mean of the latent vector toward the molecular property value, whereas other VAE-based models use **z_mean** alone.

**Note: y_arg is y_train_l.**

In [7]:
def vae_loss(x_decoded_mean, x, z_mean, z_logvar, y_arg):
    """Conditional VAE loss whose KL term is centred on the property target.

    Args:
        x_decoded_mean: decoder output probabilities (values in [0, 1]),
            same shape as ``x``.
        x: target one-hot encoding.
        z_mean: encoder mean of the latent distribution.
        z_logvar: encoder log-variance, same shape as ``z_mean``.
        y_arg: property target, broadcastable to ``z_mean``. The KL term pulls
            ``z_mean`` toward ``y_arg`` instead of toward zero.

    Returns:
        Tuple ``(0.5 * summed binary cross-entropy + KL, KL)``. Both terms are
        summed (not averaged) over the batch.
    """
    # reduction='sum' is the supported spelling of the deprecated size_average=False.
    xent_loss = F.binary_cross_entropy(x_decoded_mean, x, reduction='sum')
    # KL(N(z_mean, exp(z_logvar)) || N(y_arg, 1)), summed over batch and latent dims.
    kl_loss = -0.5 * torch.sum(1 + z_logvar - (z_mean - y_arg).pow(2) - z_logvar.exp())
    return 0.5 * xent_loss + kl_loss, kl_loss

In [8]:
# Seed torch's global RNG before the model is built, for reproducible runs.
torch.manual_seed(42)

# Run on the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = MolecularICVAE().to(device)
# Adam with its default hyper-parameters.
optimizer = optim.Adam(model.parameters())

# Number of passes over the training set.
epochs = 100

In [9]:
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``device``,
    ``train_loader`` and ``vae_loss``.

    Args:
        epoch: current epoch number (not used inside the function).

    Returns:
        Tuple ``(mean total loss per training sample, latents, labels)``.
        ``latents`` is a list of per-batch latent arrays. ``labels`` is a list of
        per-batch arrays holding the first 7 columns of the latent condition.

    Prints the per-sample "CL" value (total loss minus KL) and the per-sample KL.
    """
    model.train()
    total_sum = 0
    kl_sum = 0
    latents, labels = [], []

    for oh, ic, l, lc in train_loader:
        # Insert a singleton axis at dim 1 before feeding the model.
        one_hot = oh.unsqueeze(1).to(device)
        label_ic = ic.to(device)
        label_l = l.to(device)
        label_lc = lc.to(device)

        optimizer.zero_grad()
        output, mean, logvar, latent = model(one_hot, label_ic, label_lc)
        loss, kl_loss = vae_loss(output, one_hot.squeeze(1), mean, logvar, label_l)
        loss.backward()
        total_sum += loss.item()
        kl_sum += kl_loss.item()
        optimizer.step()

        latents.append(latent.detach().cpu().numpy())
        labels.append(label_l.detach().cpu().numpy()[:, :7])

    n_samples = len(train_loader.dataset)
    print('train CL: ' + str((total_sum - kl_sum) / n_samples) + '  train KL: ' + str(kl_sum / n_samples))
    return total_sum / n_samples, latents, labels

In [10]:
def test(epoch):
    """Evaluate the model on ``test_loader`` without updating any weights.

    Uses the notebook-level ``model``, ``device``, ``test_loader`` and
    ``vae_loss``.

    Args:
        epoch: current epoch number (not used inside the function).

    Returns:
        Mean total loss per test sample.

    Prints the per-sample "CL" value (total loss minus KL) and the per-sample KL.
    """
    model.eval()
    test_loss = 0
    KL_loss = 0
    # Evaluation needs no gradients: skip building autograd graphs entirely.
    # (The previous optimizer.zero_grad() here was unnecessary and is removed.)
    with torch.no_grad():
        for oh, ic, l, lc in test_loader:
            oh = oh.unsqueeze(1).to(device)
            label_ic = ic.to(device)
            label_l = l.to(device)
            label_lc = lc.to(device)
            output, mean, logvar, latent = model(oh, label_ic, label_lc)
            loss, kl_loss = vae_loss(output, oh.squeeze(1), mean, logvar, label_l)
            KL_loss += kl_loss.item()
            test_loss += loss.item()

    n_samples = len(test_loader.dataset)
    print('test CL: ' + str((test_loss - KL_loss) / n_samples) + '  test KL: ' + str(KL_loss / n_samples))
    return test_loss / n_samples

In [11]:
# Alternate one training pass and one evaluation pass per epoch. Each iteration
# overwrites latent_arr / label_arr, so only the final epoch's batches survive.
for epoch in range(1, epochs + 1):
    (train_loss, latent_arr, label_arr) = train(epoch)
    test_loss = test(epoch)



train CL: 362.369909375  train KL: 18481.088859375
test CL: 343.8789375  test KL: 4754.1374625
train CL: 342.2986390625  train KL: 4099.9815203125
test CL: 341.67598125  test KL: 3650.54915625
train CL: 341.5542484375  train KL: 3435.884471875
test CL: 341.30831875  test KL: 3134.9903125
train CL: 339.8165671875  train KL: 2930.931703125
test CL: 338.0603875  test KL: 2607.9593625
train CL: 339.5169796875  train KL: 9302.8506375
test CL: 337.26789375  test KL: 4212.7898875
train CL: 336.34605625  train KL: 3676.73779375
test CL: 335.8059625  test KL: 3147.043775
train CL: 335.875625  train KL: 2905.365784375
test CL: 334.623925  test KL: 2576.554525
train CL: 335.09417734375  train KL: 2372.05545625
test CL: 333.939609375  test KL: 2120.351171875
train CL: 333.43574453125  train KL: 2002.349496875
test CL: 333.579584375  test KL: 1814.442915625
train CL: 332.8581484375  train KL: 1745.85172734375
test CL: 331.87445625  test KL: 1610.31475
train CL: 332.35548828125  train KL: 1543.73333

train CL: 317.39893754882814  train KL: 108.87058256835938
test CL: 317.1351146484375  test KL: 154.5897400390625
train CL: 316.88756840820315  train KL: 88.16607827148438
test CL: 318.8325912109375  test KL: 136.1044595703125
train CL: 317.42407421875  train KL: 101.75393046875
test CL: 315.840619140625  test KL: 282.828018359375
train CL: 316.602931640625  train KL: 99.5677369140625
test CL: 315.91912734375  test KL: 190.274603125
train CL: 317.6676323730469  train KL: 86.41975024414063
test CL: 318.426579296875  test KL: 142.376055859375
train CL: 316.85592214355466  train KL: 112.35228684082031
test CL: 315.4366634765625  test KL: 132.1144537109375
train CL: 317.2624266845703  train KL: 87.81812761230469
test CL: 316.17753671875  test KL: 168.98747421875
train CL: 317.0639587402344  train KL: 94.91612778320312
test CL: 316.14378125  test KL: 117.1209578125
train CL: 316.4265678955078  train KL: 82.0211715576172
test CL: 317.583737109375  test KL: 136.905144140625
train CL: 316.1514

### save latent vector and model

Save the latent vectors so that the latent space can be plotted,

and save the model so that molecules can be sampled from it.

In [18]:
# Make sure the output directories exist before writing to them.
os.makedirs("./result/latent", exist_ok=True)
os.makedirs("./result/model", exist_ok=True)

# Join the last epoch's per-batch arrays row-wise and keep the first two
# (conditioned) columns. np.concatenate copes with a smaller final batch,
# which np.array(list_of_batches) would not stack into a regular array.
# Assumes each batch array is 2-D (batch, features) — TODO confirm for `latent`.
latent_np = np.concatenate(latent_arr, axis=0)[:, :2]
label_np = np.concatenate(label_arr, axis=0)[:, :2]
np.save("./result/latent/MW_latent.npy", latent_np)
np.save("./result/latent/MW_label.npy", label_np)

torch.save(model.state_dict(), "./result/model/MW_model.pth")