In [2]:
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

In [23]:
# Enable IPython's autoreload extension in mode 2 so that imported project
# modules (e.g. the local `vaegan` module imported in the next cell) are
# re-imported when their source changes, without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [24]:
from vaegan import VAE
from vaegan import NetD
from vaegan import Aux
from vaegan import loss_function

In [25]:
# Mini-batch size; used by both data loaders and to size the pre-allocated
# `input` / `label` buffers.
# NOTE(review): the recorded training output shows batches of 100 images, so
# that output was evidently produced with a different value than the 128 set
# here -- re-run from a fresh kernel to confirm.
bsz = 128
# Binary cross-entropy; applied to every real/fake discriminator output in the
# training loop.
criterion = nn.BCELoss()


In [26]:
# MNIST training split (the default when `train` is not given); downloaded
# into ../data if it is not already there. Batches are reshuffled each epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='../data',
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=bsz,
    shuffle=True,
)
# MNIST held-out split, read from the same ../data directory.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='../data',
        train=False,
        transform=transforms.ToTensor(),
    ),
    batch_size=bsz,
    shuffle=True,
)

In [27]:
# Instantiate the three networks from the project's `vaegan` module, in the
# same order as before (encoder, discriminator, decoder):
#   netG -- returns (mu, logvar) for an input batch,
#   netD -- returns (features, real/fake score),
#   aux  -- maps a latent code (plus label) to an image.
netG, netD, aux = VAE(), NetD(), Aux()

In [28]:
# One Adam optimizer per network, all with the same learning rate of 1e-4.
# The tuple order (netD, netG, aux) matches the names on the left-hand side.
optimizerD, optimizerG, optimizer_aux = (
    optim.Adam(net.parameters(), lr=1e-4) for net in (netD, netG, aux)
)

In [29]:
# Pre-allocated buffers that the training loop resizes and refills in place
# for every batch.
# NOTE: `input` shadows the Python builtin of the same name; the name is kept
# because the training cell refers to it.
input = torch.FloatTensor(bsz, 28, 28)
label = torch.FloatTensor(bsz)

# Target values written into `label` for real and generated samples.
real_label = 1
fake_label = 0

# Only use the GPU when one is actually available, instead of assuming it;
# a hard-coded truthy flag makes this cell crash on CPU-only machines.
USE_CUDA = torch.cuda.is_available()

if USE_CUDA:
    # Move every network, the loss module and both buffers to the GPU together
    # so they all live on the same device.
    netG = netG.cuda()
    netD = netD.cuda()
    aux = aux.cuda()
    criterion = criterion.cuda()
    input, label = input.cuda(), label.cuda()

In [36]:
%reload_ext autoreload
%autoreload 2
for epoch in range(200):
    for i, (data,y) in enumerate(train_loader):
        gamma = 1.0
        real_cpu = data;

        real_cpu = real_cpu.cuda()
        y=y.cuda()
        input.resize_as_(real_cpu).copy_(real_cpu)
        label.resize_(bsz).fill_(real_label)

        inputv = Variable(input)
        labelv = Variable(label)
        y = Variable(y)

        #need variables for dis
        #x_l, x_l_tilde
        
        
        #do discriminator calculations
        netD.zero_grad()
        #fc3_weight,fc4_weight = aux.return_weights()
        mu,logvar = netG(inputv, y)
        std = logvar.mul(0.5).exp_()
        eps = Variable(std.data.new(std.size()).normal_())
        z=eps.mul(std).add_(mu)
        fake = aux(z, y)
        
        x_l_tilde, output_fake = netD(fake, y)
        x_l, output_real = netD(inputv, y)
        #x_l_aux, output_fake_aux = netD(fake_aux)
        L_GAN_real = criterion(output_real, labelv)
        L_GAN_real.backward(retain_graph=True)
        
        labelv = Variable(label.fill_(fake_label))
        L_GAN_fake = criterion(output_fake, labelv)
        L_GAN_fake.backward(retain_graph=True)
        
        z_p = Variable(std.data.new(std.size()).normal_())
        fake_aux = aux(z_p, y)
        x_l_aux, output_aux = netD(fake_aux, y)
        L_GAN_aux = criterion(output_aux,labelv)
        L_GAN_aux.backward(retain_graph=True)
        optimizerD.step()
           
        
        #get weights of netG and use in aux
        aux.zero_grad()
        labelv=Variable(label.fill_(real_label))
        
        L_dec_vae = gamma*loss_function(x_l_tilde,x_l,mu,logvar)
        L_dec_fake = criterion(output_fake,labelv)
        L_dec_aux  = criterion(output_aux,labelv)
        L_dec_vae.backward(retain_graph=True)
        L_dec_fake.backward(retain_graph=True)
        L_dec_aux.backward(retain_graph=True)
        optimizer_aux.step()
        
        

        #encoder loss 
        netG.zero_grad()
        L_enc = loss_function(x_l_tilde, x_l,mu,logvar)
        L_enc.backward()
        optimizerG.step()

    
        if i % 100 == 0:
            print('real_cpu.size()', real_cpu.size(), "iteration: ", i)
            vutils.save_image(real_cpu,
                            './results/cvaegan results/real_samples2.png',
                                normalize=True)
            vutils.save_image(fake.data.view(-1,1,28,28),
                                './results/cvaegan results/fake_samples2.png',
                                normalize=True)
    if epoch % 25 == 0:
        print('epoch: ', epoch)
        vutils.save_image(fake.data.view(-1,1,28,28),
                                './results/cvaegan results/fake_samples2_{0}.png'.format(epoch),
                                normalize=True)

real_cpu.size() torch.Size([100, 1, 28, 28]) iteration:  0
real_cpu.size() torch.Size([100, 1, 28, 28]) iteration:  100
real_cpu.size() torch.Size([100, 1, 28, 28]) iteration:  200
real_cpu.size() torch.Size([100, 1, 28, 28]) iteration:  300
real_cpu.size() torch.Size([100, 1, 28, 28]) iteration:  400
real_cpu.size() torch.Size([100, 1, 28, 28]) iteration:  500
epoch:  0
real_cpu.size() torch.Size([100, 1, 28, 28]) iteration:  0
real_cpu.size() torch.Size([100, 1, 28, 28]) iteration:  100
real_cpu.size() torch.Size([100, 1, 28, 28]) iteration:  200
real_cpu.size() torch.Size([100, 1, 28, 28]) iteration:  300
real_cpu.size() torch.Size([100, 1, 28, 28]) iteration:  400
real_cpu.size() torch.Size([100, 1, 28, 28]) iteration:  500
real_cpu.size() torch.Size([100, 1, 28, 28]) iteration:  0
real_cpu.size() torch.Size([100, 1, 28, 28]) iteration:  100
real_cpu.size() torch.Size([100, 1, 28, 28]) iteration:  200
real_cpu.size() torch.Size([100, 1, 28, 28]) iteration:  300
real_cpu.size() torc

In [None]:
# Persist each trained network to its own file. Writing all three to the same
# path would leave only the last object saved (aux) on disk and silently lose
# the encoder and discriminator.
torch.save(netG, './pretrained models/netG3.pth')
torch.save(netD, './pretrained models/netD3.pth')
torch.save(aux, './pretrained models/aux3.pth')

## Load Pretrained Model

In [5]:
# Restore the three networks from checkpoint files on disk.
# NOTE(review): the output recorded for this cell is
# `EOFError: Ran out of input`, which suggests at least one of these files is
# empty or truncated -- confirm the checkpoints are intact.
# NOTE(review): these paths ('pretrained_models/', suffix 2) differ from the
# directory and suffix used by the save cell ('./pretrained models/',
# suffix 3) -- verify which checkpoints are meant to be loaded here.
netG = torch.load('pretrained_models/netG2.pth')
netD = torch.load('pretrained_models/netD2.pth')
aux = torch.load('pretrained_models/aux2.pth')

EOFError: Ran out of input

In [None]:
# Rebuild the loader over the MNIST held-out split so the evaluation cells
# can run after loading pretrained models; same settings as the loader
# defined alongside the training loader.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='../data',
        train=False,
        transform=transforms.ToTensor(),
    ),
    batch_size=bsz,
    shuffle=True,
)

In [29]:
# Pull a single (images, labels) batch from the test loader. The builtin
# next() is used rather than the iterator's `.next()` method, which is the
# Python 2 protocol and is not part of the Python 3 iterator interface.
data, y = next(iter(test_loader))
# Save the batch as an image grid for reference. Note these are real test
# images, despite the output file being named 'fake.png'.
vutils.save_image(data.view(-1, 1, 28, 28),
                  './fake.png',
                  normalize=True)

In [30]:
%reload_ext autoreload
%autoreload 2
mu,logvar = netG(Variable(data).cuda(), Variable(y).cuda(), Variable(torch.tensor([8])).cuda(), .5)
std = logvar.mul(0.5).exp_()
eps = Variable(std.data.new(std.size()).normal_())
z=eps.mul(std).add_(mu)
fake = aux(z, y, Variable(torch.tensor([8])).cuda(), .5)
vutils.save_image(fake.data.view(-1,1,28,28),
                                './results/cvae results/generated2.png',
                                normalize=True)

In [31]:
# Feed the previously generated images (`fake`) back through the encoder and
# decoder, this time without the extra conditioning arguments.
mu, logvar = netG(
    Variable(fake),
    Variable(y).cuda(),
)

# Reparameterisation: z = mu + eps * std with std = exp(logvar / 2).
std = (0.5 * logvar).exp()
eps = Variable(std.data.new(std.size()).normal_())
z = mu + eps * std

fake2 = aux(z, y)
vutils.save_image(
    fake2.data.view(-1, 1, 28, 28),
    './results/cvae results/generated3.png',
    normalize=True,
)

In [32]:
# Same generation step as the cell that writes generated2.png, but with the
# final scalar argument set to 1 instead of .5.
# NOTE(review): the meaning of the tensor holding 8 and of the scalar is
# defined in the vaegan module -- confirm there.
mu, logvar = netG(
    Variable(data).cuda(),
    Variable(y).cuda(),
    Variable(torch.tensor([8])).cuda(),
    1,
)

# Reparameterisation: z = mu + eps * std with std = exp(logvar / 2).
std = (0.5 * logvar).exp()
eps = Variable(std.data.new(std.size()).normal_())
z = mu + eps * std

fake = aux(
    z,
    y,
    Variable(torch.tensor([8])).cuda(),
    1,
)
vutils.save_image(
    fake.data.view(-1, 1, 28, 28),
    './results/cvae results/generated.png',
    normalize=True,
)