In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Make the project root importable so `models` / `data` resolve in the next lines.
# Set the BOUNCE_DIGITS_DIR environment variable to point at a checkout on
# another machine; the default is the original author's local path.
import os
import sys

path_to_dir = os.environ.get("BOUNCE_DIGITS_DIR", "/home/kelvinfung/Documents/bounce-digits/")
# Guard the append so re-running this cell does not keep adding duplicates to sys.path.
if path_to_dir not in sys.path:
    sys.path.append(path_to_dir)
from models import *
from data.data_classes import *

## Load data

In [2]:
# Number of context frames and target frames per sample, and samples per batch.
num_ctx_frames, num_tgt_frames = 5, 5
batch_size = 50

# NOTE(review): split_ratio presumably holds train/val/test fractions — confirm
# against TwoColourMovingMNISTDataModule.
mnist_module = TwoColourMovingMNISTDataModule(
    batch_size,
    num_ctx_frames,
    num_tgt_frames,
    split_ratio=[0.2, 0.05, 0.75],
)
# `moving_mnist` is an alias for the same object. It is unused in the cells shown
# but kept so any other cell that references it still works.
moving_mnist = mnist_module
mnist_module.setup()

In [3]:
# Grab one training batch to inspect. next(iter(...)) is the idiomatic
# "first item" and raises StopIteration right here if the loader is empty,
# rather than leaving the batch names undefined for later cells.
train_dataloader = mnist_module.train_dataloader()
train_x_batch, train_y_batch = next(iter(train_dataloader))

train_x_batch.shape  # bs x C x F x H x W

torch.Size([50, 3, 5, 128, 128])

## Test metrics

In [None]:
from skimage.metrics import structural_similarity, peak_signal_noise_ratio

In [None]:
# Two different videos for the metric checks: the context clips and the
# target clips of the same batch.
vid1 = train_x_batch
vid2 = train_y_batch

In [None]:
# Metrics between two different videos (context vs. target clips).
# NOTE(review): PSNR / SSIM are not the skimage functions imported above; they
# presumably come from the `models` / `data` star imports — confirm.
psnr, ssim = PSNR(), SSIM()
print(f'psnr: {psnr(vid1, vid2)}')
print(f'ssim: {ssim(vid1, vid2)}')

In [None]:
# Sanity check: each metric applied to a video and itself.
# NOTE(review): presumably PSNR -> inf and SSIM -> 1 here; confirm against the
# PSNR / SSIM implementations in use.
print(f'psnr: {psnr(vid1, vid1)}')
print(f'ssim: {ssim(vid1, vid1)}')

## Test plotting function

In [None]:
def make_grid(ctx, tgt, pred, epoch, cmap='gray'):
    """Plot the context, target and predicted frames of one sample as a 3-row grid.

    Row 0 shows ``ctx``, row 1 ``tgt`` and row 2 ``pred``, one frame per
    column. A row with fewer frames than the longest clip leaves its trailing
    panels blank.

    Args:
        ctx: context clip of ONE sample as a torch tensor laid out
            C x F x H x W (one element of the bs x C x F x H x W batches
            loaded above).
        tgt: target clip, same layout as ``ctx``.
        pred: predicted clip, same layout as ``tgt``; may require grad.
        epoch: value shown in the figure title.
        cmap: colormap for single-channel frames (``None`` -> matplotlib's
            default). matplotlib ignores it for RGB frames.

    Returns:
        The matplotlib ``Figure`` holding the grid.
    """
    def to_frames(video):
        # C x F x H x W -> F x H x W x C, so iterating yields whole frames
        # (squeeze() alone would iterate over channels when C > 1).
        # detach() lets tensors that require grad (model outputs) go to numpy.
        frames = video.detach().cpu().permute(1, 2, 3, 0).numpy()
        # imshow needs single-channel frames as 2-D (H x W) for cmap to apply.
        # NOTE(review): float RGB frames are clipped to [0, 1] by imshow;
        # confirm the data module's value range.
        if frames.shape[-1] == 1:
            frames = frames[..., 0]
        return frames

    def show_frames(frames, row_axes, row_label=None):
        for frame_ax, frame in zip(row_axes, frames):
            # cmap=None is imshow's own default, so no separate branch is needed.
            frame_ax.imshow(frame, cmap=cmap)
        for frame_ax in row_axes:
            frame_ax.set_xticks([])
            frame_ax.set_yticks([])
        if row_label is not None:
            row_axes[0].set_ylabel(row_label)

    ctx_frames = to_frames(ctx)
    tgt_frames = to_frames(tgt)
    pred_frames = to_frames(pred)

    # Size the grid to the longest clip so no row indexes past its end.
    num_cols = max(len(ctx_frames), len(tgt_frames), len(pred_frames))
    # squeeze=False keeps `ax` 2-D even when there is only one column.
    fig, ax = plt.subplots(3, num_cols, figsize=(9, 5), squeeze=False)
    fig.suptitle(f"EPOCH {epoch}", y=0.93)
    show_frames(ctx_frames, ax[0], "Context")
    show_frames(tgt_frames, ax[1], "Target")
    show_frames(pred_frames, ax[2], "Prediction")

    return fig

In [None]:
# Smoke-test the plot on sample 5 of the batch. The target clip doubles as a
# stand-in prediction, since no model output is available here.
fig = make_grid(train_x_batch[5], train_y_batch[5], pred=train_y_batch[5], epoch=1)

In [None]:
def fig2rgb_array(fig):
    """Render a matplotlib figure and return its pixels as an RGB array.

    Args:
        fig: figure whose canvas provides ``draw()`` and ``buffer_rgba()``
            (e.g. the Agg canvas used by the inline backend).

    Returns:
        A writable ``np.uint8`` array of shape (height, width, 3) that owns
        its memory, so later redraws of ``fig`` do not change it.
    """
    fig.canvas.draw()
    # buffer_rgba() replaces tostring_rgb(), which is deprecated since
    # Matplotlib 3.8 and removed in later releases. The buffer already has
    # the (rows, cols, 4) shape, so there is no width/height reshape to get
    # wrong on HiDPI canvases whose physical size differs from get_width_height().
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop alpha and copy, so the result is contiguous and detached from the canvas.
    return np.array(rgba[..., :3], dtype=np.uint8)

In [None]:
# Rasterise the grid figure to an RGB array and view it back with imshow.
# The trailing `;` suppresses the AxesImage repr in the cell output.
arr = fig2rgb_array(fig)
plt.imshow(arr);

In [None]:
# fig2rgb_array should yield 8-bit RGB pixels laid out H x W x 3; fail loudly
# here instead of relying on eyeballing the displayed dtype.
assert arr.dtype == np.uint8, arr.dtype
assert arr.ndim == 3 and arr.shape[-1] == 3, arr.shape
arr.dtype

In [None]:
from PIL import Image
# Display the same RGB array through PIL; as the cell's last expression it is
# rendered by the notebook's rich display.
Image.fromarray(arr)

In [None]:
# Pixel-value histogram of the rasterised figure. arr is uint8, so use one bin
# per integer value. ravel() avoids the copy flatten() makes, and the trailing
# `;` keeps the cell from dumping the returned (counts, bins, patches) tuple.
plt.hist(arr.ravel(), bins=256, range=(0, 256))
plt.xlabel('pixel value')
plt.ylabel('count')
plt.title('Pixel values of rasterised grid figure');