In [1]:
import random
import torch
from torch import nn
import torchvision.transforms as t
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler, BatchSampler
from dataset import FkDataset, Simulation
import fk
import numpy as np
from flows import MADE
from glob import glob
import matplotlib.pyplot as plt


def log(*m, **kwargs):
    if DEBUG:
        print(*m, **kwargs)
        
def log_progress(epoch, batch_number, n_batches, loss):
    s = f"Epoch: {epoch} \t Batch: {batch_number}/{n_batches} \t"
    s += "\t".join(["{}_loss: {:.4f}".format(k, v) for k, v in loss.items()])
    print(s, end="\r")
    return

def plot_progress(epoch, x_hat, x, loss, **kwargs):
    fk.plot.show(x_hat.detach().cpu().numpy(), vmin=None, vmax=None)
    fk.plot.show(x.detach().cpu().numpy(), vmin=None, vmax=None)
    plt.show()
    return

def gradient_loss(gen_frames, gt_frames):
    def gradient(x):
        h_x = x.size()[-2]
        w_x = x.size()[-1]
        left = x
        right = torch.nn.functional.pad(x, [0, 1, 0, 0])[:, :, :, 1:]
        top = x
        bottom = torch.nn.functional.pad(x, [0, 0, 0, 1])[:, :, 1:, :]

        dx, dy = right - left, bottom - top
        dx[:, :, :, -1] = 0
        dy[:, :, -1, :] = 0

        return torch.sqrt(dx.pow(2) + dy.pow(2))

    grad_pred = gradient(gen_frames)
    grad_truth = gradient(gt_frames)

    # condense into one tensor and avg
    return torch.nn.functional.mse_loss(grad_pred, grad_truth, reduction="sum")

class Downsample:
    def __init__(self, size, mode="bicubic"):
        self.size = size
        self.mode = mode
    
    def __call__(self, x):
        return torch.nn.functional.interpolate(x, self.size)
    
class UNetConvBlock3D(nn.Sequential):
    def __init__(self, in_channels, out_channels, padding):
        super(UNetConvBlock3D, self).__init__()

        # add convolution 1
        self.add_module("conv1",
                        nn.Conv3d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=3,
                                  padding=int(padding)))
        self.add_module("relu1", nn.ReLU())

        # add convolution 2
        self.add_module("conv2",
                        nn.Conv3d(in_channels=out_channels,
                                  out_channels=out_channels,
                                  kernel_size=3,
                                  padding=int(padding)))
        self.add_module("relu2", nn.ReLU())

    def forward(self, x):
        return super().forward(x)


class UNetUpBlock3D(nn.Module):
    def __init__(self, in_channels, out_channels, padding):
        super(UNetUpBlock3D, self).__init__()

        # upsample
        self.up = nn.ConvTranspose3d(in_channels=in_channels,
                                     out_channels=out_channels,
                                     kernel_size=(3, 4, 4),
                                     stride=4)

        # add convolutions block
        self.conv_block = UNetConvBlock3D(in_channels=in_channels,
                                        out_channels=out_channels,
                                        padding=padding)

    def forward(self, x, skip_connection):
        log(x.shape, skip_connection.shape)
        up = self.up(x)
        log(up.shape, skip_connection.shape)
        out = torch.cat([up, skip_connection], 2)
        log(out.shape)
        out = self.conv_block(out)
        log(out.shape)
        return out

    
class UNet3D(nn.Module):
    def __init__(self, filters, in_channels, out_channels):
        super(UNet3D, self).__init__()
        
        depth = len(filters)
        if (in_channels == 1 and out_channels == 0):
            in_channels = 3
            out_channels = 3
        
        # downsampling
        self.downsample = nn.ModuleList()
        for i in range(depth):
            log(filters[i])
            self.downsample.append(UNetConvBlock3D(in_channels=in_channels,
                                                 out_channels=filters[i],
                                                 padding=1))
            in_channels = filters[i]
        
        # upsample
        self.upsample = nn.ModuleList()
        out_filter = [out_channels] + filters
        for i in reversed(range(depth)):
            log(filters[i])
            self.upsample.append(UNetUpBlock3D(in_channels=in_channels,
                                             out_channels=out_filter[i],
                                             padding=1))
            in_channels = out_filter[i]
            
        self.output = nn.Conv2d(in_channels, out_channels, 1, 1)
        return
    
    def get_loss(self, u, sum_log_abs_det_jacobians, y_hat, y):
        mse =  torch.sqrt(F.mse_loss(y_hat, y, reduction="sum"))
        grad = torch.sqrt(gradient_loss(y_hat, y))
        nf = torch.sum(self.propagator.base_dist.log_prob(u) + sum_log_abs_det_jacobians, dim=1).sum()
        return {"mse": mse, "grad": grad, "nf": nf}
    
    def encode(self, X):
        log("input", X.shape)
        skip_connections = []
        z = X
        for i, down in enumerate(self.downsample):
            z = down(z)
            log("down", z.shape)
            log("skip_connection {}".format(i), z.shape)
            if i != len(self.downsample) - 1:
                skip_connections.append(z)
                z = F.max_pool3d(z, (1, 4, 4))
        return z, skip_connections
    
    def decode(self, z, skip_connections):
        print([s.shape for s in skip_connections])
        for i, up in enumerate(self.upsample):
            y_hat = up(z, skip_connections[-i - 1])
            log("up", y_hat.shape)
        y_hat = self.output(y_hat)
        log("output", y_hat.shape)
        return y_hat
    
    def forward(self, X):
        z, skip_connections = self.encode(X)
        y_hat = self.decode(z, skip_connections)
        return y_hat, self.get_loss(u, jacobian, y_hat, X)
    
    def parameters_count(self):
        return sum(p.numel() for p in self.parameters())

    
def collate_fn(batch):
    c = torch.stack(batch)
    log("Loading {}".format(hash(c)))
    return c

def train(model, loader, optimiser, epochs, frames_in, frames_out, device):
    for e in range(epochs):
        for b, y in enumerate(loader):
            y = torch.ra
            log(y.shape)
            y = y.to(device)
            y_in = y[:, :frames_in]
            y_out = y[:, frames_in:]
            log(y_in.shape)
            log(y_out.shape)
            model = model.to(device)
            z, skip_connections = model.encode(y_in)
            y_hat = model.decode(z, skip_connections)
            
            rec_loss = torch.sqrt(F.mse_loss(y_hat, y_out, reduction="sum"))            
            loss = rec_loss# + nf_loss
            loss.backward()
            optimiser.step()
            optimiser.zero_grad()
            
            # log
            log_progress(e, b, len(loader), {"loss": loss})
            idx = random.randint(0, len(y_hat) - 1)
            plot_progress(e, y_hat[idx], y[idx], loss)

def test():
    y_in = torch.randn(BATCH_SIZE, IN_FRAMES, 3, INPUT_SIZE, INPUT_SIZE, device=DEVICE)
    y_out = torch.randn(BATCH_SIZE, OUT_FRAMES, 3, INPUT_SIZE, INPUT_SIZE, device=DEVICE)
    model = UNet3D([8, 16, 32, 64, 128], IN_FRAMES, OUT_FRAMES).to(DEVICE)
    z, skip_connections = model.encode(y_in)
    y_hat = model.decode(z, skip_connections)
    loss = torch.sqrt(F.mse_loss(y_hat, y_out, reduction="sum"))
    loss.backward()
    
    
    
if __name__ == "__main__":
    ## HYPERPARAMS
    ROOT = "/media/SSD1/epignatelli/train_dev_set/"
    EPOCHS = 100000
    DEVICE = torch.device("cuda")
    INPUT_SIZE = 256
    HIDDEN_SIZE = 2048
    BATCH_SIZE = 4
    IN_FRAMES = 7
    OUT_FRAMES = 9
    DEBUG = True
    
    model = UNet3D([8, 16, 32, 64, 128], IN_FRAMES, OUT_FRAMES).to(DEVICE)
    log(model)
    log(model.parameters_count())
    fkset = FkDataset(ROOT, IN_FRAMES, OUT_FRAMES, 1, transforms=t.Compose([Downsample((INPUT_SIZE, INPUT_SIZE))]), squeeze=True)
    loader = DataLoader(fkset, num_workers=0, collate_fn=collate_fn, batch_size=BATCH_SIZE, drop_last=True)
    optimiser = torch.optim.Adam(model.parameters())
#     train(model, loader, optimiser, EPOCHS, IN_FRAMES, OUT_FRAMES, DEVICE)
    test()

8
16
32
64
128
128
64
32
16
8
UNet3D(
  (downsample): ModuleList(
    (0): UNetConvBlock3D(
      (conv1): Conv3d(7, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (relu1): ReLU()
      (conv2): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (relu2): ReLU()
    )
    (1): UNetConvBlock3D(
      (conv1): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (relu1): ReLU()
      (conv2): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (relu2): ReLU()
    )
    (2): UNetConvBlock3D(
      (conv1): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (relu1): ReLU()
      (conv2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (relu2): ReLU()
    )
    (3): UNetConvBlock3D(
      (conv1): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (relu1): ReLU()
      (conv2): Conv3d(64, 64, ker

RuntimeError: Given groups=1, weight of size [64, 128, 3, 3, 3], expected input[4, 64, 14, 4, 4] to have 128 channels, but got 64 channels instead