In [1]:
import cv2
import torch
import os
import numpy as np
from torch.utils.data import DataLoader
import torchvision.transforms as T
from thanos.dataset import (
    IPN, binary_label_transform, 
    IPN_HAND_ROOT, INPUT_MEAN, INPUT_STD)

from thanos.trainers.data_augmentation import ( 
    get_temporal_transform_fn,
    get_train_spatial_transform_fn,
    get_val_spatial_transform_fn)

import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import display, HTML

In [2]:
def plot_sequence_images(image_array, interval=33):
    '''Display an image sequence as an HTML5 video animation in a Jupyter notebook.

    The figure is sized so that one image pixel maps to one figure pixel, and the
    figure is closed after the video is rendered so the notebook does not also show
    an empty static figure.

    Args:
        image_array (numpy.ndarray): stack of frames with shape
            (num_images, height, width, num_channels). Only the first two axes of
            each frame are used to size the figure.
        interval (int): delay between frames in milliseconds (default 33, ~30 fps).

    Raises:
        ValueError: if ``image_array`` contains no frames.
    '''
    if len(image_array) == 0:
        raise ValueError("image_array must contain at least one frame")

    dpi = 72.0
    # A frame's shape starts with (rows, cols): rows set the figure height and
    # cols its width (figsize is given as (width, height) in inches).
    height, width = image_array[0].shape[:2]
    fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
    im = fig.figimage(image_array[0])

    def animate(i):
        # Swap in frame i; FuncAnimation redraws the returned artists.
        im.set_array(image_array[i])
        return (im,)

    anim = animation.FuncAnimation(
        fig, animate, frames=len(image_array), interval=interval,
        repeat_delay=1, repeat=True)
    try:
        video_html = anim.to_html5_video()
    finally:
        # Without this the inline backend also displays the (empty) static
        # figure after the cell, and repeated calls accumulate open figures.
        plt.close(fig)
    display(HTML(video_html))

In [3]:
# Annotation file `annotations/ipnall.json` under the IPN Hand dataset root.
ann_path = os.path.join(
    IPN_HAND_ROOT,
    os.path.join("annotations", "ipnall.json"),
)

In [4]:
# Argument to get_temporal_transform_fn — presumably the number of frames per clip; TODO confirm.
CLIP_LENGTH = 16
BATCH_SIZE = 8
SEED = 0
# Set True to preview clips after the training spatial transform. When False,
# no spatial_transform is passed, so IPN's own default applies.
USE_TRAIN_SPATIAL_TRANSFORM = False

ipn_kwargs = dict(
    temporal_transform=get_temporal_transform_fn(CLIP_LENGTH),
    target_transform=binary_label_transform,
)
if USE_TRAIN_SPATIAL_TRANSFORM:
    ipn_kwargs["spatial_transform"] = get_train_spatial_transform_fn()

ipn = IPN(IPN_HAND_ROOT, ann_path, "training", **ipn_kwargs)
# Seeded generator: makes the shuffle order (and hence the previewed batch
# selection) reproducible across Restart & Run All. Randomness inside the
# dataset transforms is not covered by this generator.
dataloader = DataLoader(
    ipn,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

[INFO]: IPN Dataset - training is loading...


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4039/4039 [00:00<00:00, 19403.11it/s]


In [5]:
# Draw one batch from a fresh iterator over the loader and unpack it.
first_batch = next(iter(dataloader))
sequences, targets = first_batch

In [13]:
# Take the first clip of the batch and move the channel axis last, giving the
# (num_images, height, width, num_channels) layout plot_sequence_images expects.
# NOTE(review): assumes each clip is laid out (T, C, H, W) — confirm against IPN.__getitem__.
first_clip = sequences[0].numpy()
seq = np.transpose(first_clip, (0, 2, 3, 1))
plot_sequence_images(seq)

<Figure size 320x240 with 0 Axes>