In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import time
import os
import random

from utils import *
from embedding_utils import *
from encoder import Encoder
from decoder import DecodeNext, Decoder

%load_ext autoreload
%autoreload 2

### Data & Model Parameters

In [None]:
# Load the GDB-13 SMILES corpus. Only element [0] of the loader's return value
# is kept (presumably the SMILES strings themselves — TODO confirm).
GDB13_DIR = './data/gdb13/'
smiles = list(fetch_smiles_gdb13(GDB13_DIR)[0])

In [None]:
# Architecture hyper-parameters.
# NOTE(review): these presumably have to match the checkpoints loaded from
# weights/ further down, otherwise load_state_dict will reject the tensor
# shapes — confirm.
GRU_HIDDEN_DIM = 256
LATENT_DIM = 128
params = make_params(
    smiles=smiles,
    GRU_HIDDEN_DIM=GRU_HIDDEN_DIM,
    LATENT_DIM=LATENT_DIM,
)

### Model: build encoder/decoder and load pretrained weights

In [None]:
# Build the encoder and decoder from the same params bundle (encoder first).
encoder, decoder = Encoder(params), Decoder(params)

In [None]:
# Restore the pretrained weights.
# - map_location='cpu': the models above are constructed on CPU (no .to(device)),
#   so remap tensors from checkpoints that may have been saved on a GPU.
# - eval(): switch off training-only behaviour (e.g. dropout / batch-norm
#   updates) before the inference cells below.
encoder.load_state_dict(torch.load('weights/encoder_weights.pth', map_location='cpu'))
decoder.load_state_dict(torch.load('weights/decoder_weights.pth', map_location='cpu'))

encoder.eval()
decoder.eval();

In [None]:
# One-hot encode the first corpus molecule (x) and a hand-picked molecule,
# thiophene "S1C=CC=C1" (y), then echo both inputs.
test_smile = ["S1C=CC=C1"]

# NOTE(review): smiles[0] is presumably a bare SMILES string, while test_smile
# is a one-element list — confirm to_one_hot treats both as a single molecule.
# If it iterates over its argument, the first call would encode each character
# as a separate molecule.
x = to_one_hot(smiles[0], params)
# NOTE(review): y is not referenced by any later cell in this notebook.
y = to_one_hot(test_smile, params)

print(smiles[0])
print(test_smile)

In [None]:
# Forward pass for inspection only. no_grad() skips building the autograd
# graph, because nothing in this notebook calls backward().
# NOTE(review): z is presumably sampled from (z_mean, z_logvar); seeding the
# RNG makes the reconstruction and losses below reproducible across runs.
torch.manual_seed(0)
with torch.no_grad():
    z_mean, z_logvar, z = encoder(x)
    y_hat = decoder(z)

In [None]:
# Loss

def CE_loss(predicted: torch.Tensor, target: torch.Tensor,
            scale: float = 21, eps: float = 1e-8) -> torch.Tensor:
    """Summed categorical cross-entropy between ``target`` and ``predicted``.

    Computes ``-sum(target * log(predicted)) / scale`` over every element.

    Args:
        predicted: probabilities with the same shape as ``target``
            (assumed to be the decoder's softmax output — confirm).
        target: one-hot encoded reference sequence(s).
        scale: divisor applied to the summed loss. Defaults to the 21 that
            was hard-coded here; presumably the padded SMILES length —
            TODO confirm against params.
        eps: lower bound applied to ``predicted`` before the log, so an
            exact-zero probability gives a finite loss instead of
            ``0 * -inf = nan``.

    Returns:
        A 0-dim tensor.
    """
    # The previous torch.mean(...) wrapper was a no-op: torch.sum already
    # reduces to a 0-dim tensor.
    return -torch.sum(target * torch.log(torch.clamp(predicted, min=eps))) / scale


def KL_divergence(z_mean: torch.Tensor, z_logvar: torch.Tensor) -> torch.Tensor:
    """KL divergence of N(z_mean, exp(z_logvar)) from N(0, I).

    Closed form ``-0.5 * sum(1 + logvar - mean^2 - exp(logvar))``. The sum runs
    over all elements, so it is not averaged over the batch.
    """
    return -0.5 * torch.sum(1 + z_logvar - z_mean ** 2 - torch.exp(z_logvar))

# Reconstruction term only: the decoder output is scored against the encoder
# input x. To include the KL regulariser as well, add
# `0.01 * KL_divergence(z_mean, z_logvar)` to this expression.
loss = CE_loss(predicted=y_hat, target=x)

In [None]:
# Display the reconstruction loss for the single encoded molecule.
loss

In [None]:
# KL term for the same encoding; shown separately because it is not part of `loss`.
KL_divergence(z_mean, z_logvar)

In [None]:
# Convert the decoder output back to text via from_one_hot (presumably SMILES)
# to compare by eye with smiles[0].
from_one_hot(y_hat, params)

### Test: evaluation and log inspection

In [None]:
# NOTE(review): this count is passed positionally; presumably the number of
# SMILES evaluated — confirm against evaluate_ae's signature in utils.
n_eval = 1000
evaluate_ae(encoder, decoder, smiles, n_eval, params=params)

In [None]:
# Load log.csv (written elsewhere; its schema is not visible in this notebook).
# NOTE(review): the header appears to have a space after each comma (the plot
# cell indexes ' similarity'); pd.read_csv(..., skipinitialspace=True) would
# normalise the names if every consumer is updated together.
log = pd.read_csv('log.csv')

In [None]:
# Plot the similarity column of log.csv against i. Column names are stripped
# first, so the lookup works whether or not the CSV header has a space after
# each comma (e.g. ' similarity').
log_clean = log.rename(columns=str.strip)
fig, ax = plt.subplots()
ax.plot(log_clean['i'], log_clean['similarity'])
ax.set(xlabel='i', ylabel='similarity', title='log.csv: similarity vs i');