In [1]:
## Adapted from the GAN implementation in the PyTorch-GAN model zoo:
## https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py

In [1]:
import argparse
import os
import numpy as np
import math
import matplotlib.pyplot as plt
import seaborn as sns; sns.set_style("white")
%matplotlib inline

import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [2]:
# Output directory for the sample grids and density plots saved during training.
os.makedirs('images', exist_ok=True)

n_epochs = 200 # number of training epochs
batch_size = 64 # mini-batch size used by the data loader
lr = 0.0002 # Adam learning rate (shared by both optimizers)
b1 = 0.5 # Adam beta1: decay rate of the first-moment (mean) gradient estimate
b2 = 0.999 # Adam beta2: decay rate of the second-moment (squared-gradient) estimate
n_cpu = 4 # intended number of CPU workers for batch loading; NOTE(review): not referenced by any cell shown here
latent_dim = 100 # dimensionality of the generator's latent (noise) space
img_size = 28 # height and width of each (square) image
channels = 1 # number of image channels
sample_interval=400 # number of batches between progress reports / saved plots

In [3]:
torch.cuda.is_available()

False

In [4]:
# Shape of a single image, laid out as (channels, height, width).
img_shape = (channels, img_size, img_size)

# Run on the GPU only when CUDA is actually available.
cuda = torch.cuda.is_available()

In [5]:
class Generator(nn.Module):
    """
    Fully connected generator.

    Maps a batch of latent vectors of size `latent_dim` through hidden layers
    of width 128, 256, 512 and 1024 to a Tanh output that is reshaped to
    `img_shape`, so every generated pixel lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        hidden_widths = [128, 256, 512, 1024]
        layers = []
        in_feat = latent_dim
        for idx, out_feat in enumerate(hidden_widths):
            layers.append(nn.Linear(in_feat, out_feat))
            # The first hidden layer is the only one without batch norm.
            if idx > 0:
                # The second positional parameter of BatchNorm1d is `eps`, so
                # the original call `BatchNorm1d(out_feat, 0.8)` sets eps=0.8.
                # NOTE(review): 0.8 looks like it was meant as a momentum
                # value - confirm before changing, since it alters training.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_feat = out_feat

        # Output layer: one unit per pixel, squashed to [-1, 1].
        layers.append(nn.Linear(in_feat, int(np.prod(img_shape))))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return generated images of shape (batch, *img_shape) for latent batch z."""
        flat = self.model(z)
        return flat.view(flat.size(0), *img_shape)

class Discriminator(nn.Module):
    """
    Fully connected discriminator.

    Flattens each image and passes it through layers of width 512 and 256
    (LeakyReLU activations) to a single sigmoid unit, i.e. one score in
    (0, 1) per image.
    """

    def __init__(self):
        super().__init__()

        n_pixels = int(np.prod(img_shape))
        stack = [
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return a (batch, 1) tensor of validity scores for the image batch."""
        # Collapse every non-batch dimension before the linear layers.
        return self.model(img.view(img.size(0), -1))

In [6]:
# Loss function: binary cross-entropy on the discriminator's sigmoid output
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
os.makedirs('../../data/mnist', exist_ok=True)

# Normalize needs exactly one mean/std per image channel. MNIST has a single
# channel, so passing three values (as for RGB) fails on current torchvision.
# Mean 0.5 / std 0.5 maps ToTensor's [0, 1] range onto [-1, 1], which matches
# the range of the generator's Tanh output.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * channels, (0.5,) * channels),
])

dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('../../data/mnist', train=True, download=True,
                   transform=mnist_transform),
    batch_size=batch_size, shuffle=True)

# Optimizers: one Adam instance per network, sharing the same hyper-parameters
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Tensor constructor living on the device selected above
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [7]:
## useful functions for collecting triangle statistics

def triplet_sample(X):
    """
    Given an array of images X (indexed along axis 0), return a list
    [X[i], X[j], X[k]] of three images at distinct, uniformly random indices.

    Raises ValueError if X holds fewer than 3 images: no triplet of distinct
    indices exists then, and the rejection loop below would never terminate.
    """
    n_samples = X.shape[0]
    if n_samples < 3:
        raise ValueError(
            "triplet_sample needs at least 3 samples, got %d" % n_samples)

    while True:
        # Draw three indices with replacement and retry on any collision.
        i, j, k = np.random.randint(0, n_samples, 3)
        if len({i, j, k}) == 3:
            return [X[i], X[j], X[k]]
        
        
def distance_L2(x,y):
    """
    Return the Euclidean (l2) distance between images x and y, divided by 56
    so that the largest possible distance is 1.

    56 = sqrt(784 * 2**2) is the distance between two 28x28 images whose
    pixels differ by 2 everywhere (pixel values in [-1, 1]).
    """
    max_distance = 56
    diff = x - y
    squared_norm = np.sum(diff * diff)
    return np.sqrt(squared_norm) / max_distance


def distance_H(x,y):
    """
    Return the normalized Hamming distance between images x and y.

    Each image is flattened and binarized (pixel > 0); the result is the
    fraction of pixel positions at which the two binary images differ, so it
    lies in [0, 1]. Images of any size are accepted as long as both have the
    same number of pixels (the original hard-coded 28*28 = 784).

    Raises ValueError if x and y have a different number of pixels.
    """
    x_binary = np.asarray(x).reshape(-1) > 0
    y_binary = np.asarray(y).reshape(-1) > 0
    if x_binary.shape != y_binary.shape:
        raise ValueError(
            "distance_H needs images of equal size, got %d and %d pixels"
            % (x_binary.size, y_binary.size))
    return np.count_nonzero(x_binary != y_binary)/len(x_binary)


def triangle_distances_l2(x,y,z):
    """
    Return the three pairwise normalized l2 distances between the points
    x, y and z as a list in increasing order.
    """
    sides = [distance_L2(x,y), distance_L2(x,z), distance_L2(y,z)]
    sides.sort()
    return sides


def triangle_distances_H(x,y,z):
    """
    Return the three pairwise normalized Hamming distances between the points
    x, y and z as a list in increasing order.
    """
    sides = [distance_H(x,y), distance_H(x,z), distance_H(y,z)]
    sides.sort()
    return sides


def angles(dxy, dxz, dyz):
    """
    Given the 3 side lengths of a triangle, return its interior angles in
    radians, sorted in increasing order.

    Each angle comes from the law of cosines. The cosine is clipped to
    [-1, 1] before arccos: for (nearly) degenerate triangles floating-point
    rounding can push it marginally outside that interval, which would
    otherwise turn the angle into NaN.

    Side lengths of zero are not handled: the division below is then by zero.
    """
    def angle_opposite(opp, a, b):
        # Angle facing side `opp` in a triangle whose other sides are a and b.
        cos_theta = (a**2 + b**2 - opp**2)/(2*a*b)
        return np.arccos(np.clip(cos_theta, -1.0, 1.0))

    theta_xy = angle_opposite(dxy, dxz, dyz)
    theta_xz = angle_opposite(dxz, dxy, dyz)
    theta_yz = angle_opposite(dyz, dxy, dxz)
    return sorted([theta_xy, theta_xz, theta_yz])


def triangle_distributions(X, Num):
    """
    Given an array of samples X, draw Num random triangles (triplets of
    distinct samples) and return a list of four arrays of shape (Num, 2).

    Array 1: the 2 dimensions are [dmid-dmin, dmax-dmid] (l2 distance)
    Array 2: the 2 dimensions are [theta_min/theta_mid, theta_min/theta_max] (l2 distance)
    Array 3: the 2 dimensions are [dmid-dmin, dmax-dmid] (Hamming distance)
    Array 4: the 2 dimensions are [theta_min/theta_mid, theta_min/theta_max] (Hamming distance)

    Both metrics are evaluated on the same triplet in each iteration.
    """

    K_distances_l2 = np.zeros((Num,2))
    K_angles_l2 = np.zeros((Num,2))
    K_distances_H = np.zeros((Num,2))
    K_angles_H = np.zeros((Num,2))

    # One entry per metric: (side-length function, distance array, angle array).
    # The Hamming arrays must be filled from triangle_distances_H; the original
    # called triangle_distances_l2 for both, so arrays 3/4 duplicated 1/2.
    metrics = [
        (triangle_distances_l2, K_distances_l2, K_angles_l2),
        (triangle_distances_H, K_distances_H, K_angles_H),
    ]

    for i in range(Num):
        [x, y, z] = triplet_sample(X)

        for side_lengths, K_distances, K_angles in metrics:
            [d_min, d_mid, d_max] = side_lengths(x,y,z)
            [theta_min, theta_mid, theta_max] = angles(d_min, d_mid, d_max)

            K_distances[i,0] = d_mid - d_min
            K_distances[i,1] = d_max - d_mid
            K_angles[i,0] = theta_min/theta_mid
            K_angles[i,1] = theta_min/theta_max

    return [K_distances_l2, K_angles_l2, K_distances_H, K_angles_H]

### Training

In [None]:
def _save_density_plot(samples, xlabel, ylabel, title, path):
    """
    Save a 2-D kernel-density plot of the two columns of `samples` to `path`.

    samples: array of shape (N, 2); column 0 goes on the x axis, column 1 on
             the y axis. Both axes are fixed to [0, 1].
    The figure is closed after saving so that repeated calls inside the
    training loop do not accumulate open figures.
    """
    fig, ax = plt.subplots(1, 1)
    # Pass the two variables by keyword: current seaborn releases no longer
    # accept the second data vector positionally.
    sns.kdeplot(x=samples[:, 0], y=samples[:, 1], ax=ax)
    ax.set_xlim([0, 1.0])
    ax.set_ylim([0, 1.0])
    ax.set_xlabel(xlabel, fontsize=16)
    ax.set_ylabel(ylabel, fontsize=16)
    ax.set_title(title, fontsize=16)
    fig.savefig(path)
    plt.close(fig)


# Losses recorded once every `sample_interval` batches.
g_loss_list = []
d_loss_list = []

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)
        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()
        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))
        # Generate a batch of images
        gen_imgs = generator(z)
        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        #----------------------
        # Generator Grad Overlap (In Progress)
        # ---------------------
        #x = Variable(torch.Tensor([1, 1]), requires_grad=True)
        #v = x.clone().detach()
        #f = 3*x[0]**2 + 4*x[0]*x[1] + x[1]**2
        #grad_f, = torch.autograd.grad(f, x, create_graph=True)
        #z = grad_f @ v
        #z.backward()
        #print(x.grad)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        # Measure discriminator's ability to classify real from generated samples.
        # gen_imgs is detached so this step does not backpropagate into the generator.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # ---------------------
        # Monitor Progress
        # ---------------------
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:# and batches_done != 0:
            print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % \
                   (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))

            batches_done_str = str(batches_done).zfill(8)

            ## save the losses
            g_loss_list.append(g_loss.item())
            d_loss_list.append(d_loss.item())

            # plot a grid of (up to) 64 samples from the current training batch;
            # the final batch of an epoch can hold fewer than 64 images
            batch_samples = gen_imgs.detach()[:, 0, :, :].cpu().numpy()
            n_shown = min(64, batch_samples.shape[0])
            fig = plt.figure(figsize=(6, 6))
            fig.suptitle("Sampled Images, Batch_Num: " + str(batches_done), fontsize=16)
            for j in range(n_shown):
                ax = fig.add_subplot(8, 8, j + 1)
                ax.imshow(batch_samples[j].reshape([28,28]), cmap='gray')
                ax.set_xticks(())
                ax.set_yticks(())
            fig.savefig('images/sample_images_' + batches_done_str + '.png')
            plt.close(fig)

            # Draw 1000 fresh samples for the triangle statistics. No gradients
            # are needed here, and a separate name keeps the training batch in
            # gen_imgs from being overwritten.
            with torch.no_grad():
                Z = Variable(Tensor(np.random.normal(0, 1, (1000, latent_dim))))
                X = generator(Z)[:, 0, :, :].cpu().numpy()

            [K_distances_l2, K_angles_l2, K_distances_H, K_angles_H] = triangle_distributions(X, 10000)

            distance_labels = (r"$d_{mid}-d_{min}$", r"$d_{max}-d_{mid}$")
            angle_labels = (r"$\theta_{min}/\theta_{mid}$", r"$\theta_{min}/\theta_{max}$")

            # (data, axis labels, title prefix, file-name prefix) for each density plot
            plot_specs = [
                (K_distances_l2, distance_labels,
                 "Distance Density Plot (l2 distance), Batch_Num: ", 'images/distance_plot_l2_'),
                (K_angles_l2, angle_labels,
                 "Angle Density Plot (l2 distance), Batch_Num: ", 'images/triangle_plot_l2_'),
                (K_distances_H, distance_labels,
                 "Distance Density Plot (H distance), Batch_Num: ", 'images/distance_plot_H_'),
                (K_angles_H, angle_labels,
                 "Angle Density Plot (H distance), Batch_Num: ", 'images/triangle_plot_H_'),
            ]
            for samples, (xlabel, ylabel), title_prefix, path_prefix in plot_specs:
                _save_density_plot(samples, xlabel, ylabel,
                                   title_prefix + str(batches_done),
                                   path_prefix + batches_done_str + '.png')

[Epoch 0/200] [Batch 0/938] [D loss: 0.702290] [G loss: 0.690400]
[Epoch 0/200] [Batch 400/938] [D loss: 0.371219] [G loss: 0.740968]
[Epoch 0/200] [Batch 800/938] [D loss: 0.232573] [G loss: 2.164175]
[Epoch 1/200] [Batch 262/938] [D loss: 0.472385] [G loss: 3.074948]
[Epoch 1/200] [Batch 662/938] [D loss: 0.235233] [G loss: 1.870987]
[Epoch 2/200] [Batch 124/938] [D loss: 0.704411] [G loss: 5.161714]
[Epoch 2/200] [Batch 524/938] [D loss: 0.168405] [G loss: 2.434054]
[Epoch 2/200] [Batch 924/938] [D loss: 0.174887] [G loss: 1.849967]
[Epoch 3/200] [Batch 386/938] [D loss: 0.261089] [G loss: 1.197292]
[Epoch 3/200] [Batch 786/938] [D loss: 0.262792] [G loss: 2.042330]
[Epoch 4/200] [Batch 248/938] [D loss: 0.114492] [G loss: 3.086748]
[Epoch 4/200] [Batch 648/938] [D loss: 0.260728] [G loss: 2.599126]
[Epoch 5/200] [Batch 110/938] [D loss: 0.187764] [G loss: 1.882069]
[Epoch 5/200] [Batch 510/938] [D loss: 0.332962] [G loss: 1.148920]
[Epoch 5/200] [Batch 910/938] [D loss: 0.726651] [

[Epoch 51/200] [Batch 162/938] [D loss: 0.215064] [G loss: 1.847807]
[Epoch 51/200] [Batch 562/938] [D loss: 0.223668] [G loss: 2.226358]
[Epoch 52/200] [Batch 24/938] [D loss: 0.322233] [G loss: 1.263673]
[Epoch 52/200] [Batch 424/938] [D loss: 0.292803] [G loss: 1.770442]
[Epoch 52/200] [Batch 824/938] [D loss: 0.232745] [G loss: 1.447896]
[Epoch 53/200] [Batch 286/938] [D loss: 0.387485] [G loss: 1.678121]
[Epoch 53/200] [Batch 686/938] [D loss: 0.341823] [G loss: 2.499058]
[Epoch 54/200] [Batch 148/938] [D loss: 0.228127] [G loss: 2.517919]
[Epoch 54/200] [Batch 548/938] [D loss: 0.222957] [G loss: 2.354793]
[Epoch 55/200] [Batch 10/938] [D loss: 0.344666] [G loss: 3.054306]
[Epoch 55/200] [Batch 410/938] [D loss: 0.430465] [G loss: 3.233263]
[Epoch 55/200] [Batch 810/938] [D loss: 0.348050] [G loss: 1.567875]
[Epoch 56/200] [Batch 272/938] [D loss: 0.336219] [G loss: 1.203178]
[Epoch 56/200] [Batch 672/938] [D loss: 0.275431] [G loss: 1.809331]
[Epoch 57/200] [Batch 134/938] [D lo

[Epoch 101/200] [Batch 862/938] [D loss: 0.236271] [G loss: 1.882722]
[Epoch 102/200] [Batch 324/938] [D loss: 0.340078] [G loss: 2.157157]
[Epoch 102/200] [Batch 724/938] [D loss: 0.194775] [G loss: 2.210497]
[Epoch 103/200] [Batch 186/938] [D loss: 0.319068] [G loss: 2.242239]
[Epoch 103/200] [Batch 586/938] [D loss: 0.256296] [G loss: 1.803264]
[Epoch 104/200] [Batch 48/938] [D loss: 0.296626] [G loss: 2.318252]
[Epoch 104/200] [Batch 448/938] [D loss: 0.238718] [G loss: 2.452931]
[Epoch 104/200] [Batch 848/938] [D loss: 0.261896] [G loss: 3.041712]
[Epoch 105/200] [Batch 310/938] [D loss: 0.317626] [G loss: 2.463090]
[Epoch 105/200] [Batch 710/938] [D loss: 0.231684] [G loss: 2.010285]
[Epoch 106/200] [Batch 172/938] [D loss: 0.251759] [G loss: 2.187737]
[Epoch 106/200] [Batch 572/938] [D loss: 0.274817] [G loss: 2.075555]
[Epoch 107/200] [Batch 34/938] [D loss: 0.290542] [G loss: 2.586968]
[Epoch 107/200] [Batch 434/938] [D loss: 0.248760] [G loss: 2.296454]
[Epoch 107/200] [Batch

[Epoch 152/200] [Batch 224/938] [D loss: 0.321077] [G loss: 2.086518]
[Epoch 152/200] [Batch 624/938] [D loss: 0.333415] [G loss: 2.048362]
[Epoch 153/200] [Batch 86/938] [D loss: 0.328353] [G loss: 2.358866]
[Epoch 153/200] [Batch 486/938] [D loss: 0.429424] [G loss: 2.313407]
[Epoch 153/200] [Batch 886/938] [D loss: 0.352999] [G loss: 1.606602]
[Epoch 154/200] [Batch 348/938] [D loss: 0.333544] [G loss: 1.747373]
[Epoch 154/200] [Batch 748/938] [D loss: 0.281827] [G loss: 2.066483]
[Epoch 155/200] [Batch 210/938] [D loss: 0.313337] [G loss: 1.611214]
[Epoch 155/200] [Batch 610/938] [D loss: 0.268371] [G loss: 1.873268]
[Epoch 156/200] [Batch 72/938] [D loss: 0.371652] [G loss: 2.015953]
[Epoch 156/200] [Batch 472/938] [D loss: 0.299577] [G loss: 1.681162]
[Epoch 156/200] [Batch 872/938] [D loss: 0.365319] [G loss: 1.458398]
[Epoch 157/200] [Batch 334/938] [D loss: 0.404813] [G loss: 2.423235]
[Epoch 157/200] [Batch 734/938] [D loss: 0.317120] [G loss: 1.803173]
[Epoch 158/200] [Batch

## Code Snippets - Work in Progress

In [None]:
## get the norm of the gradient
#g_grad_norm = Variable(Tensor(1).fill_(0.0), requires_grad=True)
#for p in generator.parameters():
#    g_grad_norm += p.grad.data.norm(2).item()**2
#g_grad_norm = g_grad_norm **(0.5)

In [None]:
## look at 1d measures of UM
## look at raw MNIST data
## double check code
## look into top-subspace

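Here's how to compute the "Hessian-vector product" $\sum_j H_{ij} v_j$, where $H_{ij} = \partial_i \partial_j f$ and $v$ is an arbitrary vector, without ever forming $H$: take the gradient of $f$ with `create_graph=True`, dot it with $v$, and backpropagate through the resulting scalar. As long as $v$ is treated as a constant, $\partial_i \big( \sum_j v_j \, \partial_j f \big) = \sum_j H_{ij} v_j$.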

This came from this PyTorch help forum post: https://discuss.pytorch.org/t/calculating-hessian-vector-product/11240/4

In [None]:
## Case 1: v is an independent constant vector, so x.grad ends up holding
## the plain Hessian-vector product H @ v.
v = torch.tensor([1.0, 1.0])
x = torch.tensor([1.0, 1.0], requires_grad=True)
f = 3 * x[0] ** 2 + 4 * x[0] * x[1] + x[1] ** 2
grad_f = torch.autograd.grad(f, x, create_graph=True)[0]
z = torch.dot(grad_f, v)
z.backward()
print(x.grad)

## Case 2: v is x itself - the answer changes, because the second backward
## pass now also differentiates through v.
x = torch.tensor([1.0, 1.0], requires_grad=True)
v = x
f = 3 * x[0] ** 2 + 4 * x[0] * x[1] + x[1] ** 2
grad_f = torch.autograd.grad(f, x, create_graph=True)[0]
z = torch.dot(grad_f, v)
z.backward()
print(x.grad)

## Case 3: v is a detached copy of x, so it has the same values as x but is
## treated as a constant; the result matches case 1 again.
x = torch.tensor([1.0, 1.0], requires_grad=True)
v = x.detach().clone()
f = 3 * x[0] ** 2 + 4 * x[0] * x[1] + x[1] ** 2
grad_f = torch.autograd.grad(f, x, create_graph=True)[0]
z = torch.dot(grad_f, v)
z.backward()
print(x.grad)

In [None]:
## Hessian-vector product along the gradient direction: v is a detached copy
## of grad_f, so x.grad ends up holding H @ grad_f with v treated as a constant.
x = torch.tensor([1.0, 1.0], requires_grad=True)
f = 3 * x[0] ** 2 + 4 * x[0] * x[1] + x[1] ** 2
grad_f = torch.autograd.grad(f, x, create_graph=True)[0]
v = grad_f.detach().clone()
z = torch.dot(grad_f, v)
z.backward()
print(x.grad)

In [None]:
## Recompute the generator loss on a fresh noise batch (same batch size and
## `valid` targets as the last training iteration).
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))
gen_imgs = generator(z)
g_loss = adversarial_loss(discriminator(gen_imgs), valid)

## Flatten every generator parameter and concatenate them into one vector.
theta_g_tmp = [param.view(-1) for param in generator.parameters()]
theta_g = torch.cat(theta_g_tmp)

In [None]:
## g_loss is computed from the generator's own parameter tensors, not from
## theta_g (a concatenated copy built separately), so theta_g is not part of
## g_loss's autograd graph and differentiating with respect to it fails.
## Differentiate with respect to the parameters instead and flatten the
## per-parameter gradients in the same order that theta_g is assembled in,
## which gives one gradient vector aligned with theta_g.
grad_g_parts = torch.autograd.grad(g_loss, list(generator.parameters()), create_graph=True)
grad_g = torch.cat([g.reshape(-1) for g in grad_g_parts])

In [None]:

## Flatten every generator parameter and concatenate them into one vector.
theta_g_tmp = [param.view(-1) for param in generator.parameters()]
theta_g = torch.cat(theta_g_tmp)

In [None]:
theta_g

In [None]:
## Attempted Hessian-vector product for the generator, using theta_g as the vector.
## NOTE(review): `@` needs grad_g to be a single flat tensor, but
## torch.autograd.grad returns a tuple of per-parameter gradients - confirm
## grad_g has been flattened before this cell runs.
## NOTE(review): theta_g is not detached here, so backward() also
## differentiates through the vector itself (the "v = x" situation); use
## theta_g.detach() if the plain H @ theta_g is wanted.
z = grad_g @ theta_g
z.backward()
## NOTE(review): theta_g is the output of torch.cat, i.e. not a leaf tensor,
## so .grad is presumably None here unless theta_g.retain_grad() was called;
## the accumulated result should be on each generator parameter's .grad - confirm.
print(theta_g.grad)

In [None]:
## Scratch experiment: flatten the gradients of a single linear layer and try
## to differentiate their dot product with a random vector x a second time.
linear = nn.Linear(10, 20)
# NOTE: `input` shadows the Python builtin of the same name for the rest of the session.
input = torch.randn(1, 10)
# `out` is linear in the layer's weight and bias, so its Hessian with respect
# to those parameters is identically zero.
out = linear(input).sum()
grads = torch.autograd.grad([out], linear.parameters(), create_graph=True)
flatten = torch.cat([g.reshape(-1) for g in grads if g is not None])
x = torch.randn_like(flatten)
print(flatten.shape)
# flatten2 is rebuilt from flatten.data, so it is a fresh leaf tensor with no
# graph connection to linear's parameters.
flatten2 = Variable(flatten.data, requires_grad=True)
# NOTE(review): since flatten2 is disconnected from the parameters, this call
# presumably returns None for each of them (allow_unused=True suppresses the
# error) rather than a Hessian-vector product - confirm.
hvps = torch.autograd.grad([flatten2 @ x], linear.parameters(), allow_unused=True)

In [None]:
hvps

In [None]:
flatten

In [None]:
## Random direction vector with the same shape as the flattened gradient.
x = torch.randn_like(flatten)
# NOTE(review): the recorded size 1792 does not match the nn.Linear(10, 20)
# layer defined in this notebook (10*20 + 20 = 220 parameters); it presumably
# comes from a different (conv) model run in an earlier session - confirm.
print(flatten.shape) ## torch.Size([1792])
# x2 is a fresh leaf copy of x that tracks gradients.
x2 = Variable(x.data, requires_grad=True)

In [None]:
# NOTE(review): `conv` is not defined in any cell visible in this notebook, so
# this cell raises NameError on a fresh kernel - presumably leftover state
# from an earlier session; confirm or define the layer explicitly.
hvps = torch.autograd.grad([flatten @ x2], conv.parameters(), allow_unused=True)

In [None]:
hvps

In [None]:
# With allow_unused=True, autograd.grad returns None for any input the output
# does not depend on; the recorded result shows this for the second parameter.
print(hvps[1]) ## None

In [None]:
# Concatenate the non-None Hessian-vector blocks into one flat vector; the
# None entries (unused parameters) are skipped, so the result can be shorter
# than the flattened gradient.
# NOTE: this overwrites the `flatten2` defined in the linear-layer cell.
flatten2 = torch.cat([g.reshape(-1) for g in hvps if g is not None])
print(flatten2.shape) ## torch.Size([1728])

In [None]:
## a simple neural network with a quadratic loss. The loss must be non-linear
## in the parameters: for `linear(x).sum()` the gradient does not depend on
## the parameters at all, so there is nothing to differentiate a second time.
linear = nn.Linear(10, 20)
x = torch.randn(1, 10)
y = (linear(x) ** 2).sum()

## compute the gradient, keeping its graph. autograd.grad returns one tensor
## per parameter, so flatten them into a single vector before using `@`.
grad_parts = torch.autograd.grad(y, list(linear.parameters()), create_graph=True)
grad = torch.cat([g.reshape(-1) for g in grad_parts])

## make a copy that is detached from the graph, so it acts as a constant vector
v = grad.clone().detach()

## compute the Hessian vector product: after backward(), each parameter's
## .grad holds its block of H @ v
z = grad @ v
z.backward()

In [None]:
## Hessian-vector product H @ v for f(x) = 3*x0^2 + 4*x0*x1 + x1^2, where v is
## a detached copy of x: it has the same values as x, but the second backward
## pass treats it as a constant.
x = torch.tensor([1.0, 1.0], requires_grad=True)
v = x.detach().clone()
f = 3 * x[0] ** 2 + 4 * x[0] * x[1] + x[1] ** 2
grad_f = torch.autograd.grad(f, x, create_graph=True)[0]
z = torch.dot(grad_f, v)
z.backward()

In [None]:
grad.view

In [None]:
# NOTE(review): repeats the last line of the linear-layer cell and relies on
# whatever `flatten2`, `x` and `linear` currently hold in the kernel; cells in
# between rebind all three names, so on a top-to-bottom run the shapes
# presumably no longer match - confirm or remove.
hvps = torch.autograd.grad([flatten2 @ x], linear.parameters(), allow_unused=True)

In [None]:
# Repeat of the flatten / random-vector / second-derivative steps from the
# linear-layer cell. NOTE(review): relies on `grads` and `linear` still being
# defined from earlier cells.
flatten = torch.cat([g.reshape(-1) for g in grads if g is not None])
x = torch.randn_like(flatten)
print(flatten.shape)
# flatten2 is rebuilt from flatten.data, so it is a fresh leaf tensor with no
# graph connection to linear's parameters.
flatten2 = Variable(flatten.data, requires_grad=True)
# NOTE(review): with flatten2 disconnected from the parameters, this
# presumably yields None per parameter rather than a Hessian-vector product.
hvps = torch.autograd.grad([flatten2 @ x], linear.parameters(), allow_unused=True)

In [None]:
## Hessian-vector product H @ v for f(x) = 3*x0^2 + 4*x0*x1 + x1^2, where v is
## a detached copy of x: it has the same values as x, but the second backward
## pass treats it as a constant. x.grad then holds H @ v.
x = torch.tensor([1.0, 1.0], requires_grad=True)
v = x.detach().clone()
f = 3 * x[0] ** 2 + 4 * x[0] * x[1] + x[1] ** 2
grad_f = torch.autograd.grad(f, x, create_graph=True)[0]
z = torch.dot(grad_f, v)
z.backward()
print(x.grad)

In [None]:
grad_f @ v

In [None]:
v