In [1]:
# Move two directories up (to the repository root) before the project imports
# below; presumably required so `combinators` and `examples` resolve — confirm.
%cd ../..

/home/eli/AnacondaProjects/combinators


In [2]:
import logging
import torch
import tqdm

In [3]:
# Show INFO-level log records, each prefixed with a month/day/year timestamp.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)

In [4]:
from combinators import lens, sampler
from combinators.model import collections
from examples.moving_mnist import data
from examples.multi_mnist import apg_inference, generative, proposals, spatial_transformer

In [5]:
# Seed torch's RNGs first so particle sampling (and any torch-driven data
# shuffling) is reproducible under Restart & Run All.
# NOTE(review): if the data loader draws from numpy or `random`, seed those too.
SEED = 0
torch.manual_seed(SEED)

# Optimiser step size (passed to apg as `lr`), number of APG sweeps, and the
# total particle budget that is divided across the sweeps.
LR = 1e-4
NUM_SWEEPS = 6
SAMPLE_BUDGET = 60

In [6]:
T = 10           # timesteps per sequence (data_args['timesteps'], sequential_ssm length)
NUM_DIGITS = 3   # digits per frame
NUM_EPOCHS = 10

# Spread the fixed sample budget evenly over the sweeps. Floor division drops
# any remainder when SAMPLE_BUDGET is not a multiple of NUM_SWEEPS, and would
# give zero particles if NUM_SWEEPS > SAMPLE_BUDGET — guard against the latter.
NUM_PARTICLES = SAMPLE_BUDGET // NUM_SWEEPS
assert NUM_PARTICLES >= 1, (
    f"SAMPLE_BUDGET ({SAMPLE_BUDGET}) must be at least NUM_SWEEPS ({NUM_SWEEPS}) "
    "so every sweep gets at least one particle"
)

BATCH_SIZE = 5
PATIENCE = 1     # passed to apg as `patience`; semantics defined there

In [7]:
# Hidden sizes for the digit-code networks and the location networks
# (both 400 here; presumably hidden-layer widths).
NUM_HIDDEN_DIGIT = NUM_HIDDEN_LOCATION = 400

In [8]:
# Side lengths of a single digit glimpse and of the full frame
# (FRAME_SIDE is also the dataset's 'frame_size').
DIGIT_SIDE, FRAME_SIDE = 28, 96

# Per-digit latent sizes: WHAT_DIM for the object code, WHERE_DIM for the location.
WHAT_DIM, WHERE_DIM = 10, 2

In [9]:
# Dataset options; ChunkLoader reads most of these keys. The meaning of 'dv' is
# defined by the moving_mnist data helpers.
data_args = dict(
    data='movingmnist',
    batch_size=BATCH_SIZE,
    train=True,
    timesteps=T,
    num_digits=NUM_DIGITS,
    frame_size=FRAME_SIDE,
    dv=0.1,
)

In [10]:
class ChunkLoader:
    """Walk the Moving MNIST dataset chunk by chunk.

    The chunk index is built once at construction with
    ``data.data_loader_indices``. Each pass (``chunks()`` or iterating the
    ChunkLoader itself) builds one loader per chunk with
    ``data.setup_data_loader``.

    Args:
        args: Dataset options. Keys read: ``'train'``, ``'timesteps'``,
            ``'num_digits'``, ``'frame_size'``, ``'dv'`` and ``'batch_size'``.
            Defaults to the notebook-level ``data_args``, so ``ChunkLoader()``
            behaves as before.
    """

    def __init__(self, args=None):
        # Resolve the default at construction time rather than binding the
        # global as a default value, so the notebook's current `data_args` is used.
        self.args = data_args if args is None else args
        self.data_paths = data.data_loader_indices(
            train=self.args['train'], timesteps=self.args['timesteps'],
            num_digits=self.args['num_digits'],
            frame_size=self.args['frame_size'],
            dv=self.args['dv']
        )

    def chunks(self):
        """Yield ``(chunk_index, loader)`` for every chunk, showing a tqdm bar.

        This is a generator: each call is a fresh, single pass over the chunks.
        Call it again (or iterate the ChunkLoader) for another pass.
        """
        for chunk, chunk_path in enumerate(tqdm.tqdm(self.data_paths)):
            train = data.setup_data_loader(chunk_path,
                                           self.args['batch_size'],
                                           train=self.args['train'])
            yield chunk, train

    def __iter__(self):
        # Makes the loader re-iterable: each `for ... in loader` starts a new pass,
        # unlike a single `chunks()` generator which is exhausted after one pass.
        return self.chunks()

In [11]:
# Build the chunk index once; `loader.chunks()` then yields (chunk_index, loader) pairs.
loader = ChunkLoader()

In [12]:
# Shape tuples handed to every importance box (and PARTICLE_SHAPE to apg).
BATCH_SHAPE, PARTICLE_SHAPE = (BATCH_SIZE,), (NUM_PARTICLES,)

In [13]:
# One spatial-transformer instance, shared by the proposals and the step model below.
attention = spatial_transformer.SpatialTransformer(
    FRAME_SIDE,
    DIGIT_SIDE,
    WHAT_DIM,
    NUM_HIDDEN_DIGIT,
)

In [14]:
# Initial object codes ('init_what'): the generative model and its proposal
# get their own names, and only the boxed result is bound to `object_codes`
# (used when composing apg_model), so no name refers to two different things.
init_what_model = generative.InitialObjectCodes(WHAT_DIM, NUM_DIGITS)
init_what_proposal = proposals.ObjectCodesProposal(attention, NUM_HIDDEN_DIGIT, WHAT_DIM)
object_codes = sampler.importance_box(
    'init_what', init_what_model, init_what_proposal,
    BATCH_SHAPE, PARTICLE_SHAPE,
    lens.PRO(0), lens.PRO(1),  # NOTE(review): lens arguments' role defined in combinators — confirm
)

In [15]:
# Initial object locations ('init_where'): separate names for the generative
# model and its proposal; only the boxed result is bound to `object_locs`.
init_where_model = generative.InitialObjectLocations(WHERE_DIM, NUM_DIGITS)
init_where_proposal = proposals.InitialLocationsProposal(
    attention, FRAME_SIDE, DIGIT_SIDE, NUM_HIDDEN_LOCATION, WHERE_DIM
)
object_locs = sampler.importance_box(
    'init_where', init_where_model, init_where_proposal,
    BATCH_SHAPE, PARTICLE_SHAPE,
    lens.PRO(0), lens.PRO(1),  # NOTE(review): lens arguments' role defined in combinators — confirm
)

In [16]:
# Per-timestep location update ('step_where'): separate names for the step
# model and its proposal; only the boxed result is bound to `step_locs`.
step_where_model = generative.StepObjects(WHERE_DIM, attention)
step_where_proposal = proposals.StepLocationsProposal(
    attention, FRAME_SIDE, DIGIT_SIDE, NUM_HIDDEN_LOCATION, WHERE_DIM
)
step_locs = sampler.importance_box(
    'step_where', step_where_model, step_where_proposal,
    BATCH_SHAPE, PARTICLE_SHAPE,
    lens.PRO(2), lens.PRO(1),  # NOTE(review): differs from the init boxes (PRO(2) vs PRO(0)) — confirm intent
)

In [17]:
# Sequence model over T timesteps built from the boxed step model
# (presumably applies step_locs once per timestep — confirm in combinators).
step_ssm = collections.sequential_ssm(
    lens.PRO(1),
    lens.PRO(1),
    step_locs,
    T,
)

In [18]:
# Full model: the initial-location and initial-code boxes combined with `@`,
# then chained into the step SSM with `>>`.
# NOTE(review): operator semantics come from combinators — presumably parallel
# and sequential composition respectively.
apg_model = (
    (object_locs @ object_codes)
    >> step_ssm
)

In [19]:
# Train with APG; the return value is displayed as the cell output.
# NOTE(review): `loader.chunks()` is a one-shot generator. If `apg` iterates it
# once per epoch, epochs after the first would see no data — confirm in apg_inference.
apg_inference.apg(
    apg_model,
    NUM_EPOCHS,
    PARTICLE_SHAPE,
    loader.chunks(),
    lr=LR,
    num_sweeps=NUM_SWEEPS,
    patience=PATIENCE,
)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [01:19<00:00,  1.33s/it]
  0%|                                                                                                                                                                                                                                                              | 0/10 [00:00<?, ?it/s]
  0%|                                                                                                                                                                                                                                                              | 0/60 [00:00<?, ?it/s][A
  2%|████                                                                                                                                           

tensor([-1.7253e+10, -1.7262e+10, -1.7267e+10, -1.7267e+10, -1.7268e+10,
        -1.7266e+10, -1.7263e+10, -1.7269e+10, -1.7273e+10, -1.7278e+10])