In [1]:
## Adapted from the GAN implementation in the PyTorch-GAN model zoo:
## https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py

In [2]:
import argparse
import os
import numpy as np
import math
import matplotlib.pyplot as plt
import seaborn as sns; sns.set_style("white")
%matplotlib inline

import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [3]:
os.makedirs('images', exist_ok=True)

n_epochs = 200 #'number of epochs of training'
batch_size = 64 #'size of the batches'
lr = 0.0002 #'adam: learning rate'
b1 = 0.5 #'adam: decay of first order momentum of gradient'
b2 = 0.999 #'adam: decay of first order momentum of gradient'
n_cpu = 4 #'number of cpu threads to use during batch generation'
latent_dim = 100 #'dimensionality of the latent space'
img_size = 28 #'size of each image dimension'
channels = 1 #'number of image channels'
sample_interval=400 #'interval betwen image samples'

In [4]:
torch.cuda.is_available()

False

In [5]:
img_shape = (channels, img_size, img_size)

cuda = True if torch.cuda.is_available() else False

In [6]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

In [7]:
# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
os.makedirs('../../data/mnist', exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('../../data/mnist', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ])),
    batch_size=batch_size, shuffle=True)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [8]:
## useful functions for collecting triangle statistics

def triplet_sample(X):
    """
    given an array of images X, return a triplet
    """
    while True:
        [i,j,k] = list(np.random.randint(0, X.shape[0], 3))
        if len({i,j,k}) == 3:
            return [X[i], X[j], X[k]]
        
        
def distance_L2(x,y):
    """
    given 2 images, x, y, return the normalized l2 distance between them,
    normalized so that the max distance is 1
    """
    return np.sqrt(np.sum((x - y)**2))/56


def distance_H(x,y):
    """
    given 2 images, x, y, return the normalized Hamming distance between them,
    """
    x_binary = np.asarray(x.reshape(784)>0)
    y_binary = np.asarray(y.reshape(784)>0) 
    return np.count_nonzero(x_binary != y_binary)/len(x_binary)


def triangle_distances_l2(x,y,z):
    """
    return the sorted l2 distances of the triangle formed by the points (x,y,z)
    """
    return sorted([distance_L2(x,y), distance_L2(x,z), distance_L2(y,z)])


def triangle_distances_H(x,y,z):
    """
    return the sorted normalized Hamming distances of the triangle formed by the points (x,y,z)
    """
    return sorted([distance_H(x,y), distance_H(x,z), distance_H(y,z)])


def angles(dxy, dxz, dyz):
    """
    Given the 3 distances of a triangle, return the sorted angles
    """
    theta_xy = np.arccos((-dxy**2 + dxz**2 + dyz**2)/(2*dxz*dyz))
    theta_xz = np.arccos((dxy**2 - dxz**2 + dyz**2)/(2*dxy*dyz))
    theta_yz = np.arccos((dxy**2 + dxz**2 - dyz**2)/(2*dxy*dxz))
    return sorted([theta_xy, theta_xz, theta_yz])


def triangle_distributions(X, Num):
    """
    Given an array of samples 4, generate two arrays of shape (Num, 2).
    
    Array 1: the 2 dimensions are [dmid-dmin, dmax-dmid] (l2 distance)
    Array 2: the 2 dimensions are [theta_min/theta_max, theta_min/theta_mid] (l2 distance)
    Array 3: the 2 dimensions are [dmid-dmin, dmax-dmid] (Hamming distance)
    Array 4: the 2 dimensions are [theta_min/theta_max, theta_min/theta_mid] (Hamming distance)
    """

    K_distances_l2 = np.zeros((Num,2))
    K_angles_l2 = np.zeros((Num,2))
    K_distances_H = np.zeros((Num,2))
    K_angles_H = np.zeros((Num,2))
    
    for i in range(Num):
        [x, y, z] = triplet_sample(X)
        
        ## l2 distances
        [d_min, d_mid, d_max] = triangle_distances_l2(x,y,z)
        [theta_min, theta_mid, theta_max] = angles(d_min, d_mid, d_max)        

        K_distances_l2[i,0] = d_mid - d_min
        K_distances_l2[i,1] = d_max - d_mid        
        K_angles_l2[i,0] = theta_min/theta_mid
        K_angles_l2[i,1] = theta_min/theta_max

        ## Hamming distances
        [d_min, d_mid, d_max] = triangle_distances_l2(x,y,z)
        [theta_min, theta_mid, theta_max] = angles(d_min, d_mid, d_max)        

        K_distances_H[i,0] = d_mid - d_min
        K_distances_H[i,1] = d_max - d_mid        
        K_angles_H[i,0] = theta_min/theta_mid
        K_angles_H[i,1] = theta_min/theta_max

        
    return [K_distances_l2, K_angles_l2, K_distances_H, K_angles_H]

### Training

In [None]:
g_loss_list = []
d_loss_list = []
g_overlap_list = []
d_overlap_list = []

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)
        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        optimizer_G.zero_grad()
        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))
        # Generate a batch of images
        gen_imgs = generator(z)
        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()
        # -----------------
        
        # ---------------------
        #  Train Discriminator
        optimizer_D.zero_grad()
        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
        # ---------------------

        # ---------------------
        # Monitor Progress
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:# and batches_done != 0:
            print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % \
                   (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))
            
            batches_done_str = str(batches_done).zfill(8)
            
            ## save the losses
            g_loss_list.append(g_loss.item())
            d_loss_list.append(d_loss.item())
            
            #----------------------
            # Generator Grad Overlap
            gen_imgs = generator(z)
 
            ## loss computed via built-in CE function
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)

            ## loss computed by scratch (needed to compute 2nd deriv)
            g_loss_homebrew = - torch.log(discriminator(gen_imgs)).mean()
            print("g loss comparison: ", g_loss, g_loss_homebrew)
            
            ## this computes g_i, ||g||^2
            g_loss_grad = torch.autograd.grad(g_loss_homebrew, generator.parameters(), create_graph=True)
            g_loss_grad_norm_sq = 0
            for g in g_loss_grad:
                g_loss_grad_norm_sq = g_loss_grad_norm_sq + g.pow(2).sum()

            ## this computes 2(g^T H)_i, 4 ||H g||^2
            g_loss_grad2 = torch.autograd.grad(g_loss_grad_norm_sq, generator.parameters(), create_graph=True)
            g_loss_grad2_norm_sq = 0
            for g in g_loss_grad2:
                g_loss_grad2_norm_sq = g_loss_grad2_norm_sq + g.pow(2).sum()

            ## this computes 2(g^T H g)
            g_overlap_raw = 0
            for i in range(len(g_loss_grad2)):
                g_overlap_raw = g_overlap_raw + torch.mul(g_loss_grad[i], g_loss_grad2[i]).sum() 

            ## finally, compute the overlap (g^T H g)/(||g|| ||H g||)
            g_overlap = g_overlap_raw/torch.sqrt(g_loss_grad_norm_sq)/torch.sqrt(g_loss_grad2_norm_sq)
            g_overlap_list.append(g_overlap)
            #----------------------

            #----------------------
            # Discriminator Grad Overlap
            gen_imgs = generator(z)

            ## loss computed via built-in CE function
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            ## loss computed by scratch (needed to compute 2nd deriv)
            real_loss_homebrew = - torch.log(discriminator(real_imgs)).mean()
            fake_loss_homebrew = - torch.log(1.0-discriminator(gen_imgs.detach())).mean()
            d_loss_homebrew = (real_loss_homebrew + fake_loss_homebrew) / 2            
            print("d loss comparison: ", d_loss, d_loss_homebrew)

            ## this computes g_i, ||g||^2
            d_loss_grad = torch.autograd.grad(g_loss_homebrew, discriminator.parameters(), create_graph=True)
            d_loss_grad_norm_sq = 0
            for g in d_loss_grad:
                d_loss_grad_norm_sq = d_loss_grad_norm_sq + g.pow(2).sum()

            ## this computes 2(g^T H)_i, 4 ||H g||^2
            d_loss_grad2 = torch.autograd.grad(d_loss_grad_norm_sq, discriminator.parameters(), create_graph=True)
            d_loss_grad2_norm_sq = 0
            for g in d_loss_grad2:
                d_loss_grad2_norm_sq = d_loss_grad2_norm_sq + g.pow(2).sum()

            ## this computes 2(g^T H g)
            d_overlap_raw = 0
            for i in range(len(d_loss_grad2)):
                d_overlap_raw = d_overlap_raw + torch.mul(d_loss_grad[i], d_loss_grad2[i]).sum() 

            ## finally, compute the overlap (g^T H g)/(||g|| ||H g||)
            d_overlap = d_overlap_raw/torch.sqrt(d_loss_grad_norm_sq)/torch.sqrt(d_loss_grad2_norm_sq)
            d_overlap_list.append(d_overlap)
            #----------------------

            ## print the overlaps
            print('g_overlap: %.3f, d_overlap: %.3f' % (g_overlap, d_overlap))

            # plot a grid of samples
            x = gen_imgs.data[:,0,:,:].cpu()
            plt.figure(figsize=(6, 6))
            plt.suptitle("Sampled Images, Batch_Num: " + str(batches_done), fontsize=16)
            for j in range(64):
                plt.subplot(8, 8, j + 1)
                plt.imshow(x[j].reshape([28,28]), cmap='gray')
                plt.xticks(())
                plt.yticks(())
            plt.savefig('images/sample_images_' + batches_done_str + '.png')
            plt.close()
            #plt.show()

            Z = Variable(Tensor(np.random.normal(0, 1, (1000, latent_dim))))
            gen_imgs = generator(Z)
            X = gen_imgs.data[:,0,:,:].cpu().numpy()

            [K_distances_l2, K_angles_l2, K_distances_H, K_angles_H] = triangle_distributions(X, 10000)
            
            ## plot the 2d triangle distance density plot (l2 norm)
            fig,ax = plt.subplots(1,1)
            ax1 = sns.kdeplot(K_distances_l2[:,0], K_distances_l2[:,1])
            ax1.set_xlim([0, 1.0])
            ax1.set_ylim([0, 1.0])
            ax1.set_xlabel(r"$d_{mid}-d_{min}$", fontsize=16)
            ax1.set_ylabel(r"$d_{max}-d_{mid}$", fontsize=16)
            ax1.set_title("Distance Density Plot (l2 distance), Batch_Num: " + str(batches_done), fontsize=16)
            plt.savefig('images/distance_plot_l2_' + batches_done_str + '.png')
            plt.close()
            #plt.show()
            
            ## plot the 2d triangle angle density plot (l2 norm)
            fig,ax = plt.subplots(1,1)
            ax2 = sns.kdeplot(K_angles_l2[:,0], K_angles_l2[:,1])
            ax2.set_xlim([0, 1.0])
            ax2.set_ylim([0, 1.0])
            ax2.set_xlabel(r"$\theta_{min}/\theta_{mid}$", fontsize=16)
            ax2.set_ylabel(r"$\theta_{min}/\theta_{max}$", fontsize=16)
            ax2.set_title("Angle Density Plot (l2 distance), Batch_Num: " + str(batches_done), fontsize=16)
            plt.savefig('images/triangle_plot_l2_' + batches_done_str + '.png')
            plt.close()            
            #plt.show()
    
            ## plot the 2d triangle distance density plot (H norm)
            fig,ax = plt.subplots(1,1)
            ax1 = sns.kdeplot(K_distances_H[:,0], K_distances_H[:,1])
            ax1.set_xlim([0, 1.0])
            ax1.set_ylim([0, 1.0])
            ax1.set_xlabel(r"$d_{mid}-d_{min}$", fontsize=16)
            ax1.set_ylabel(r"$d_{max}-d_{mid}$", fontsize=16)
            ax1.set_title("Distance Density Plot (H distance), Batch_Num: " + str(batches_done), fontsize=16)
            plt.savefig('images/distance_plot_H_' + batches_done_str + '.png')
            plt.close()
            #plt.show()
            
            ## plot the 2d triangle angle density plot (H norm)
            fig,ax = plt.subplots(1,1)
            ax2 = sns.kdeplot(K_angles_H[:,0], K_angles_H[:,1])
            ax2.set_xlim([0, 1.0])
            ax2.set_ylim([0, 1.0])
            ax2.set_xlabel(r"$\theta_{min}/\theta_{mid}$", fontsize=16)
            ax2.set_ylabel(r"$\theta_{min}/\theta_{max}$", fontsize=16)
            ax2.set_title("Angle Density Plot (H distance), Batch_Num: " + str(batches_done), fontsize=16)
            plt.savefig('images/triangle_plot_H_' + batches_done_str + '.png')
            plt.close()            
            #plt.show()
        # ---------------------

[Epoch 0/200] [Batch 0/938] [D loss: 0.368506] [G loss: 0.942657]
g loss comparison:  tensor(4.9051, grad_fn=<BinaryCrossEntropyBackward>) tensor(4.9051, grad_fn=<NegBackward>)
d loss comparison:  tensor(0.5938, grad_fn=<DivBackward0>) tensor(0.5938, grad_fn=<DivBackward0>)
g_overlap: 0.187, d_overlap: 0.674
[Epoch 0/200] [Batch 400/938] [D loss: 0.279333] [G loss: 2.732106]
g loss comparison:  tensor(1.6974, grad_fn=<BinaryCrossEntropyBackward>) tensor(1.6974, grad_fn=<NegBackward>)
d loss comparison:  tensor(0.2091, grad_fn=<DivBackward0>) tensor(0.2091, grad_fn=<DivBackward0>)
g_overlap: 0.789, d_overlap: 0.995
[Epoch 0/200] [Batch 800/938] [D loss: 0.573991] [G loss: 5.688356]
g loss comparison:  tensor(3.1213, grad_fn=<BinaryCrossEntropyBackward>) tensor(3.1213, grad_fn=<NegBackward>)
d loss comparison:  tensor(0.1260, grad_fn=<DivBackward0>) tensor(0.1260, grad_fn=<DivBackward0>)
g_overlap: 0.447, d_overlap: 0.964
[Epoch 1/200] [Batch 262/938] [D loss: 0.470148] [G loss: 0.731117

[Epoch 11/200] [Batch 482/938] [D loss: 0.290539] [G loss: 1.648951]
g loss comparison:  tensor(2.2932, grad_fn=<BinaryCrossEntropyBackward>) tensor(2.2932, grad_fn=<NegBackward>)
d loss comparison:  tensor(0.2651, grad_fn=<DivBackward0>) tensor(0.2651, grad_fn=<DivBackward0>)
g_overlap: 0.548, d_overlap: 0.993
[Epoch 11/200] [Batch 882/938] [D loss: 0.169859] [G loss: 2.094741]
g loss comparison:  tensor(2.4817, grad_fn=<BinaryCrossEntropyBackward>) tensor(2.4817, grad_fn=<NegBackward>)
d loss comparison:  tensor(0.1568, grad_fn=<DivBackward0>) tensor(0.1568, grad_fn=<DivBackward0>)
g_overlap: 0.528, d_overlap: 0.987
[Epoch 12/200] [Batch 344/938] [D loss: 0.159646] [G loss: 2.513616]
g loss comparison:  tensor(1.8554, grad_fn=<BinaryCrossEntropyBackward>) tensor(1.8554, grad_fn=<NegBackward>)
d loss comparison:  tensor(0.1706, grad_fn=<DivBackward0>) tensor(0.1706, grad_fn=<DivBackward0>)
g_overlap: 0.449, d_overlap: 0.990
[Epoch 12/200] [Batch 744/938] [D loss: 0.202407] [G loss: 2.

[Epoch 23/200] [Batch 26/938] [D loss: 0.280291] [G loss: 2.726737]
g loss comparison:  tensor(1.5716, grad_fn=<BinaryCrossEntropyBackward>) tensor(1.5716, grad_fn=<NegBackward>)
d loss comparison:  tensor(0.2389, grad_fn=<DivBackward0>) tensor(0.2389, grad_fn=<DivBackward0>)
g_overlap: 0.574, d_overlap: 0.991
[Epoch 23/200] [Batch 426/938] [D loss: 0.264271] [G loss: 1.946367]
g loss comparison:  tensor(1.7748, grad_fn=<BinaryCrossEntropyBackward>) tensor(1.7748, grad_fn=<NegBackward>)
d loss comparison:  tensor(0.2559, grad_fn=<DivBackward0>) tensor(0.2559, grad_fn=<DivBackward0>)
g_overlap: 0.595, d_overlap: 0.983
[Epoch 23/200] [Batch 826/938] [D loss: 0.192806] [G loss: 1.845007]
g loss comparison:  tensor(1.6536, grad_fn=<BinaryCrossEntropyBackward>) tensor(1.6536, grad_fn=<NegBackward>)
d loss comparison:  tensor(0.2147, grad_fn=<DivBackward0>) tensor(0.2147, grad_fn=<DivBackward0>)
g_overlap: 0.626, d_overlap: 0.980
[Epoch 24/200] [Batch 288/938] [D loss: 0.304218] [G loss: 1.8

g_overlap: 0.429, d_overlap: 0.982
[Epoch 34/200] [Batch 508/938] [D loss: 0.357055] [G loss: 2.089746]
g loss comparison:  tensor(1.2457, grad_fn=<BinaryCrossEntropyBackward>) tensor(1.2457, grad_fn=<NegBackward>)
d loss comparison:  tensor(0.3538, grad_fn=<DivBackward0>) tensor(0.3538, grad_fn=<DivBackward0>)
g_overlap: 0.423, d_overlap: 0.996
[Epoch 34/200] [Batch 908/938] [D loss: 0.227212] [G loss: 3.089481]
g loss comparison:  tensor(2.6963, grad_fn=<BinaryCrossEntropyBackward>) tensor(2.6963, grad_fn=<NegBackward>)
d loss comparison:  tensor(0.1851, grad_fn=<DivBackward0>) tensor(0.1851, grad_fn=<DivBackward0>)
g_overlap: 0.439, d_overlap: 0.972
[Epoch 35/200] [Batch 370/938] [D loss: 0.324291] [G loss: 2.343469]
g loss comparison:  tensor(1.2226, grad_fn=<BinaryCrossEntropyBackward>) tensor(1.2226, grad_fn=<NegBackward>)
d loss comparison:  tensor(0.3205, grad_fn=<DivBackward0>) tensor(0.3205, grad_fn=<DivBackward0>)
g_overlap: 0.053, d_overlap: 0.984
[Epoch 35/200] [Batch 770/

g_overlap: 0.537, d_overlap: 0.974
[Epoch 46/200] [Batch 52/938] [D loss: 0.272010] [G loss: 1.894344]
g loss comparison:  tensor(2.2324, grad_fn=<BinaryCrossEntropyBackward>) tensor(2.2324, grad_fn=<NegBackward>)
d loss comparison:  tensor(0.2406, grad_fn=<DivBackward0>) tensor(0.2406, grad_fn=<DivBackward0>)
g_overlap: 0.318, d_overlap: 0.981
[Epoch 46/200] [Batch 452/938] [D loss: 0.253375] [G loss: 2.697104]
g loss comparison:  tensor(1.5322, grad_fn=<BinaryCrossEntropyBackward>) tensor(1.5322, grad_fn=<NegBackward>)
d loss comparison:  tensor(0.2383, grad_fn=<DivBackward0>) tensor(0.2383, grad_fn=<DivBackward0>)
g_overlap: 0.569, d_overlap: 0.985
[Epoch 46/200] [Batch 852/938] [D loss: 0.448799] [G loss: 3.563986]
g loss comparison:  tensor(1.3309, grad_fn=<BinaryCrossEntropyBackward>) tensor(1.3309, grad_fn=<NegBackward>)
d loss comparison:  tensor(0.3159, grad_fn=<DivBackward0>) tensor(0.3159, grad_fn=<DivBackward0>)
g_overlap: 0.397, d_overlap: 0.987
[Epoch 47/200] [Batch 314/9

In [None]:
g_overlap_list