Mostly based on https://github.com/aksub99/molecular-vae/blob/master/Molecular_VAE.ipynb 

Additions:

- Property prediction segment and auxiliary loss
- Different data prep (in load_data.ipynb) that more closely follows the original code https://github.com/aspuru-guzik-group/chemical_vae/
- Sigmoid annealing schedule
- Slower training it seems (TM)
- Validation set and loss

TODO:
- better data loading (canonical only, less storage space)
- teacher forcing gru

In [1]:
import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
import pandas as pd
# imports the torch_xla package
import os
TPU = 'COLAB_TPU_ADDR' in os.environ
if TPU:
  !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl
  import torch_xla
  import torch_xla.core.xla_model as xm


torch.manual_seed(42)

<torch._C.Generator at 0x7f3b3f412b70>

In [4]:
!git clone https://github.com/loodvn/pytorch-chemicalvae.git
!mv pytorch-chemicalvae/data data
# !ls data

Cloning into 'pytorch-chemicalvae'...
remote: Enumerating objects: 30, done.[K
remote: Counting objects: 100% (30/30), done.[K
remote: Compressing objects: 100% (28/28), done.[K
remote: Total 30 (delta 3), reused 22 (delta 1), pack-reused 0[K
Unpacking objects: 100% (30/30), done.


In [2]:
X = np.load('data/train_compressed.npz')['arr_0']
Y = np.load('data/Y_reg.npy')
# X = np.load('data/X_100.npy')
# Y = np.load('data/Y_reg100.npy')

In [3]:
# Put in load_data
from torch.utils.data import DataLoader, TensorDataset, DataLoader

TMP_TRAIN_SIZE = -1
BATCH_SIZE = 256

if TMP_TRAIN_SIZE < 0:
    TMP_TRAIN_SIZE = Y.shape[0]
# 75/25 split
valid_idx = int(TMP_TRAIN_SIZE*0.75)
x_train, y_train, x_valid, y_valid = map(torch.tensor, (X[:valid_idx], Y[:valid_idx], X[valid_idx:], Y[valid_idx:]))
  
del(X)  # Takes up too much RAM

train_ds = TensorDataset(x_train, y_train)
valid_ds = TensorDataset(x_valid, y_valid)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_ds, batch_size=2*BATCH_SIZE)

In [4]:
class ChemVAE(torch.nn.Module):
    def __init__(self):
        super(ChemVAE, self).__init__()
        
        self.latent_dims = 196  # p7, VAEs
        self.num_char = 35 # including +1 for padding
        
        # From Methods/Autoencoder architecture section (p13)
        self.enc_cnn1 = nn.Conv1d(in_channels=120, out_channels=9, kernel_size=9)  # 9,9
        self.enc_cnn2 = nn.Conv1d(in_channels=9, out_channels=9, kernel_size=9)  # 9,9
        self.enc_cnn3 = nn.Conv1d(in_channels=9, out_channels=11, kernel_size=10)  # 10, 11 (filter size, convolutional kernels)
        
        
        self.enc_fc_mu = nn.Linear(11*10, self.latent_dims)  # 11  (out_channels * whatever's left?)
        self.enc_fc_var = nn.Linear(11*10, self.latent_dims)  # 11
        
        
        self.dec_gru = nn.GRU(input_size=self.latent_dims, hidden_size=488, num_layers=3, batch_first=True)  # TODO input_size is latent space?
#         self.dec_gru_last = nn.GRU(input_size = self.latent_dims, hidden_size=488, )  # output GRU layer had one additional input, corresponding to the character sampled from the softmax output
        self.dec_fc  = nn.Linear(488, self.num_char)
        
        self.property_1 = nn.Linear(self.latent_dims, 1000)
        self.property_2 = nn.Linear(1000, 3)
        self.property_dropout = nn.Dropout(p=0.2)
        
        # TODO activation functions? Assuming tanh not relu? Also, difference between F.relu and nn.ReLU?
        self.act = F.relu
        
        
    def encode(self, x):
#         print("initial size:", x.shape)
        x = self.act(self.enc_cnn1(x))
#         print("initial size:", x.shape)
        x = self.act(self.enc_cnn2(x))
        x = self.act(self.enc_cnn3(x))
#         print("size after enc_cnns:", x.shape)

        x = x.view(x.size(0), -1) # Flatten, Keep batch size
        mu = self.enc_fc_mu(x)
        var = self.enc_fc_var(x)

        return mu, var

    def decode(self, z):
#         print("size before reshape", z.size)
        z = z.view(z.size(0), 1, z.size(-1))  # Expand_dims (1, latent_dim) -> (1, 1, latent_dim)
#         print("size mid-reshape", z.size)
        z = z.repeat(1, 120, 1)               # Repeat latent*120: (1, 1, latent_dim) -> (1, 120, latent_dim)
#         print("size after reshape", z.size)
        output, hn = self.dec_gru(z)
        softmax = self.dec_fc(output)
        softmax = F.softmax(softmax, dim=1)
#         print("softmax shape:", softmax.size())
        return softmax
        
    
    
    # Copied from PyTorch VAE example
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
    
    def prediction(self, z):
        # two fully connected layers of 1000 neurons, dropout rate of 0.2
        fc1 = self.act(self.property_dropout(self.property_1(z)))
#         print("prop1 shape: ", fc1.shape)
        pred = self.act(self.property_dropout(self.property_2(fc1)))
#         print("prop 2 shape", pred.shape)
        
        # output: batch size * 3
        return pred
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar, z

Training
- variational loss (KL divergence) annealed according to sigmoid schedule after 29 epochs, running for a total 120 epochs.
- output GRU layer had one additional input, corresponding to the character sampled from the softmax output, trained using teacher forcing

Getting output samples from softmax (depending on temperature):
https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#preparing-for-training

Pytorch training loop over batches:
loss.backward()
opt.step()
opt.zero_grad()

Which reconstruction loss?
CE loss?

In [5]:
def one_hot_array(i, n):
    return map(int, [ix == i for ix in xrange(n)])

def one_hot_index(vec, charset):
    return map(charset.index, vec)

def from_one_hot_array(vec):
    oh = np.where(vec == 1)
    if oh[0].shape == (0, ):
        return None
    return int(oh[0][0])

def decode_smiles_from_indexes(vec, charset):
    return "".join(map(lambda x: charset[x], vec)).strip()

charset = ['n',
 '[',
 'o',
 'I',
 '3',
 'H',
 '+',
 'S',
 '@',
 '8',
 '4',
 '1',
 's',
 'N',
 'F',
 'P',
 '/',
 '=',
 'O',
 'B',
 'C',
 '\\',
 '(',
 '-',
 ']',
 '6',
 ')',
 'r',
 '5',
 '7',
 '2',
 '#',
 'l',
 'c',
 ' ']

In [6]:
def sigmoid_schedule(time_step, slope=1., start=22):
    return float(1 / (1. + np.exp(slope * (start - float(time_step)))))
sigmoid_schedule()

0.9996646498695336

# Baseline: Mean prediction


In [7]:
# 
logP = np.mean(np.abs(Y[:,0].mean()-Y[:,0]))
print("logP baseline: ", logP)
QED = np.mean(np.abs(Y[:,1].mean()-Y[:,1]))
print("QED baseline: ", QED)

logP baseline:  1.1381030935900205
QED baseline:  0.1121743723222245


In [8]:
(np.abs(Y.mean(axis=0)-Y)).mean(axis=0)  # logP, QED, SAS

array([1.13810309, 0.11217437, 0.66557906])

## Train

In [9]:
# From other pytorch implementation
def vae_loss(x_decoded_mean, x, z_mean, z_logvar):
    xent_loss = F.binary_cross_entropy(x_decoded_mean, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    return xent_loss + kl_loss

def xent_loss(x_decoded_mean, x):
    return F.binary_cross_entropy(x_decoded_mean, x, reduction='sum')

def kl_loss(z_mean, z_logvar):
    return -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())

# prediction loss: mse
def pred_loss(y_pred, y_true):
    return torch.mean((y_pred - y_true).pow(2)).to(device)

def mae(y_pred, y_true):
    return torch.mean(torch.abs(y_pred - y_true))

In [None]:
print("Starting training")
import time
start = time.time()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = xm.xla_device()
epochs = 20 #120

model = ChemVAE().to(device)
optimizer = optim.Adam(model.parameters())

SIGMOID_ANNEALING = True

# From other pytorch implementation TODO reference properly
# TODO save checkpoints every 1 hours
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(train_loader):
        y_true = data[1].to(device)
        data = data[0].to(device)
        optimizer.zero_grad()
        output, mean, logvar, z = model(data)
        pred = model.prediction(z)
#         print("pred:", pred.shape, "y: ", y_true.shape)
        
        if batch_idx==0:
              inp = data.cpu().numpy()
              outp = output.cpu().detach().numpy()
              lab = data.cpu().numpy()
              print("Input:")
              print(decode_smiles_from_indexes(map(from_one_hot_array, inp[0]), charset))
              print("Label:")
              print(decode_smiles_from_indexes(map(from_one_hot_array, lab[0]), charset))
              sampled = outp[0].reshape(1, 120, len(charset)).argmax(axis=2)[0]
              print("Output:")
              print(decode_smiles_from_indexes(sampled, charset))

        
#         print("pred loss: ", pred_loss(pred, y_true), "shape: ", pred_loss(pred, y_true).shape)
        sched = torch.tensor(sigmoid_schedule(epoch)).to(device) if SIGMOID_ANNEALING else 1
        loss = sched*kl_loss(mean, logvar) + xent_loss(output, data) + sched*pred_loss(pred, y_true)
        # import pdb; pdb.set_trace()
        loss.backward()
        train_loss += loss
        optimizer.step()
        if TPU:
          xm.mark_step()
        
        if batch_idx % 100 == 0:
            print(f'epoch {epoch} / batch {batch_idx}\tFull loss: {loss/BATCH_SIZE:.4f}')  # TODO print all of the loss components seperately
            pred_mae = mae(pred, y_true)
            print(f'epoch {epoch} / batch {batch_idx}\tPred mae loss: {pred_mae/BATCH_SIZE:.4f}')
#     print(f'epoch {epoch}: train loss:', (train_loss / len(train_loader.dataset)))
    return train_loss / len(train_loader.dataset)

def eval_model():
    model.eval()
    with torch.no_grad():
      eval_loss = 0
      eval_pred_loss = 0
      logP_loss = 0
      QED_loss = 0

      for batch_idx, data in enumerate(valid_loader):
        y_true = data[1].to(device)
        data = data[0].to(device)
        output, mean, logvar, z = model(data)
        pred = model.prediction(z)

        sched = torch.tensor(sigmoid_schedule(epoch)).to(device)
        loss = sched*kl_loss(mean, logvar) + xent_loss(output, data) + sched*pred_loss(pred, y_true)
        
        eval_loss += loss
        eval_pred_loss += pred_loss(pred, y_true)
        logP_loss += torch.sum(torch.abs(pred[:,0] - y_true[:,0]))  # MAE loss to reproduce Table 2
        QED_loss += torch.sum(torch.abs(pred[:,1] - y_true[:,1]))  # MAE loss to reproduce Table 2

    return eval_loss / len(valid_loader.dataset), eval_pred_loss / len(valid_loader.dataset), torch.sum(logP_loss) / len(valid_loader.dataset), torch.sum(QED_loss / len(valid_loader.dataset))

val_losses = []
val_pred_losses = []
logP_losses = []
qed_losses = []
for epoch in range(1, epochs + 1):
    e_start = time.time()
    train_loss = train(epoch)
    print(f"{epoch} Training loss: {train_loss}")
    e_end = time.time()
    print(f"Time per epoch ({epoch}): {e_end-e_start:.3f}s")

    print("Evaluating...")
    val_loss, eval_pred_loss, logP_loss, qed_loss = eval_model()
    print(f"Evaluation loss (training): \n{val_loss}, \n{eval_pred_loss}, \n{logP_loss}, \n{qed_loss}")
    
    val_losses.append(val_loss.item())
    val_pred_losses.append(eval_pred_loss.item())
    logP_losses.append(logP_loss.item())
    qed_losses.append(qed_loss.item())
    print(f"Elapsed time: {e_end-start:.3f}s")

end = time.time()
print(f"Total time taken: {int(end-start)//60}m{(end-start)%60:.3f}s")

Starting training
Input:
44-on6nF(FFFF-r4PP+\c4-4-on6n4-4646o4-46Oo41#444-on6O1c6F(
Label:
44-on6nF(FFFF-r4PP+\c4-4-on6n4-4646o4-46Oo41#444-on6O1c6F(
Output:
BB====ssssOOOOOrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr
epoch 1 / batch 0	Full loss: 608.6711
epoch 1 / batch 0	Pred mae loss: 0.0078
epoch 1 / batch 100	Full loss: 534.9409
epoch 1 / batch 100	Pred mae loss: 0.0066
epoch 1 / batch 200	Full loss: 530.1423
epoch 1 / batch 200	Pred mae loss: 0.0047
epoch 1 / batch 300	Full loss: 527.1787
epoch 1 / batch 300	Pred mae loss: 0.0049
epoch 1 / batch 400	Full loss: 529.0965
epoch 1 / batch 400	Pred mae loss: 0.0050
epoch 1 / batch 500	Full loss: 527.3892
epoch 1 / batch 500	Pred mae loss: 0.0049
epoch 1 / batch 600	Full loss: 526.2923
epoch 1 / batch 600	Pred mae loss: 0.0048
epoch 1 / batch 700	Full loss: 526.9380
epoch 1 / batch 700	Pred mae loss: 0.0052
1 Training loss: 531.5980829516166
Time per epoch (1): 224.476s
Evaluat

Plot losses

In [None]:
df = pd.DataFrame({"val": val_losses, "val_pred": val_pred_losses, "logP": logP_losses, "qed": qed_losses})
df_total = df['val']
df_pred = df.drop(columns=["val"])
df_pred.plot()

## tmp
TPU Error: /usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    130     retain_graph: Optional[bool] = None,
    131     create_graph: bool = False,
--> 132     only_inputs: bool = True,
    133     allow_unused: bool = False
    134 ) -> Tuple[torch.Tensor, ...]:

RuntimeError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1)

Loss after ~10 mins:
Evaluation loss (training):  (tensor(513.6445, device='cuda:0', dtype=torch.float64), tensor(0.0252, device='cuda:0', dtype=torch.float64), tensor([0.0208, 0.0050, 0.0183], device='cuda:0', dtype=torch.float64), tensor([0.0208, 0.0050, 0.0183], device='cuda:0', dtype=torch.float64


Starting training
Input:
4@(FF-BFcFFFF-4rO+]\144r4PP\-n6-4-I6-I6I6416Fc6F@(
Label:
4@(FF-BFcFFFF-4rO+]\144r4PP\-n6-4-I6-I6I6416Fc6F@(
Output:
ssssssssssn#111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111
epoch 1 / batch 0	Full loss: 155817.0156
epoch 1 / batch 0	Pred loss: 1.9640
epoch 1 / batch 100	Full loss: 135395.6875
epoch 1 / batch 100	Pred loss: 1.2291
epoch 1 / batch 200	Full loss: 135623.5313
epoch 1 / batch 200	Pred loss: 1.1131
epoch 1 / batch 300	Full loss: 135002.1563
epoch 1 / batch 300	Pred loss: 1.1840
epoch 1 / batch 400	Full loss: 135334.3438
epoch 1 / batch 400	Pred loss: 1.2385
epoch 1 / batch 500	Full loss: 134454.6094
epoch 1 / batch 500	Pred loss: 1.2087
epoch 1 / batch 600	Full loss: 134295.0781
epoch 1 / batch 600	Pred loss: 1.2694
epoch 1 / batch 700	Full loss: 134368.5938
epoch 1 / batch 700	Pred loss: 1.2350
Time per epoch (1): 208.468s
Evaluating...
Evaluation loss (training): 
523.524375040089, 
0.004267985728560455, 
1.5497679990182962, 
0.3987843898209971
Elapsed time: 218.533s
Input:
44O(44rO+]\c444Or4P+\c4(on
Label:
44O(44rO+]\c444Or4P+\c4(on
Output:
4=(1PFBBNNNNN222221666H(((l
epoch 2 / batch 0	Full loss: 134005.5000
epoch 2 / batch 0	Pred loss: 1.2578
epoch 2 / batch 100	Full loss: 133798.4688
epoch 2 / batch 100	Pred loss: 1.2414
epoch 2 / batch 200	Full loss: 134115.9375
epoch 2 / batch 200	Pred loss: 1.3186
epoch 2 / batch 300	Full loss: 134341.1875
epoch 2 / batch 300	Pred loss: 1.2813
epoch 2 / batch 400	Full loss: 133899.8438
epoch 2 / batch 400	Pred loss: 1.2360
epoch 2 / batch 500	Full loss: 133523.3281
epoch 2 / batch 500	Pred loss: 1.2713
epoch 2 / batch 600	Full loss: 134387.4063
epoch 2 / batch 600	Pred loss: 1.2330
epoch 2 / batch 700	Full loss: 133031.4063
epoch 2 / batch 700	Pred loss: 1.2627
Time per epoch (2): 223.383s
Evaluating...
Evaluation loss (training): 
518.1136467954343, 
0.003951255076182853, 
1.5243500593229338, 
0.36010640248726383
Elapsed time: 462.739s
Input:
44r4PP+\-4F(FFFFF(6O4-on6O(444r4P+\-4-on6Oc44n44c64(
Label:
44r4PP+\-4F(FFFFF(6O4-on6O(444r4P+\-4-on6Oc44n44c64(
Output:
no#-]]\((@FFFFFFFFNNNNNO332222222222222221111cc6H(((
epoch 3 / batch 0	Full loss: 132805.6719
epoch 3 / batch 0	Pred loss: 1.2380
epoch 3 / batch 100	Full loss: 133457.9688
epoch 3 / batch 100	Pred loss: 1.1281
epoch 3 / batch 200	Full loss: 132526.6563
epoch 3 / batch 200	Pred loss: 1.0878
epoch 3 / batch 300	Full loss: 131957.4688
epoch 3 / batch 300	Pred loss: 1.0116
epoch 3 / batch 400	Full loss: 132188.6094
epoch 3 / batch 400	Pred loss: 1.0222
epoch 3 / batch 500	Full loss: 131492.4531
epoch 3 / batch 500	Pred loss: 0.9942
epoch 3 / batch 600	Full loss: 131535.2813
epoch 3 / batch 600	Pred loss: 0.9581
epoch 3 / batch 700	Full loss: 130860.2813
epoch 3 / batch 700	Pred loss: 1.0071
Time per epoch (3): 225.783s
Evaluating...
Evaluation loss (training): 
511.09088476526335, 
0.002368703126348485, 
1.230272634271861, 
0.21858335207106477
Elapsed time: 710.049s
Input:
44-46-46@(FF-4-on6O44c-4-O6on644n44c6F@(
Label:
44-46-46@(FF-4-on6O44c-4-O6on644n44c6F@(
Output:
4=--6nnF(FFFrrrPrr]]++\\\\\\111111cccH((l
epoch 4 / batch 0	Full loss: 130977.4609
epoch 4 / batch 0	Pred loss: 1.0126

# Manually push data through network

In [None]:
example_input = x_train[0]
x = example_input
x = x.view(1, x.size(0), -1).to(device)
print(x.size())
mu, logvar = model.encode(x)
print(mu.shape, logvar.shape)

z = model.reparameterize(mu, logvar)
z.shape
output = model.decode(z)
print("decoded shape: ", output.shape)

out, m, l, z = model.forward(x)
vae_loss(out, x, m, l)

In [None]:
model.prediction(z).shape  # TODO should we still have batch here?