In [1]:
%cd ..

/home/eli/AnacondaProjects/categorical_bpl


In [2]:
import argparse
import collections
import pyro
import torch
import numpy as np
import data_loader.data_loaders as module_data
import model.model as module_arch
from parse_config import ConfigParser
from trainer import Trainer

In [3]:
%matplotlib inline

In [4]:
# pyro.enable_validation(True)
# torch.autograd.set_detect_anomaly(True)

In [5]:
# fix random seeds for reproducibility
SEED = 123
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)

In [6]:
Args = collections.namedtuple('Args', 'config resume device')
config = ConfigParser.from_args(Args(config='chemical_config.json', resume=None, device=None))

In [7]:
logger = config.get_logger('train')

In [8]:
# setup data_loader instances
data_loader = config.init_obj('data_loader', module_data)
valid_data_loader = data_loader.split_validation()

In [9]:
# build model architecture, then print to console
model = config.init_obj('arch', module_arch)

In [10]:
optimizer = pyro.optim.ReduceLROnPlateau({
    'optimizer': torch.optim.Adam,
    'optim_args': {
        "lr": 1e-4,
        "weight_decay": 0,
        "amsgrad": True
    },
    "patience": 5,
    "factor": 0.1,
    "verbose": True,
})

In [11]:
# optimizer = config.init_obj('optimizer', pyro.optim)

In [12]:
trainer = Trainer(model, [], optimizer, config=config,
                  data_loader=data_loader,
                  valid_data_loader=valid_data_loader,
                  lr_scheduler=optimizer, log_images=False)

In [13]:
trainer.train()

    epoch          : 1
    loss           : 59610.15647442051
    val_loss       : 29540.532146309375
    epoch          : 2
    loss           : 24987.04930140785
    val_loss       : 18361.310375298774
    epoch          : 3
    loss           : 16661.594989867746
    val_loss       : 15367.864824351635
    epoch          : 4
    loss           : 14878.382201454067
    val_loss       : 14075.20717163447
    epoch          : 5
    loss           : 14300.013958511092
    val_loss       : 13858.643186481631
    epoch          : 6
    loss           : 14135.414915742322
    val_loss       : 13357.040671203818
    epoch          : 7
    loss           : 13753.11090594781
    val_loss       : 14594.07436124646
    epoch          : 8
    loss           : 14078.343145620023
    val_loss       : 13992.593780232859
    epoch          : 9
    loss           : 14754.08815504124
    val_loss       : 16101.55443220529
    epoch          : 10
    loss           : 14478.094203284983
    val_loss    

In [14]:
model.cpu()

MolecularVaeCategoryModel(
  (_category): FreeCategory(
    (generator_0): MolecularDecoder(
      (pre_recurrence_linear): Sequential(
        (0): Linear(in_features=196, out_features=34, bias=True)
        (1): SELU()
      )
      (recurrence1): GRUCell(34, 64)
      (recurrence2): GRUCell(64, 64)
      (decoder): Sequential(
        (0): Linear(in_features=64, out_features=34, bias=True)
        (1): Softmax(dim=-1)
      )
    )
    (generator_1): MolecularDecoder(
      (pre_recurrence_linear): Sequential(
        (0): Linear(in_features=196, out_features=34, bias=True)
        (1): SELU()
      )
      (recurrence1): GRUCell(34, 64)
      (recurrence2): GRUCell(64, 64)
      (decoder): Sequential(
        (0): Linear(in_features=64, out_features=34, bias=True)
        (1): Softmax(dim=-1)
      )
    )
    (generator_2): MolecularDecoder(
      (pre_recurrence_linear): Sequential(
        (0): Linear(in_features=196, out_features=34, bias=True)
        (1): SELU()
      )
     

In [15]:
valid_xs, valid_ys = list(valid_data_loader)[-1]

In [16]:
m, recons = model(observations=valid_xs, train=False)

In [20]:
(recons == valid_xs).all(dim=-1).to(dtype=torch.float).mean(dim=0)

tensor([0.8036, 0.1845, 0.1845, 0.2321, 0.0417, 0.1429, 0.0952, 0.0833, 0.1071,
        0.0952, 0.0774, 0.0357, 0.1488, 0.0952, 0.1667, 0.1488, 0.0833, 0.1607,
        0.1310, 0.1131, 0.0833, 0.0655, 0.1012, 0.0833, 0.1131, 0.1429, 0.1250,
        0.1429, 0.0655, 0.0893, 0.1250, 0.1131, 0.0833, 0.0893, 0.1071, 0.1131,
        0.1786, 0.1667, 0.1250, 0.1964, 0.2321, 0.3690, 0.3750, 0.4940, 0.5655,
        0.6488, 0.6905, 0.7262, 0.7619, 0.8155, 0.8512, 0.8750, 0.8988, 0.9107,
        0.9405, 0.9583, 0.9702, 0.9762, 0.9940, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 