In [1]:
# Step up to the parent directory (the output below shows it is the project
# root), presumably so the project-local imports in the next cell resolve.
# NOTE: not idempotent — re-running this cell moves up one more level.
%cd ..

/home/eli/AnacondaProjects/categorical_bpl


In [2]:
import argparse
import collections
import pyro
import torch
import numpy as np
import data_loader.data_loaders as module_data
import model.model as module_arch
from parse_config import ConfigParser
from trainer import Trainer

In [3]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [4]:
# Optional debugging switches, left disabled for normal runs. Uncomment to turn
# on Pyro's validation checks and autograd anomaly detection while tracking
# down a failure.
# pyro.enable_validation(True)
# torch.autograd.set_detect_anomaly(True)

In [5]:
# Pin the NumPy and PyTorch random number generators to one shared seed so
# repeated runs of the notebook draw the same random numbers.
SEED = 123
np.random.seed(SEED)
torch.manual_seed(SEED)

# cuDNN: request deterministic kernels and switch off the benchmark mode.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [6]:
# The notebook has no command line, so a namedtuple with the fields
# config / resume / device is passed to ConfigParser.from_args in place of
# parsed arguments.
Args = collections.namedtuple('Args', ['config', 'resume', 'device'])
config = ConfigParser.from_args(
    Args(config='chemical_config.json', resume=None, device=None)
)

In [7]:
# Fetch the logger registered under the name 'train' from the parsed config.
logger = config.get_logger('train')

In [8]:
# setup data_loader instances
# `data_loader` is built from the 'data_loader' entry of the config, resolved
# against the project's data_loaders module; the validation loader is then
# derived from it via split_validation().
# NOTE(review): split_validation() may also modify `data_loader` itself —
# confirm before re-running this cell in isolation.
data_loader = config.init_obj('data_loader', module_data)
valid_data_loader = data_loader.split_validation()

In [9]:
# Build the architecture named by the 'arch' entry of the config, passing the
# size of the dataset's alphabet and the dataset's `max_len` as positional
# arguments. Nothing is printed here.
model = config.init_obj(
    'arch',
    module_arch,
    len(data_loader.dataset.alphabet),
    data_loader.dataset.max_len,
)

In [10]:
# Pyro ReduceLROnPlateau scheduler wrapping torch's Adam (AMSGrad enabled,
# lr 1e-3, no weight decay). The keys outside 'optimizer' / 'optim_args'
# (patience, factor, verbose) are presumably forwarded to the underlying
# torch scheduler — confirm against the Pyro docs.
optimizer = pyro.optim.ReduceLROnPlateau(dict(
    optimizer=torch.optim.Adam,
    optim_args=dict(lr=1e-3, weight_decay=0, amsgrad=True),
    patience=500,
    factor=0.1,
    verbose=True,
))

In [11]:
# Disabled alternative: build the optimizer from the 'optimizer' entry of the
# config instead of the hard-coded scheduler defined in the previous cell.
# optimizer = config.init_obj('optimizer', pyro.optim)

In [12]:
# Wire the model, the optimizer and both data loaders into the Trainer.
# The empty list is the second positional argument (presumably the metric
# functions — TODO confirm against Trainer's signature). The same Pyro object
# is passed as both the optimizer and the lr_scheduler.
trainer = Trainer(
    model,
    [],
    optimizer,
    config=config,
    data_loader=data_loader,
    valid_data_loader=valid_data_loader,
    lr_scheduler=optimizer,
    log_images=False,
    log_step=128,
)

In [13]:
# Run the training loop; the per-epoch loss and validation metrics are
# reported in the cell output below.
trainer.train()

    epoch          : 1
    loss           : 13462.489185600187
    val_loss       : 12778.562715682812
    val_log_likelihood: -12651.409677257805
    val_log_marginal: -12673.81804152637
    epoch          : 2
    loss           : 12646.095470171631
    val_loss       : 12618.526757684005
    val_log_likelihood: -12576.564716772644
    val_log_marginal: -12599.938644158414
    epoch          : 3
    loss           : 12606.822964162015
    val_loss       : 12592.67585207355
    val_log_likelihood: -12569.1071792287
    val_log_marginal: -12585.684240669389
    epoch          : 4
    loss           : 12590.151348770936
    val_loss       : 12587.28955834878
    val_log_likelihood: -12567.976368156793
    val_log_marginal: -12582.2468197438
    epoch          : 5
    loss           : 12584.497491308673
    val_loss       : 12583.06028246757
    val_log_likelihood: -12566.50969696676
    val_log_marginal: -12576.339917911204
    epoch          : 6
    loss           : 12578.471785340416
 

In [14]:
# Move the model's parameters to the CPU. As the last expression in the cell,
# the returned module is displayed, which yields the architecture dump below.
model.cpu()

SelfiesAutoencodingModel(
  (_operad): FreeOperad(
    (generator_0): RecurrentDecoder(
      (recurrence): GRU(12, 64, batch_first=True)
      (decoder): Sequential(
        (0): Linear(in_features=64, out_features=18, bias=True)
        (1): Softmax(dim=-1)
      )
    )
    (generator_1): RecurrentDecoder(
      (recurrence): GRU(12, 64, batch_first=True)
      (decoder): Sequential(
        (0): Linear(in_features=64, out_features=18, bias=True)
        (1): Softmax(dim=-1)
      )
    )
    (generator_2): RecurrentDecoder(
      (recurrence): GRU(12, 64, num_layers=2, batch_first=True)
      (decoder): Sequential(
        (0): Linear(in_features=128, out_features=18, bias=True)
        (1): Softmax(dim=-1)
      )
    )
    (generator_3): RecurrentDecoder(
      (recurrence): GRU(12, 64, num_layers=2, batch_first=True)
      (decoder): Sequential(
        (0): Linear(in_features=128, out_features=18, bias=True)
        (1): Softmax(dim=-1)
      )
    )
    (generator_4): Recurren