In [107]:
#importing dependencies

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import save_image

#import _pickle as cpickle #to store model histories in a file
import os
import imageio
from PIL import Image

# Prefer the GPU whenever CUDA is available; otherwise everything runs on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(use_cuda)
    
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots"

False


In [108]:
# --- Hyperparameters and paths ------------------------------------------------
# NOTE(review): image_size, G_*/D_*, num_filters, learning_rate, betas and
# save_dir are not read by any of the visible cells (they look like leftovers
# from a DCGAN notebook) -- confirm before removing them.

# Image / network dimensions
image_size = 64  # HACK to use MNIST architecture
G_input_dim, G_output_dim = 100, 3
D_input_dim, D_output_dim = 3, 1
num_filters = [1024, 512, 256, 128]

# Optimisation settings
learning_rate = 2e-4
betas = (0.5, 0.999)
batch_size = 128
num_epochs = 1000

# Filesystem locations (relative to the working directory)
data_dir = './Train_data'
save_dir = './DCGAN_results/'

In [109]:
# Print the kernel's working directory; the relative paths used in this notebook resolve against it.
print('{}'.format(os.getcwd()))

/Users/vivek1410patel/Documents/Quarter 3/CS231n/project/PoseGuided


In [110]:
# --- Data loading ---
# Images are only converted to tensors here; no resizing or normalisation is applied.
# A normalisation step was tried and disabled; it is kept for reference. Its statistics
# appear to be on a 0-255 scale -- TODO confirm before re-enabling:
#   transforms.Normalize(mean=(214.0466981, 206.55220904, 203.99178198), std=(54.34939265, 55.62690195, 58.85794001))
transform = transforms.Compose([transforms.ToTensor()])

# One sub-folder of `data_dir` per class, as required by ImageFolder.
df_data = dsets.ImageFolder(root=data_dir, transform=transform)

# Shuffled mini-batches of `batch_size` images.
data_loader = DataLoader(df_data, batch_size=batch_size, shuffle=True)


## VAE Model

In [111]:
class VAE(nn.Module):
    """Fully connected variational autoencoder over flattened inputs.

    The encoder passes a flattened sample through three ReLU hidden layers
    (4096 -> 1000 -> 400) and emits the mean and log-variance of a diagonal
    Gaussian over the latent code. The decoder mirrors the encoder
    (400 -> 1000 -> 4096) and ends in a plain linear layer.

    Args:
        input_dim: Number of values per flattened sample. Defaults to
            ``64*64*3``, the size that used to be hard-coded in this class.
        latent_dim: Size of the latent code. Defaults to 60, the previously
            hard-coded value.
    """

    def __init__(self, input_dim=64*64*3, latent_dim=60):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Attribute names and construction order are unchanged, so parameter
        # names in state_dict() and seeded initialisation stay the same.
        self.fc1 = nn.Linear(input_dim, 4096)
        self.fc2 = nn.Linear(4096, 1000)
        self.fc3 = nn.Linear(1000, 400)
        self.fc31 = nn.Linear(400, latent_dim)  # mean head
        self.fc32 = nn.Linear(400, latent_dim)  # log-variance head
        self.fc4 = nn.Linear(latent_dim, 400)
        self.fc5 = nn.Linear(400, 1000)
        self.fc6 = nn.Linear(1000, 4096)
        self.fc7 = nn.Linear(4096, input_dim)

    def encode(self, x):
        """Map flattened inputs of shape (N, input_dim) to ``(mu, logvar)``."""
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        h3 = F.relu(self.fc3(h2))
        return self.fc31(h3), self.fc32(h3)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` while training; return ``mu`` in eval mode."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes of shape (N, latent_dim) to (N, input_dim) outputs.

        No output non-linearity is applied: the raw result of the last linear
        layer is returned (a sigmoid was tried here earlier and disabled).
        """
        h4 = F.relu(self.fc4(z))
        h5 = F.relu(self.fc5(h4))
        h6 = F.relu(self.fc6(h5))
        return self.fc7(h6)

    def forward(self, x):
        """Flatten ``x`` to (-1, input_dim) and return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Build the network on the selected device and hand all of its parameters to Adam.
# NOTE(review): lr is fixed at 1e-3 here; `learning_rate` and `betas` from the
# parameter cell are not used.
model = VAE()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

## Loss Function

In [112]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss: summed squared reconstruction error plus KL divergence.

    Args:
        recon_x: Reconstruction of shape (N, D), as returned by the decoder.
        x: Original batch; it is flattened to (-1, D) to match ``recon_x``.
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        A scalar tensor. Both terms are summed (not averaged) over every
        element and every sample in the batch.
    """
    # Flatten the target to the reconstruction's width instead of a fixed 64*64*3.
    target = x.view(-1, recon_x.size(-1))

    # reduction='sum' is the supported spelling of the deprecated size_average=False.
    # The reconstruction term is a squared error (MSE), not a binary cross-entropy.
    recon_loss = F.mse_loss(recon_x, target, reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + KLD

## Set up training and testing functions

In [113]:
def train(epoch):
    """Run one optimisation pass over ``data_loader`` and print progress.

    Uses the notebook-level ``model``, ``optimizer``, ``data_loader``,
    ``device`` and ``loss_function``. Prints the per-sample loss of every
    10th batch and, at the end, the summed loss divided by the dataset size.

    Args:
        epoch: Epoch number, used only in the printed messages.
    """
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(data_loader):
        data = data.to(device)
        optimizer.zero_grad()
        # Call the module itself rather than .forward() so that nn.Module's
        # __call__ machinery (hooks) runs, matching what test() does.
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # read the scalar once and reuse it below
        train_loss += batch_loss
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(data_loader.dataset),
                100. * batch_idx / len(data_loader),
                batch_loss / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(data_loader.dataset)))


def test(epoch, save_every=1):
    """Evaluate the model without gradients and save a reconstruction grid.

    NOTE(review): this iterates over the same ``data_loader`` that ``train``
    uses, so the printed "Test set loss" is computed on the training data --
    confirm whether a held-out loader was intended.

    Args:
        epoch: Epoch number, used in the output file name and to decide
            whether an image is written this epoch.
        save_every: Write the comparison image only when
            ``epoch % save_every == 0``. Must be a positive integer; the
            default of 1 saves every epoch, as before.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(data_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0 and epoch % save_every == 0:
                # Top row: up to 8 inputs; bottom row: their reconstructions.
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                        recon_batch.view(data.shape)[:n]])
                # save_image does not create missing directories.
                os.makedirs('results', exist_ok=True)
                save_image(comparison.cpu(),
                           'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(data_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))


## Train and Test

In [None]:
# Every epoch: one training pass, one evaluation pass, then decode 64 latent
# vectors drawn from a standard normal and save them as an image grid.
# The output directory is created up front because save_image does not create
# missing directories.
os.makedirs('results', exist_ok=True)

for epoch in range(1, num_epochs+1):
    train(epoch)
    test(epoch)

    with torch.no_grad():
        sample = torch.randn(64, 60).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 3, 64, 64), 'results/sample_' + str(epoch) + '.png')


====> Epoch: 1 Average loss: 1228.0886
====> Test set loss: 465.1393
====> Epoch: 2 Average loss: 410.0470
====> Test set loss: 453.9904
====> Epoch: 3 Average loss: 374.0098
====> Test set loss: 377.6066
====> Epoch: 4 Average loss: 354.5689
====> Test set loss: 362.3549
====> Epoch: 5 Average loss: 328.1678
====> Test set loss: 323.1482


====> Epoch: 6 Average loss: 313.8504
====> Test set loss: 319.5590
====> Epoch: 7 Average loss: 304.3807
====> Test set loss: 306.9815
====> Epoch: 8 Average loss: 299.2436
====> Test set loss: 291.8725
====> Epoch: 9 Average loss: 291.5027
====> Test set loss: 315.3355
====> Epoch: 10 Average loss: 286.5118
====> Test set loss: 289.6316
====> Epoch: 11 Average loss: 281.9416
====> Test set loss: 283.8439
