In [1]:
# import math
import random
import torch
from torch import nn
import torchvision.transforms as t
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler, BatchSampler
from dataset import FkDataset, Simulation
import fk
import numpy as np
from flows import MADE
from glob import glob
import matplotlib.pyplot as plt


def log(*m, **kwargs):
    """Print the given arguments only when the module-level ``DEBUG`` flag is truthy.

    Args:
        *m: positional arguments forwarded to ``print``.
        **kwargs: keyword arguments forwarded to ``print`` (``sep``, ``end``, ...).
    """
    # DEBUG is only assigned inside the ``__main__`` guard of this file, so it
    # does not exist when the module is imported; treat "missing" as disabled
    # instead of raising NameError.
    if globals().get("DEBUG", False):
        print(*m, **kwargs)
        
def log_progress(epoch, batch_number, n_batches, loss):
    """Write a one-line training status to stdout, ending with a carriage return.

    Args:
        epoch: value shown after ``Epoch:``.
        batch_number: value shown before the ``/`` in the ``Batch:`` field.
        n_batches: value shown after the ``/`` in the ``Batch:`` field.
        loss: mapping of name -> value; each entry is rendered as
            ``<name>_loss: <value>`` with four decimals, tab separated.
    """
    header = "Epoch: {} \t Batch: {}/{} \t".format(epoch, batch_number, n_batches)
    fields = []
    for name, value in loss.items():
        fields.append(f"{name}_loss: {value:.4f}")
    # "\r" keeps the cursor on the same line so the next call overwrites this one
    print(header + "\t".join(fields), end="\r")

def plot_progress(epoch, x_hat, x, loss, **kwargs):
    """Show the reconstruction and then the reference with ``fk.plot.show``.

    Both tensors are detached, moved to the CPU and converted to numpy before
    being handed to ``fk.plot.show`` with ``vmin=None, vmax=None``; finally
    ``plt.show()`` is called. ``epoch``, ``loss`` and ``**kwargs`` are accepted
    but not used here.
    """
    # reconstruction first, reference second
    for tensor in (x_hat, x):
        fk.plot.show(tensor.detach().cpu().numpy(), vmin=None, vmax=None)
    plt.show()

def gradient_loss(gen_frames, gt_frames, eps=0.0):
    """Summed squared difference between the gradient magnitudes of two tensors.

    For each input, forward differences (step 1) are taken along the last
    dimension (``dx``) and the second-to-last dimension (``dy``); the loss is
    ``sum((sqrt(dx_gen^2 + dy_gen^2 + eps) - sqrt(dx_gt^2 + dy_gt^2 + eps))^2)``.

    Args:
        gen_frames: floating point tensor with at least two dimensions, e.g.
            ``(b, c, h, w)``; the last two are treated as height and width.
        gt_frames: tensor of the same shape as ``gen_frames``.
        eps: constant added under the square root. The default ``0.0`` keeps
            the values exactly as before; a small positive value avoids the
            NaN gradient that ``sqrt`` produces where both differences are 0.

    Returns:
        A scalar tensor (``reduction="sum"``, i.e. not averaged).
    """
    def gradient(x):
        # idea from tf.image.image_gradients(image)
        # https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/python/ops/image_ops_impl.py#L3441-L3512

        # neighbour to the right / below, zero padded past the border;
        # ellipsis indexing keeps this valid for any number of leading dims
        right = torch.nn.functional.pad(x, [0, 1, 0, 0])[..., 1:]
        bottom = torch.nn.functional.pad(x, [0, 0, 0, 1])[..., 1:, :]

        dx, dy = right - x, bottom - x
        # the last column of dx and the last row of dy only see padding, so
        # they are defined to be zero
        dx[..., -1] = 0
        dy[..., -1, :] = 0

        return torch.sqrt(dx.pow(2) + dy.pow(2) + eps)

    grad_pred = gradient(gen_frames)
    grad_truth = gradient(gt_frames)

    # condense into one scalar by summing the squared differences
    return torch.nn.functional.mse_loss(grad_pred, grad_truth, reduction="sum")

class Downsample:
    """Callable transform that resizes a tensor with ``torch.nn.functional.interpolate``.

    Args:
        size: target spatial size, passed to ``interpolate`` as ``size``.
        mode: interpolation algorithm, passed to ``interpolate`` as ``mode``
            (default ``"bicubic"``).
    """

    def __init__(self, size, mode="bicubic"):
        self.size = size
        self.mode = mode

    def __call__(self, x):
        """Return ``x`` resized to ``self.size`` using ``self.mode``."""
        # forward the configured mode; without it interpolate silently falls
        # back to its own default and the constructor argument has no effect
        return torch.nn.functional.interpolate(x, size=self.size, mode=self.mode)

class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        """Return a view of ``input`` with shape ``(input.size(0), -1)``."""
        batch = input.shape[0]
        return input.view(batch, -1)
    
class Unflatten(nn.Module):
    """Reshape a flat ``(batch, features)`` tensor into ``(batch, size, h, w)``.

    Args:
        size: number of channels of the output view (default 256).
        h: output height; ``None`` means 1.
        w: output width; ``None`` means 1.
    """

    def __init__(self, size=256, h=None, w=None):
        # initialise nn.Module before assigning any attribute
        super(Unflatten, self).__init__()
        self.size = size
        self.h = 1 if h is None else h
        self.w = 1 if w is None else w

    def forward(self, input):
        """Return a view of ``input`` with shape ``(input.size(0), size, h, w)``."""
        # height before width, matching the (N, C, H, W) layout the rest of
        # this file assumes (see ConvTransposeBlock.center_crop)
        return input.view(input.size(0), self.size, self.h, self.w)

class Elu(nn.Module):
    """Module wrapper around ``nn.functional.elu`` called with its default arguments."""

    def forward(self, x):
        """Apply ELU element-wise to ``x`` and return the result."""
        activated = nn.functional.elu(x)
        return activated
    
class ConvBlock(nn.Sequential):
    """Two 3x3 convolutions, each followed by a ReLU.

    The children are registered as ``conv1``, ``relu1``, ``conv2``, ``relu2``.

    Args:
        in_channels: channels of the input to the first convolution.
        out_channels: channels produced by both convolutions.
        padding: padding of both convolutions; converted with ``int()``.
    """

    def __init__(self, in_channels, out_channels, padding):
        super().__init__()

        pad = int(padding)
        channels = in_channels
        for stage in (1, 2):
            conv = nn.Conv2d(in_channels=channels,
                             out_channels=out_channels,
                             kernel_size=3,
                             padding=pad)
            self.add_module(f"conv{stage}", conv)
            self.add_module(f"relu{stage}", nn.ReLU())
            # the second convolution consumes the output of the first
            channels = out_channels

    def forward(self, x):
        """Run ``x`` through the four children in registration order."""
        return nn.Sequential.forward(self, x)


class ConvTransposeBlock(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, ``ConvBlock``.

    Args:
        in_channels: channels of the tensor to upsample; also the number of
            input channels of the ``ConvBlock``, so the upsampled tensor and
            the cropped skip connection together must have this many channels.
        out_channels: channels produced by the upsampling and by the block.
        padding: padding forwarded to ``ConvBlock``.
    """

    def __init__(self, in_channels, out_channels, padding):
        super().__init__()

        # kernel 4 / stride 4: every spatial dimension grows by a factor of 4
        self.up = nn.ConvTranspose2d(in_channels=in_channels,
                                     out_channels=out_channels,
                                     kernel_size=4,
                                     stride=4)

        # convolutions applied after the skip connection is concatenated
        self.conv_block = ConvBlock(in_channels=in_channels,
                                    out_channels=out_channels,
                                    padding=padding)

    def center_crop(self, layer, target_size):
        """Return the central ``target_size`` (height, width) window of a 4-D ``layer``."""
        _, _, height, width = layer.size()
        target_h, target_w = target_size[0], target_size[1]
        top = (height - target_h) // 2
        left = (width - target_w) // 2
        return layer[:, :, top:top + target_h, left:left + target_w]

    def forward(self, x, skip_connection):
        """Upsample ``x``, append the cropped ``skip_connection`` on dim 1, convolve."""
        upsampled = self.up(x)
        cropped = self.center_crop(skip_connection, upsampled.shape[2:])
        merged = torch.cat([upsampled, cropped], 1)
        return self.conv_block(merged)

    
class UNet(nn.Module):
    """U-Net style encoder/decoder with a linear latent bottleneck and a ``MADE`` propagator.

    The encoder is a stack of ``ConvBlock``s separated by 4x4 max pooling, the
    decoder a stack of ``ConvTransposeBlock``s (4x upsampling) fed with the
    encoder's skip connections.

    Args:
        filters: sequence with the number of output channels of each encoder
            block; its length is the depth of the network.
        hidden_size: size of the latent vector produced by ``latent_in`` and
            consumed by ``latent_out`` and the propagator.
        input_size: ``input_size // 2`` is used as the number of features of
            the flattened encoder output.
            NOTE(review): this only matches the flattened size when the last
            encoder activation has ``input_size // 2`` elements per sample
            (e.g. 128 channels at 1x1) - confirm for other configurations.
        in_channels: channels of the network input.
        out_channels: channels of the network output.
    """

    def __init__(self, filters, hidden_size, input_size, in_channels, out_channels):
        super(UNet, self).__init__()

        depth = len(filters)
        # NOTE(review): only triggers for exactly (in_channels=1, out_channels=0);
        # the intent of this special case is not visible from here
        if (in_channels == 1 and out_channels == 0):
            in_channels = 3
            out_channels = 3

        # downsampling
        self.downsample = nn.ModuleList()
        for n_filters in filters:
            log(n_filters)
            # ConvBlock is the convolution block defined in this file (the
            # previously referenced UNetConvBlock does not exist here)
            self.downsample.append(ConvBlock(in_channels=in_channels,
                                             out_channels=n_filters,
                                             padding=1))
            in_channels = n_filters

        # latent
        # latent_cont is registered but not used by any method of this class
        self.latent_cont = nn.Conv2d(in_channels, in_channels, 1, 1)
        self.flatten = Flatten()
        self.latent_in = nn.Linear(input_size // 2, hidden_size)
        self.latent_out = nn.Linear(hidden_size, input_size // 2)
        self.unflatten = Unflatten(input_size // 2, 1, 1)

        # propagator
        self.propagator = MADE(hidden_size, hidden_size, 5)

        # upsample
        self.upsample = nn.ModuleList()
        # list() so that a tuple of filters is accepted as well
        out_filter = [out_channels] + list(filters)
        for i in reversed(range(1, depth)):
            log(filters[i])
            # ConvTransposeBlock is the upsampling block defined in this file
            # (the previously referenced UNetUpBlock does not exist here)
            self.upsample.append(ConvTransposeBlock(in_channels=in_channels,
                                                    out_channels=out_filter[i],
                                                    padding=1))
            in_channels = out_filter[i]

        self.output = nn.Conv2d(in_channels, out_channels, 1, 1)

    def get_loss(self, u, sum_log_abs_det_jacobians, y_hat, y):
        """Return the loss components as a dict with keys ``mse``, ``grad`` and ``nf``.

        ``mse`` is the square root of the summed squared error between
        ``y_hat`` and ``y``; ``grad`` is the square root of ``gradient_loss``;
        ``nf`` is the sum of ``base_dist.log_prob(u) + sum_log_abs_det_jacobians``.

        NOTE(review): ``nf`` is returned with a positive sign, whereas
        ``train`` in this file negates the same quantity - confirm which one
        is intended before minimising this value.
        """
        mse = torch.sqrt(F.mse_loss(y_hat, y, reduction="sum"))
        grad = torch.sqrt(gradient_loss(y_hat, y))
        nf = torch.sum(self.propagator.base_dist.log_prob(u) + sum_log_abs_det_jacobians, dim=1).sum()
        return {"mse": mse, "grad": grad, "nf": nf}

    def encode(self, X):
        """Encode ``X`` into a latent vector.

        Returns:
            ``(z, skip_connections)``: the output of ``latent_in`` and the list
            of activations of every encoder block except the last, in
            encoder order.
        """
        log("input", X.shape)
        skip_connections = []
        z = X
        last = len(self.downsample) - 1
        for i, down in enumerate(self.downsample):
            z = down(z)
            log("down", z.shape)
            # the deepest block is neither stored as a skip nor pooled
            if i != last:
                skip_connections.append(z)
                z = F.max_pool2d(z, 4)

        z = self.flatten(z)
        log("flatten", z.shape)
        z = self.latent_in(z)
        log("latent_in", z.shape)
        return z, skip_connections

    def decode(self, z, skip_connections):
        """Decode latent ``z`` using ``skip_connections`` as returned by ``encode``."""
        y_hat = self.latent_out(z)
        log("latent_out", y_hat.shape)
        y_hat = self.unflatten(y_hat)
        log("unflatten", y_hat.shape)

        for i, up in enumerate(self.upsample):
            # skips are consumed from the deepest to the shallowest
            y_hat = up(y_hat, skip_connections[-i - 1])
            log("up", y_hat.shape)
        y_hat = self.output(y_hat)
        log("output", y_hat.shape)
        return y_hat

    def propagate(self, z):
        """Apply the propagator to ``z`` and return its two outputs ``(u, jacobian)``."""
        u, jacobian = self.propagator(z)
        return u, jacobian

    def forward(self, X):
        """Return ``(y_hat, losses)`` for the input ``X``.

        The reconstruction is decoded from ``z`` (not from the propagated
        ``u``) and the losses compare ``y_hat`` against ``X`` itself.
        """
        z, skip_connections = self.encode(X)
        u, jacobian = self.propagate(z)
        y_hat = self.decode(z, skip_connections)
        return y_hat, self.get_loss(u, jacobian, y_hat, X)

    def parameters_count(self):
        """Return the total number of elements over all parameters."""
        return sum(p.numel() for p in self.parameters())

class FKLoader:
    """Minimal loader that returns contiguous slices of a dataset as tensors.

    Args:
        dataset: object supporting ``len()`` and slice indexing.
        batch_size: number of consecutive items returned by ``__getitem__``.
        shuffle: when true, ``self.indices`` is shuffled in place.
            NOTE(review): ``__getitem__`` does not consult ``self.indices``, so
            this flag currently does not change which items are returned.
    """

    def __init__(self, dataset, batch_size=32, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = list(range(len(dataset)))
        if shuffle:
            random.shuffle(self.indices)

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        # use the instance's dataset rather than a module-level name
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``dataset[idx:idx + batch_size]`` converted with ``torch.as_tensor``.

        ``idx`` is an item offset into the dataset, not a batch number.
        """
        return torch.as_tensor(self.dataset[idx:idx + self.batch_size])
        
    
def collate_fn(batch):
    """Convert a batch of samples into a single tensor with ``torch.as_tensor``.

    Args:
        batch: the samples handed over by the ``DataLoader``.

    Returns:
        The tensor built from ``batch``.
    """
    # bind the converted batch to the name that is logged and returned
    c = torch.as_tensor(batch)
    log("Loading {}".format(hash(c)))
    return c

def train(model, loader, optimiser, epochs, frames_in, frames_out, device):
    """Train ``model`` for ``epochs`` passes over ``loader``.

    For every batch the model encodes the batch, propagates the latent code,
    decodes the propagated code and is optimised on the sum of a
    reconstruction term (square root of the summed squared error against the
    batch itself) and the negated sum of ``base_dist.log_prob(u) + jacobian``.
    After each step the progress is printed and one random sample of the
    batch is plotted.

    Args:
        model: network exposing ``encode``, ``propagate``, ``decode`` and
            ``propagator.base_dist``.
        loader: iterable of batches supporting ``len()``.
        optimiser: optimiser over the model parameters.
        epochs: number of passes over ``loader``.
        frames_in: accepted for interface compatibility; not used - the whole
            batch is fed to the model.
        frames_out: accepted for interface compatibility; not used.
        device: device the model and every batch are moved to.
    """
    # moving the model does not depend on the batch, so do it once up front
    model = model.to(device)

    for e in range(epochs):
        for b, y in enumerate(loader):
            y = y.to(device)

            z, skip_connections = model.encode(y)
            u, jacobian = model.propagate(z)
            y_hat = model.decode(u, skip_connections)

            rec_loss = torch.sqrt(F.mse_loss(y_hat, y, reduction="sum"))
            nf_loss = -torch.sum(model.propagator.base_dist.log_prob(u) + jacobian, dim=1).sum()

            loss = rec_loss + nf_loss

            loss.backward()
            optimiser.step()
            # clear the gradients so they do not accumulate into the next batch
            optimiser.zero_grad()

            # log
            log_progress(e, b, len(loader), {"loss": loss})
            idx = random.randint(0, len(y_hat) - 1)
            plot_progress(e, y_hat[idx], y[idx], loss)

            
if __name__ == "__main__":
    # ---- hyperparameters ----
    ROOT = "/media/SSD1/epignatelli/train_dev_set/"
    EPOCHS = 100000
    DEVICE = torch.device("cuda")
    INPUT_SIZE = 256
    HIDDEN_SIZE = 2048
    BATCH_SIZE = 2
    IN_FRAMES = 4
    OUT_FRAMES = 5
    # read by log(); set to True to print shapes and the model summary
    DEBUG = False

    # ---- model ----
    net = UNet([8, 16, 32, 64, 128], HIDDEN_SIZE, INPUT_SIZE, IN_FRAMES, OUT_FRAMES)
    net = net.to(DEVICE)
    log(net)
    log(net.parameters_count())

    # ---- data ----
    # `filename` is not used below; the lookup raises IndexError when ROOT
    # contains no .hdf5 file
    filename = glob(ROOT + "*.hdf5")[0]
    resize = Downsample((INPUT_SIZE, INPUT_SIZE))
    fkset = FkDataset(ROOT, IN_FRAMES, OUT_FRAMES, 1,
                      transforms=t.Compose([resize]),
                      squeeze=True)
    # random batches of BATCH_SIZE items, incomplete last batch dropped
    batch_sampler = BatchSampler(RandomSampler(fkset), batch_size=BATCH_SIZE, drop_last=True)
    loader = DataLoader(fkset,
                        num_workers=0,
                        collate_fn=collate_fn,
                        batch_sampler=batch_sampler)

    # ---- optimisation ----
    optimiser = torch.optim.Adam(net.parameters())

    train(net, loader, optimiser, EPOCHS, IN_FRAMES, OUT_FRAMES, DEVICE)

RuntimeError: Given groups=1, weight of size [8, 4, 3, 3], expected input[9, 3, 1200, 1200] to have 4 channels, but got 3 channels instead