In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torchvision
from torch import nn
import torchvision.transforms as t
from torch.utils.data import DataLoader
from dataset import FkDataset
import matplotlib.pyplot as plt
import fk
import numpy as np
import random

In [3]:
# Global verbosity switch: log() only prints when this is truthy, and the
# training cell uses it to pick num_workers=0 for the DataLoader.
DEBUG = False

In [4]:
## LOGGING

def log(m):
    """Print ``m`` only when the module-level ``DEBUG`` flag is truthy."""
    if not DEBUG:
        return
    print(m)
        
def log_progress(epoch, batch_number, n_batches, loss):
    """
    Write a single-line training status to stdout.

    The line ends with a carriage return instead of a newline, so successive
    calls overwrite each other in the terminal / notebook output.

    Args:
        epoch: Current epoch index.
        batch_number: Index of the current batch within the epoch.
        n_batches: Total number of batches per epoch.
        loss: Mapping from loss name to a value that supports the ``.4f`` format spec.
    """
    header = "Epoch: {} \t Batch: {}/{} \t".format(epoch, batch_number, n_batches)
    parts = []
    for name, value in loss.items():
        parts.append("{}_loss: {:.4f}".format(name, value))
    print(header + "\t".join(parts), end="\r")
    
        
def plot_progress(epoch, x_hat, x, loss, **kwargs):
    """
    Display the reconstruction followed by the ground truth using ``fk.plot.show``.

    Both tensors are detached, moved to CPU and converted to numpy first.
    ``epoch``, ``loss`` and ``kwargs`` are accepted but not used.
    """
    for frame in (x_hat, x):
        fk.plot.show(frame.detach().cpu().numpy(), vmin=None, vmax=None)

## TRANSFORMS

class Downsample:
    """
    Transform that resizes a batched tensor to a fixed spatial size.

    Args:
        size: Target spatial size, forwarded to ``torch.nn.functional.interpolate``.
        mode: Interpolation algorithm, forwarded to ``interpolate``
            (e.g. "nearest", "bilinear", "bicubic", "area").
    """

    # Modes for which ``interpolate`` accepts (and warns without) align_corners.
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, size, mode="bicubic"):
        self.size = size
        self.mode = mode

    def __call__(self, x):
        """Resize ``x`` to ``self.size`` using the configured interpolation mode."""
        # The configured mode used to be ignored, so every call silently fell
        # back to interpolate's default ("nearest").
        align_corners = False if self.mode in self._ALIGNABLE_MODES else None
        return torch.nn.functional.interpolate(
            x, size=self.size, mode=self.mode, align_corners=align_corners
        )

class SetwiseMinMax:
    """
    Min-max normalisation using fixed, dataset-wide statistics.

    Maps values so that ``min`` becomes 0 and ``max`` becomes 1:
    ``(tensor - min) / (max - min)``.

    Args:
        min: Value mapped to 0 (dataset-wide minimum).
        max: Value mapped to 1 (dataset-wide maximum).
    """
    def __init__(self, min=-85, max=15):
        self.min = min
        self.max = max

    def __call__(self, tensor):
        """
        Normalise a tensor using the configured minimum and maximum.

        Args:
            tensor (Tensor): Tensor to be normalised.

        Returns:
            Tensor: Normalised tensor (a new tensor; the input is not modified).
        """
        return tensor.sub(self.min).div(self.max - self.min)

    def unnormalise(self, tensor):
        """
        Invert ``__call__``: map normalised values back to the original range.

        This used to be decorated with ``@staticmethod`` while still taking
        ``self``, so ``instance.unnormalise(t)`` raised a TypeError. As a regular
        method both ``instance.unnormalise(t)`` and
        ``SetwiseMinMax.unnormalise(instance, t)`` work.
        """
        return tensor.mul(self.max - self.min).add(self.min)

def random_noise(tensor):
    """
    With probability ~0.5, return ``tensor`` plus standard-normal noise.

    The input is never modified: the noisy result is a new tensor, matching
    random_rotate / random_flip which also return new tensors. When no noise
    is applied the input object itself is returned.
    """
    if random.random() > 0.5:
        # Out-of-place add: the previous ``+=`` corrupted the caller's tensor.
        tensor = tensor + torch.randn_like(tensor)
    return tensor

def random_rotate(tensor):
    """
    With probability ~0.5, rotate ``tensor`` by 90, 180 or 270 degrees.

    The rotation is applied in the plane of the last two axes, the same axes
    random_flip operates on. ``torch.rot90`` defaults to axes (0, 1), which for
    inputs with leading batch/channel axes would rotate those instead of the
    image plane.
    """
    if random.random() > 0.5:
        tensor = torch.rot90(tensor, random.randint(1, 3), dims=(-2, -1))
    return tensor

def random_flip(tensor):
    """
    With probability ~0.5, mirror ``tensor`` along one of its last two axes.

    The axis (-2 or -1) is drawn uniformly; otherwise the input is returned as is.
    """
    if random.random() <= 0.5:
        return tensor
    axis = random.randint(-2, -1)
    return torch.flip(tensor, (axis,))
    
class Augment:
    """Callable transform: random rotation followed by random flip."""

    def __call__(self, tensor):
        rotated = random_rotate(tensor)
        return random_flip(rotated)

    
class ToCuda:
    """Callable transform that returns the result of ``tensor.cuda()``."""

    def __call__(self, tensor):
        on_device = tensor.cuda()
        return on_device

## MODULES

class Flatten(nn.Module):
    """Collapse every axis after the first into one: (B, ...) -> (B, prod(...))."""

    def forward(self, input):
        batch = input.size(0)
        return input.view(batch, -1)
    
class Unflatten(nn.Module):
    """Reshape a flat (B, size) tensor into a (B, size, 1, 1) feature map."""

    def __init__(self, size=256):
        super(Unflatten, self).__init__()
        # Number of channels of the produced 1x1 feature map.
        self.size = size

    def forward(self, input):
        batch = input.size(0)
        return input.view(batch, self.size, 1, 1)

class T(nn.Module):
    """Placeholder for the transition model; no layers or logic are implemented yet."""

    def __init__(self):
        """
        This is another autoregressive flow, probably a Deep Dense Sigmoidal Flow.
        It propagates a state-space z_{0:t} in time, and returns a belief state b_{t+1}, which will be decoded by the network O.
        """
        # nn.Module must be initialised before the instance can be called,
        # moved to a device, or asked for its parameters.
        super(T, self).__init__()

    def forward(self, zt):
        """Not implemented yet: ignores ``zt`` and returns None."""
        return None
    
## LOSSES

def gradient_loss(gen_frames, gt_frames, eps=1e-12):
    """
    Sum of squared differences between the spatial gradient magnitudes of two tensors.

    Forward differences (step 1) are taken along the last two axes; the last
    column of dx and the last row of dy are zero, as in
    tf.image.image_gradients:
    https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/python/ops/image_ops_impl.py#L3441-L3512

    Args:
        gen_frames: Predicted frames, (b, c, h, w), floating point.
        gt_frames: Ground-truth frames, same shape as ``gen_frames``.
        eps: Small constant added under the square root. The gradient magnitude
            is exactly zero wherever both differences vanish (always true in the
            bottom-right corner), and d/dx sqrt(x) at 0 is infinite, which made
            ``backward()`` produce NaNs.

    Returns:
        Scalar tensor: ``mse_loss(..., reduction="sum")`` of the two gradient
        magnitude maps (a sum, not an average).
    """
    pad = torch.nn.functional.pad

    def gradient(x):
        # Difference of neighbouring pixels, then zero-pad the trailing
        # column / row so the result keeps the input's (h, w).
        dx = pad(x[..., :, 1:] - x[..., :, :-1], [0, 1, 0, 0])
        dy = pad(x[..., 1:, :] - x[..., :-1, :], [0, 0, 0, 1])
        return torch.sqrt(dx.pow(2) + dy.pow(2) + eps)

    grad_pred = gradient(gen_frames)
    grad_truth = gradient(gt_frames)
    return torch.nn.functional.mse_loss(grad_pred, grad_truth, reduction="sum")

class GradientLoss(nn.Module):
    """nn.Module wrapper that delegates to the ``gradient_loss`` function."""

    def forward(self, pred, truth):
        loss = gradient_loss(pred, truth)
        return loss

In [5]:
class Encoder(nn.Module):
    """
    Convolutional encoder producing a reparameterised latent sample.

    The input is repeatedly halved by 2x2 / stride-2 convolutions. The number of
    layers is derived from ``input_size``, and the channel count doubles every
    second layer starting from 8. The Linear heads take ``latent_dim`` features,
    so the flattened feature map must be 1x1 spatially; for the stack built here
    that holds when ``input_size`` is a power of two (e.g. 256 gives 8 layers
    and 64 channels).
    """

    def __init__(self, input_size, hidden_dim):
        """
        Encodes a single state x_i into a posterior distribution b that contains the information about the belief-state of the environment.

        Args:
            input_size: Side length of the (square) input images.
            hidden_dim: Dimensionality of the latent sample / mu / logvar.
        """
        super(Encoder, self).__init__()
        self.input_size = input_size
        self.encode = nn.ModuleList()
        # Two fixed stem layers: 3 input channels -> 8, each halving the resolution.
        self.encode.append(nn.Conv2d(in_channels=3, out_channels=8, kernel_size=2, stride=2))
        self.encode.append(nn.Conv2d(in_channels=8, out_channels=8, kernel_size=2, stride=2))

        # ``out_ch`` starts at the stem width so it is defined even when no
        # further layer fits (input_size < 8 used to raise NameError).
        out_ch = 8
        current_size = input_size // 4
        widen = True  # layers alternate: double the channels, then keep them
        while current_size >= 2:
            current_size /= 2
            in_ch = out_ch
            if widen:
                out_ch = in_ch * 2
            self.encode.append(nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=2, stride=2))
            widen = not widen

        self.latent_dim = out_ch
        self.hidden_dim = hidden_dim

        self.flatten = Flatten()
        self.logvar = nn.Linear(self.latent_dim, self.hidden_dim)
        self.mu = nn.Linear(self.latent_dim, self.hidden_dim)

    def reparameterise(self, mu, logvar):
        """Sample z = mu + exp(logvar / 2) * eps with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(mu)
        return mu + std * eps

    def forward(self, X):
        """
        Encode ``X`` and draw a latent sample.

        Returns:
            Tuple ``(b, logvar, mu)``, in that order: the sampled latent, the
            log-variance and the mean.
        """
        for module in self.encode:
            X = nn.functional.elu(module(X))
            log(X.shape)
        z = self.flatten(X)
        log(z.shape)
        # The mean must come from its own head; it was previously computed with
        # ``self.logvar``, leaving ``self.mu`` unused and mu identical to logvar.
        mu = self.mu(z)
        logvar = self.logvar(z)
        b = self.reparameterise(mu, logvar)
        return b, logvar, mu
    

class Decoder(nn.Module):
    """
    Transposed-convolution decoder mirroring the encoder.

    A (B, latent_dim) vector is viewed as a (B, latent_dim, 1, 1) map and
    upsampled by 2x2 / stride-2 transposed convolutions. The channel count is
    halved every second layer; a final layer maps to 3 channels and is asked to
    produce exactly ``output_size`` x ``output_size``.
    """

    def __init__(self, output_size, latent_dim):
        """
        Args:
            output_size: Side length of the (square) output image.
            latent_dim: Size of the latent vector; also the channel count of
                the first layer.

        Raises:
            ValueError: If ``latent_dim`` is too small to be halved for the
                number of layers ``output_size`` requires.
        """
        super(Decoder, self).__init__()
        self.output_size = output_size

        self.unflatten = Unflatten(latent_dim)

        self.decode = nn.ModuleList()
        # Start from the latent width so ``out_ch`` is defined even when no
        # hidden layer is needed (output_size < 4 used to raise NameError).
        out_ch = latent_dim
        current_size = 1
        halve = True  # layers alternate: halve the channels, then keep them
        while current_size * 2 <= output_size // 2:
            current_size *= 2
            in_ch = out_ch
            if halve:
                out_ch = in_ch // 2
            if out_ch < 1:
                raise ValueError(
                    "latent_dim={} is too small for output_size={}: "
                    "channel count reached zero".format(latent_dim, output_size)
                )
            self.decode.append(nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_ch, kernel_size=2, stride=2))
            halve = not halve
        self.decode.append(nn.ConvTranspose2d(in_channels=out_ch, out_channels=3, kernel_size=2, stride=2))

    def forward(self, b):
        """Decode a (B, latent_dim) latent into a (B, 3, output_size, output_size) tensor."""
        b = self.unflatten(b)
        *hidden_layers, output_layer = self.decode
        for module in hidden_layers:
            b = torch.nn.functional.elu(module(b))
            log(b.shape)
        # No activation on the last layer; output_size pins the exact spatial shape.
        return output_layer(b, output_size=(self.output_size, self.output_size))


class VAE(nn.Module):
    """
    Variational autoencoder: Encoder -> latent sample -> Decoder.

    Args:
        input_size: Side length of the (square) images; also used as the
            decoder's output size.
        hidden_dim: Dimensionality of the latent sample.
    """

    def __init__(self, input_size, hidden_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_size, hidden_dim)
        self.decoder = Decoder(input_size, hidden_dim)
        # Total number of scalar parameters, computed once at construction.
        self.n_params = sum(p.numel() for p in self.parameters())

    def forward(self, x):
        """
        Reconstruct ``x`` and compute the loss terms.

        Returns:
            Tuple ``(x_hat, loss)`` where ``loss`` is the dict from ``get_loss``.
        """
        # The encoder returns (sample, logvar, mu) in that order. Unpacking it
        # as (b, mu, logvar) fed the KL term with swapped arguments.
        b, logvar, mu = self.encoder(x)
        x_hat = self.decoder(b)
        loss = self.get_loss(x_hat, x, mu, logvar)
        return x_hat, loss

    def get_loss(self, x_hat, y, mu, logvar):
        """
        Compute the individual, unweighted loss terms.

        Returns:
            Dict with keys:
              "mse":  mean squared reconstruction error (sum / element count),
              "kld":  KL divergence to N(0, I), averaged over all latent elements,
              "grad": constant 0.0 (the gradient loss is currently disabled).
        """
        mse = torch.nn.functional.mse_loss(x_hat, y, reduction="sum") / x_hat.numel()
        kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        grad = 0.#gradient_loss(x_hat, y) / x_hat.numel()
        return {"mse": mse, "kld": kld, "grad": grad}

In [67]:
# Scratch check of a single sample's shape.
# NOTE(review): `fkset` is only created in the training cell further down, so
# this cell fails on a fresh kernel (Restart & Run All) -- move or remove.
fkset[100].shape

torch.Size([3, 256, 256])

In [68]:
# NOTE(review): `Autoencoder` is not defined anywhere in the visible notebook
# (the model class here is `VAE`) and the recorded output of this cell is a
# RuntimeError; this looks like a stale scratch cell -- confirm and remove.
Autoencoder(256, 512)(fkset[0].unsqueeze(0)).shape

torch.Size([1, 8, 84, 84])
torch.Size([1, 16, 27, 27])
torch.Size([1, 32, 8, 8])
torch.Size([1, 64, 2, 2])
torch.Size([1, 256])
torch.Size([1, 64, 2, 2])
torch.Size([1, 64, 9, 9])


RuntimeError: Given transposed=1, weight of size [32, 32, 5, 5], expected input[1, 64, 9, 9] to have 32 channels, but got 64 channels instead

In [8]:
if __name__ == "__main__":
    ## HYPERPARAMS
    # NOTE(review): machine-specific absolute path; consider making it configurable.
    root = "/home/ep119/repos/fenton_karma_jax/data/train_dev_set/"
    epochs = 100000
    # Use the GPU when present, otherwise fall back to CPU instead of crashing.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_size = 256
    hidden_dim = 128
    # Per-term weights applied to the dict returned by VAE.get_loss.
    loss_coeff = {
        "mse": 10000.,
        "kld": 1.,
        "grad": 0.
    }

    ## SETUP
    fkset = FkDataset(root, 1, 0, 1, transforms=t.Compose([Downsample((input_size, input_size))]), squeeze=True)
    loader = DataLoader(fkset, batch_size=32, shuffle=True, num_workers=0 if DEBUG else 12)
    vae = VAE(input_size=input_size, hidden_dim=hidden_dim).to(device)
    optimiser = torch.optim.Adam(vae.parameters())
    print(vae)
    print("{} paramameters".format(vae.n_params))

    ## OPTIMISE
    for e in range(epochs):
        pred = None
        for b, sample in enumerate(loader):
            pred, loss = vae(sample.to(device))
            # Weight each term, then sum into the scalar that is optimised.
            loss = {k: loss[k] * loss_coeff[k] for k in loss}
            total_loss = sum(loss.values())

            total_loss.backward()
            optimiser.step()
            optimiser.zero_grad()

            log_progress(e, b, len(loader), loss)

        if pred is None:
            # Previously this surfaced as a NameError on `pred` below.
            raise RuntimeError("DataLoader produced no batches; check the dataset at {}".format(root))

        # Once per epoch: show a random reconstruction from the last batch
        # next to its ground truth.
        idx = random.randint(0, len(pred) - 1)
        plot_progress(e, pred[idx], sample[idx], total_loss)

VAE(
  (encoder): Encoder(
    (encode): ModuleList(
      (0): Conv2d(3, 8, kernel_size=(2, 2), stride=(2, 2))
      (1): Conv2d(8, 8, kernel_size=(2, 2), stride=(2, 2))
      (2): Conv2d(8, 16, kernel_size=(2, 2), stride=(2, 2))
      (3): Conv2d(16, 16, kernel_size=(2, 2), stride=(2, 2))
      (4): Conv2d(16, 32, kernel_size=(2, 2), stride=(2, 2))
      (5): Conv2d(32, 32, kernel_size=(2, 2), stride=(2, 2))
      (6): Conv2d(32, 64, kernel_size=(2, 2), stride=(2, 2))
      (7): Conv2d(64, 64, kernel_size=(2, 2), stride=(2, 2))
    )
    (flatten): Flatten()
    (logvar): Linear(in_features=64, out_features=128, bias=True)
    (mu): Linear(in_features=64, out_features=128, bias=True)
  )
  (decoder): Decoder(
    (unflatten): Unflatten()
    (decode): ModuleList(
      (0): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
      (1): ConvTranspose2d(64, 64, kernel_size=(2, 2), stride=(2, 2))
      (2): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
      (3): Co

KeyboardInterrupt: 