In [1]:
import sys

# NOTE: makes packages from the project's local virtualenv importable when the
# kernel itself is not running inside that venv. The path is relative to the
# notebook's working directory and pinned to Python 3.6. Registering the venv
# as a Jupyter kernel (`python -m ipykernel install --user --name ...`) is the
# cleaner alternative to patching sys.path by hand.
venv_site_packages = './venv/lib/python3.6/site-packages'

# Guard so that re-running this cell does not append duplicate entries.
if venv_site_packages not in sys.path:
    sys.path.append(venv_site_packages)

In [2]:
import os
import time
import random

import numpy as np
import matplotlib.pyplot as plt

from IPython.display import Image

In [3]:
import torch
from torch import nn
from torch import optim
from torchvision import transforms as vtransforms
from torchvision import utils as vutils
from torchvision import datasets

In [4]:
from gans.utils.data import CartoonSet
from gans.utils.layers import *
from gans.utils.functions import *
from gans.models import *
from gans.trainer import *

In [5]:
# Set the random seed for reproducibility. Every RNG the notebook imports
# (Python `random`, NumPy, PyTorch) is seeded; torch.manual_seed also seeds
# the CUDA generators.
manualSeed = 999
# manualSeed = random.randint(1, 10000)  # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed);  # trailing ';' hides the returned Generator repr

Random Seed:  999


<torch._C.Generator at 0x7f082b8b78f0>

In [6]:
# Dataset
data_root = './datasets'

# data_name = 'CartoonSet'
data_name = 'FashionMNIST'
# data_name = 'MNIST'

# batch_size = 200
# sample_size = 50
# nrow = 5

# num_epochs = 100

# num_workers = 8


# # Output dir
# output_root = './output/vanillaGAN'
# # overwrite real sample output
# overwrite_real = False

# noize_dim = 100


In [7]:
# Prefer the first CUDA GPU when one is visible to PyTorch; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

print('using device: ', device)

using device:  cuda:0


In [8]:
# Load the dataset selected by `data_name`. Every branch ends with a
# Normalize(mean=0.5, std=0.5) transform, so after ToTensor's [0, 1] scaling
# the image tensors lie in [-1, 1]. `img_size` / `in_channels` describe the
# resulting images for later cells.
img_size = 28
in_channels = 1

dataset_dir = os.path.join(data_root, data_name)

if data_name == 'CartoonSet':
    img_size = 75
    # The 3-element Normalize below implies RGB images. Without this line,
    # in_channels would stay at the single-channel default of 1.
    in_channels = 3
    dataset = CartoonSet(root=dataset_dir,
                         transform=vtransforms.Compose([
                             vtransforms.CenterCrop(size=400),
                             vtransforms.Resize(size=img_size),
                             vtransforms.ToTensor(),
                             vtransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                         ]))

elif data_name in ('MNIST', 'FashionMNIST'):
    # Both torchvision datasets use the same single-channel pipeline, so only
    # the dataset class differs.
    dataset_cls = {'MNIST': datasets.MNIST,
                   'FashionMNIST': datasets.FashionMNIST}[data_name]
    dataset = dataset_cls(root=dataset_dir,
                          download=True,
                          transform=vtransforms.Compose([
                              vtransforms.ToTensor(),
                              vtransforms.Normalize((0.5,), (0.5,)),
                          ]))

else:
    # Fail loudly instead of leaving `dataset` undefined (or stale from an
    # earlier run of this cell with a different data_name).
    raise ValueError('unknown data_name: {!r}'.format(data_name))


In [11]:
# Export a reproducible random subset of real images to
# samples/<data_name>/original/, one PNG per image named by its dataset index.
image_sample_dir = 'samples'
num_real_samples = 2048  # size of the exported real-image subset

real = os.path.join(data_name, 'original')
real_dir = os.path.join(image_sample_dir, real)
os.makedirs(real_dir, exist_ok=True)

image_path = os.path.join(real_dir, 'img{:05}.png')

# Draw indices from a private RNG seeded with manualSeed, so that re-running
# this cell rewrites the same files instead of adding a second, different
# subset. min() avoids a ValueError on datasets smaller than the subset size.
sample_rng = random.Random(manualSeed)
sample_indices = sample_rng.sample(range(len(dataset)),
                                   min(num_real_samples, len(dataset)))

for i in sample_indices:
    img = dataset[i][0]
    # The loading transforms end with Normalize(0.5, 0.5), which maps pixels
    # to [-1, 1]; undo it here so the PNG keeps the original intensities.
    # (`normalize=True, range=(0, 1)` would clamp every negative value to 0,
    # turning the darker half of each image black.)
    vutils.save_image(img * 0.5 + 0.5, image_path.format(i))

Scores by dataset and model:

- MNIST:
  - vanillaGAN: 29.8860039314948
  - cGAN: 27.60618076139116
  - DCGAN: 7.548617460620278
  - cDCGAN: 7.420052673632085

- FashionMNIST:
  - vanillaGAN: 55.893088118384526
  - cGAN: 60.76478200199051
  - DCGAN: 13.996017101103291
  - cDCGAN: 13.584223451920764

In [10]:
# Number of examples in the loaded dataset.
len(dataset)

60000