In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import time
import os
import random

from utils import *
from encoder import Encoder
from decoder import DecodeNext, Decoder

%load_ext autoreload
%autoreload 2

### Data & Model Parameters

In [3]:
# Keep the `[0]` entry of fetch_smiles_gdb13's result (an element or column,
# depending on what the helper returns) and materialize it as a plain list.
smiles = [smi for smi in fetch_smiles_gdb13('./data/gdb13/')[0]]

In [4]:
# Model/data parameters. By default they are rebuilt from `smiles` and
# (presumably, via `to_file=`) written to PARAMS_FILE; set REBUILD_PARAMS = False
# to reuse the saved file instead of editing commented-out code.
# NOTE(review): the pretrained weights loaded below are only meaningful with the
# params they were trained with. If make_params derives the token vocabulary from
# `smiles`, prefer reloading PARAMS_FILE. Confirm against utils.make_params.
PARAMS_FILE = 'gdb13_params.json'
REBUILD_PARAMS = True
if REBUILD_PARAMS:
    params = make_params(smiles=smiles, GRU_HIDDEN_DIM=256, LATENT_DIM=128, to_file=PARAMS_FILE)
else:
    params = make_params(from_file=PARAMS_FILE)

### Model

In [5]:
# Build both halves of the autoencoder from the shared params; their freshly
# initialised weights are overwritten by load_state_dict in the next cell.
encoder, decoder = Encoder(params), Decoder(params)

In [6]:
# Load pretrained weights. map_location='cpu' lets a checkpoint saved from a GPU
# run load on a CPU-only machine; load_state_dict then copies the tensors into the
# modules' own parameters, wherever those live.
# NOTE: torch.load unpickles the file, so only load checkpoints from trusted
# sources (recent PyTorch versions also accept weights_only=True).
WEIGHTS_DIR = 'weights'
encoder.load_state_dict(torch.load(os.path.join(WEIGHTS_DIR, 'encoder_weights.pth'), map_location='cpu'))
decoder.load_state_dict(torch.load(os.path.join(WEIGHTS_DIR, 'decoder_weights.pth'), map_location='cpu'))

<All keys matched successfully>

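## Test

Reconstruction check: run `evaluate_ae` on 1000 SMILES with the loaded encoder/decoder and report the mean token matching between inputs and their reconstructions.

**NOTE:** the recorded run below reaches only ~0.19 mean token matching, and the decoded strings are not valid SMILES (e.g. `<EOS>` tokens appear mid-sequence). Before trusting these weights, confirm that `params` (token vocabulary) matches the one used in training and that the models are in `eval()` mode.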

In [9]:
# Inference must run with dropout/batch-norm layers (if any) in evaluation mode;
# otherwise reconstructions can be stochastic.
# NOTE(review): call .train() on both modules before any further training.
N_EVAL = 1000  # count passed to evaluate_ae — presumably how many SMILES are evaluated; confirm in utils
encoder.eval()
decoder.eval()
evaluate_ae(encoder, decoder, smiles, N_EVAL, params=params)

mean token matching: 0.1872750073671341

['[NH3+]C1COC2=C1SC1=C2COCO1', '[NH2+]=C1NCC2CC1N1CCOC2C1', 'O=C1CC2[NH2+]CCC2C2=C1C=NS2', 'O=C1CNC2CC3=C(OC=C3)C12', 'OC1C2OCC11C[NH2+]C(C1)CCO2', '[NH3+]C1C2CC3CC4COC3C(O2)C14', 'OC1C2CCC(O2)C1C1=CC=CN1', 'OC12CCOCC34CC1[NH+](CC23)C4', 'OC1COC2=C(C1)C1=C(O2)C=NS1', 'NC1=C2C=CC3=NCCN3C2=CO1']
['<BOS>2--2<EOS>-]22<EOS><EOS><EOS>]]]22<EOS>2l<EOS>2<EOS><EOS><EOS><EOS>=N2<EOS><EOS>22l<EOS>2<EOS>2<EOS>', '<BOS>2--2<EOS>-]222<EOS>2<EOS>25+<EOS><EOS><EOS><EOS>22H2)]==N)N3+]22<EOS>2=', '<BOS>C254=C=N=-]221<EOS>13H<EOS><EOS><EOS>2<EOS><EOS><EOS>2<EOS><EOS><EOS><EOS>2<EOS><EOS><EOS><EOS>=<EOS><EOS><EOS>', '<BOS>2--2<EOS>-]22<EOS><EOS><EOS>]]]2<EOS>2<EOS><EOS>2N1<EOS>22222<EOS>22-CN2<EOS>22', '<BOS>C2S1==NNNNNNN=N2N=NHH<EOS><EOS>2)<EOS>11<EOS>1HH<EOS><EOS><EOS><EOS><EOS><EOS><EOS>', '<BOS>OS1N=NN1HH<EOS><EOS>2<EOS><EOS><EOS><EOS>C2l<EOS><EOS>5l<EOS>222<EOS>2<EOS>l)N12lHN', '<BOS>C251=C=N=N==C=H<EOS>2=<EOS>2l<EOS>N2]1]==H22]<EOS><EOS>2<EOS><EOS>=', '<BOS>C