In [None]:
%run ../../import_envs.py
import numpy as np
import torch
import os
# --- Training parameters ---
K = 3  # number of digits
T = 10  # timesteps  (NOTE(review): not passed to any call in this notebook)
NUM_EPOCHS = 1000
APG_SWEEPS = 0  # 0 APG sweeps, matching MODEL_NAME 'oneshot-as-baseline'
SAMPLE_SIZE = 20
BATCH_SIZE = 10
LR = 1e-4

# --- Model parameters ---
FRAME_PIXELS = 96
DIGIT_PIXELS = 28
NUM_HIDDEN_DIGIT = 200
NUM_HIDDEN_COOR = 200
Z_WHERE_DIM = 2  # z_where dims
Z_WHAT_DIM = 10

MODEL_NAME = 'oneshot-as-baseline'
RESAMPLING_STRATEGY = 'systematic'
# Version tag passed to init_model (LOAD_VERSION) and train (MODEL_VERSION).
MODEL_VERSION = 'bmnist-%ddigits-%s' % (K, MODEL_NAME)

print('inference method:%s, resampling:%s, apg sweeps:%s, epochs:%s, sample size:%s, batch size:%s, learning rate:%s' % (MODEL_NAME, RESAMPLING_STRATEGY, APG_SWEEPS, NUM_EPOCHS, SAMPLE_SIZE, BATCH_SIZE, LR))

MNIST_MEAN_PATH = '../mnist_mean.npy'

CUDA = torch.cuda.is_available()
# Fall back to the CPU so DEVICE is always a usable device, even on
# machines without a GPU.
DEVICE = torch.device('cuda:0' if CUDA else 'cpu')

In [None]:
# Train data path. It can be overridden through the BMNIST_DATA_DIR
# environment variable; otherwise it defaults to the K-digit dataset dir.
TRAIN_DATA_DIR = os.environ.get('BMNIST_DATA_DIR',
                                '/data/hao/nvi_data/bmnist/%ddigits/' % K)
# Alternative dataset location: '/data/hao/apg_data/bmnist/train/'

print('===Loading training data from %s===' % TRAIN_DATA_DIR)
# os.listdir returns entries in arbitrary order, so sort for a reproducible
# data order. Keep only regular, non-hidden files so stray directories
# (e.g. .ipynb_checkpoints) are not handed to the training loop.
DATA_PATHS = sorted(
    os.path.join(TRAIN_DATA_DIR, name)
    for name in os.listdir(TRAIN_DATA_DIR)
    if not name.startswith('.')
    and os.path.isfile(os.path.join(TRAIN_DATA_DIR, name))
)
if not DATA_PATHS:
    raise FileNotFoundError('no training data files found in %s' % TRAIN_DATA_DIR)
print('===%d groups of data files are loaded.===' % len(DATA_PATHS))

In [None]:
# Project-local components: model/optimizer factory, affine (spatial)
# transformer, and particle resampler.
from BMNIST.apg_modeling import init_model
from BMNIST.affine_transformer import Affine_Transformer
from resample import Resampler

# Build the networks and their optimizer. MODEL_VERSION is forwarded as
# LOAD_VERSION (presumably to restore a saved checkpoint of that name —
# TODO confirm in init_model).
model, optimizer = init_model(
    frame_pixels=FRAME_PIXELS,
    digit_pixels=DIGIT_PIXELS,
    num_hidden_digit=NUM_HIDDEN_DIGIT,
    num_hidden_coor=NUM_HIDDEN_COOR,
    z_where_dim=Z_WHERE_DIM,
    z_what_dim=Z_WHAT_DIM,
    CUDA=CUDA,
    DEVICE=DEVICE,
    LOAD_VERSION=MODEL_VERSION,
    LR=LR,
)

AT = Affine_Transformer(
    frame_pixels=FRAME_PIXELS,
    digit_pixels=DIGIT_PIXELS,
    CUDA=CUDA,
    DEVICE=DEVICE,
)

resampler = Resampler(
    strategy=RESAMPLING_STRATEGY,
    sample_size=SAMPLE_SIZE,
    CUDA=CUDA,
    DEVICE=DEVICE,
)

In [None]:
from BMNIST.apg_training import train

# Launch training. Every setting is taken from the configuration, data and
# model cells above.
train(
    # components
    optimizer=optimizer,
    model=model,
    AT=AT,
    resampler=resampler,
    # inference / data
    apg_sweeps=APG_SWEEPS,
    data_paths=DATA_PATHS,
    mnist_mean_path=MNIST_MEAN_PATH,
    K=K,
    # optimisation schedule
    num_epochs=NUM_EPOCHS,
    sample_size=SAMPLE_SIZE,
    batch_size=BATCH_SIZE,
    # device / bookkeeping
    CUDA=CUDA,
    DEVICE=DEVICE,
    MODEL_VERSION=MODEL_VERSION,
)