In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

import sys
sys.path.append("/home/kelvinfung/Documents/bounce-digits")
from models import *

In [2]:
N, C, F, H, W = 50, 1, 5, 64, 64
vid_sample = torch.randn(N, C, F, H, W)

In [9]:
sample1 = torch.FloatTensor([[5, 1], [0, 3]])
sample2 = torch.FloatTensor([[3, 4], [10, 9]])

In [10]:
print(nn.L1Loss()(sample1, sample2))
print(nn.MSELoss()(sample1, sample2))

tensor(5.2500)
tensor(37.2500)


## Convolution outputs

In [None]:
# Without stride; same output HxW
conv3d = nn.Conv3d(1, 5,
                   stride=(1, 1, 1),
                   kernel_size=(3,3,3), padding=(1,1,1))

conv3d(vid_sample).shape

In [None]:
# With stride; Downsample by halving height and width
conv3d = nn.Conv3d(1, 5,
                   stride=(1, 2, 2),
                   kernel_size=(3,3,3), padding=(1,1,1))

conv3d(vid_sample).shape

In [None]:
# Tranpose; Same height and width
conv3dtranspose = nn.ConvTranspose3d(1, 5,
                           stride=(1,1,1),kernel_size=(3,3,3), 
                           padding=(1,1,1))

conv3dtranspose(vid_sample).shape

In [None]:
# Tranpose upsample; Doubles the height and width
conv3dtranspose = nn.ConvTranspose3d(1, 5,
                           stride=(1,2,2),kernel_size=(3,3,3), 
                           padding=(1,1,1), output_padding=(0,1,1))

conv3dtranspose(vid_sample).shape

## Sample batch of context frames

In [None]:
N, C, F, H, W = 50, 1, 5, 64, 64
sample_batch = torch.randn(N, C, F, H, W)

In [None]:
sample_batch = torch.randint_like(vid_sample, low=-5, high=15)
plt.hist(sample_batch.detach().numpy().flatten())

## FutureDiscriminator

In [None]:
# config = {
#     'nframes_pred': 5,
#     'nframes_in' : 5,
#     'batch_norm' : False,
#     'w_norm' : True,
#     'loss' : 'wgan_gp',
#     'd_gdrop' : False,
#     'padding' : 'zero',
#     'lrelu' : True,
#     'd_sigmoid' : False,
#     'nz' : 512,            # dim of input noise vector z    
#     'nc' : 1,              # number of channels
#     'ndf' : 512,           # discriminator first layer's feature dim
#     'd_cond' : True
# }

# dis = Discriminator(config)

## DRNet

In [None]:
import socket
import numpy as np
from torchvision import datasets, transforms
from torch.autograd import Variable


######## Refer to the implementation given at 'https://github.com/edenton/drnet-py' for indepth understanding ###########

class MovingMNIST(object):
    
    """Data Handler that creates Bouncing MNIST dataset on the fly."""

    def __init__(self, train, data_root, seq_len=20, num_digits=2, image_size=64):
        path = data_root
        self.seq_len = seq_len
        self.num_digits = num_digits  
        self.image_size = image_size 
        self.step_length = 0.1
        self.digit_size = 32
        self.seed_is_set = False # multi threaded loading

        self.data = datasets.MNIST(
            path,
            train=train,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(self.digit_size),
                 transforms.ToTensor()]))

        self.N = len(self.data) 

    def set_seed(self, seed):
        if not self.seed_is_set:
            self.seed_is_set = True
            np.random.seed(seed)
          
    def __len__(self):
        return self.N

    def __getitem__(self, index):
        self.set_seed(index)
        image_size = self.image_size
        digit_size = self.digit_size
        x = np.zeros((self.seq_len,
                      image_size, 
                      image_size, 
                      3),
                    dtype=np.float32)
        for n in range(self.num_digits):
            idx = np.random.randint(self.N)
            digit, _ = self.data[idx]

            sx = np.random.randint(image_size-digit_size)
            sy = np.random.randint(image_size-digit_size)
            dx = np.random.randint(-4, 4)
            dy = np.random.randint(-4, 4)
            for t in range(self.seq_len):
                if sy < 0:
                    sy = 0 
                    dy = -dy
                elif sy >= image_size-32:
                    sy = image_size-32-1
                    dy = -dy
                    
                if sx < 0:
                    sx = 0 
                    dx = -dx
                elif sx >= image_size-32:
                    sx = image_size-32-1
                    dx = -dx
                   
                x[t, sy:sy+32, sx:sx+32, n] = np.copy(digit.numpy())
                sy += dy
                sx += dx
        # pick on digit to be in front
        front = np.random.randint(self.num_digits)
        for cc in range(self.num_digits):
            if cc != front:
                x[:, :, :, cc][x[:, :, :, front] > 0] = 0
        return x

def sequence_input(seq):
    return [Variable(x.type(torch.cuda.FloatTensor)) for x in seq]
    
def normalize_data(sequence):
    sequence.transpose_(0, 1)
    sequence.transpose_(3, 4).transpose_(2, 3)

    return sequence_input(sequence)

def get_training_batch(train_loader):
	while True:
		for sequence in train_loader:
			batch = normalize_data(sequence)
			yield batch

def get_testing_batch(test_loader):
	while True:
		for sequence in test_loader:
			batch = normalize_data(sequence)
			yield batch

def make_rgb_plot(ctx, tgt, pred, epoch=999):
    num_ctx_frames= ctx.shape[1]
    num_tgt_frames = tgt.shape[1]

    def show_frames(frames, ax, row_label=None):
        for i, frame in enumerate(frames):
            ax[i].imshow(frame)
            ax[i].set_xticks([])
            ax[i].set_yticks([])

        if row_label is not None:
            ax[0].set_ylabel(row_label)

    ctx_frames = ctx.squeeze().permute(1, 2, 3, 0).cpu().numpy()
    tgt_frames = tgt.squeeze().permute(1, 2, 3, 0).cpu().numpy()
    pred_frames = pred.squeeze().permute(1, 2, 3, 0).cpu().numpy()

    fig, ax = plt.subplots(3, max(num_ctx_frames, num_tgt_frames),
                       figsize = (9, 5))
    fig.suptitle(f"EPOCH {epoch}", y=0.93)
    show_frames(ctx_frames, ax[0], "Context")
    show_frames(tgt_frames, ax[1], "Target")
    show_frames(pred_frames, ax[2], "Prediction")

    return fig

In [None]:
data_root = "/home/kelvinfung/Documents/bounce-digits/data/"
seq_len=10
image_width=128
batch_size=16

train_data = MovingMNIST(
            train=True,
            data_root=data_root,
            seq_len=seq_len,
            image_size=image_width,
            num_digits=2)
test_data = MovingMNIST(
        train=False,
        data_root=data_root,
        seq_len=seq_len,
        image_size=image_width,
        num_digits=2)

train_loader = DataLoader(train_data, 
                        num_workers=4, 
                        batch_size=batch_size,
                        shuffle=True, 
                        drop_last=True, 
                        pin_memory=True)
test_loader = DataLoader(test_data, 
                        num_workers=4, 
                        batch_size=16,
                        shuffle=False, 
                         drop_last=True, 
                         pin_memory=True)

train_generator = get_training_batch(train_loader)
test_generator = get_testing_batch(test_loader)

In [None]:
x = next(train_generator)
print(len(x))
x[0].shape
# x: list of tensors of length = seq_len
# x[0] tensor of shape: B * C * H * W

In [None]:
channels=3
pose_dim=5
discriminator_dim=100

scene_discriminator = SceneDiscriminator(pose_dim, discriminator_dim).to("cuda")
pose_encoder = Encoder(channels, pose_dim).to("cuda")


In [None]:
target = torch.cuda.FloatTensor(batch_size, 1)
x1 = x[0].to("cuda")  # First frame of all videos in batch: BS x C x H x W
x2 = x[1].to("cuda")  # Second frame of all videos in batch
h_p1 = pose_encoder(x1)[0].detach()  # Pose of first frames of all videos in a batch: BS x pose_dim x 1 x 1
h_p2 = pose_encoder(x2)[0].detach()

In [None]:
half = int(batch_size/2)
rp = torch.randperm(half).cuda()
h_p2[:half] = h_p2[rp]  # Permute first half of h_p2; allowing frames from different videos to be compared by the discriminator

In [None]:
target[:half] = 1
target[half:] = 0

In [None]:
print(f"h_p1 shape: {h_p1.shape}")
out = scene_discriminator(h_p1, h_p2)
print(f"scene discriminator out shape: {out.shape}")

In [None]:
bce = nn.MSELoss()(out, Variable(target))
acc =out[:half].gt(0.5).sum() + out[half:].le(0.5).sum()
acc

## PredRNN

In [None]:
args_dict = {"batch_size": 8,
        "total_length": 10,
        "input_length": 5,
        "img_width": 128,
        "patch_size": 4,
        "img_channel": 3,
        "scheduled_sampling": 1,
        "sampling_stop_iter": 50000,
        "sampling_changing_rate": 0.00002}

class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self

args = AttrDict(args_dict)

In [None]:
def schedule_sampling(eta, itr):
    zeros = np.zeros((args.batch_size,
                      args.total_length - args.input_length - 1,
                      args.img_width // args.patch_size,
                      args.img_width // args.patch_size,
                      args.patch_size ** 2 * args.img_channel))
    if not args.scheduled_sampling:
        return 0.0, zeros

    if itr < args.sampling_stop_iter:
        eta -= args.sampling_changing_rate
    else:
        eta = 0.0
    random_flip = np.random.random_sample(
        (args.batch_size, args.total_length - args.input_length - 1))
    true_token = (random_flip < eta)
    ones = np.ones((args.img_width // args.patch_size,
                    args.img_width // args.patch_size,
                    args.patch_size ** 2 * args.img_channel))
    zeros = np.zeros((args.img_width // args.patch_size,
                      args.img_width // args.patch_size,
                      args.patch_size ** 2 * args.img_channel))
    real_input_flag = []
    for i in range(args.batch_size):
        for j in range(args.total_length - args.input_length - 1):
            if true_token[i, j]:
                real_input_flag.append(ones)
            else:
                real_input_flag.append(zeros)
    real_input_flag = np.array(real_input_flag)
    real_input_flag = np.reshape(real_input_flag,
                                 (args.batch_size,
                                  args.total_length - args.input_length - 1,
                                  args.img_width // args.patch_size,
                                  args.img_width // args.patch_size,
                                  args.patch_size ** 2 * args.img_channel))
    return eta, real_input_flag


In [None]:
eta = 1.0
iter = 2

eta, real_input_flag = schedule_sampling(eta, iter)

In [None]:
eta

In [1]:
import torch
import torch.nn as nn

class SpatioTemporalLSTMCell(nn.Module):
    def __init__(self, in_channel, num_hidden, kernel_size, stride):
        super().__init__()

        self.num_hidden = num_hidden
        self.padding = kernel_size // 2
        self._forget_bias = 1.0

        self.conv_x = nn.Sequential(
            nn.Conv2d(in_channel, num_hidden * 7, kernel_size=kernel_size, stride=stride, padding=self.padding, bias=False),
        )
        self.conv_h = nn.Sequential(
            nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=kernel_size, stride=stride, padding=self.padding, bias=False),
        )
        self.conv_m = nn.Sequential(
            nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=kernel_size, stride=stride, padding=self.padding, bias=False),
        )
        self.conv_o = nn.Sequential(
            nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=kernel_size, stride=stride, padding=self.padding, bias=False),
        )

        self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1, stride=1, padding=0, bias=False)
    
    def forward(self, x_t, h_t, c_t, m_t):
        x_concat = self.conv_x(x_t)
        h_concat = self.conv_h(h_t)
        m_concat = self.conv_m(m_t)
        i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = torch.split(x_concat, self.num_hidden, dim=1)
        i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
        i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1)

        i_t = torch.sigmoid(i_x + i_h)
        f_t = torch.sigmoid(f_x + f_h + self._forget_bias)
        g_t = torch.tanh(g_x + g_h)

        # print(f"f_t: {f_t.shape}")
        # print(f"c_t: {c_t.shape}")
        # print(f"i_t: {i_t.shape}")
        # print(f"g_t: {g_t.shape}")
        c_new = f_t * c_t + i_t * g_t

        i_t_prime = torch.sigmoid(i_x_prime + i_m)
        f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias)
        g_t_prime = torch.tanh(g_x_prime + g_m)

        m_new = f_t_prime * m_t + i_t_prime * g_t_prime

        mem = torch.cat((c_new, m_new), 1)
        o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem))
        h_new = o_t * torch.tanh(self.conv_last(mem))

        return h_new, c_new, m_new

class PredRNN(nn.Module):
    def __init__(self, input_channels, img_width, img_height,
                 num_layers, num_hidden, 
                 num_ctx_frames, num_tgt_frames,
                 kernel_size,
                 stride,
                 learning_rate=1e-3):
        super().__init__()

        self.input_channels = input_channels
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        cell_list = []

        self.mse = nn.MSELoss()
        self.num_ctx_frames = num_ctx_frames
        self.num_tgt_frames = num_tgt_frames
        self.total_length = num_ctx_frames + num_tgt_frames
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        self.learning_rate = learning_rate 

        for i in range(num_layers):
            in_channel = self.input_channels if i == 0 else num_hidden[i - 1]
            cell_list.append(
                SpatioTemporalLSTMCell(in_channel, num_hidden[i], kernel_size, stride)
            )
        self.cell_list = nn.ModuleList(cell_list)
        self.conv_last = nn.Conv2d(num_hidden[num_layers - 1], 
                                   self.input_channels,
                                   kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        B, C, F, H, W = x.shape

        next_frames = []
        h_t = []
        c_t = []

        # Initialize hidden states and cell states
        for i in range(self.num_layers):
            zeros = torch.zeros([B, self.num_hidden[i], H, W])
            h_t.append(zeros)
            c_t.append(zeros)

        # Initialize memory state
        memory = torch.zeros([B, self.num_hidden[0], H, W])

        for t in range(self.total_length - 1):
            print(f"t: {t}")
            frame = x[:, :, t]
            h_t[0], c_t[0], memory = self.cell_list[0](frame, h_t[0], c_t[0], memory)

            for i in range(1, self.num_layers):
                h_t[i], c_t[i], memory = self.cell_list[i](h_t[i - 1], h_t[i], c_t[i], memory)

            x_gen = self.conv_last(h_t[self.num_layers - 1])
            next_frames.append(x_gen)

        next_frames = torch.stack(next_frames, dim=2)
        return next_frames

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr = self.learning_rate)
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, )
        return {"optimizer": optimizer, 
                "lr_scheduler": lr_scheduler,
                "monitor": "val_loss"
                }

In [8]:
input_channels=3
img_width=128
img_height=128
num_hidden = [64,64,64]
num_layers = len(num_hidden)
num_ctx_frames=5
num_tgt_frames=5
patch_size=4
kernel_size=5
stride=1
layer_norm=0

model = PredRNN(input_channels=input_channels, 
                img_width=img_width, 
                img_height=img_height,
                num_layers=num_layers,
                num_hidden=num_hidden, 
                num_ctx_frames=num_ctx_frames, 
                num_tgt_frames=num_tgt_frames,
                kernel_size=kernel_size,
                stride=stride)

In [9]:
train_x_batch = torch.randn(8, 3, 10, 128, 128)
next_frames = model(train_x_batch)

t: 0
t: 1
t: 2
t: 3
t: 4
t: 5
t: 6
t: 7
t: 8


In [10]:
next_frames.shape

torch.Size([8, 3, 9, 128, 128])