In [102]:
import os
from functools import partial
import operator as op
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as trans
import more_trans
import more_sampler
import exprlib
import trainlib
import vmdata
from aecorr.trainutils import tbatch2cbatch
import aecorr.eval
import aecorr.models.unet_pred3_f1to2 as net_module

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

%matplotlib inline

# Directory containing the training runs (checkpoints are loaded from it
# below) and `todir`, presumably where this notebook's results are written.
# $PROJ_HOME is expanded from the environment.
rundir, todir = (
    os.path.expandvars(path)
    for path in ('$PROJ_HOME/experiments/src/aecorr',
                 '$PROJ_HOME/experiments/ipynb/data.experiments-aecorr')
)

In [90]:
# Rebuild the autoencoder, restore the run.1 checkpoint into it, then freeze
# it for inference: eval mode and no gradient tracking on any weight.
net = net_module.Autoencoder()
trainlib.load_checkpoint(
    net,
    os.path.join(rundir, 'run.1', 'save'),
    'checkpoint_{}_{}.pth',
    (0, 29990),  # NOTE(review): presumably (epoch, iteration) -- confirm in trainlib
)
net.eval()
for param in net.parameters():
    param.requires_grad_(False)

In [3]:
# Root directory of the video dataset used for evaluation.
# NOTE(review): the meaning of the positional arguments (1, 8, 2) is not
# visible here -- TODO document them (or name them via keywords).
root = vmdata.dataset_root(1, 8, 2)

In [4]:
# Grayscale dataset statistics, and the forward/inverse normalization
# transforms built from the very same statistics.
nmlstat = vmdata.get_normalization_stats(root, bw=True)
normalize, denormalize = (
    trans.Normalize(*nmlstat),
    more_trans.DeNormalize(*nmlstat),
)

In [6]:
# Show the forward normalization transform to sanity-check the statistics.
normalize

Normalize(mean=(0.5116854310035706,), std=(0.07510992884635925,))

In [5]:
# Show the inverse transform; it should carry the same mean/std as `normalize`.
denormalize

DeNormalize(mean=(0.5116854310035706,), std=(0.07510992884635925,))

In [7]:
# Look up the preprocessing transform's parameters (used to build `transform`).
help(more_trans.BWCAEPreprocess.__init__)

Help on function __init__ in module more_trans:

__init__(self, normalize:torchvision.transforms.transforms.Normalize, pool_scale:int=1, downsample_scale:int=1, to_rgb:bool=False)
    :param normalize: the normalization transform
    :param pool_scale: the overall scale of the pooling operations in
           subsequent encoder; the image will be cropped to (H', W') where
           H' and W' are the nearest positive integers to H and W that are
           the power of ``pool_scale``, so that ``unpool(pool(x))`` is of
           the same shape as ``x``
    :param downsample_scale: the scale to downsample the video frames
    :param to_rgb: if True, at the last step convert from B&W image to
           RGB image



In [11]:
# Look up the sliding-window sampler's parameters (used inside evalpred_it).
help(more_sampler.SlidingWindowBatchSampler.__init__)

Help on function __init__ in module more_sampler:

__init__(self, indices, window_width:int, shuffled:bool=False, batch_size:int=1, drop_last:bool=False)
    :param indices: array-like integer indices to sample; when presented as
           a list of arrays, no sample will span across more than one array
    :param window_width: the width of the window; if ``window_width`` is
           larger than the length of ``indices`` or the length of one of
           the sublists, then that list won't be sampled
    :param shuffled: whether to shuffle sampling, but the indices order
           within a batch is never shuffled
    :param batch_size: how many batches to yield upon each sampling
    :param drop_last: True to drop the remaining batches if the number of
           remaining batches is less than ``batch_size``
    
    Note on ``batch_size``
    ----------------------
    
    When ``batch_size = 2``, assuming that the two batch of indices are
    ``[1, 2, 3, 4]`` and ``[4, 5, 6, 7]`

In [8]:
# Per-frame preprocessing: normalize with the dataset statistics and crop each
# frame so its size survives the network's pool/unpool round trip.
transform = more_trans.BWCAEPreprocess(normalize, pool_scale=net_module.pool_scale)

In [41]:
def evalpred_it(root, transform, detransform, indices, model=None):
    """Yield ground-truth and predicted frames for each 3-frame sliding window.

    For every window of three frames drawn from ``indices``, the frames at
    positions 0 and 2 (along axis 2 after ``rearrange_temporal_batch``,
    presumably the temporal axis) are fed to the model. The output is
    returned alongside the frame at position 1, which is the target.

    The ``VideoDataset`` stays open while the generator is suspended; exhaust
    the generator or call its ``close()`` method to release it.

    :param root: dataset root passed to ``vmdata.VideoDataset``
    :param transform: per-frame transform applied by the dataset
    :param detransform: inverse of the normalization done by ``transform``;
           applied before clamping to [0, 1]
    :param indices: frame indices to slide the window over
    :param model: network to evaluate; defaults to the global ``net``
    :return: generator of ``(frame0, frame1, frame2, predicted_frame1)``
             numpy arrays, each clamped to [0, 1]
    """
    if model is None:
        model = net
    window_width = 3  # the [0, 2] / [1] split below is tied to this width

    def to_image(tensor):
        # Undo the normalization, clamp into [0, 1] and convert to numpy.
        return torch.clamp(detransform(tensor), 0.0, 1.0).numpy()

    with vmdata.VideoDataset(root, transform=transform) as vdset:
        sam = more_sampler.SlidingWindowBatchSampler(indices, window_width)
        for frames in DataLoader(vdset, batch_sampler=sam):
            frames = more_trans.rearrange_temporal_batch(frames, window_width)
            # Outer frames are the model input; the middle frame is the target.
            inputs, targets = frames[:, :, [0, 2], ...], frames[:, :, [1], ...]
            inputs, targets = tbatch2cbatch(inputs), tbatch2cbatch(targets)

            # Inference only: don't build an autograd graph.
            with torch.no_grad():
                outputs = model(inputs)

            # Iterate over axis 1 (moved first) so each input frame is
            # de-normalized on its own. NOTE(review): assumes tbatch2cbatch
            # stacks frames along axis 1 -- confirm.
            input_frames = [to_image(x) for x in inputs.transpose(0, 1)]
            # squeeze(0) drops the batch axis; the sampler uses its default
            # batch_size=1, so each batch holds exactly one window.
            prediction = to_image(outputs.squeeze(0))
            target = to_image(targets.squeeze(0))
            yield input_frames[0][0], target[0], input_frames[1][0], prediction[0]

In [91]:
# Lazily evaluate the model over sliding windows of frames 500..509.
_eit = evalpred_it(root, transform, denormalize, indices=range(500, 510))

In [100]:
# Pick one sliding window deterministically. A fresh generator is built each
# time, so re-running this cell always shows the same frames; advancing a
# shared generator instead gives different frames on every run and finally
# raises StopIteration. Change WINDOW_INDEX (0-based) to inspect another window.
WINDOW_INDEX = 3
_windows = evalpred_it(root, transform, denormalize, range(500, 510))
try:
    for _ in range(WINDOW_INDEX + 1):
        f1, f2, f3, pf2 = next(_windows)
finally:
    # Closing the generator exits the `with VideoDataset(...)` block in evalpred_it.
    _windows.close()
f = f1, f2, f3

In [101]:
def showf(i, showp):
    """Show one frame of the selected window as a grayscale image.

    Reads the globals ``f`` (the three ground-truth frames) and ``pf2`` (the
    predicted middle frame) set by the window-selection cell.

    :param i: index into ``f`` (0, 1 or 2) of the ground-truth frame to show
    :param showp: if True, show the predicted middle frame ``pf2`` instead;
           ``i`` is ignored in that case
    """
    if showp:
        image, title = pf2, 'predicted frame 1'
    else:
        image, title = f[i], 'frame {}'.format(i)
    _, ax = plt.subplots()
    # Pin the gray range to [0, 1], the range evalpred_it clamps to. Otherwise
    # imshow stretches each image to its own min/max, and target vs. predicted
    # brightness cannot be compared.
    ax.imshow(image, cmap='gray', vmin=0.0, vmax=1.0)
    ax.set_title(title)
    ax.set_axis_off()

interact(showf, i=widgets.IntSlider(min=0, max=2, step=1, value=0), showp=False);

interactive(children=(IntSlider(value=0, description='i', max=2), Checkbox(value=False, description='showp'), …

In [None]:
def evalcorr_it(root, transform, detransform, indices):
    """Yield ``(frame0, frame1, frame2, predicted_frame1)`` for each window.

    NOTE(review): this function was a verbatim copy of ``evalpred_it``. It now
    delegates to ``evalpred_it`` so the two cannot drift apart.
    TODO: implement the evaluation this name implies, or remove the function.

    :param root: dataset root passed to ``vmdata.VideoDataset``
    :param transform: per-frame transform applied by the dataset
    :param detransform: inverse normalization applied before clamping to [0, 1]
    :param indices: frame indices to slide the 3-frame window over
    """
    # `yield from` keeps this a generator function (lazy, and close() propagates).
    yield from evalpred_it(root, transform, detransform, indices)