In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import time
import os
import random

from utils import *
from encoder import Encoder
from decoder import DecodeNext, Decoder

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Data & Model Parameters

In [15]:
# Materialise element [0] of fetch_smiles_gdb13's return value as a plain list;
# later cells pass `smiles` to random.sample, make_params and evaluate_ae.
# NOTE(review): what the remaining elements of the return value hold is not
# visible from this notebook — confirm in utils.
smiles = [*fetch_smiles_gdb13('./data/gdb13/')[0]]

In [16]:
# Build the shared parameter object used by to_one_hot, Encoder, Decoder and
# evaluate_ae.
# NOTE(review): `to_file` presumably persists the parameters as JSON — confirm
# in utils.make_params.
params = make_params(
    smiles=smiles,
    GRU_HIDDEN_DIM=256,
    LATENT_DIM=128,
    to_file='gdb13_params.json',
)
# Alternative: reuse a previously written parameter file instead of rebuilding.
#params = make_params(from_file='gdb13_params.json')

In [17]:
# Size of the random subset drawn from `smiles`, and how it is split.
total_n = 100000
train_n = 100000
test_n = total_n - train_n

# Fixed seed so the sampled subset and the batch order are identical on every
# Restart & Run All.
SEED = 0

# Fail early with a readable message instead of random.sample's generic error
# (or a silently negative test_n).
if not 0 <= train_n <= total_n <= len(smiles):
    raise ValueError(
        'need 0 <= train_n <= total_n <= len(smiles), got '
        f'train_n={train_n}, total_n={total_n}, len(smiles)={len(smiles)}'
    )

# A dedicated RNG makes the draw reproducible without touching the global
# `random` state that other code may rely on.
one_hots = to_one_hot(random.Random(SEED).sample(smiles, total_n), params)

# The explicit generator makes the shuffle order reproducible as well.
train_dataloader = DataLoader(
    one_hots[:train_n],
    batch_size=10,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
# Held-out loader stays disabled: with test_n == 0 the slice one_hots[train_n:]
# is empty.
#test_dataloader = DataLoader(one_hots[train_n:], batch_size=5, shuffle=True)

print(f'total_n = {total_n}')
print(f'train_n = {train_n}')
print(f'test_n = {test_n}')

total_n = 100000
train_n = 100000
test_n = 0


### Model

In [40]:
# Instantiate both halves of the autoencoder from the shared `params`
# (encoder first, then decoder) and print the encoder's layer summary.
encoder, decoder = Encoder(params), Decoder(params)
print(encoder)

Encoder(
  (gru): GRU(21, 256, batch_first=True)
  (dense_encoder): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.1, inplace=False)
    (3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): Linear(in_features=256, out_features=128, bias=True)
    (5): Tanh()
    (6): Dropout(p=0.1, inplace=False)
    (7): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): Linear(in_features=128, out_features=128, bias=True)
  )
)


In [41]:
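# Optional: resume from the weights written by the training cell
# ('encoder_weights.pth' / 'decoder_weights.pth') instead of starting fresh.
#encoder.load_state_dict(torch.load('encoder_weights.pth'))
#decoder.load_state_dict(torch.load('decoder_weights.pth'))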

### Train

In [42]:
# Optimisation hyperparameters consumed by the training cell.
LR = 1e-5    # learning rate handed to both Adam optimizers
EPOCHS = 2   # number of full passes over train_dataloader

In [50]:
# --- Training setup ----------------------------------------------------------
# Cross-entropy over the class dimension; `losses` collects one Python float
# per batch, in training order.
criterion = nn.CrossEntropyLoss()
losses = []

# Encoder and decoder each get their own Adam optimizer, both using LR.
encoder_optimizer = optim.Adam(encoder.parameters(), lr=LR)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=LR)

# Put both modules into training mode (the encoder contains Dropout and
# BatchNorm1d layers whose behaviour depends on it).
encoder.train()
decoder.train()

# --- Optimisation ------------------------------------------------------------
for epoch_n in range(EPOCHS):
    for x in train_dataloader:
        # x is a one-hot batch with x.shape = (N, L, C).

        # Drop gradients accumulated by the previous step.
        for opt in (encoder_optimizer, decoder_optimizer):
            opt.zero_grad()

        # Encode, then decode; the decoder also receives x as `target`.
        latents = encoder(x)
        y = decoder(latents, target=x)

        # CrossEntropyLoss expects class scores on dim 1, so swap dims 1 and 2
        # of the prediction; the integer class targets are recovered from the
        # one-hot input along its last dimension.
        loss = criterion(y.transpose(1, 2), x.argmax(dim=2))

        # Record the scalar before back-propagating, as a plain float.
        losses.append(loss.item())

        loss.backward()

        for opt in (encoder_optimizer, decoder_optimizer):
            opt.step()

# Persist the weights once all epochs have completed.
for net, path in ((encoder, 'encoder_weights.pth'), (decoder, 'decoder_weights.pth')):
    torch.save(net.state_dict(), path)

torch.Size([10, 128])
Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/jshe/smiles_vae/env/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/scratch/local/jobs/10533712/ipykernel_3051911/1458238080.py", line 23, in <module>
    y = decoder(latents, target=x)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jshe/smiles_vae/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jshe/smiles_vae/decoder.py", line 104, in forward
    prediction, hidden = self.decode_next(inp, hidden)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jshe/smiles_vae/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jshe/smiles_vae/decoder.py", line 44, in f

In [None]:
# Plot the per-batch cross-entropy values collected in `losses` by the
# training cell. Labels and a title let the figure stand on its own, and
# plt.show() avoids echoing the Line2D repr as cell output.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(
    xlabel='training batch',
    ylabel='cross-entropy loss',
    title='Training loss per batch',
)
plt.show()

### Test

In [None]:
# The training cell leaves both modules in train mode, where the encoder's
# Dropout and BatchNorm1d layers still behave as during training. Switch to
# eval mode first; this is harmless if evaluate_ae already does so, and the
# training cell calls .train() again when re-run.
encoder.eval()
decoder.eval()

# Evaluation needs no gradients, so skip autograd bookkeeping.
# NOTE(review): 1000 is presumably the number of SMILES evaluated — confirm
# against utils.evaluate_ae.
with torch.no_grad():
    eval_result = evaluate_ae(encoder, decoder, smiles, 1000, params=params)

# Bare last expression keeps the cell's rich display of the return value.
eval_result