In [None]:
# Adapted from https://github.com/cwkx/GON (Gradient Origin Networks)
# requirements
import torch
import torch.nn as nn
import torchvision
import numpy as np

# colab requirements
from IPython.display import clear_output
import matplotlib.pyplot as plt
from time import sleep

In [None]:
# image data
img_size = 32    # height/width of the CIFAR-10 images loaded below
n_channels = 3   # RGB
img_coords = 2   # NOTE(review): not referenced by any code visible in this notebook

# training info
lr = 1e-4        # Adam learning rate for the generator F
batch_size = 64  # slerp_batch lays interpolations out on a sqrt(batch_size) x sqrt(batch_size) grid
nz = 128         # latent channels (z has shape (batch_size, nz, 1, 1))
ngf = 64         # base channel width of the generator
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# reproducibility: torch.manual_seed seeds the CPU and all CUDA RNGs, which covers
# weight initialisation and DataLoader shuffling in this notebook
SEED = 0
torch.manual_seed(SEED)

In [None]:
# load datasets
# ToTensor is the only transform: pixels become floats in [0, 1] with no
# normalisation, the same range as the generator's Sigmoid output.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
# drop_last=True keeps every batch at exactly batch_size samples; the training loop
# allocates its latent z with batch_size rows and compares it against each batch.
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [None]:
# create the GON network
class Generator(nn.Module):
    """DCGAN-style transposed-conv decoder used as the GON network.

    For a latent of shape (B, nz, 1, 1) the spatial size grows 1 -> 4 -> 8 -> 16 -> 32,
    and the output has n_channels channels squashed into (0, 1) by a Sigmoid.
    Layer order (and therefore the indices inside ``self.main``) is:
    3 x [ConvTranspose2d, BatchNorm2d, ELU] followed by ConvTranspose2d, Sigmoid.
    """

    def __init__(self):
        super().__init__()
        widths = [nz, ngf * 4, ngf * 2, ngf]
        layers = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # first stage lifts 1x1 -> 4x4; later stages double the resolution
            stride, pad = (1, 0) if idx == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=True),
                nn.BatchNorm2d(c_out),
                nn.ELU(True),
            ]
        layers += [
            nn.ConvTranspose2d(ngf, n_channels, 4, 2, 1, bias=True),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)
        
class Res_Block(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(Res_Block, self).__init__()
    self.C1 = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, 3, 1, 1))
    self.C2 = nn.utils.spectral_norm(nn.Conv2d(out_channels, out_channels, 4, 2, 1))
    self.CS = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, 4, 2, 1))

  def forward(self, x):
    x_i = self.CS(x)
    x = nn.functional.leaky_relu(self.C1(x))
    x = nn.functional.leaky_relu(x_i + self.C2(x))
    return x

class Disc(nn.Module):
  """Spectrally-normalised residual discriminator returning one raw score per image.

  The input goes through three Res_Blocks (32, 64, 128 channels, each halving the
  spatial size), a 3x3 conv with leaky ReLU, and a linear head, producing a
  tensor of shape (batch, 1). No sigmoid is applied to the score.

  Args:
    img_size: height and width of the square input images; must be >= 8.
    in_channels: number of image channels (default 3, the previously hard-coded value).

  Raises:
    ValueError: if img_size < 8, because the feature map would collapse to 0x0.
  """
  def __init__(self, img_size, in_channels=3):
    super(Disc, self).__init__()
    if img_size < 8:
      raise ValueError(f"img_size must be >= 8 (got {img_size})")
    # three k=4, s=2, p=1 downsamplings map a side of length s to s // 8
    self.h, self.w = img_size//8, img_size//8
    # NOTE(review): C1-C3 are never called in forward(); they are kept, in their
    # original construction order, so state_dict keys, parameter counts and the
    # RNG draws consumed during initialisation remain unchanged.
    self.C1 = nn.utils.spectral_norm(nn.Conv2d(in_channels, 32, 4, 2, 1))
    self.C2 = nn.utils.spectral_norm(nn.Conv2d(32, 64, 4, 2, 1))
    self.C3 = nn.utils.spectral_norm(nn.Conv2d(64, 128, 4, 2, 1))
    self.C4 = nn.utils.spectral_norm(nn.Conv2d(128, 128, 3, 1, 1))

    self.R1 = Res_Block(in_channels, 32)
    self.R2 = Res_Block(32, 64)
    self.R3 = Res_Block(64, 128)
    self.D1 = nn.utils.spectral_norm(nn.Linear(128*self.h*self.w, 1))

  def forward(self, x):
    x = self.R3(self.R2(self.R1(x)))
    x = nn.functional.leaky_relu(self.C4(x))
    # flatten the (128, h, w) feature map for the linear head
    x = x.reshape(x.size(0), 128*self.h*self.w)
    return self.D1(x)

In [None]:
# helper functions
def cycle(iterable):
    """Yield the items of ``iterable`` forever, restarting it after each full pass.

    Unlike ``itertools.cycle`` nothing is cached: ``iter(iterable)`` is called
    again on every pass, so a source that reshuffles on each iteration (such as a
    DataLoader with shuffle=True) yields a new order per pass. An empty iterable
    makes this spin forever without yielding.
    """
    while True:
        yield from iterable

def slerp(a, b, t, eps=1e-7):
    """Spherical linear interpolation between corresponding rows of ``a`` and ``b``.

    Args:
        a, b: 2-D tensors of shape (N, D); row i of ``a`` moves towards row i of ``b``.
        t: interpolation weight(s) broadcastable to (N, 1), e.g. an (N, 1) tensor or
            a float. t=0 returns ``a`` and t=1 returns ``b``.
        eps: rows whose angle has |sin(omega)| < eps (parallel or anti-parallel)
            fall back to linear interpolation instead of dividing by ~0.

    Returns:
        Tensor of shape (N, D).

    Note: rows with zero norm still produce NaN (the direction is undefined).
    """
    a_unit = a / torch.norm(a, dim=1, keepdim=True)
    b_unit = b / torch.norm(b, dim=1, keepdim=True)
    # rounding can push the cosine slightly outside [-1, 1], where acos is NaN
    cos_omega = (a_unit * b_unit).sum(1).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega).unsqueeze(1)
    sin_omega = torch.sin(omega)
    degenerate = sin_omega.abs() < eps
    # avoid 0/0 in the slerp branch; those rows are replaced by the lerp result below
    safe_sin = torch.where(degenerate, torch.ones_like(sin_omega), sin_omega)
    slerped = (torch.sin((1.0 - t) * omega) / safe_sin) * a + (torch.sin(t * omega) / safe_sin) * b
    lerped = (1.0 - t) * a + t * b
    return torch.where(degenerate, lerped, slerped)

def slerp_batch(model, z):
    """Decode a k x k grid of spherical interpolations between latents of ``z``.

    With n = z.size(0) and k = int(sqrt(n)), each of the first k latents is paired
    with the latent n // 2 positions later in the batch (wrapping around), and every
    pair is interpolated at k evenly spaced t in [0, 1]. Rows of the decoded batch
    are t-major: rows [i*k, (i+1)*k) all use the i-th t value, so
    ``make_grid(..., nrow=k)`` shows one t per grid row.

    Args:
        model: generator taking latents of shape (k*k, nz, 1, 1).
        z: latent batch, presumably (n, nz, 1, 1); it is detached, not modified.

    Returns:
        ``model`` output for the k*k interpolated latents, computed without autograd.
        NOTE: ``model`` is not switched to eval(), so BatchNorm layers (if any) use
        batch statistics and update running stats, as the original code did.
    """
    lz = z.detach().squeeze(-1).squeeze(-1)
    n = lz.size(0)
    col_size = int(np.sqrt(n))
    src_z = lz[:col_size].repeat(col_size, 1)
    # partner of latent i is latent (i + n // 2) mod n: the batch rotated by half
    tgt_z = torch.roll(lz, shifts=-(n // 2), dims=0)[:col_size].repeat(col_size, 1)
    # t-major layout: t0 repeated k times, then t1 k times, ... -> shape (k*k, 1)
    t = torch.linspace(0, 1, col_size, device=lz.device).repeat_interleave(col_size).unsqueeze(1)
    z_slerp = slerp(src_z, tgt_z, t)
    with torch.no_grad():
        g_slerp = model(z_slerp.unsqueeze(-1).unsqueeze(-1))
    return g_slerp

In [None]:
# build the generator F and discriminator D (construction order fixes the init RNG draws)
F = Generator().to(device)
D = Disc(img_size).to(device)

bce_loss = nn.BCEWithLogitsLoss()  # NOTE(review): not used by the training loop visible here
optim = torch.optim.Adam(lr=lr, params=F.parameters())
# discriminator optimiser: higher learning rate and beta1 = 0
lr_d = 0.0005
betas_d = (0.0, 0.99)
optimizerD = torch.optim.Adam(lr=lr_d, params=D.parameters(), betas=betas_d)
# count parameters without materialising a concatenated copy of every weight
print(f'> Number of G parameters {sum(p.numel() for p in F.parameters())}')
print(f'> Number of D parameters {sum(p.numel() for p in D.parameters())}')

In [None]:
train_iterator = iter(cycle(train_loader))
iterations = 10000 + 1
recent_zs = []  # last 100 latents found by the inner gradient step

for step in range(iterations):
    # sample a batch of data (class labels t are not used by the losses below)
    x, t = next(train_iterator)
    x, t = x.to(device), t.to(device)
    real_label = torch.ones(x.size(0), 1, device=device)
    fake_label = torch.zeros(x.size(0), 1, device=device)

    # compute the gradients of the inner loss with respect to zeros (gradient origin)
    z = torch.zeros(batch_size, nz, 1, 1, device=device, requires_grad=True)
    g = F(z)
    L_inner = ((g - x)**2).sum(1).mean()
    # create_graph=True so the outer loss can backpropagate through this gradient into F
    grad = torch.autograd.grad(L_inner, [z], create_graph=True, retain_graph=True)[0]
    z = (-grad)

    # now with z as our new latent points, optimise the data fitting loss
    g = F(z)

    # Freeze D for the generator update: gradients still flow through D into g, but
    # D's own weight gradients (which D.zero_grad() would discard anyway) are not computed.
    D.requires_grad_(False)
    fake_output = D(g)
    real_output = D(x)
    errG = (torch.mean(((real_output - torch.mean(fake_output)) - fake_label)**2)
            + torch.mean(((fake_output - torch.mean(real_output)) - real_label)**2)) / 2

    L_outer = errG + 10*((g - x)**2).sum(1).mean()
    optim.zero_grad()
    L_outer.backward()
    optim.step()
    D.requires_grad_(True)

    recent_zs.append(z.detach())
    recent_zs = recent_zs[-100:]

    # Discriminator update ##########################################################################################
    D.zero_grad()
    real_images = x
    sample = g.detach()  # no gradient into F during the D step

    real_output = D(real_images)
    fake_output = D(sample)

    errD_real = torch.mean(((real_output - torch.mean(fake_output)) - real_label)**2)
    errD_fake = torch.mean(((fake_output - torch.mean(real_output)) - fake_label)**2)
    errD = (errD_fake + errD_real)/2
    errD.backward()
    optimizerD.step()

    if step % 50 == 0 and step > 0:
        print(f"Step: {step}   Loss: {L_outer.item():.3f} ({errG.item():.3f}) DLoss: {errD.item():.3f}")
    if step % 250 == 0 and step > 0:
        # clear_output wipes the line printed just above, so print it again
        clear_output()
        print(f"Step: {step}   Loss: {L_outer.item():.3f} ({errG.item():.3f}) DLoss: {errD.item():.3f}")

        # plot reconstructions and interpolations
        recons = torchvision.utils.make_grid(torch.clamp(g.detach(), 0, 1)[:16], nrow=8)
        slerps = torchvision.utils.make_grid(torch.clamp(slerp_batch(F, z.data), 0, 1), nrow=8)
        plt.figure(figsize=(15,15))
        plt.title('Reconstructions')
        plt.imshow(recons.permute(1,2,0).cpu().data.numpy())
        plt.figure(figsize=(10,10))
        plt.title('Spherical Interpolations')
        plt.imshow(slerps.permute(1,2,0).cpu().data.numpy())
        plt.show()
        # brief pause after rendering (presumably to let the frontend draw before training resumes)
        sleep(1)