In [1]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset 
from torch.utils.data import DataLoader
import copy
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from collections import defaultdict
import pdb
import torch.distributions as tdist
import matplotlib
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from mpl_toolkits.mplot3d import Axes3D  

from matplotlib import cm

In [2]:
def gmm_sampler(num_samples, num_mixtures, mean, cov, mix_coeffs):
    z = np.random.multinomial(num_samples, mix_coeffs)

    samples = np.zeros(shape=[num_samples, len(mean[0])])
    target = np.zeros(shape=[num_samples])
    
    i_start = 0
    data = []
    for i in range(len(mix_coeffs)):
        i_end = i_start + z[i]
        samples[i_start:i_end, :] = np.random.multivariate_normal(
            mean=np.array(mean)[i, :],
            cov=np.diag(np.array(cov)[i, :]),            
            size=z[i])
        
        target[i_start:i_end] = i

        for j in range(i_start,i_end):
            data.append({"x":samples[j],"class":target[j]})
        i_start = i_end
   
    return data


class SynthDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, train_set_len, train_set):
        self.train_set_len = train_set_len
        self.train_set = train_set
    
    def __len__(self):
        return self.train_set_len  

    def __getitem__(self, idx):
        return self.train_set[idx]
        


In [3]:
nz = 10
ngf = 64
ndf = 64

inputdim=2
num_mixtures = 5
radius = 2.0
std = 0.01
thetas = np.linspace(0, 2 * np.pi, num_mixtures + 1)[:num_mixtures]
xs, ys = radius * np.sin(thetas), radius * np.cos(thetas)
mix_coeffs = [1./num_mixtures for i in range(num_mixtures)]
mean=tuple(zip(xs, ys))
cov=tuple([(std, std)] * num_mixtures)



In [4]:
div = 'JS'

In [5]:

batch_size = 64
train_set_len = 10000
train_set = gmm_sampler(train_set_len, num_mixtures, mean, cov, mix_coeffs)

synthdataset = SynthDataset(train_set_len, train_set)
dataloader = DataLoader(synthdataset, batch_size=batch_size,
                        shuffle=True, num_workers=4) 
if div == 'KL':
    def activation_func(x):
        return x

    def conjugate(x):
        return torch.exp(x - 1)

elif div == 'Reverse-KL':
    def activation_func(x):
        return -torch.exp(-x)

    def conjugate(x):
        return -1 - torch.log(-x)

elif div == 'JS':
    def activation_func(x):
        return torch.log(2.0 / (1 + torch.exp(-x)))

    def conjugate(x):
        return -torch.log(2 - torch.exp(x))
elif div == 'Pearson':
    def activation_func(x):
        return x

    def conjugate(x):
        return 0.25 * torch.pow(x, 2) + x
elif div == 'Total-Variation':
    def activation_func(x):
        return 0.5 * torch.tanh(x)

    def conjugate(x):
        return x


def g_loss(d_fake_score):
    g_loss_kl = -torch.mean(conjugate(activation_func(d_fake_score)))
    return g_loss_kl

def d_loss(d_real,d_fake):
    return -(torch.mean(activation_func(d_real)) - torch.mean(conjugate(activation_func(d_fake))))


def g_loss_nomean(d_fake_score):
    g_loss_kl = -conjugate(activation_func(d_fake_score))
    return g_loss_kl




In [6]:
device=torch.device('cuda')
class G(nn.Module):
    def __init__(self):
        super(G, self).__init__()   
        
        
        self.c1 = nn.Linear(nz, 128)
        self.c1_bn = nn.BatchNorm1d(128)
        self.c1_relu = nn.ReLU(True)
        self.c2 = nn.Linear(128, 128)
        self.c2_bn = nn.BatchNorm1d(128)
        self.c2_relu = nn.ReLU(True)
        self.c3 = nn.Linear(128, 2)
        self.c3_tanh = nn.Tanh()    
        
    def forward(self, input_1):
        output = self.c1_relu(self.c1_bn(self.c1(input_1)))
        output = self.c2_relu(self.c2_bn(self.c2(output)))
        output = self.c3(output) # removed tanh because we dont want bounding
        return output

        
class D(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(D, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.activation_fn = F.relu
        self.sigmoid_act = F.sigmoid

    def forward(self, x):
        x = self.activation_fn(self.map1(x))
        x = self.activation_fn(self.map2(x))
        return self.map3(x)
      

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
        
netD = D(2,256,1)
netD.to(device)
netD.apply(weights_init)  
optimizerD = optim.Adam(netD.parameters(), lr=0.0001, betas=(0.5, 0.999))



netG = G()
netG.to(device)
netG.apply(weights_init)
optimizerG = optim.Adam(netG.parameters(), lr=0.0001, betas=(0.5, 0.999))

meanmatrix = np.matrix(mean)
import matplotlib.tri as tri
import matplotlib.mlab as mlab


In [7]:


def g_sample():
    with torch.no_grad():
        gen_input = torch.randn(batch_size*10, nz, device=device)
        g_fake_data = netG(gen_input)
        return g_fake_data.cpu().numpy()



In [8]:
    
Tensor = torch.cuda.FloatTensor
LongTensor = torch.cuda.LongTensor
grads=[]
xaxis=[]
stopeps = [1,2,5,10]

In [9]:
for epoch in tqdm(range(11)):
      print(epoch)
      for i, sample in enumerate(dataloader):
            points = Variable(sample['x'].type(Tensor))
            targets = Variable((sample['class']).type(LongTensor), requires_grad = False)        
            batch_size = points.size(0)

            z = torch.randn(batch_size, nz, device=device)

            valid = Variable(Tensor(points.size(0), 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(points.size(0), 1).fill_(0.0), requires_grad=False)

            real_points = Variable(points.type(Tensor), requires_grad = False) 

            # Update G

            optimizerG.zero_grad()
            gen_points = netG(z)
            output_d = netD(gen_points)

            gloss = g_loss(output_d)
            gloss.backward()
            optimizerG.step()
        
            optimizerD.zero_grad()
            output_d_fake = netD(gen_points.detach())
            output_d_real = netD(real_points)

            dloss = d_loss(output_d_real, output_d_fake)
            dloss.backward()
            optimizerD.step()
            
            x = np.arange(-4, 4, 0.1)
            y = np.arange(-4, 4, 0.1)
            X, Y = np.meshgrid(x,y) # X: 80 x 80, Y: 80 x 80
            
            
            if i%50==0:
                data = np.array(list(zip(X.flatten(),Y.flatten()))) # data: 6400 x 2
                tensordata = Variable(torch.Tensor(data)).cuda()
                with torch.no_grad():
                    d_output = netD(tensordata)
                output_loss = g_loss_nomean(d_output).cpu().numpy().reshape(80,80)
                plt.clf()
                ax = Axes3D(plt.figure())
                surf = ax.plot_surface(X, Y, output_loss, cmap=cm.coolwarm, linewidth=0, antialiased=False)
                plt.colorbar(surf, shrink=0.5, aspect=5)
                plt.savefig('visualize_loss/'+str(epoch)+'_'+str(i)+'.png')
                plt.close()

  0%|          | 0/11 [00:00<?, ?it/s]

0


  9%|▉         | 1/11 [00:03<00:33,  3.36s/it]

1


 18%|█▊        | 2/11 [00:06<00:31,  3.49s/it]

2


 27%|██▋       | 3/11 [00:10<00:27,  3.49s/it]

3


 36%|███▋      | 4/11 [00:13<00:24,  3.46s/it]

4


 45%|████▌     | 5/11 [00:17<00:21,  3.52s/it]

5


 55%|█████▍    | 6/11 [00:21<00:17,  3.55s/it]

6


 64%|██████▎   | 7/11 [00:24<00:14,  3.52s/it]

7


 73%|███████▎  | 8/11 [00:28<00:10,  3.55s/it]

8


 82%|████████▏ | 9/11 [00:32<00:07,  3.56s/it]

9


 91%|█████████ | 10/11 [00:35<00:03,  3.60s/it]

10


100%|██████████| 11/11 [00:39<00:00,  3.61s/it]


<Figure size 432x288 with 0 Axes>