In [None]:
#From https://github.com/cwkx/GON
#
# requirements
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import cv2
from random import shuffle
import matplotlib.pyplot as plt

# colab requirements
from IPython.display import clear_output
import matplotlib.pyplot as plt
from time import sleep

In [None]:
# Image geometry.
img_size = 32     # side length in pixels; images are handled as img_size x img_size
n_channels = 3    # colour channels per pixel
img_coords = 2    # number of coordinate values per pixel fed to the GON

# training info
lr = 1e-4                # learning rate constant
batch_size = 32          # mini-batch size passed to the DataLoader
num_latent = 32          # size of the latent vector z
hidden_features = 256    # hidden layer width constant
num_layers = 4           # hidden layer count constant
# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CIFAR-10 training split. ToTensor converts each image to a float tensor
# laid out as (channels, height, width).
transform = transforms.Compose([transforms.ToTensor()])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=1)

# Sanity check: pull one batch, print its shape and preview the first image.
x, _ = next(iter(dataloader))
print(x.shape)

# Move channels last for matplotlib. Use the tensor's own permute and an
# explicit conversion instead of calling np.transpose on a torch.Tensor,
# which only works through NumPy's fallback path.
img = x[0].permute(1, 2, 0).numpy()
plt.imshow(img)
plt.show()

In [None]:
# create the GON network (a SIREN as in https://vsitzmann.github.io/siren/)
class SirenLayer(nn.Module):
    """A single linear layer followed by a scaled sine activation.

    Args:
        in_f: number of input features.
        out_f: number of output features.
        w0: factor the linear output is multiplied by before the sine.
        is_first: if True, weights are drawn from ``[-1/in_f, 1/in_f]``.
        is_last: if True, ``forward`` returns the linear output without the sine.
    """

    def __init__(self, in_f, out_f, w0=30, is_first=False, is_last=False):
        super().__init__()
        self.is_first = is_first
        self.is_last = is_last
        self.w0 = w0
        self.in_f = in_f
        self.linear = nn.Linear(in_f, out_f)
        self.init_weights()

    def init_weights(self):
        """Overwrite the linear weights with a uniform draw; the bias is left as created."""
        if self.is_first:
            bound = 1 / self.in_f
        else:
            bound = np.sqrt(6 / self.in_f) / self.w0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, x):
        """Apply the linear map, then ``sin(w0 * .)`` unless this is the last layer."""
        pre_activation = self.linear(x)
        if self.is_last:
            return pre_activation
        return torch.sin(self.w0 * pre_activation)

def gon_model(dimensions):
    """Build a stack of SirenLayers from a list of layer widths.

    ``dimensions[0]`` is the input width and ``dimensions[-1]`` the output
    width; every consecutive pair in between becomes one hidden layer.
    The first layer is flagged ``is_first`` and the final one ``is_last``.

    Returns:
        An ``nn.Sequential`` containing the layers in order.
    """
    layers = [SirenLayer(dimensions[0], dimensions[1], is_first=True)]

    hidden_widths = dimensions[1:-1]
    layers.extend(
        SirenLayer(n_in, n_out)
        for n_in, n_out in zip(hidden_widths[:-1], hidden_widths[1:])
    )

    layers.append(SirenLayer(dimensions[-2], dimensions[-1], is_last=True))
    return nn.Sequential(*layers)

class Res_Block(nn.Module):
    """Residual block built from three spectrally normalised convolutions.

    ``C1`` (3x3, stride 1) followed by ``C2`` (4x4, stride 2) forms the main
    path; ``CS`` (4x4, stride 2) is the shortcut applied to the block input.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        sn = nn.utils.spectral_norm
        self.C1 = sn(nn.Conv2d(in_channels, out_channels, 3, 1, 1))
        self.C2 = sn(nn.Conv2d(out_channels, out_channels, 4, 2, 1))
        self.CS = sn(nn.Conv2d(in_channels, out_channels, 4, 2, 1))

    def forward(self, x):
        """Return leaky_relu(shortcut(x) + C2(leaky_relu(C1(x))))."""
        # nn.functional is spelled out on purpose: the notebook later rebinds
        # the name `F` to the generator network.
        activation = nn.functional.leaky_relu
        shortcut = self.CS(x)
        main = self.C2(activation(self.C1(x)))
        return activation(shortcut + main)

class Disc(nn.Module):
    """Discriminator: three residual blocks, one 3x3 convolution, one linear output.

    Args:
        img_size: side length of the square input image. The linear layer is
            sized for a ``img_size // 8`` feature map, matching the three
            stride-2 residual blocks. The first block expects 3 input channels.
    """

    def __init__(self, img_size):
        super().__init__()
        sn = nn.utils.spectral_norm
        self.h = img_size // 8
        self.w = img_size // 8

        # C1-C3 are created (and therefore appear in parameters()/state_dict)
        # but forward() below does not call them; only C4 is used.
        self.C1 = sn(nn.Conv2d(3, 32, 4, 2, 1))
        self.C2 = sn(nn.Conv2d(32, 64, 4, 2, 1))
        self.C3 = sn(nn.Conv2d(64, 128, 4, 2, 1))
        self.C4 = sn(nn.Conv2d(128, 128, 3, 1, 1))

        self.R1 = Res_Block(3, 32)
        self.R2 = Res_Block(32, 64)
        self.R3 = Res_Block(64, 128)
        self.D1 = sn(nn.Linear(128 * self.h * self.w, 1))

    def forward(self, x):
        """Map a batch of images to one unbounded score per image."""
        features = self.R1(x)
        features = self.R2(features)
        features = self.R3(features)
        features = nn.functional.leaky_relu(self.C4(features))
        flat = features.reshape(features.size(0), 128 * self.h * self.w)
        return self.D1(flat)

In [None]:
# helper functions
def get_mgrid(sidelen, dim=2):
    """Return a flattened grid of coordinates in [-1, 1].

    Args:
        sidelen: number of samples along every axis.
        dim: number of axes.

    Returns:
        Tensor of shape ``(sidelen ** dim, dim)``. Rows are ordered with the
        first coordinate varying slowest (matrix / 'ij' indexing).
    """
    tensors = tuple(dim * [torch.linspace(-1, 1, steps=sidelen)])
    # Pin the indexing mode: calling meshgrid without it emits a warning and
    # the default is documented to change, which would silently transpose the grid.
    mgrid = torch.stack(torch.meshgrid(*tensors, indexing='ij'), dim=-1)
    return mgrid.reshape(-1, dim)

def cycle(iterable):
    """Yield the items of ``iterable`` forever.

    The iterable is iterated afresh on every pass (nothing is cached), so an
    object that produces a new order each time it is iterated keeps doing so.
    """
    while True:
        yield from iterable

def slerp(a, b, t, eps=1e-6):
    """Spherical linear interpolation between the rows of ``a`` and ``b``.

    Args:
        a: tensor of shape (batch, features), start points.
        b: tensor of shape (batch, features), end points.
        t: interpolation factor, broadcastable to (batch, 1); 0 gives ``a``, 1 gives ``b``.
        eps: rows whose ``|sin(angle)|`` is below this use plain linear
            interpolation instead of dividing by (almost) zero.

    Returns:
        Tensor of shape (batch, features).

    Rows of ``a`` or ``b`` with zero norm still produce NaN, as before.
    """
    t = torch.as_tensor(t, dtype=a.dtype, device=a.device)
    a_unit = a / torch.norm(a, dim=1, keepdim=True)
    b_unit = b / torch.norm(b, dim=1, keepdim=True)

    # Rounding can push the cosine slightly outside [-1, 1], where acos is NaN.
    cos_omega = (a_unit * b_unit).sum(1, keepdim=True).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    sin_omega = torch.sin(omega)

    # Parallel (or anti-parallel) rows make sin(omega) vanish; the original
    # formula then evaluates 0/0. Fall back to linear weights for those rows.
    degenerate = sin_omega.abs() < eps
    safe_sin = torch.where(degenerate, torch.ones_like(sin_omega), sin_omega)
    weight_a = torch.sin((1.0 - t) * omega) / safe_sin
    weight_b = torch.sin(t * omega) / safe_sin
    weight_a = torch.where(degenerate, (1.0 - t).expand_as(weight_a), weight_a)
    weight_b = torch.where(degenerate, t.expand_as(weight_b), weight_b)

    return weight_a * a + weight_b * b

def slerp_batch(model, z, coords):
    """Decode a square grid of spherical interpolations between latents.

    The grid has ``col_size = floor(sqrt(len(z)))`` rows and columns. Column
    ``j`` interpolates from latent ``j`` towards latent ``j`` of the batch
    rotated by half its length; row ``i`` uses ``t = i / (col_size - 1)``.

    Args:
        model: callable applied to the concatenated (coords, latent) tensor.
        z: latents of shape (batch, 1, num_latent).
        coords: coordinates of shape (>= col_size**2, num_points, coord_dim).

    Returns:
        ``model`` output for the ``col_size ** 2`` interpolated latents.
    """
    latents = z.detach().clone().squeeze(1)
    num_latents = latents.size(0)
    col_size = int(np.sqrt(num_latents))
    # Size everything by the grid, not by a global batch size: for a
    # non-square batch (e.g. 32 -> 5x5 = 25) the two differ.
    grid_size = col_size * col_size

    src_z = latents[:col_size].repeat(col_size, 1)

    # Rotate the batch by half its length; slicing also works for odd sizes,
    # where a two-way unpack of split() would fail.
    half = num_latents // 2
    tgt_z = torch.cat([latents[half:], latents[:half]])
    tgt_z = tgt_z[:col_size].repeat(col_size, 1)

    t = (
        torch.linspace(0, 1, col_size, device=latents.device)
        .unsqueeze(1)
        .repeat(1, col_size)
        .reshape(grid_size, 1)
    )
    z_slerp = slerp(src_z, tgt_z, t)

    grid_coords = coords[:grid_size]
    z_slerp_rep = z_slerp.unsqueeze(1).repeat(1, grid_coords.size(1), 1)
    g_slerp = model(torch.cat((grid_coords, z_slerp_rep), dim=-1))
    return g_slerp

def gon_sample(model, recent_zs, coords):
    """Decode latents drawn from a Gaussian fitted to recently seen latents.

    Args:
        model: callable applied to the concatenated (coords, latent) tensor.
        recent_zs: list of latent tensors, each of shape (batch, 1, num_latent).
        coords: coordinates of shape (num_samples, num_points, coord_dim).
            One latent is drawn per row of ``coords``.

    Returns:
        ``model`` output for the sampled latents.
    """
    zs = torch.cat(recent_zs, dim=0).squeeze(1).detach().cpu().numpy()
    mean = np.mean(zs, axis=0)
    cov = np.cov(zs.T)

    # Take sample count and device from `coords` rather than from notebook
    # globals, so the function cannot get out of sync with its own input.
    num_samples = coords.size(0)
    sample = np.random.multivariate_normal(mean, cov, size=num_samples)
    sample = (
        torch.tensor(sample)
        .unsqueeze(1)
        .repeat(1, coords.size(1), 1)
        .to(coords.device)
        .float()
    )
    model_input = torch.cat((coords, sample), dim=-1)
    return model(model_input)
    

In [None]:
# Define the GON architecture from the config constants. With the defaults
# this gives gon_shape = [34, 256, 256, 256, 256, 3]: (coordinates + latent)
# in, `num_layers` hidden layers of width `hidden_features`, colour channels out.
gon_shape = [img_coords + num_latent] + [hidden_features] * num_layers + [n_channels]
# NOTE: `F` and `optim` below rebind the names imported at the top of the
# notebook (torch.nn.functional / torch.optim); the training cell relies on
# these bindings, so they are kept.
F = gon_model(gon_shape).to(device)
D = Disc(img_size).to(device)

bce_loss = nn.BCEWithLogitsLoss()
optim = torch.optim.Adam(lr=lr, params=F.parameters())
optimizerD = torch.optim.Adam(lr=0.0005, params=D.parameters(), betas=(0.0, 0.99))

# Coordinates: build the pixel grid once and repeat it for every batch entry,
# giving shape (batch_size, img_size * img_size, img_coords).
c = get_mgrid(img_size, img_coords).unsqueeze(0).repeat(batch_size, 1, 1).to(device)
print(c.shape)

# Latents seen during training; filled by the training loop.
recent_zs = []
print(f'> Number of G parameters {sum(p.numel() for p in F.parameters())}')
print(f'> Number of D parameters {sum(p.numel() for p in D.parameters())}')

In [None]:
# Endless stream of (image, label) batches from the dataloader.
train_iterator = iter(cycle(dataloader))

for step in range(10001):
    # sample a batch of data
    x, _ = next(train_iterator)
    # Rebinds the notebook-level `batch_size` to the actual size of this batch,
    # which may be smaller than the configured value.
    batch_size = x.size(0)
    x = x.to(device)
    # Channels last, then flatten the pixels: (batch, img_size*img_size, n_channels).
    x = x.permute(0, 2, 3, 1).reshape(batch_size, -1, n_channels)

    # Targets for the least-squares adversarial losses below.
    real_label, fake_label = torch.ones(batch_size, 1).to(device), torch.zeros(batch_size, 1).to(device)

    # compute the gradients of the inner loss with respect to zeros (gradient origin) ############################################
    z = torch.zeros(batch_size, 1, num_latent).to(device).requires_grad_()

    # Pair the same latent with every coordinate of its image.
    z_rep = z.repeat(1,c.size(1),1)
    g = F(torch.cat((c[:batch_size], z_rep), dim=-1))
    # Squared error summed over dim 1 (the pixel axis), averaged over the rest.
    L_inner = ((g - x)**2).sum(1).mean() 
    # The latent is the negative gradient of the inner loss at z = 0.
    # create_graph=True keeps this step differentiable so the outer loss can
    # backpropagate through it.
    z = -torch.autograd.grad(L_inner, [z], create_graph=True, retain_graph=True)[0]

    # now with z as our new latent points, optimise the data fitting loss #######################################################
    z_rep = z.repeat(1, c.size(1), 1)
    g = F(torch.cat((c[:batch_size], z_rep), dim=-1))

    # Score reconstructions and real images; both are reshaped to (batch, 3, H, W) for D.
    fake_output = D(g.view(-1, img_size, img_size, 3).permute(0, 3, 1, 2))
    real_output = D(x.view(-1, img_size, img_size, 3).permute(0, 3, 1, 2))
    # Generator-side adversarial loss: each score is taken relative to the mean
    # score of the other set; real scores are pulled towards fake_label and
    # fake scores towards real_label.
    errG = torch.mean(((real_output - torch.mean(fake_output)) - fake_label)**2) + torch.mean(((fake_output - torch.mean(real_output)) - real_label)**2)
    errG /= 2

    # Outer loss: weighted adversarial term plus the reconstruction error.
    L_outer = 10*errG + ((g - x)**2).sum(1).mean() 
    optim.zero_grad()
    # This backward pass also leaves gradients on D's parameters; they are
    # cleared by D.zero_grad() before the discriminator update below.
    L_outer.backward()
    optim.step()

    # compute sampling statistics
    recent_zs.append(z.detach())
    # Keep only the latents of the 100 most recent steps.
    recent_zs = recent_zs[-100:]

    # Discriminator update ##########################################################################################
    D.zero_grad()
    real_images = x.view(-1, img_size, img_size, 3).permute(0, 3, 1, 2)

    # Fakes for D are images decoded from sampled latents (not the
    # reconstructions above), detached so no gradient reaches the generator.
    sample = gon_sample(F, recent_zs, c[:batch_size]).detach()[:batch_size]
    sample = sample.view(-1, img_size, img_size, 3).permute(0, 3, 1, 2)

    real_output = D(real_images)
    fake_output = D(sample)

    # Same relative least-squares form as errG with the targets swapped.
    errD_real = torch.mean(((real_output - torch.mean(fake_output)) - real_label)**2) 
    errD_fake = torch.mean(((fake_output - torch.mean(real_output)) - fake_label)**2)
    errD = (errD_fake + errD_real)/2
    errD.backward()
    optimizerD.step()

    # Progress line every 10 steps (step 0 is skipped).
    if step % 10 == 0 and step > 0:
        print(f"Step: {step}   Loss: {L_outer.item():.3f}", "(", errD.item(), errG.item(), ")")
    # Every 100 steps: clear the cell output and redraw the preview figures.
    if step % 100 == 0 and step > 0:
        clear_output()
        print(f"Step: {step}   Loss: {L_outer.item():.3f}")

        # plot reconstructions, interpolations, and samples
        recons = g.clone().detach()[:16]
        # Min-max rescale every image to [0, 1] for display.
        for i, r in enumerate(recons):
          recons[i] = (r - r.min())/(r.max() - r.min())

        sample = gon_sample(F, recent_zs, c[:batch_size]).detach()[:16]
        for i, r in enumerate(sample):
          sample[i] = (r - r.min())/(r.max() - r.min())

        # (batch, pixels, channels) -> (batch, channels, H, W) image grids, 8 per row.
        recons = torchvision.utils.make_grid(recons.permute(0, 2, 1).reshape(-1, n_channels, img_size, img_size), nrow=8)
        #slerps = torchvision.utils.make_grid(torch.clamp(slerp_batch(F, z.data, c), 0, 1).permute(0,2,1).reshape(-1, n_channels, img_size, img_size), nrow=8)
        sample = torchvision.utils.make_grid(sample.permute(0,2,1).reshape(-1, n_channels, img_size, img_size), nrow = 8)

        plt.figure(figsize=(15,15))
        plt.title('Reconstructions')
        plt.imshow(recons.permute(1, 2, 0).cpu().data.numpy())

        #plt.figure(figsize=(15,15))
        #plt.title('Spherical Interpolations')
        #plt.imshow(slerps.permute(1, 2, 0).cpu().data.numpy())

        plt.figure(figsize=(15,15))
        plt.title('Samples')
        plt.imshow(sample.permute(1, 2, 0).cpu().data.numpy())
        plt.show()
