In [3]:
import os
import collections
import json
import tempfile
import gc

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as trans
import torchvision.models
import cv2
from tqdm import tqdm
import multiprocessing

import vmdata
import more_trans
from pyflow import coarse2fine_flow as opticalflow

from ezfirstae.loaddata import SlidingWindowBatchSampler

%matplotlib inline
%load_ext memory_profiler

The memory_profiler extension is already loaded. To reload it, use:
  %reload_ext memory_profiler


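`flow_params` holds the arguments passed positionally to pyflow's `coarse2fine_flow` (imported as `opticalflow`), in this order: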

- alpha
- ratio ($\in [0.4, 0.98]$)
- minWidth
- nOuterFPIterations
- nInnerFPIterations
- nSORIterations
- colType (0 for RGB, 1 for GRAY)

In [18]:
# Field order matters: a FlowParams instance is splatted positionally into
# pyflow's coarse2fine_flow, so fields must follow its argument order.
FlowParams = collections.namedtuple(
    'FlowParams',
    'alpha ratio minWidth nOuterFPIterations nInnerFPIterations nSORIterations colType',
)
# Accepted values for FlowParams.colType.
FlowParams_colType = dict(rgb=0, gray=1)

In [19]:
# Notebook-wide optical-flow parameters (grayscale input).
flow_params = FlowParams(
    alpha=0.012, ratio=0.75, minWidth=20,
    nOuterFPIterations=7, nInnerFPIterations=1, nSORIterations=30,
    colType=FlowParams_colType['gray'],
)

# Keep ratio inside the range documented in the markdown cell.
assert flow_params.ratio >= 0.4 and flow_params.ratio <= 0.98

In [7]:
def visualize_flow(u, v):
    """Render a dense flow field as an RGB image via the HSV colour wheel.

    Hue encodes the flow angle and value encodes the flow magnitude (min-max
    normalised over the whole image). Saturation is fixed at its maximum.

    :param u: first flow component, passed as ``x`` to ``cv2.cartToPolar``
              (float32/float64 array, as required by that function)
    :param v: second flow component, same shape and dtype as ``u``
    :return: float64 RGB image of shape ``u.shape[:2] + (3,)`` with values in [0, 1]
    """
    assert u.shape == v.shape
    # cartToPolar returns the angle in radians by default.
    mag, ang = cv2.cartToPolar(u, v)
    hsv = np.zeros(u.shape[:2] + (3,), dtype=np.uint8)
    # OpenCV's 8-bit hue spans [0, 180), i.e. half the angle in degrees.
    hsv[..., 0] = ang * 180 / np.pi / 2
    hsv[..., 1] = 255
    hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
    rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB).astype(np.float64) / 255.0
    return rgb

In [23]:
def opticalflow_caller(image_pair, params=None):
    """Run pyflow's coarse-to-fine optical flow on one image pair.

    :param image_pair: sequence unpacked as ``im1, im2``
    :param params: flow parameters in ``FlowParams`` field order (the positional
                   argument order of ``opticalflow``). Defaults to the
                   notebook-global ``flow_params``, looked up at call time.
                   Under the fork start method, pool workers therefore see the
                   global as it was when the pool was created.
    :return: ``u`` and ``v`` stacked along axis 2 (``HW2`` when they are 2-D)
    """
    if params is None:
        params = flow_params
    im1, im2 = image_pair
    u, v, _ = opticalflow(im1, im2, *params)
    return np.stack((u, v), axis=2)

In [4]:
def _flow_worker(task):
    """Compute the optical flow of one ``(image_pair, params)`` task in a pool worker.

    Defined at module level so ``multiprocessing.Pool`` can pickle it by reference.

    :param task: ``(image_pair, params)``. ``image_pair`` unpacks as ``im1, im2``;
                 ``params`` is a plain tuple in ``FlowParams`` field order, which
                 is the positional argument order of ``opticalflow``.
    :return: ``u`` and ``v`` stacked along axis 2 (``HW2`` when they are 2-D)
    """
    image_pair, params = task
    im1, im2 = image_pair
    u, v, _ = opticalflow(im1, im2, *params)
    return np.stack((u, v), axis=2)


def compute_flows(settings, num_workers=1):
    """Compute optical flow for every 2-frame sliding window of a dataset.

    Frames come from ``vmdata.VideoDataset`` and are grouped by
    ``SlidingWindowBatchSampler`` into windows of 2 frames, ``num_workers``
    windows per batch. Each window's flow is computed in its own pool process.

    The flow parameters are read from ``settings['flow']`` and sent to the
    workers explicitly. The result therefore always matches ``settings``,
    whatever the notebook-global ``flow_params`` currently holds.

    NOTE(review): with ``drop_last=True``, trailing windows that do not fill a
    whole batch are presumably dropped, so up to ``num_workers - 1`` flows at the
    end may be missing. Confirm against ``SlidingWindowBatchSampler``.

    :param settings: dict with ``'data'`` (keys ``'root'``, ``'normalize'``,
                     ``'indices'``) and ``'flow'`` (keyword arguments of ``FlowParams``)
    :param num_workers: number of worker processes, which is also the number of
                        frame pairs per batch
    :return: array of shape ``NHW2``
    :raises ValueError: if ``num_workers < 1`` or no flow could be computed
    """
    if num_workers < 1:
        raise ValueError('num_workers must be at least 1, got {}'.format(num_workers))
    root = settings['data']['root']
    transform = [trans.ToTensor()]
    if settings['data']['normalize']:
        transform.append(trans.Normalize(*vmdata.get_normalization_stats(root, bw=True)))
    transform = trans.Compose(transform)
    indices = settings['data']['indices']
    # Plain tuple in FlowParams field order == positional order of opticalflow.
    params = tuple(FlowParams(**settings['flow']))
    # Progress-bar length: complete batches only, since partial ones are dropped.
    num_batches = max(len(indices) - 1, 0) // num_workers

    all_uvs = []
    with vmdata.VideoDataset(root, transform=transform) as vdset:
        sam = SlidingWindowBatchSampler(indices, 2,
                                        batch_size=num_workers, drop_last=True)
        dataloader = DataLoader(vdset, batch_sampler=sam)
        batches = map(more_trans.chw2hwc, more_trans.numpy_loader(dataloader))
        with multiprocessing.Pool(num_workers) as pool:
            for j, frames in enumerate(tqdm(batches, total=num_batches, ascii=True)):
                # One C-contiguous float64 copy; pyflow needs contiguous float64 input.
                frames = frames.astype(np.float64, order='C')
                image_pairs = np.split(frames, num_workers, axis=0)
                tasks = [(pair, params) for pair in image_pairs]
                all_uvs.extend(pool.map(_flow_worker, tasks))
                if j and j % 2 == 0:
                    gc.collect()
    if not all_uvs:
        raise ValueError(
            'no optical flow was computed for {} indices with num_workers={} '
            '(incomplete trailing batches are dropped)'.format(len(indices), num_workers))
    return np.stack(all_uvs)  # shape: NHW2

def load_json(filename):
    """Parse and return the JSON document stored in ``filename``.

    The file is decoded as UTF-8, the encoding required for JSON interchange
    (RFC 8259), rather than with the platform's locale encoding.

    :param filename: path of the JSON file
    :return: the decoded Python object
    """
    with open(filename, encoding='utf-8') as infile:
        return json.load(infile)

def compute_flows_caller(settings, num_workers=1, basedir='data.experiments-flow/flows'):
    """Return the optical flows for ``settings``, using an on-disk cache.

    The cache in ``basedir`` holds pairs ``<stem>.json`` (the settings) and
    ``<stem>.npy`` (the flow array computed from them). An entry matches when
    its JSON equals ``settings`` after a JSON round trip, so tuples vs. lists or
    OrderedDict vs. dict do not cause spurious misses. On a miss the flows are
    computed with ``compute_flows`` and stored as a new pair.

    :param settings: JSON-serialisable settings dict (see ``compute_flows``)
    :param num_workers: forwarded to ``compute_flows`` on a cache miss
    :param basedir: cache directory, which must already exist
    :return: the flow array (``NHW2``)
    :raises TypeError: if ``settings`` is not JSON-serialisable. This is checked
                       before any expensive computation.
    """
    settings_text = json.dumps(settings)
    normalized = json.loads(settings_text)

    setting_files = sorted(os.path.join(basedir, name)
                           for name in os.listdir(basedir)
                           if name.endswith('.json'))
    for setting_file in setting_files:
        if load_json(setting_file) != normalized:
            continue
        data_file = os.path.splitext(setting_file)[0] + '.npy'
        # Skip an entry whose array is missing instead of crashing; it is
        # superseded by the fresh pair written below.
        if os.path.exists(data_file):
            return np.load(data_file)

    all_uvs = compute_flows(settings, num_workers=num_workers)
    # Write the array first and the settings last. A crash in between leaves
    # only an orphan .npy, which is ignored because only .json files are
    # scanned. It never leaves a .json that points at a missing array.
    fd, data_file = tempfile.mkstemp(prefix='flowd_', suffix='.npy', dir=basedir)
    os.close(fd)
    np.save(data_file, all_uvs)  # path already ends in '.npy', so none is appended
    with open(os.path.splitext(data_file)[0] + '.json', 'x', encoding='utf-8') as outfile:
        outfile.write(settings_text)
    return all_uvs

In [41]:
# Dataset root plus the loading transform (normalisation stats computed with bw=True).
root = vmdata.prepare_dataset_root(9, (8, 0, 0))
normalize = trans.Normalize(*vmdata.get_normalization_stats(root, bw=True))
transform = trans.Compose([trans.ToTensor(), normalize])
# Frame indices 0, 1, ..., 999.
indices = np.arange(0, 1000)

In [42]:
# Settings dict: both the input to compute_flows and its cache key, so every
# value is kept JSON-native (plain ints, not numpy ints).
flow_settings = flow_params._asdict()
data_settings = {
    'root': root,
    'normalize': any(type(t) is trans.Normalize for t in transform.transforms),
    'indices': [int(i) for i in indices],
}
settings = dict(flow=flow_settings, data=data_settings)

In [None]:
# Keep the (possibly cached) flows for later cells instead of only displaying
# them, and show the shape (N, H, W, 2) rather than the full array.
NUM_WORKERS = 4
all_uvs = compute_flows_caller(settings, num_workers=NUM_WORKERS)
all_uvs.shape