In [None]:
import time
import math
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

import torchode as to
from adabelief_pytorch import AdaBelief
from pytorch_msssim import SSIM, MS_SSIM, ssim
from focal_frequency_loss import FocalFrequencyLoss as FFL

from IPython.display import Image, Video, HTML

In [None]:
# Load the CIFAR-10 training split as [0, 1] float tensors and preview a few samples.
# NOTE(review): the root directory is named "MNIST", but the dataset downloaded/loaded
# here is CIFAR-10; the path is left unchanged so existing downloads are reused.
to_tensor = torchvision.transforms.ToTensor()
train_dataset = torchvision.datasets.CIFAR10("MNIST", train=True, download=True, transform=to_tensor)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)

# Show the first 10 images of one shuffled batch as a 2x5 grid (CHW -> HWC for imshow).
data, labels = next(iter(trainloader))
preview = torchvision.utils.make_grid(data[:10], nrow=5, padding=0)
grid = preview.permute(1, 2, 0)
plt.imshow(grid)
plt.axis('off')
plt.show()

In [None]:
class GNCA(nn.Module):
    """Latent-conditioned neural cellular automaton; one ``forward`` call is one CA step.

    The state is a grid of cells with ``aug_dim`` channels each. A step:
      1. ``perceive``: fixed depthwise Sobel-x, Sobel-y and Laplacian filters, the raw
         state, and 3x3 max- and min-pooling are concatenated (``6 * aug_dim`` channels).
      2. ``update``: the latent code ``z`` is broadcast to every cell and a per-cell MLP
         (1x1 convolutions with an ELU) maps the features to an ``aug_dim``-channel residual.
      3. ``forward``: residual add, light dropout (training mode only), then sigmoid.

    Args:
        latent_dim: length of the per-sample conditioning vector ``z``.
        aug_dim: number of state channels per cell.
        img_dim: ``(h, w)`` of the state grid; sizes the learned ``PosEmb`` / ``scale`` maps.
    """

    def __init__(self, latent_dim, aug_dim, img_dim=(32, 32)):
        super(GNCA, self).__init__()
        self.latent_dim = latent_dim
        self.aug_dim = aug_dim
        self.h, self.w = img_dim[0], img_dim[1]
        self.set_modules()

    def _fixed_depthwise_conv(self, kernel):
        """Return a frozen depthwise 3x3 conv applying ``kernel`` to every state channel."""
        conv = nn.Conv2d(self.aug_dim, self.aug_dim, kernel_size=3, stride=1, padding=1,
                         bias=False, groups=self.aug_dim)
        # Weight shape (aug_dim, 1, 3, 3): one copy of the kernel per channel (groups=aug_dim).
        conv.weight = nn.Parameter(kernel.unsqueeze(0).repeat(self.aug_dim, 1, 1, 1), requires_grad=False)
        return conv

    def set_modules(self):
        """Create the perception filters, the update MLP and the auxiliary parameters."""
        sobel_x = torch.Tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]])
        sobel_y = torch.Tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]])
        laplace = torch.Tensor([[1, 2, 1], [2, -12, 2], [1, 2, 1]])

        # Creation order is kept as before so parameter initialisation consumes the RNG identically.
        self.C_S_x = self._fixed_depthwise_conv(sobel_x)
        self.C_S_y = self._fixed_depthwise_conv(sobel_y)
        self.C_S_l = self._fixed_depthwise_conv(laplace)

        # NOTE(review): PosEmb, G0/B0 and scale are not used by forward(); they are kept so the
        # module's parameter set (and any saved state_dict) is unchanged. Sized from img_dim.
        self.PosEmb = nn.Parameter(torch.randn(1, self.aug_dim, self.h, self.w), requires_grad=True)
        self.G0, self.B0 = nn.Linear(self.latent_dim, self.aug_dim), nn.Linear(self.latent_dim, self.aug_dim)

        # Per-cell MLP: perceived features (6 * aug_dim) + broadcast latent -> residual update.
        self.D1 = nn.Conv2d(6 * self.aug_dim + self.latent_dim, 128, 1, 1, 0)
        self.D4 = nn.Conv2d(128, self.aug_dim, 1, 1, 0, bias=False)

        self.Pool = nn.MaxPool2d(3, stride=1, padding=1)
        self.scale = nn.Parameter(torch.ones(1, self.aug_dim, self.h, self.w), requires_grad=True)

    def perceive(self, x):
        """Concatenate neighbourhood features of state ``x`` along dim 1.

        Returns ``[sobel_x, sobel_y, laplacian, x, maxpool3x3(x), minpool3x3(x)]`` stacked on
        the channel dim (``6 * aug_dim`` channels); padding keeps x's spatial size.
        """
        x1 = self.C_S_x(x)       # Sobel X
        x2 = self.C_S_y(x)       # Sobel Y
        x3 = self.C_S_l(x)       # Laplacian
        x_max = self.Pool(x)     # 3x3 max pooling
        x_min = -self.Pool(-x)   # 3x3 min pooling, via max pooling of the negated state
        return torch.cat([x1, x2, x3, x, x_max, x_min], dim=1)

    def update(self, x, z):
        """Map perceived features ``x`` plus latent ``z`` (batch, latent_dim) to a residual.

        ``z`` is broadcast to every cell of ``x``'s spatial grid before the 1x1-conv MLP.
        """
        h, w = x.shape[-2:]
        z_map = z[:, :, None, None].expand(-1, -1, h, w)
        x = torch.cat([x, z_map], dim=1)
        return self.D4(F.elu(self.D1(x)))

    def forward(self, x, z):
        """Apply one CA step to state ``x`` conditioned on ``z``; output values are in (0, 1)."""
        residual = self.update(self.perceive(x), z)
        x = x + residual
        # F.dropout defaults to training=True; tie it to the module mode so eval() is deterministic.
        x = F.dropout(x, p=0.01, training=self.training)
        return torch.sigmoid(x)


class Encoda(nn.Module):
    """Convolutional image encoder producing a ``tanh``-bounded latent code.

    Four 4x4, stride-2, padding-1 convolutions halve the spatial size each time, so a
    3x32x32 input becomes a 64x2x2 feature map. A linear layer maps that map to
    ``latent_dim`` values in [-1, 1]. The hard-coded ``2*2*64`` input width of that layer
    means the encoder assumes 32x32 inputs.
    """

    def __init__(self, latent_dim):
        super(Encoda, self).__init__()

        layers = []
        for c_in, c_out in ((3, 32), (32, 64), (64, 64)):
            layers += [nn.Conv2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.ELU()]
        layers.append(nn.Conv2d(64, 64, 4, 2, 1))  # final conv: no norm / activation
        self.N = nn.Sequential(*layers)
        self.F = nn.Linear(2 * 2 * 64, latent_dim)

    def forward(self, x):
        """Encode image batch ``x`` into a ``(batch, latent_dim)`` tensor in [-1, 1]."""
        batch = x.size(0)
        features = self.N(x).reshape(batch, -1)
        return torch.tanh(self.F(features))



In [None]:
# Reproducibility: seeds model initialisation and subsequent torch-RNG use (e.g. DataLoader shuffling).
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

ts = torch.linspace(0, 1, 16).to(device)  # NOTE(review): not used by the training loop in this notebook section
aug_dim = 16      # CA state channels per cell; the training loop treats the first 3 as RGB
latent_dim = 128  # length of the encoder's latent code z

# Creation order (encoder first) determines which random draws each model's init receives.
netEnc = Encoda(latent_dim).to(device)
netG = GNCA(latent_dim, aug_dim).to(device)

criterion = nn.CrossEntropyLoss()  # NOTE(review): unused here; kept in case later cells reference it
ffl = FFL(loss_weight=1.0, alpha=1.0)  # focal frequency loss; only referenced by a disabled term of lossfn

# Shared AdaBelief hyper-parameters for both networks (previously duplicated inline).
ADABELIEF_KWARGS = dict(lr=1e-3, eps=1e-16, betas=(0.9, 0.999), weight_decouple=True, rectify=False)
optG = AdaBelief(netG.parameters(), **ADABELIEF_KWARGS)
optEnc = AdaBelief(netEnc.parameters(), **ADABELIEF_KWARGS)

def lossfn(x, y):
    """Mean squared error between ``x`` and ``y``, averaged over every element.

    Alternative objectives (``1 - ssim(x, y)`` and the focal frequency loss ``ffl``)
    were tried and are intentionally disabled.
    """
    residual = x - y
    return torch.square(residual).mean()

# Training: encode each image to a latent z, grow it from a single seed cell with
# N_CA_STEPS GNCA updates, and supervise a growing top-left crop of the RGB channels.
N_CA_STEPS = 32       # CA updates per image; output k (0-based) is compared on a (k+1)x(k+1) crop
N_TRAIN_STEPS = 10000


def _infinite_batches(loader):
    """Yield batches from ``loader`` forever, starting a fresh (reshuffled) epoch when one ends."""
    while True:
        empty = True
        for batch in loader:
            empty = False
            yield batch
        if empty:
            raise ValueError("loader yielded no batches")


batches = _infinite_batches(trainloader)
netEnc.train()
netG.train()

for step in range(N_TRAIN_STEPS):
    start = time.perf_counter()

    images, _ = next(batches)
    images = images.to(device)
    b_size = images.shape[0]
    img_h, img_w = images.shape[-2:]

    optEnc.zero_grad()
    optG.zero_grad()

    z = netEnc(images)

    # Initial CA state: all zeros except the RGB channels of the top-left cell.
    seed = torch.zeros(b_size, aug_dim, img_h, img_w, device=device)
    g = seed.clone()
    g[:, :3, 0, 0] = 1

    out = g
    states = []
    for _ in range(N_CA_STEPS):
        out = netG(out, z)
        states.append(out)
    outs = torch.stack(states)  # (N_CA_STEPS, batch, aug_dim, H, W)

    # Growth curriculum: the k-th output is compared on a (k+1)x(k+1) top-left crop.
    # Using k + 1 (not k) means the final output is supervised on the whole image,
    # including the last row/column, and no output is skipped.
    L = 0
    for k in range(N_CA_STEPS):
        crop = k + 1
        L += lossfn(outs[k][:, :3, :crop, :crop], images[:, :, :crop, :crop])

    L.backward()
    optG.step()
    optEnc.step()

    if device.type == "cuda":
        torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait so the timing is real
    end = time.perf_counter()

    if step % 20 == 0:
        print(f"Step: {step} Loss: {L.item()} (Gen) Time per step: {round(end - start, 3)}s")
    if step % 100 == 0:
        print("Images, reconstruction")

        # Top row: inputs; bottom row: RGB channels of the final CA state.
        train_imgs = torch.cat([images[:8], outs[-1][:8, :3]], dim=0)
        grid = torchvision.utils.make_grid(train_imgs, nrow=8, padding=0).permute(1, 2, 0).detach().cpu().numpy()
        plt.imshow(grid)
        plt.axis('off')
        plt.show()

