In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torchvision
from torch import nn
import torchvision.transforms as t
from torch.utils.data import DataLoader
from dataset import FkDataset
import matplotlib.pyplot as plt
import fk
import numpy as np
import random

In [3]:
DEBUG = False

In [84]:
class Downsample:
    def __init__(self, size, mode="bicubic"):
        self.size = size
        self.mode = mode
    
    def __call__(self, x):
        return torch.nn.functional.interpolate(x, self.size)

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)
    
class Unflatten(nn.Module):
    def __init__(self, size=256, h=None, w=None):
        self.size = size
        self.h = h if h is not None else 1
        self.w = w if w is not None else 1
        super(Unflatten, self).__init__()
    def forward(self, input):
        return input.view(input.size(0), self.size, self.w, self.h)

class Elu(nn.Module):
    def forward(self, x):
        return nn.functional.elu(x)

class ConvBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, stride, pool_kernel_size):
        modules = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride),
            Elu(),
            nn.MaxPool2d(pool_kernel_size),
        ]
        super(ConvBlock, self).__init__(*modules)

class ConvTransposeBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        modules = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride),
            nn.ConvTranspose2d(out_channels, out_channels, kernel_size, stride),
            nn.Conv2d(out_channels, out_channels, 1, 1),
            Elu(),
        ]
        super(ConvTransposeBlock, self).__init__(*modules)
        
class Autoencoder(nn.Module):
    def __init__(self, input_size, hidden_dim):
        super(Autoencoder, self).__init__()
        
        self.encoder = nn.ModuleList([
            ConvBlock(3, 8, 3, 1, 3),
            ConvBlock(8, 16, 3, 1, 3),
            ConvBlock(16, 32, 3, 1, 3),
            ConvBlock(32, 64, 3, 1, 3),
            Flatten(),
        ])
        
        self.latent_encode = nn.Linear(input_size, hidden_dim)
        self.latent_decode = nn.Linear(hidden_dim, input_size)
        
        self.decoder = nn.ModuleList([
            Unflatten(64, 2, 2),
            ConvTransposeBlock(64, 32, 3, 2),
            ConvTransposeBlock(32, 16, 3, 2),
            ConvTransposeBlock(16, 8, 2, 2),
            ConvTransposeBlock(8, 3, 3, 1),
        ])
    
    @property
    def n_params(self):
        return sum(p.numel() for p in self.parameters())
    
    def encode(self, x):
        for module in self.encoder:
            x = module(x)
            print(x.shape)
        return x
    
    def decode(self, z):
        for module in self.decoder:
            z = module(z)
            print(z.shape)
        return z        
    
    def forward(self, x):
        x = self.encode(x)
        z = self.latent_encode(x)
        z = self.latent_decode(z)
        y_hat = self.decode(z) 
        return y_hat

In [85]:
fkset[100].shape

torch.Size([3, 256, 256])

In [86]:
Autoencoder(256, 512)(fkset[0].unsqueeze(0)).shape

torch.Size([1, 8, 84, 84])
torch.Size([1, 16, 27, 27])
torch.Size([1, 32, 8, 8])
torch.Size([1, 64, 2, 2])
torch.Size([1, 256])
torch.Size([1, 64, 2, 2])
torch.Size([1, 32, 11, 11])
torch.Size([1, 16, 15, 15])
torch.Size([1, 8, 60, 60])
torch.Size([1, 3, 62, 62])


torch.Size([1, 3, 62, 62])

In [14]:
if __name__ == "__main__":
    ## HYPERPARAMS
    root = "/home/ep119/repos/fenton_karma_jax/data/train_dev_set/"
    epochs = 100000
    device = torch.device("cuda")
    input_size = 256
    hidden_dim = 128
    loss_coeff = {
        "mse": 10000.,
        "kld": 1.,
        "grad": 0.
    }
    
    ## SETUP
    fkset = FkDataset(root, 1, 0, 1, transforms=t.Compose([Downsample((input_size, input_size))]), squeeze=True)
    loader = DataLoader(fkset, batch_size=32, shuffle=True, num_workers=0 if DEBUG else 12)
    vae = Autoencoder(input_size=input_size, hidden_dim=hidden_dim).to(device)
    optimiser = torch.optim.Adam(vae.parameters())
    print(vae)
    print("{} paramameters".format(vae.n_params))
    
    ## TEST
    Autoencoder(256, 512)(fkset[0].unsqueeze(0)).shape
    
    ## OPTIMISE
#     for e in range(epochs):
#         for b, sample in enumerate(loader):
#             pred, loss = vae(sample.to(device))
#             loss = {k: loss[k] * loss_coeff[k] for k in loss}
#             total_loss = sum(loss.values())

#             total_loss.backward()
#             optimiser.step()
#             optimiser.zero_grad()

#             log_progress(e, b, len(loader), loss)
#             idx = random.randint(0, len(pred) - 1)
#         plot_progress(e, pred[idx], sample[idx], total_loss)

Autoencoder(
  (encoder): ModuleList(
    (0): ConvBlock(
      (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1))
      (1): Elu()
      (2): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
    )
    (1): ConvBlock(
      (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1))
      (1): Elu()
      (2): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
    )
    (2): ConvBlock(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
      (1): Elu()
      (2): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
    )
    (3): ConvBlock(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): Elu()
      (2): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
    )
    (4): Flatten()
  )
  (latent_encode): Linear(in_features=256, out_features=128, bias=True)
  (latent_decode): Linear(in_features=128, out_features=256, bias=True)
  (decoder): ModuleList(
    (0): Unfl

RuntimeError: Given transposed=1, weight of size [32, 32, 5, 5], expected input[1, 64, 9, 9] to have 32 channels, but got 64 channels instead