In [1]:
# Move the working directory two levels up before anything is imported.
# NOTE(review): presumably this is what lets `combinators` and `examples`
# resolve as top-level packages — confirm. The path is relative, so running
# this cell a second time climbs two further directories.
%cd ../..

/home/eli/AnacondaProjects/combinators


In [2]:
import logging
import torch
import tqdm

In [3]:
# Root logger setup: INFO and above, each record prefixed with a
# month/day/year hour:minute:second timestamp.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)

In [4]:
from combinators import lens, sampler
from combinators.model import collections
from examples.moving_mnist import data
from examples.multi_mnist import apg_inference, generative, proposals, spatial_transformer

In [5]:
# Inference hyperparameters.
SAMPLE_BUDGET = 20  # divided by NUM_SWEEPS to obtain NUM_PARTICLES
NUM_SWEEPS = 2      # forwarded to apg_inference.apg as `num_sweeps`
LR = 2e-4           # forwarded to apg_inference.apg as `lr`

In [6]:
# Problem-size and batching constants.
NUM_EPOCHS = 500  # not referenced by any other cell visible in this notebook
BATCH_SIZE = 5    # passed to the data loader as `batch_size`; first entry of BATCH_SHAPE
NUM_DIGITS = 3    # passed to the data loader as `num_digits`; last entry of BATCH_SHAPE
T = 10            # passed to the data loader as `timesteps` and to collections.sequential
# Split the sample budget evenly over the sweeps (floor division).
NUM_PARTICLES = SAMPLE_BUDGET // NUM_SWEEPS

In [7]:
# Sizes handed to the network constructors: NUM_HIDDEN_DIGIT goes to
# ObjectCodesProposal and StepObjects, NUM_HIDDEN_LOCATION to the two
# location proposals (presumably hidden-layer widths — confirm).
NUM_HIDDEN_DIGIT, NUM_HIDDEN_LOCATION = 400, 400

In [8]:
# Side lengths given to the SpatialTransformer (frame first, then digit);
# FRAME_SIDE is also the data loader's `frame_size`.
FRAME_SIDE, DIGIT_SIDE = 96, 28

# Dimensions given to the code ("what") and location ("where") models.
WHERE_DIM, WHAT_DIM = 2, 10

In [9]:
# Settings consumed by the data-loading cell; the values come from the
# constants defined in the earlier configuration cells.
data_args = dict(
    data='movingmnist',
    batch_size=BATCH_SIZE,
    train=True,
    timesteps=T,
    num_digits=NUM_DIGITS,
    frame_size=FRAME_SIDE,
    dv=0.1,
)

In [10]:
# Keyword arguments for the chunk lookup, taken from the shared data_args.
index_kwargs = {
    key: data_args[key]
    for key in ('train', 'timesteps', 'num_digits', 'frame_size', 'dv')
}
data_paths = data.data_loader_indices(**index_kwargs)

# Build a loader for each entry of data_paths in turn, with a progress bar.
# NOTE(review): `train_loader` is rebound on every pass, so only the loader
# for the final entry remains once the loop ends.
pbar = tqdm.tqdm(range(len(data_paths)))
for chunk_idx in pbar:
    chunk_path = data_paths[chunk_idx]
    train_loader = data.setup_data_loader(
        chunk_path,
        data_args['batch_size'],
        train=data_args['train'],
    )

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:06<00:00,  9.11it/s]


In [11]:
# Shape tuple passed to every sampler.importance_box call below.
BATCH_SHAPE = (
    BATCH_SIZE,
    NUM_PARTICLES,
    NUM_DIGITS,
)

In [12]:
# One SpatialTransformer instance, shared by the proposals and by StepObjects.
attention = spatial_transformer.SpatialTransformer(
    FRAME_SIDE,
    DIGIT_SIDE,
)

In [13]:
# Importance box named 'init_what': the InitialObjectCodes model together with
# its ObjectCodesProposal. The lens.PRO(...) arguments are passed through as
# given; their meaning is defined by the combinators library.
object_codes_model = generative.InitialObjectCodes(WHAT_DIM)
proposal = proposals.ObjectCodesProposal(attention, NUM_HIDDEN_DIGIT, WHAT_DIM)
object_codes = sampler.importance_box(
    'init_what',
    object_codes_model,
    BATCH_SHAPE,
    proposal,
    lens.PRO(0),
    lens.PRO(1),
)

In [14]:
# Importance box named 'init_where': the InitialObjectLocations model together
# with its InitialLocationsProposal. The lens.PRO(...) arguments are passed
# through as given; their meaning is defined by the combinators library.
object_locs_model = generative.InitialObjectLocations(WHERE_DIM)
proposal = proposals.InitialLocationsProposal(
    attention, FRAME_SIDE, NUM_HIDDEN_LOCATION, WHERE_DIM
)
object_locs = sampler.importance_box(
    'init_where',
    object_locs_model,
    BATCH_SHAPE,
    proposal,
    lens.PRO(0),
    lens.PRO(1),
)

In [15]:
# Importance box named 'step_where': the StepObjects model together with its
# StepLocationsProposal. Unlike the two initial boxes, the first lens here is
# lens.PRO(2) rather than lens.PRO(0).
step_objects_model = generative.StepObjects(
    NUM_HIDDEN_DIGIT, WHERE_DIM, WHAT_DIM, attention
)
proposal = proposals.StepLocationsProposal(
    attention, FRAME_SIDE, NUM_HIDDEN_LOCATION, WHERE_DIM
)
step_locs = sampler.importance_box(
    'step_where',
    step_objects_model,
    BATCH_SHAPE,
    proposal,
    lens.PRO(2),
    lens.PRO(1),
)

In [16]:
# Wrap the per-step importance box in a parameterized SSM; both lens arguments
# are lens.PRO(1).
step_ssm = collections.parameterized_ssm(
    lens.PRO(1),
    lens.PRO(1),
    step_locs,
)

In [17]:
# Combine the two initial boxes with `@`, then feed the result via `>>` into
# the step model repeated T times by collections.sequential.
# NOTE(review): `@` and `>>` are operators overloaded by the combinators
# library; their exact composition semantics are not visible here — confirm.
initial_objects = object_locs @ object_codes
unrolled_steps = collections.sequential(step_ssm, T)
apg_model = initial_objects >> unrolled_steps

In [18]:
# Run apg_inference.apg on the composed model with the configured learning
# rate and sweep count.
# NOTE(review): the recorded output of this cell is
# `RuntimeError: number of dims don't match in permute`, so the call does not
# currently complete. The second positional argument is a bare literal `1`
# whose meaning is not visible here, and the `train_loader` built earlier is
# never passed in — confirm what apg() expects before re-running.
apg_inference.apg(apg_model, 1, lr=LR, num_sweeps=NUM_SWEEPS)

RuntimeError: number of dims don't match in permute