In [1]:
import sys
sys.path.append('../src')

In [2]:
import argparse
import math
import os
import time
from functools import partial

import numpy as np
import torch
import visdom

import pyro
import pyro.contrib.examples.multi_mnist as multi_mnist
import pyro.optim as optim
import pyro.poutine as poutine
from components.AIR import AIR, latents_to_tensor
from pyro.contrib.examples.util import get_data_directory
from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO
from utils.viz import draw_many, tensor_to_objs
from utils.visualizer import plot_mnist_sample
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
%matplotlib inline

## Args

In [3]:
# Checkpoint used by default for both --load and --save. It can be overridden
# through the AIR_CHECKPOINT environment variable so the notebook is not tied
# to one machine's directory layout; the fallback is the original path.
DEFAULT_CHECKPOINT = os.environ.get(
    'AIR_CHECKPOINT',
    "/Users/chamathabeysinghe/Projects/monash/VAE_v2/checkpoints/model-size-75-3ants.ckpt")

# argument_default=SUPPRESS: options declared without an explicit default are
# left out of the parsed namespace entirely; later cells rely on this through
# `key in args` membership tests.
parser = argparse.ArgumentParser(description="Pyro AIR example", argument_default=argparse.SUPPRESS)
parser.add_argument('-n', '--num-steps', type=int, default=int(1e8),
                    help='number of optimization steps to take')
parser.add_argument('-b', '--batch-size', type=int, default=64,
                    help='batch size')
parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4,
                    help='learning rate')
parser.add_argument('-blr', '--baseline-learning-rate', type=float, default=1e-3,
                    help='baseline learning rate')
parser.add_argument('--progress-every', type=int, default=1,
                    help='number of steps between writing progress to stdout')
parser.add_argument('--eval-every', type=int, default=0,
                    help='number of steps between evaluations')
parser.add_argument('--baseline-scalar', type=float,
                    help='scale the output of the baseline nets by this value')
parser.add_argument('--no-baselines', action='store_true', default=False,
                    help='do not use data dependent baselines')
parser.add_argument('--encoder-net', type=int, nargs='+', default=[200],
                    help='encoder net hidden layer sizes')
parser.add_argument('--decoder-net', type=int, nargs='+', default=[200],
                    help='decoder net hidden layer sizes')
parser.add_argument('--predict-net', type=int, nargs='+',
                    help='predict net hidden layer sizes')
parser.add_argument('--embed-net', type=int, nargs='+',
                    help='embed net architecture')
parser.add_argument('--bl-predict-net', type=int, nargs='+',
                    help='baseline predict net hidden layer sizes')
parser.add_argument('--non-linearity', type=str,
                    help='non linearity to use throughout')
# NOTE(review): store_true combined with default=True means this flag can never
# switch visualisation off; kept unchanged for compatibility.
parser.add_argument('--viz', action='store_true', default=True,
                    help='generate vizualizations during optimization')
parser.add_argument('--viz-every', type=int, default=100,
                    help='number of steps between vizualizations')
parser.add_argument('--visdom-env', default='main',
                    help='visdom enviroment name')
parser.add_argument('--load', type=str, default=DEFAULT_CHECKPOINT,
                    help='load previously saved parameters')
# NOTE(review): --save defaults to the same file as --load, so saving with the
# defaults would overwrite the checkpoint that was loaded.
parser.add_argument('--save', type=str, default=DEFAULT_CHECKPOINT,
                    help='save parameters to specified file')
parser.add_argument('--save-every', type=int, default=100,
                    help='number of steps between parameter saves')
parser.add_argument('--cuda', action='store_true', default=False,
                    help='use cuda')
parser.add_argument('--jit', action='store_true', default=False,
                    help='use PyTorch jit')
parser.add_argument('-t', '--model-steps', type=int, default=3,
                    help='number of time steps')
parser.add_argument('--rnn-hidden-size', type=int, default=256,
                    help='rnn hidden size')
parser.add_argument('--encoder-latent-size', type=int, default=50,
                    help='attention window encoder/decoder latent space size')
parser.add_argument('--decoder-output-bias', type=float,
                    help='bias added to decoder output (prior to applying non-linearity)')
parser.add_argument('--decoder-output-use-sigmoid', action='store_true',
                    help='apply sigmoid function to output of decoder network')
parser.add_argument('--window-size', type=int, default=28,
                    help='attention window size')
parser.add_argument('--z-pres-prior', type=float, default=0.5,
                    help='prior success probability for z_pres')
parser.add_argument('--z-pres-prior-raw', action='store_true', default=False,
                    help='use --z-pres-prior directly as success prob instead of a geometric like prior')
parser.add_argument('--anneal-prior', choices='none lin exp'.split(), default='none',
                    help='anneal z_pres prior during optimization')
parser.add_argument('--anneal-prior-to', type=float, default=1e-7,
                    help='target z_pres prior prob')
parser.add_argument('--anneal-prior-begin', type=int, default=0,
                    help='number of steps to wait before beginning to anneal the prior')
parser.add_argument('--anneal-prior-duration', type=int, default=100000,
                    help='number of steps over which to anneal the prior')
parser.add_argument('--pos-prior-mean', type=float,
                    help='mean of the window position prior')
parser.add_argument('--pos-prior-sd', type=float,
                    help='std. dev. of the window position prior')
parser.add_argument('--scale-prior-mean', type=float,
                    help='mean of the window scale prior')
parser.add_argument('--scale-prior-sd', type=float,
                    help='std. dev. of the window scale prior')
parser.add_argument('--no-masking', action='store_true', default=False,
                    help='do not mask out the costs of unused choices')
parser.add_argument('--seed', type=int, help='random seed', default=None)
parser.add_argument('-v', '--verbose', action='store_true', default=False,
                    help='write hyper parameters and network architecture to stdout')

_StoreTrueAction(option_strings=['-v', '--verbose'], dest='verbose', nargs=0, const=True, default=False, type=None, choices=None, help='write hyper parameters and network architecture to stdout', metavar=None)

In [4]:
# Parse an empty argument list so every option takes its declared default.
# parse_args already returns an argparse.Namespace, so no re-wrapping is needed.
args = parser.parse_args([])

In [5]:
def make_prior(k, num_steps=3):
    """Build per-step z_pres success probabilities for a geometric-like prior.

    The prior puts mass proportional to ``k**n`` on seeing ``n`` objects, for
    ``n = 0 .. num_steps``. The returned function gives, for step ``t``, the
    conditional probability ``P(n > t) / P(n >= t)``. This is the chance that
    step ``t`` is present given that every earlier step was present.

    Args:
        k: ratio between successive object-count probabilities, in (0, 1].
        num_steps: number of time steps (maximum object count). The default
            of 3 reproduces the original hard-coded behaviour exactly.

    Returns:
        A function ``t -> p_t`` defined for ``t`` in ``range(num_steps)``.

    Raises:
        ValueError: if ``k`` is outside (0, 1] or ``num_steps < 1``.
        IndexError: (from the returned function) for ``t`` out of range.
    """
    if not 0 < k <= 1:
        raise ValueError('k must lie in (0, 1], got {}'.format(k))
    if num_steps < 1:
        raise ValueError('num_steps must be at least 1, got {}'.format(num_steps))
    # Normaliser for the unnormalised weights k**0, k**1, ..., k**num_steps.
    u = 1 / sum(k**n for n in range(num_steps + 1))
    trial_probs = []
    reach = 1.0  # product of earlier trial probabilities, i.e. P(n >= t)
    for t in range(num_steps):
        # P(n == t) = k**t * u, hence P(n > t | n >= t) = 1 - k**t * u / P(n >= t).
        p = 1 - (k**t * u) / reach
        trial_probs.append(p)
        reach *= p
    # Implied count distribution: [1 - p0, p0 * (1 - p1), ..., p0 * ... * p_{N-1}].

    def prior_p(t):
        # Explicit bounds check: a negative t would otherwise silently wrap around.
        if not 0 <= t < num_steps:
            raise IndexError('time step {} outside [0, {})'.format(t, num_steps))
        return trial_probs[t]

    return prior_p

In [6]:
if args.seed is not None:
    pyro.set_rng_seed(args.seed)


def lin_decay(initial, final, begin, duration, t):
    """Linearly anneal a value from ``initial`` towards ``final``.

    The value starts moving at step ``begin`` and reaches ``final`` after
    ``duration`` further steps. The result is clamped between the two
    endpoints (this assumes ``final <= initial``).
    """
    if duration <= 0:
        raise ValueError('duration must be positive, got {}'.format(duration))
    x = (final - initial) * (t - begin) / duration + initial
    return max(min(x, initial), final)


def exp_decay(initial, final, begin, duration, t):
    """Exponentially anneal a value from ``initial`` towards ``final``.

    The decay rate is chosen so the value reaches ``final`` exactly
    ``duration`` steps after ``begin``. The result is clamped between the two
    endpoints (this assumes ``0 < final <= initial``).
    """
    if final <= 0:
        raise ValueError('final must be positive, got {}'.format(final))
    if duration <= 0:
        raise ValueError('duration must be positive, got {}'.format(duration))
    decay_rate = math.log(initial / final) / duration
    x = initial * math.exp(-decay_rate * (t - begin))
    return max(min(x, initial), final)


# Build a function to compute z_pres prior probabilities.
if args.z_pres_prior_raw:
    def base_z_pres_prior_p(t):
        return args.z_pres_prior
else:
    base_z_pres_prior_p = make_prior(args.z_pres_prior)

# Annealing schedules selectable through --anneal-prior.
_ANNEAL_SCHEDULES = dict(lin=lin_decay, exp=exp_decay)


def z_pres_prior_p(opt_step, time_step):
    """z_pres prior probability at ``time_step``, annealed by optimisation step ``opt_step``."""
    # NOTE(review): this function is not passed to AIR or to the guide/prior
    # calls in the visible cells, so annealing has no effect in this notebook.
    p = base_z_pres_prior_p(time_step)
    if args.anneal_prior == 'none':
        return p
    decay = _ANNEAL_SCHEDULES[args.anneal_prior]
    return decay(p, args.anneal_prior_to, args.anneal_prior_begin,
                 args.anneal_prior_duration, opt_step)



## Data loader

In [7]:
def load_data(path='/Users/chamathabeysinghe/Projects/monash/test/variational_auto_encoder/data/ANTS-SYNTHETIC/original: 300-resize: 75-3ants.npy',
              num_objects=1):
    """Load a ``.npy`` image stack as a float32 tensor scaled by 1/255.

    Args:
        path: ``.npy`` file holding the images along its first axis. Defaults
            to the original hard-coded dataset file.
        num_objects: value recorded in ``counts`` for every image.

    Returns:
        ``(X, counts)``. ``X`` is the loaded array cast to float32 and divided
        by 255. ``counts`` is a float32 tensor with one entry per image.
    """
    X_np = np.load(path).astype(np.float32) / 255.0
    X = torch.from_numpy(X_np)
    # NOTE(review): the default file name mentions "3ants" while every count
    # defaults to 1; pass num_objects=3 if each image really holds 3 objects.
    # Float counts allow comparison with values sampled from a Bernoulli.
    counts = torch.full((len(X_np),), float(num_objects), dtype=torch.float32)
    return X, counts

In [47]:
def load_data2(path_template='/Users/chamathabeysinghe/Projects/monash/VAE_v2/data/synthetic/complex_dataset/original:_300-resize:_75-1ants/{:05d}.png',
               num_images=50, num_objects=1):
    """Load a numbered sequence of image files as one float32 tensor.

    Args:
        path_template: ``str.format`` template taking the zero-based image
            index (``00000.png``, ``00001.png``, ... with the default).
        num_images: number of consecutive indices, starting at 0, to load.
        num_objects: value recorded in ``counts`` for every image.

    Returns:
        ``(X, counts)``. ``X`` stacks the images along a new first axis and
        divides by 255 (values lie in [0, 1] assuming 8-bit images).
        ``counts`` is a float32 tensor of length ``num_images`` filled with
        ``num_objects``.
    """
    images = []
    for index in range(num_images):
        # The context manager closes each file handle once its pixels are copied.
        with Image.open(path_template.format(index)) as img:
            images.append(np.asarray(img))
    X_np = np.asarray(images).astype(np.float32) / 255.0
    X = torch.from_numpy(X_np)
    # Float counts allow comparison with values sampled from a Bernoulli.
    counts = torch.full((len(X_np),), float(num_objects), dtype=torch.float32)
    return X, counts

In [48]:
X, counts = load_data2()
# Number of images in the dataset (length of the first axis).
X_size = len(X)
X = X.cuda() if args.cuda else X
    
    

In [49]:
# Shape of the loaded image tensor.
X.size()

torch.Size([50, 75, 75])

## Model

In [50]:
# Hyper-parameters forwarded to the AIR constructor. The parser uses
# argument_default=SUPPRESS, so options with no explicit default that were not
# set are absent from `args` and are skipped here.
model_arg_keys = ('window_size rnn_hidden_size decoder_output_bias '
                  'decoder_output_use_sigmoid baseline_scalar encoder_net '
                  'decoder_net predict_net embed_net bl_predict_net '
                  'non_linearity pos_prior_mean pos_prior_sd '
                  'scale_prior_mean scale_prior_sd').split()
model_args = dict((name, getattr(args, name)) for name in model_arg_keys if name in args)

In [51]:
# AIR takes a single side length for its canvas. Derive it from the loaded
# images instead of hard-coding 75 so the model cannot silently disagree with
# the data. A single side length implies square images, so check that first.
if X.size(-1) != X.size(-2):
    raise ValueError('expected square images, got {}x{}'.format(X.size(-2), X.size(-1)))
air = AIR(
        num_steps=args.model_steps,
        x_size=X.size(-1),
        use_masking=not args.no_masking,
        use_baselines=not args.no_baselines,
        z_what_size=args.encoder_latent_size,
        use_cuda=args.cuda,
        **model_args
    )

In [52]:
if 'load' in args:
    print('Loading parameters...')
    # torch.load restores tensors onto the device they were saved from. Remap
    # them to the CPU unless --cuda is set, so a GPU-trained checkpoint still
    # loads on a CPU-only machine.
    air.load_state_dict(torch.load(args.load, map_location=None if args.cuda else 'cpu'))

Loading parameters...


## Visualize

In [53]:
# Visdom client that the visualisation cells below post images to; `env`
# selects the dashboard environment (--visdom-env, 'main' by default).
count = 100  # NOTE(review): not referenced by any other visible cell
vis = visdom.Visdom(
    env=args.visdom_env,
)

Setting up a new session...


## Post processing

In [54]:
def img2arr(img):
    """Convert an RGB image into a channels-first uint8 array for visdom.

    Args:
        img: image with three values per pixel. ``img.size`` is
            ``(width, height)`` and ``img.getdata()`` yields pixels row by row,
            as in PIL.

    Returns:
        uint8 array of shape ``(3, height, width)``.
    """
    width, height = img.size
    # getdata() is row-major, so rows (height) must come first in the reshape.
    # Reshaping to img.size + (3,) transposed non-square images.
    pixels = np.array(img.getdata(), np.uint8).reshape((height, width, 3))
    return pixels.transpose((2, 0, 1))

def arr2img(arr):
    """Convert a 2-D float array in [0, 1] into an 8-bit greyscale image.

    Args:
        arr: 2-D array indexed ``[row, column]`` with values in [0, 1].

    Returns:
        A mode ``'L'`` PIL image whose size is ``(columns, rows)``.
    """
    height, width = arr.shape
    # tobytes() replaces ndarray.tostring(), which NumPy deprecated and removed
    # in 2.0. It always emits C-order (row-major) bytes.
    data = (arr * 255).astype(np.uint8).tobytes()
    # PIL sizes are (width, height); passing arr.shape directly swapped the
    # axes for non-square inputs.
    return Image.frombuffer('L', (width, height), data, 'raw', 'L', 0, 1)

def bounding_box(z_where, x_size):
    """Approximate pixel-space box of one attention window.

    Interpolation is ignored, so the box is only roughly right, but it is close
    enough to be usable for visual inspection.

    Args:
        z_where: object exposing a scale ``s`` and offsets ``x`` and ``y``.
        x_size: canvas side length in pixels (width and height both derive
            from it).

    Returns:
        ``((left, top), width, height)`` with the origin at the top left.
    """
    side = x_size / z_where.s

    def shift(offset):
        # Pixel translation of the window for one axis (note the sign flip).
        return -offset / z_where.s * x_size / 2.

    margin = (x_size - side) / 2
    return (margin + shift(z_where.x), margin + shift(z_where.y)), side, side

def colors(k):
    """RGB outline colour for the k-th object, cycling red, green, blue."""
    palette = ((255, 0, 0), (0, 255, 0), (0, 0, 255))
    return palette[k % len(palette)]

In [55]:
def draw_one(imgarr, z_arr, pixel_threshold=0.01, min_foreground_ratio=0.09):
    """Render one image with outlines for its inferred objects.

    Args:
        imgarr: 2-D image tensor; it is detached, moved to the CPU and clipped
            to [0, 1] before drawing.
        z_arr: per-step latents exposing ``pres`` (used here) and ``s``, ``x``,
            ``y`` (used by ``bounding_box``).
        pixel_threshold: intensity separating foreground from background
            pixels when deciding whether a box is worth drawing.
        min_foreground_ratio: boxes whose foreground/background pixel ratio
            falls below this value are skipped as (probably) empty.

    Returns:
        The channels-first uint8 array produced by ``img2arr``.
    """
    # Note that this clipping makes the visualisation somewhat
    # misleading, as it incorrectly suggests objects occlude one
    # another.
    clipped = np.clip(imgarr.detach().cpu().numpy(), 0, 1)
    img = arr2img(clipped).convert('RGB')
    draw = ImageDraw.Draw(img)
    x_size = imgarr.size(0)
    for k, z in enumerate(z_arr):
        # z_pres darkens the outline colour (per-box opacity was awkward with
        # PIL); boxes with z_pres == 0 are not drawn at all.
        if not z.pres > 0:
            continue
        (x, y), w, h = bounding_box(z, x_size)
        # Clamp the crop window at 0. A negative slice index would wrap around
        # to the opposite image edge and test the wrong pixels.
        top = max(int(y), 0)
        left = max(int(x), 0)
        bottom = max(int(y) + int(h), 0)
        right = max(int(x) + int(w), 0)
        crop = clipped[top:bottom, left:right]
        foreground = np.count_nonzero(crop > pixel_threshold)
        background = np.count_nonzero(crop < pixel_threshold)
        # Skip boxes that cover (almost) only background pixels.
        if background > 0 and foreground / background < min_foreground_ratio:
            continue
        color = tuple(int(c * z.pres) for c in colors(k))
        draw.rectangle([x, y, x + w, y + h], outline=color)
    # NOTE(review): the count below sums z_pres over every step, including
    # boxes skipped by the foreground filter above.
    is_relaxed = any(z.pres != math.floor(z.pres) for z in z_arr)
    fmtstr = '{:.1f}' if is_relaxed else '{:.0f}'
    draw.text((0, 0), fmtstr.format(sum(z.pres for z in z_arr)), fill='white')
    return img2arr(img)

In [56]:
def draw_many_custom(imgarrs, z_arr):
    """Apply ``draw_one`` to each image paired with its latents.

    Args:
        imgarrs: batch of images (a tensor); it is moved to the CPU and
            iterated along its first dimension.
        z_arr: per-image latent sequences, paired positionally with
            ``imgarrs``.

    Returns:
        A list with one rendered array per pair; pairing stops at the shorter
        of the two inputs.
    """
    return list(map(draw_one, imgarrs.cpu(), z_arr))


In [57]:
examples_to_viz = X[0:50]
# Run the guide to infer latents, then replay the prior on them to obtain
# reconstructions.
trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
z, recons = poutine.replay(air.prior, trace=trace)(examples_to_viz.size(0))
z_wheres = tensor_to_objs(latents_to_tensor(z))

# Show data with inferred object positions (boxes filtered by draw_one).
vis.images(draw_many_custom(examples_to_viz, z_wheres))
# Same data drawn with utils.viz.draw_many (presumably without the filter), for comparison.
vis.images(draw_many(examples_to_viz, z_wheres))
# Show reconstructions of data (previously computed but never displayed).
vis.images(draw_many(recons, z_wheres))

'window_389193b9e673de'

## Repeat experiments

In [None]:
# Batch re-used by every repeat_experiment() call below: the first 48 images.
examples_to_viz = X[:48]

In [None]:
def repeat_experiment(examples=None, show_reconstructions=False):
    """Run inference once on a batch and post the inferred boxes to visdom.

    Args:
        examples: image batch to run on; defaults to the notebook-level
            ``examples_to_viz``.
        show_reconstructions: also post the reconstructions obtained by
            replaying the prior on the inferred latents.
    """
    if examples is None:
        examples = examples_to_viz
    trace = poutine.trace(air.guide).get_trace(examples, None)
    z, recons = poutine.replay(air.prior, trace=trace)(examples.size(0))
    z_wheres = tensor_to_objs(latents_to_tensor(z))

    # Show data with inferred object positions.
    vis.images(draw_many(examples, z_wheres))
    if show_reconstructions:
        vis.images(draw_many(recons, z_wheres))

In [None]:
# Repeat inference 50 times; each run calls vis.images once with its boxes.
for run in range(50):
    repeat_experiment()