In [None]:
# Move the working directory up one level; the relative config path and the
# project-local imports in later cells are presumably resolved from there.
# NOTE(review): not idempotent - re-running this cell moves up another level.
%cd ..

In [2]:
import argparse
import collections
import pyro
import torch
import numpy as np
import data_loader.data_loaders as module_data
import model.model as module_arch
from parse_config import ConfigParser
from trainer import Trainer

In [3]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [4]:
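# Optional debugging aids, disabled by default. Uncomment to turn on Pyro's
# validation checks and PyTorch's autograd anomaly detection.
# pyro.enable_validation(True)
# torch.autograd.set_detect_anomaly(True)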

In [5]:
# Reproducibility: seed the NumPy and PyTorch RNGs with one shared value and
# ask cuDNN for deterministic behaviour (autotuner off, deterministic on).
SEED = 123
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [6]:
# Stand-in for parsed command-line arguments: a plain namedtuple carrying the
# three fields handed to ConfigParser.from_args (config path, resume, device).
Args = collections.namedtuple('Args', ['config', 'resume', 'device'])
config = ConfigParser.from_args(
    Args(config='omniglot_config.json', resume=None, device=None)
)

In [7]:
# Logger obtained from the parsed config under the name 'train'.
# NOTE(review): not referenced by any of the cells visible here.
logger = config.get_logger('train')

In [8]:
# Build the training data loader from the 'data_loader' config entry, passing
# module_data as the module to look it up in, then derive the validation
# loader from it.
# NOTE(review): split_validation() presumably holds out part of the training
# data - confirm against the loader implementation.
data_loader = config.init_obj('data_loader', module_data)
valid_data_loader = data_loader.split_validation()

Files already downloaded and verified
Files already downloaded and verified


In [9]:
# Build the model from the 'arch' config entry, passing module_arch as the
# module to look it up in. Nothing is printed or logged in this cell.
model = config.init_obj('arch', module_arch)

In [10]:
# Adam (AMSGrad enabled, no weight decay, lr 1e-3) wrapped in Pyro's
# ReduceLROnPlateau scheduler. The top-level keys other than 'optimizer' and
# 'optim_args' (patience, factor, verbose) are scheduler settings.
optimizer = pyro.optim.ReduceLROnPlateau(dict(
    optimizer=torch.optim.Adam,
    optim_args=dict(lr=1e-3, weight_decay=0, amsgrad=True),
    patience=50,
    factor=0.1,
    verbose=True,
))

In [11]:
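# Alternative (disabled): build the optimizer from the 'optimizer' config entry
# instead of constructing it by hand in the notebook.
# optimizer = config.init_obj('optimizer', pyro.optim)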

In [12]:
# Wire the model, optimizer, config and both data loaders into the Trainer.
# NOTE(review): the empty list is the second positional argument - presumably
# the list of metric functions (none tracked); confirm against Trainer.__init__.
trainer = Trainer(model, [], optimizer, config=config,
                  data_loader=data_loader,
                  valid_data_loader=valid_data_loader)

In [13]:
# Run training; the cell output reports loss and val_loss for each epoch.
trainer.train()

    epoch          : 1
    loss           : -8559.506116623847
    val_loss       : -10346.917751057943
    epoch          : 2
    loss           : -51163.041335137896
    val_loss       : -28662.722952651977
    epoch          : 3
    loss           : -71385.55160215237
    val_loss       : -44383.4581509908
    epoch          : 4
    loss           : -76858.99100612153
    val_loss       : -41838.66355142593
    epoch          : 5
    loss           : -81566.3032939322
    val_loss       : -54003.96229199568
    epoch          : 6
    loss           : -91011.69418519296
    val_loss       : -60872.51834947268
    epoch          : 7
    loss           : -86973.46594156355
    val_loss       : -52400.853752628966
    epoch          : 8
    loss           : -95483.38371041317
    val_loss       : -56635.630278937024
    epoch          : 9
    loss           : -104010.95857965866
    val_loss       : -61753.06954429944
    epoch          : 10
    loss           : -110567.59904572148
    