In [None]:
from __future__ import print_function

import argparse
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
%matplotlib inline

# Model / optimisation hyper-parameters.
nz = 100          # noise dimensions fed to the generator (before the cluster code)
ngf = 64          # base channel width of the generator
ndf = 64          # base channel width of the discriminator
nc = 1            # image channels (grayscale MNIST)
lr = 0.0002       # Adam learning rate (both networks)
beta1 = 0.5       # Adam beta1 (both networks)
imageSize = 28    # side length of the input images
ncl = 20          # number of clusters = discriminator output size
cls = 10          # number of ground-truth classes used for the purity metric
batchSize = 50




def gallery(array, ncols=3):
    """Tile a batch of images into a single grid image.

    Args:
        array: array of shape (N, H, W, C); N must be a multiple of `ncols`
            (an AssertionError is raised otherwise).
        ncols: number of images per grid row.

    Returns:
        Array of shape (H * N // ncols, W * ncols, C) in which image k is
        placed at grid row k // ncols, grid column k % ncols.
    """
    n_images, img_h, img_w, channels = array.shape
    n_rows = n_images // ncols
    assert n_images == n_rows * ncols
    tiles = array.reshape((n_rows, ncols, img_h, img_w, channels))
    # Put each tile's pixel rows next to its grid row: (rows, H, cols, W, C).
    tiles = tiles.swapaxes(1, 2)
    return tiles.reshape((img_h * n_rows, img_w * ncols, channels))

def make_array(path='face.png', count=12):
    """Load one image as RGB and stack `count` copies of it.

    Args:
        path: image file to read (defaults to 'face.png' in the working dir).
        count: number of copies to stack along a new leading axis.

    Returns:
        numpy array of shape (count, H, W, 3) built from the RGB pixels.
    """
    # Local imports: numpy is otherwise only bound by a later cell, and PIL is
    # only needed here.
    import numpy as np
    from PIL import Image

    # Context manager closes the underlying file handle once the pixels are read.
    with Image.open(path) as img:
        rgb = np.asarray(img.convert('RGB'))
    return np.array([rgb] * count)




# MNIST loaders. Only ToTensor is applied, so pixels are in [0, 1]; no
# Normalize step is used (presumably so real images share the (0, 1) range of
# the generator's Sigmoid output).
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        '../data',
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    batch_size=batchSize,
    shuffle=True,
)
# NOTE(review): test_loader is not read by any of the visible cells.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        '../data',
        train=False,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    batch_size=batchSize,
    shuffle=True,
)

# Raw training tensors: images as float in [0, 1] with a channel axis added,
# labels as the dataset's digit labels.
trainData = train_loader.dataset.train_data.unsqueeze(1).float()/255
trainLabels = train_loader.dataset.train_labels

# Regroup the training set by class: `data` holds every class-0 image, then
# every class-1 image, ...; `labels` holds the matching class as a float.
# `data_class[k]` keeps the images of class k on their own.
data_class = []
label_chunks = []
for i in range(cls):
    # view(-1) keeps a 1-D index even if a class has a single sample
    # (squeeze() would produce a 0-dim tensor there).
    class_idx = (trainLabels == i).nonzero().view(-1)
    data_class.append(trainData.index_select(0, class_idx))
    label_chunks.append(torch.zeros(class_idx.size(0)).fill_(i))

# Concatenate once instead of re-allocating the growing tensors every class.
data = torch.cat(data_class, 0)
labels = torch.cat(label_chunks, 0)

In [None]:
# Fixed seed so runs are repeatable; use random.randint(1, 1000) instead for a
# fresh run. Both the CPU and the CUDA generators are seeded.
seed = 8
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


def weights_init(m):
    """DCGAN-style parameter initialisation, meant for `Module.apply`.

    * Any module whose class name contains 'Conv' (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0, 0.02); its bias (if any) is left untouched.
    * Any module whose class name contains 'BatchNorm': weight ~ N(1, 0.02),
      bias = 0.
    * Every other module is left unchanged.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class _netG(nn.Module):
    """Conditional generator built from transposed convolutions.

    Input: (B, nz + ncl, 1, 1), i.e. noise channels followed by cluster-code
    channels. Output: (B, nc, 28, 28), squashed into (0, 1) by a final Sigmoid.

    Args:
        ngpu: number of GPUs; when > 1 and the input is a CUDA float tensor,
            the forward pass is split with nn.parallel.data_parallel.
    """

    def __init__(self, ngpu):
        super(_netG, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # in: (nz + ncl) x 1 x 1
            nn.ConvTranspose2d(nz + ncl, ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (ngf*4) x 4 x 4
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (ngf*2) x 7 x 7
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> (ngf) x 14 x 14
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Sigmoid(),
            # out: (nc) x 28 x 28
        )

    def forward(self, input):
        multi_gpu = self.ngpu > 1 and isinstance(input.data, torch.cuda.FloatTensor)
        if not multi_gpu:
            return self.main(input)
        return nn.parallel.data_parallel(self.main, input, range(self.ngpu))

    
class _netD(nn.Module):
    """Discriminator/clusterer: image batch -> per-cluster probabilities.

    Input: (B, nc, 28, 28). Output: (B, ncl), each row a softmax over the ncl
    clusters (rows sum to 1).

    Args:
        ngpu: number of GPUs; when > 1 and the input is a CUDA float tensor,
            the forward pass is split with nn.parallel.data_parallel.
    """

    def __init__(self, ngpu):
        super(_netD, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # in: (nc) x 28 x 28
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf) x 14 x 14
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*2) x 7 x 7
            nn.Conv2d(ndf * 2, ndf * 2, 3, 1, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*2) x 7 x 7
            nn.Conv2d(ndf * 2, ndf * 4, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 4, ndf * 4, 3, 1, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 4, ncl, 4, 1, 0, bias=False),
            # out: (ncl) x 1 x 1 cluster logits
        )

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        # Explicit dim=1: normalise over the clusters of each sample. Calling
        # softmax without `dim` is deprecated and emits a warning.
        return F.softmax(output.view(-1, ncl), dim=1)

    

In [None]:
# Build both networks (single-GPU) and apply the DCGAN-style initialisation.
netG, netD = _netG(1), _netD(1)
# NOTE(review): criterion is only moved to the GPU later; trainD/trainG compute
# their own losses and do not call it.
criterion = nn.BCELoss()
netG.apply(weights_init)
netD.apply(weights_init)
epoch = 0


In [None]:
# NOTE(review): real_label / fake_label / fixed_noise / label are not read by
# the visible training code.
real_label, fake_label = 1, 0

netD.cuda()
netG.cuda()
criterion.cuda()

# Persistent GPU buffers, refilled in place by trainD/trainG each iteration.
input = Variable(torch.FloatTensor(batchSize, nc, imageSize, imageSize).cuda())
# First nz channels: Gaussian noise; last ncl channels: one-hot cluster code.
noise = Variable(torch.FloatTensor(batchSize, nz + ncl, 1, 1).cuda())
fixed_noise = Variable(torch.FloatTensor(batchSize, nz, 1, 1).normal_(0, 1).cuda())
label = Variable(torch.FloatTensor(batchSize).cuda())

optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# One-hot code table (row k = cluster k), random soft assignments phi_all
# (rows sum to 1) and hard assignments z_all sampled from them.
classMat = torch.eye(ncl).cuda()
phi_all = torch.rand(data.size(0), ncl)
phi_all = phi_all.div(phi_all.sum(1, keepdim=True))
z_all = classMat.cpu().index_select(0, torch.multinomial(phi_all, 1).view(-1))

In [None]:
def Estep(phi):
    """Hard E-step: pick one cluster per row of `phi`, returned one-hot.

    Each column of `phi` is divided by a per-cluster normaliser before the
    row-wise argmax:
      * while the module-level `epoch` is 0: the column sums of `phi` itself;
      * afterwards: the column sums of the module-level `phi_backup` buffer.

    Args:
        phi: (B, ncl) cluster probabilities for one batch (as passed by trainD).

    Returns:
        (B, ncl) tensor whose rows are the selected rows of `classMat`.
    """
    if epoch == 0:
        normaliser = phi.sum(0)
    else:
        # type_as moves the buffer to phi's tensor type/device instead of
        # hard-coding .cuda().
        normaliser = phi_backup.sum(0).type_as(phi)
    px_z = phi / normaliser.expand_as(phi)
    _, inds = px_z.max(1)
    # view(-1) keeps a 1-D index for a batch of one (squeeze() would make it 0-dim).
    return classMat.index_select(0, inds.view(-1))


In [None]:
############################
# (1) Update D network
###########################

def _cluster_nll(z, log_phi):
    """Summed negative log of each sample's batch-normalised cluster score.

    Assuming the rows of `z` are one-hot (Estep returns rows of an identity
    matrix), with c_i the cluster of row i this is
        sum_i log( sum_j phi[j, c_i] / phi[i, c_i] ).
    """
    # exponent[i, j] = log phi[j, c_i]
    exponent = torch.mm(z, log_phi.t())
    # subtract the row's own term log phi[i, c_i]; view(-1, 1) follows the batch size
    exponent = exponent - torch.diag(exponent).view(-1, 1).expand_as(exponent)
    return exponent.exp().sum(1).log().sum()


def _loss_value(loss):
    """Python float from a one-element loss (1-element or 0-dim, old and new PyTorch)."""
    return float(loss.data.view(-1)[0])


def trainD():
    """One discriminator step on the current batch.

    Reads the module-level `real_cpu` (batch images) and `inds_batch` (their
    row indices) set by the training loop. Updates `phi_backup` (newest D
    outputs first, capped at 3000 rows), the rows `inds_batch` of `phi_all`
    and `z_all`, saves the batch's cluster codes in `target`, and refills
    `noise` with fresh Gaussian noise plus those codes (reused by trainG).

    Returns:
        (real_loss, fake_loss) as Python floats.
    """
    global real_loss, fake_loss, target, phi_backup
    optimizerD.zero_grad()

    # --- train with real ---
    input.data.resize_(real_cpu.size()).copy_(real_cpu)
    output = netD(input) + 1e-5  # epsilon keeps log() away from log(0)
    phi_backup = torch.cat((output.data, phi_backup), 0)[0:3000]
    phi_all.index_copy_(0, inds_batch, output.data.cpu())
    z = Variable(Estep(output.data))
    z_all.index_copy_(0, inds_batch, z.data.cpu())
    real_loss = _cluster_nll(z, torch.log(output))
    real_loss.backward()

    # --- train with fake, conditioned on the same cluster codes ---
    target = z.data.clone()
    noise.data.normal_(0, 1)
    noise[:, nz:nz + ncl, 0, 0].data.copy_(z.data)
    fake = netG(noise)
    fake_output = netD(fake.detach()) + 1e-5
    # sum p * log p (negative entropy of D's prediction on fakes)
    fake_loss = (fake_output * torch.log(fake_output)).sum()
    fake_loss.backward()

    optimizerD.step()
    return _loss_value(real_loss), _loss_value(fake_loss)

############################
# (2) Update G network:
###########################

def trainG():
    """One generator step.

    Regenerates images from the `noise` buffer filled by the preceding trainD
    call and applies the same cluster loss as the real batch, using the codes
    stored in `target` as the assignments.
    """
    optimizerG.zero_grad()
    z = Variable(target.clone())
    output = netD(netG(noise)) + 1e-5
    loss = _cluster_nll(z, torch.log(output))
    loss.backward()
    optimizerG.step()
 
    

In [None]:
# NOTE(review): niter, sigma, outf, posterior and alpha are not read by any
# visible cell; `label = 0` rebinds the `label` buffer name and is unused later.
niter = 25
sigma = .5
outf = '/home/gaurav/Desktop/Research/models/github/IBFNN/'  # NOTE(review): absolute local path
label = 0
# Buffer of recent D outputs (trainD keeps the newest 3000 rows); Estep uses
# its column sums once epoch > 0.
phi_backup = torch.rand(3000, ncl).cuda()
posterior = torch.zeros(data.size(0), ncl)
alpha = 1
n_epochs = 100


def cluster_purity(assignments, true_labels, n_clusters=ncl, n_classes=cls):
    """Purity of hard cluster assignments against ground-truth labels.

    For each cluster, count its most frequent true class; purity is the sum of
    those counts divided by the total count.
    """
    intersect = torch.zeros(n_clusters, n_classes)
    for i in range(n_clusters):
        in_cluster = assignments == i  # hoisted out of the class loop
        for j in range(n_classes):
            intersect[i][j] = (in_cluster * (true_labels == j)).sum()
    return intersect.max(1)[0].sum() / intersect.sum()


def show_cluster_samples(generator, per_cluster=10, path="./MNIST.png"):
    """Generate `per_cluster` samples for every cluster code and save a grid.

    Grid row k holds the samples of cluster k.
    """
    n_samples = per_cluster * ncl
    codes = torch.zeros(n_samples, nz + ncl, 1, 1).cuda()
    codes.normal_(0, 1)
    codes[:, nz:nz + ncl].zero_()
    for k in range(ncl):
        # Set the one-hot code in place (copying a (N, ncl) matrix into the
        # (N, ncl, 1, 1) slice is not broadcastable).
        codes[k * per_cluster:(k + 1) * per_cluster, nz + k].fill_(1)
    samples = generator(Variable(codes))
    # NOTE(review): the first `per_cluster` rows (cluster 0) are left out of
    # the grid; confirm this is intended.
    images = samples.permute(0, 2, 3, 1).data.cpu()[per_cluster:].numpy()
    grid = gallery(images, per_cluster)
    plt.clf()  # reuse the figure instead of stacking one image per epoch
    plt.axis('off')
    plt.imshow(grid[:, :, 0], cmap='gray')
    plt.savefig(path, bbox_inches='tight')


# `epoch`, `inds_batch` and `real_cpu` stay module-level on purpose: Estep
# reads `epoch`, trainD reads `inds_batch` and `real_cpu`.
for epoch in range(n_epochs):
    netD.train()
    netG.train()
    real_loss_all = 0
    fake_loss_all = 0
    shuffle = torch.randperm(data.size(0))
    for i in range(0, data.size(0), batchSize):
        inds_batch = shuffle[i:i + batchSize]
        real_cpu = data.index_select(0, inds_batch)
        real_loss, fake_loss = trainD()
        trainG()

        real_loss_all += real_loss
        fake_loss_all += fake_loss

    print(str(real_loss_all) + ' ' + str(fake_loss_all))

    # Hard cluster of every training image = argmax of its latest D output.
    _, indices_fake = phi_all.cpu().max(1)
    indices_fake = indices_fake.squeeze().float()
    accuracy = cluster_purity(indices_fake, labels)
    print('Purity of clusters:', accuracy)

    netG.eval()
    show_cluster_samples(netG)
