In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

import sys
sys.path.append("/home/kelvinfung/Documents/bounce-digits")
from models import *

In [None]:
# Random video batch laid out as N x C x F x H x W (batch, channels, frames, height, width).
N, C, F = 50, 1, 5
H = W = 64
vid_sample = torch.randn((N, C, F, H, W))

## Convolution outputs

In [None]:
# Stride 1, kernel 3, padding 1 in every dimension: F, H and W are all preserved.
conv3d = nn.Conv3d(in_channels=1, out_channels=5, kernel_size=3, stride=1, padding=1)
conv3d(vid_sample).shape

In [None]:
# Stride (1, 2, 2): frames are kept, height and width are halved (64 -> 32).
conv3d = nn.Conv3d(in_channels=1, out_channels=5, kernel_size=3,
                   stride=(1, 2, 2), padding=1)
conv3d(vid_sample).shape

In [None]:
# Transposed convolution with stride 1, kernel 3, padding 1: F, H and W are preserved.
conv3dtranspose = nn.ConvTranspose3d(in_channels=1, out_channels=5,
                                     kernel_size=3, stride=1, padding=1)
conv3dtranspose(vid_sample).shape

In [None]:
# Transposed upsampling: stride (1, 2, 2) plus output_padding (0, 1, 1)
# keeps the frame count and exactly doubles height and width (64 -> 128).
conv3dtranspose = nn.ConvTranspose3d(
    in_channels=1,
    out_channels=5,
    kernel_size=3,
    stride=(1, 2, 2),
    padding=1,
    output_padding=(0, 1, 1),
)
conv3dtranspose(vid_sample).shape

## Sample batch of context frames

In [None]:
# Context-frame batch with the same N x C x F x H x W layout as above.
N, C, F = 50, 1, 5
H = W = 64
sample_batch = torch.randn((N, C, F, H, W))

In [None]:
# Integer-valued batch in [RAND_LOW, RAND_HIGH), templated on this section's own
# `sample_batch` (same shape/dtype) instead of `vid_sample` from another section.
RAND_LOW, RAND_HIGH = -5, 15
sample_batch = torch.randint_like(sample_batch, low=RAND_LOW, high=RAND_HIGH)

fig, ax = plt.subplots()
# One bin per possible integer value so the uniform distribution is visible.
ax.hist(sample_batch.detach().numpy().ravel(), bins=range(RAND_LOW, RAND_HIGH + 1))
ax.set(title=f"randint_like values in [{RAND_LOW}, {RAND_HIGH})",
       xlabel="value", ylabel="count");

## FutureDiscriminator

In [None]:
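# NOTE: disabled and kept for reference only. Uncomment to build the FutureDiscriminator
# (assumes `Discriminator` is provided by `models` — TODO confirm).
# config = {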
#     'nframes_pred': 5,
#     'nframes_in' : 5,
#     'batch_norm' : False,
#     'w_norm' : True,
#     'loss' : 'wgan_gp',
#     'd_gdrop' : False,
#     'padding' : 'zero',
#     'lrelu' : True,
#     'd_sigmoid' : False,
#     'nz' : 512,            # dim of input noise vector z    
#     'nc' : 1,              # number of channels
#     'ndf' : 512,           # discriminator first layer's feature dim
#     'd_cond' : True
# }

# dis = Discriminator(config)

## DRNet

In [None]:
import socket
import numpy as np
from torchvision import datasets, transforms
from torch.autograd import Variable


######## Based on the implementation at 'https://github.com/edenton/drnet-py'; refer to it for an in-depth understanding ###########

class MovingMNIST(object):
    """Data handler that creates a bouncing-MNIST dataset on the fly.

    Every item is a fresh random sequence. ``num_digits`` MNIST digits, each
    resized to ``digit_size`` x ``digit_size``, move with a constant integer
    velocity and bounce off the frame edges. Digit ``n`` is drawn into colour
    channel ``n`` of a 3-channel frame. One randomly chosen digit occludes the
    others where they overlap.

    Args:
        train: Use the MNIST training split if True, else the test split.
        data_root: Directory where MNIST is stored / downloaded to.
        seq_len: Number of frames per sequence.
        num_digits: Digits per sequence; 1-3 (one colour channel each).
        image_size: Height and width of the square frames, in pixels. Must be
            larger than the digit size (32).

    Raises:
        ValueError: If ``num_digits`` or ``image_size`` is out of range.
    """

    # Frames have three colour channels and each digit owns one of them.
    _NUM_CHANNELS = 3

    def __init__(self, train, data_root, seq_len=20, num_digits=2, image_size=64):
        self.seq_len = seq_len
        self.num_digits = num_digits
        self.image_size = image_size
        self.step_length = 0.1  # carried over from the reference code; not used in this class
        self.digit_size = 32
        self.seed_is_set = False  # multi threaded loading: numpy is seeded only once per object copy

        # Validate before the (possibly slow) MNIST download so bad arguments fail fast.
        if not 1 <= num_digits <= self._NUM_CHANNELS:
            raise ValueError(
                f"num_digits must be between 1 and {self._NUM_CHANNELS}, got {num_digits}")
        if image_size <= self.digit_size:
            raise ValueError(
                f"image_size must be larger than digit_size={self.digit_size}, got {image_size}")

        self.data = datasets.MNIST(
            data_root,
            train=train,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(self.digit_size),
                 transforms.ToTensor()]))

        self.N = len(self.data)

    def set_seed(self, seed):
        """Seed numpy's global RNG with ``seed`` the first time this is called on this object."""
        if not self.seed_is_set:
            self.seed_is_set = True
            np.random.seed(seed)

    def __len__(self):
        return self.N

    def __getitem__(self, index):
        """Generate one random float32 sequence of shape (seq_len, image_size, image_size, 3)."""
        self.set_seed(index)
        image_size = self.image_size
        digit_size = self.digit_size
        # Exclusive upper bound for a digit's top-left coordinate that keeps the
        # whole digit inside the frame; derived from digit_size, not a literal.
        max_pos = image_size - digit_size
        x = np.zeros((self.seq_len, image_size, image_size, self._NUM_CHANNELS),
                     dtype=np.float32)
        for n in range(self.num_digits):
            idx = np.random.randint(self.N)
            digit, _ = self.data[idx]

            sx = np.random.randint(max_pos)
            sy = np.random.randint(max_pos)
            dx = np.random.randint(-4, 4)  # velocity in [-4, 3] pixels per frame
            dy = np.random.randint(-4, 4)
            for t in range(self.seq_len):
                # Bounce: clamp the position back inside the frame and reverse velocity.
                if sy < 0:
                    sy = 0
                    dy = -dy
                elif sy >= max_pos:
                    sy = max_pos - 1
                    dy = -dy

                if sx < 0:
                    sx = 0
                    dx = -dx
                elif sx >= max_pos:
                    sx = max_pos - 1
                    dx = -dx

                x[t, sy:sy + digit_size, sx:sx + digit_size, n] = np.copy(digit.numpy())
                sy += dy
                sx += dx
        # Pick one digit to be in front: zero the other channels wherever it is drawn.
        front = np.random.randint(self.num_digits)
        for cc in range(self.num_digits):
            if cc != front:
                x[:, :, :, cc][x[:, :, :, front] > 0] = 0
        return x

def sequence_input(seq, device="cuda"):
    """Convert each element of ``seq`` to a float32 tensor on ``device``.

    Args:
        seq: Iterable of tensors. Iterating a tensor yields its slices along dim 0.
        device: Target device. Defaults to ``"cuda"`` to keep the original behaviour.

    Returns:
        A list with one float32 tensor per element of ``seq``.
    """
    # ``Variable`` is a no-op wrapper in modern PyTorch, and the legacy
    # ``torch.cuda.FloatTensor`` type forces CUDA; ``.to`` handles both dtype and device.
    return [x.to(device=device, dtype=torch.float32) for x in seq]
    
def normalize_data(sequence):
    """Reorder a batch of sequences to time-major, channels-first frames.

    Despite the name, pixel values are not rescaled; only the axes are permuted.

    Args:
        sequence: 5-D tensor, presumably (batch, time, height, width, channels),
            as the DataLoaders over MovingMNIST items yield it. TODO confirm for other callers.

    Returns:
        ``sequence_input`` applied to the (time, batch, channels, height, width)
        view, i.e. one (batch, channels, height, width) tensor per time step.
    """
    # permute returns a view, so the caller's tensor keeps its original layout
    # (the equivalent in-place transpose_ calls would reorder it for the caller too).
    reordered = sequence.permute(1, 0, 4, 2, 3)

    return sequence_input(reordered)

def get_training_batch(train_loader):
    """Endlessly yield ``normalize_data`` applied to each batch of ``train_loader``.

    The loader is re-iterated from the start every time it is exhausted.
    """
    while True:
        for raw_batch in train_loader:
            yield normalize_data(raw_batch)

def get_testing_batch(test_loader):
    """Endlessly yield ``normalize_data`` applied to each batch of ``test_loader``.

    The loader is re-iterated from the start every time it is exhausted.
    """
    while True:
        for raw_batch in test_loader:
            yield normalize_data(raw_batch)

def make_rgb_plot(ctx, tgt, pred, epoch=999):
    """Plot context, target and predicted frames in a 3-row grid.

    Each video is a tensor laid out as (channels, frames, height, width). Extra
    leading singleton dimensions, such as a batch of one, are dropped. Tensors
    that require grad, like raw model predictions, and GPU tensors are accepted.

    Args:
        ctx: Context frames, shown in the "Context" row.
        tgt: Ground-truth frames, shown in the "Target" row.
        pred: Predicted frames, shown in the "Prediction" row.
        epoch: Value shown in the figure title.

    Returns:
        The matplotlib ``Figure`` holding the grid.
    """

    def to_frames(video):
        # Detach before .numpy(): numpy() raises on tensors that require grad.
        video = video.detach().cpu()
        # Drop only *leading* singleton dims, so a single frame or channel survives.
        while video.dim() > 4 and video.shape[0] == 1:
            video = video[0]
        frames = video.permute(1, 2, 3, 0).numpy()  # -> (frames, H, W, C)
        if frames.shape[-1] == 1:
            # imshow accepts (H, W), (H, W, 3) or (H, W, 4), but not (H, W, 1).
            frames = frames[..., 0]
        return frames

    def show_frames(frames, ax, row_label=None):
        for i, a in enumerate(ax):
            a.set_xticks([])
            a.set_yticks([])
            if i < len(frames):
                a.imshow(frames[i])
            else:
                a.set_axis_off()  # unused column in a row with fewer frames

        if row_label is not None:
            ax[0].set_ylabel(row_label)

    rows = [
        (to_frames(ctx), "Context"),
        (to_frames(tgt), "Target"),
        (to_frames(pred), "Prediction"),
    ]
    # Size the grid from the frames themselves so every row, including pred, fits.
    num_cols = max(max(len(frames) for frames, _ in rows), 1)

    # squeeze=False keeps `ax` 2-D even when there is only a single column.
    fig, ax = plt.subplots(3, num_cols, figsize=(9, 5), squeeze=False)
    fig.suptitle(f"EPOCH {epoch}", y=0.93)
    for row_ax, (frames, label) in zip(ax, rows):
        show_frames(frames, row_ax, label)

    return fig

In [None]:
# NOTE(review): absolute, machine-specific path; consider making it configurable.
data_root = "/home/kelvinfung/Documents/bounce-digits/data/"
seq_len = 10
image_width = 128  # passed to MovingMNIST as the square image_size
batch_size = 16

train_data = MovingMNIST(
    train=True,
    data_root=data_root,
    seq_len=seq_len,
    image_size=image_width,
    num_digits=2)
test_data = MovingMNIST(
    train=False,
    data_root=data_root,
    seq_len=seq_len,
    image_size=image_width,
    num_digits=2)

train_loader = DataLoader(train_data,
                          num_workers=4,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          pin_memory=True)
# Share batch_size with training so code sized by batch_size also fits test batches.
test_loader = DataLoader(test_data,
                         num_workers=4,
                         batch_size=batch_size,
                         shuffle=False,
                         drop_last=True,
                         pin_memory=True)

# Infinite generators over normalized (time-major) batches.
train_generator = get_training_batch(train_loader)
test_generator = get_testing_batch(test_loader)

In [None]:
# Draw one training batch: a list with one tensor per time step (length seq_len),
# each tensor shaped B x C x H x W.
x = next(train_generator)
print(len(x))
x[0].shape

In [None]:
channels, pose_dim, discriminator_dim = 3, 5, 100

# Scene discriminator over pose pairs and the pose encoder, both moved to the GPU.
scene_discriminator = SceneDiscriminator(pose_dim, discriminator_dim).cuda()
pose_encoder = Encoder(channels, pose_dim).cuda()


In [None]:
# Discriminator labels, shape batch_size x 1. The values are assigned in a later cell;
# zeros replace the deprecated, uninitialised legacy torch.cuda.FloatTensor constructor.
target = torch.zeros(batch_size, 1, device="cuda")
x1 = x[0].to("cuda")  # First frame of all videos in batch: BS x C x H x W
x2 = x[1].to("cuda")  # Second frame of all videos in batch
h_p1 = pose_encoder(x1)[0].detach()  # Pose of first frames of all videos in a batch: BS x pose_dim x 1 x 1
h_p2 = pose_encoder(x2)[0].detach()

In [None]:
half = batch_size // 2
rp = torch.randperm(half).cuda()
# Shuffle the first half of h_p2, so those pairs mostly come from different videos
# (a random permutation may leave some entries in place).
h_p2[:half] = h_p2[rp]

In [None]:
# Label the shuffled first half 1 and the untouched second half 0.
target.fill_(0)
target[:half] = 1

In [None]:
print(f"h_p1 shape: {h_p1.shape}")
# Score each (h_p1[i], h_p2[i]) pose pair with the scene discriminator.
out = scene_discriminator(h_p1,
                          h_p2)
print(f"scene discriminator out shape: {out.shape}")

In [None]:
# `out` is thresholded at 0.5 below, i.e. treated as a probability, so binary
# cross-entropy (as the name `bce` says) is the matching loss, not MSE.
# Assumes `out` lies in [0, 1] (e.g. a sigmoid output); BCELoss raises otherwise.
bce = nn.BCELoss()(out, target)
# Number of correctly classified pairs (a count, not a fraction of the batch).
acc = out[:half].gt(0.5).sum() + out[half:].le(0.5).sum()
acc

In [None]:
# A batch_size x 1 float32 CUDA tensor filled with 0.5.
torch.full((batch_size, 1), 0.5, device="cuda")

In [2]:
# Full DRNet model with the hyper-parameters used in this notebook.
mod = DRNetMain(
    channels=3,
    content_dim=128,
    pose_dim=5,
    discriminator_dim=100,
    learning_rate=1e-3,
    alpha=1,
    beta=0.1,
)

In [4]:
# Sanity check of BCELoss on hand-picked probabilities: three positives, three negatives.
bce = nn.BCELoss(reduction="mean")
pred = torch.tensor([0.9, 0.8, 0.9, 0.1, 0.0, 0.2])
target = torch.tensor([1.0] * 3 + [0.0] * 3)
bce(pred, target)

tensor(0.1271)