In [1]:
## Adapted from the GAN implementation in the PyTorch-GAN model zoo:
## https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py

In [2]:
import argparse
import os
import numpy as np
import math
import matplotlib.pyplot as plt
import seaborn as sns; sns.set_style("white")
%matplotlib inline

import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [3]:
os.makedirs('images', exist_ok=True)

n_epochs = 200 #'number of epochs of training'
batch_size = 64 #'size of the batches'
lr = 0.0002 #'adam: learning rate'
b1 = 0.5 #'adam: decay of first order momentum of gradient'
b2 = 0.999 #'adam: decay of first order momentum of gradient'
n_cpu = 4 #'number of cpu threads to use during batch generation'
latent_dim = 100 #'dimensionality of the latent space'
img_size = 28 #'size of each image dimension'
channels = 1 #'number of image channels'
sample_interval=400 #'interval betwen image samples'

In [4]:
img_shape = (channels, img_size, img_size)

cuda = True if torch.cuda.is_available() else False

In [5]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

In [6]:
# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
os.makedirs('../../data/mnist', exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('../../data/mnist', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ])),
    batch_size=batch_size, shuffle=True)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [7]:
def distance(x,y):
    """
    given 2 images, x, y, return the normalized l2 distance between them,
    normalized so that the max distance is 1
    """
    return np.sqrt(np.sum((x - y)**2))/56


def triangle_distances(x,y,z):
    """
    return the sorted distances of the triangle formed by the points (x,y,z)
    """
    return sorted([distance(x,y), distance(x,z), distance(y,z)])


def triplet_sample(X):
    """
    given an array of images X, return a triplet
    """
    while True:
        [i,j,k] = list(np.random.randint(0, X.shape[0], 3))
        if len({i,j,k}) == 3:
            return [X[i], X[j], X[k]]

        
def angles(x, y, z):
    """
    Given a triangle-defining triplet, return the sorted angles
    """
    [dxy, dxz, dyz] = triangle_distances(x,y,z)
    theta_xy = np.arccos((-dxy**2 + dxz**2 + dyz**2)/(2*dxz*dyz))
    theta_xz = np.arccos((dxy**2 - dxz**2 + dyz**2)/(2*dxy*dyz))
    theta_yz = np.arccos((dxy**2 + dxz**2 - dyz**2)/(2*dxy*dxz))
    return sorted([theta_xy, theta_xz, theta_yz])


def triangle_distributions(X, Num):
    """
    Given an array of samples X, generate two arrays of shape (Num, 2).
    
    Array 1: the 2 dimensions are [dmid-dmin, dmax-dmid]
    Array 2: the 2 dimensions are [theta_min/theta_max, theta_min/theta_mid]
    """

    K_distances = np.zeros((Num,2))
    K_angles = np.zeros((Num,2))
    
    for i in range(Num):
        [x, y, z] = triplet_sample(X)
        [d_min, d_mid, d_max] = triangle_distances(x,y,z)
        [theta_min, theta_mid, theta_max] = angles(x,y,z)
        
        K_distances[i,0] = d_mid - d_min
        K_distances[i,1] = d_max - d_mid
        
        K_angles[i,0] = theta_min/theta_mid
        K_angles[i,1] = theta_min/theta_max

        
    return [K_distances, K_angles]

In [8]:
# ----------
#  Training
# ----------

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:# and batches_done != 0:
            print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))
            
            batches_done_str = str(batches_done).zfill(8)
            
            # plot a grid of samples
            x = gen_imgs.data[:,0,:,:].cpu()
            plt.figure(figsize=(6, 6))
            plt.suptitle("Sampled Images, Batch_Num: " + str(batches_done), fontsize=16)
            for j in range(64):
                plt.subplot(8, 8, j + 1)
                plt.imshow(x[j].reshape([28,28]), cmap='gray')
                plt.xticks(())
                plt.yticks(())
            plt.savefig('images/sample_images_' + batches_done_str + '.png')
            plt.close()
            #plt.show()

            ## plot the 2d triangle distance density plot
            Z = Variable(Tensor(np.random.normal(0, 1, (1000, latent_dim))))
            gen_imgs = generator(Z)
            X = gen_imgs.data[:,0,:,:].cpu().numpy()

            [K_distances, K_angles] = triangle_distributions(X, 10000)
            fig,ax = plt.subplots(1,1)
            ax1 = sns.kdeplot(K_distances[:,0], K_distances[:,1])
            ax1.set_xlim([0, 1.0])
            ax1.set_ylim([0, 1.0])
            ax1.set_xlabel(r"$d_{mid}-d_{min}$", fontsize=16)
            ax1.set_ylabel(r"$d_{max}-d_{mid}$", fontsize=16)
            ax1.set_title("Distance Density Plot, Batch_Num: " + str(batches_done), fontsize=16)
            plt.savefig('images/distance_plot_' + batches_done_str + '.png')
            plt.close()
            #plt.show()
            
            ## plot the 2d triangle angle density plot
            fig,ax = plt.subplots(1,1)
            ax2 = sns.kdeplot(K_angles[:,0], K_angles[:,1])
            ax2.set_xlim([0, 1.0])
            ax2.set_ylim([0, 1.0])
            ax2.set_xlabel(r"$\theta_{min}/\theta_{mid}$", fontsize=16)
            ax2.set_ylabel(r"$\theta_{min}/\theta_{max}$", fontsize=16)
            ax2.set_title("Angle Density Plot, Batch_Num: " + str(batches_done), fontsize=16)
            plt.savefig('images/triangle_plot_' + batches_done_str + '.png')
            plt.close()            
            #plt.show()

[Epoch 0/200] [Batch 0/938] [D loss: 0.653889] [G loss: 0.718994]


  return np.add.reduce(sorted[indexer] * weights, axis=axis) / sumval


[Epoch 0/200] [Batch 400/938] [D loss: 0.483200] [G loss: 0.578756]
[Epoch 0/200] [Batch 800/938] [D loss: 0.342723] [G loss: 0.822985]
[Epoch 1/200] [Batch 262/938] [D loss: 0.588816] [G loss: 2.485917]
[Epoch 1/200] [Batch 662/938] [D loss: 0.484637] [G loss: 0.761122]
[Epoch 2/200] [Batch 124/938] [D loss: 0.290944] [G loss: 1.087301]
[Epoch 2/200] [Batch 524/938] [D loss: 0.607129] [G loss: 0.460838]
[Epoch 2/200] [Batch 924/938] [D loss: 0.162961] [G loss: 1.576303]
[Epoch 3/200] [Batch 386/938] [D loss: 0.264576] [G loss: 3.198326]
[Epoch 3/200] [Batch 786/938] [D loss: 0.198144] [G loss: 2.381291]
[Epoch 4/200] [Batch 248/938] [D loss: 0.323987] [G loss: 3.146493]
[Epoch 4/200] [Batch 648/938] [D loss: 0.308600] [G loss: 1.853544]
[Epoch 5/200] [Batch 110/938] [D loss: 0.309909] [G loss: 1.326990]
[Epoch 5/200] [Batch 510/938] [D loss: 0.235691] [G loss: 3.604652]
[Epoch 5/200] [Batch 910/938] [D loss: 0.380048] [G loss: 3.296381]
[Epoch 6/200] [Batch 372/938] [D loss: 0.231563]

[Epoch 51/200] [Batch 562/938] [D loss: 0.422541] [G loss: 1.475723]
[Epoch 52/200] [Batch 24/938] [D loss: 0.333872] [G loss: 1.893930]
[Epoch 52/200] [Batch 424/938] [D loss: 0.356275] [G loss: 2.178028]
[Epoch 52/200] [Batch 824/938] [D loss: 0.487356] [G loss: 1.890546]
[Epoch 53/200] [Batch 286/938] [D loss: 0.377878] [G loss: 1.881037]
[Epoch 53/200] [Batch 686/938] [D loss: 0.432842] [G loss: 1.327764]
[Epoch 54/200] [Batch 148/938] [D loss: 0.324685] [G loss: 1.905217]
[Epoch 54/200] [Batch 548/938] [D loss: 0.358527] [G loss: 1.368164]
[Epoch 55/200] [Batch 10/938] [D loss: 0.352783] [G loss: 1.695581]
[Epoch 55/200] [Batch 410/938] [D loss: 0.438742] [G loss: 1.783901]
[Epoch 55/200] [Batch 810/938] [D loss: 0.255529] [G loss: 1.977053]
[Epoch 56/200] [Batch 272/938] [D loss: 0.341433] [G loss: 2.181850]
[Epoch 56/200] [Batch 672/938] [D loss: 0.437403] [G loss: 1.547978]
[Epoch 57/200] [Batch 134/938] [D loss: 0.376669] [G loss: 1.565433]
[Epoch 57/200] [Batch 534/938] [D lo

[Epoch 102/200] [Batch 324/938] [D loss: 0.405372] [G loss: 1.351003]
[Epoch 102/200] [Batch 724/938] [D loss: 0.411918] [G loss: 1.354909]
[Epoch 103/200] [Batch 186/938] [D loss: 0.461757] [G loss: 1.696794]
[Epoch 103/200] [Batch 586/938] [D loss: 0.328002] [G loss: 1.896736]
[Epoch 104/200] [Batch 48/938] [D loss: 0.444186] [G loss: 2.072689]
[Epoch 104/200] [Batch 448/938] [D loss: 0.438934] [G loss: 1.736473]
[Epoch 104/200] [Batch 848/938] [D loss: 0.364862] [G loss: 1.614417]
[Epoch 105/200] [Batch 310/938] [D loss: 0.356435] [G loss: 2.200775]
[Epoch 105/200] [Batch 710/938] [D loss: 0.422420] [G loss: 1.391967]
[Epoch 106/200] [Batch 172/938] [D loss: 0.422086] [G loss: 2.312851]
[Epoch 106/200] [Batch 572/938] [D loss: 0.398597] [G loss: 2.369773]
[Epoch 107/200] [Batch 34/938] [D loss: 0.409924] [G loss: 1.922763]
[Epoch 107/200] [Batch 434/938] [D loss: 0.414433] [G loss: 1.816016]
[Epoch 107/200] [Batch 834/938] [D loss: 0.354329] [G loss: 1.474966]
[Epoch 108/200] [Batch

[Epoch 152/200] [Batch 624/938] [D loss: 0.431338] [G loss: 1.332425]
[Epoch 153/200] [Batch 86/938] [D loss: 0.411982] [G loss: 1.589581]
[Epoch 153/200] [Batch 486/938] [D loss: 0.452257] [G loss: 2.041316]
[Epoch 153/200] [Batch 886/938] [D loss: 0.429951] [G loss: 1.973903]
[Epoch 154/200] [Batch 348/938] [D loss: 0.505281] [G loss: 1.622912]
[Epoch 154/200] [Batch 748/938] [D loss: 0.426713] [G loss: 1.238048]
[Epoch 155/200] [Batch 210/938] [D loss: 0.343675] [G loss: 1.652057]
[Epoch 155/200] [Batch 610/938] [D loss: 0.468899] [G loss: 1.594764]
[Epoch 156/200] [Batch 72/938] [D loss: 0.437651] [G loss: 1.926811]
[Epoch 156/200] [Batch 472/938] [D loss: 0.344062] [G loss: 1.640561]
[Epoch 156/200] [Batch 872/938] [D loss: 0.504861] [G loss: 1.521666]
[Epoch 157/200] [Batch 334/938] [D loss: 0.432307] [G loss: 1.668160]
[Epoch 157/200] [Batch 734/938] [D loss: 0.474866] [G loss: 1.387235]
[Epoch 158/200] [Batch 196/938] [D loss: 0.417040] [G loss: 1.441964]
[Epoch 158/200] [Batch

In [9]:
## look at 1d measures of UM
## look at raw MNIST data
## double check code
## look into top-subspace