In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
import glob
import PIL
from PIL import Image
from scipy.misc import toimage
from torch.utils import data as D
import random

In [3]:
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

In [14]:
img_size = (64, 64, 1)
latent_size = 100
g_hidden_size = 128
d_hidden_size = 64

num_epochs = 100
batch_size = 128

dataset_dir = "./data/FMNIST"
sample_dir = "./result_dcgan_FMNIST/"

if not os.path.exists(dataset_dir):
    os.makedirs(dataset_dir)

if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

In [8]:
transform = transforms.Compose([
                transforms.Resize(img_size[0]),
                #transforms.CenterCrop(img_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5),   # 3 for RGB channels
                                     std=(0.5, 0.5, 0.5))])

trainset = torchvision.datasets.FashionMNIST(root=dataset_dir, train=True ,transform=transform, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [9]:
class Generator(nn.Module):
    def __init__(self, input_size, latent_size, hidden_size, kernel_size=4):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            # 1st deconv block
            nn.ConvTranspose2d(latent_size, hidden_size * 8, kernel_size=kernel_size, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(hidden_size * 8),
            nn.ReLU(True),
            
            # 2nd deconv block
            nn.ConvTranspose2d(hidden_size * 8, hidden_size * 4, kernel_size=kernel_size, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_size * 4),
            nn.ReLU(True),
            
            # 3rd deconv block
            nn.ConvTranspose2d(hidden_size * 4, hidden_size * 2, kernel_size=kernel_size, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_size * 2),
            nn.ReLU(True),
            
            # 4th deconv block
            nn.ConvTranspose2d(hidden_size * 2, hidden_size * 1, kernel_size=kernel_size, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_size * 1),
            nn.ReLU(True),
            
            # 5th deconv block
            nn.ConvTranspose2d(hidden_size * 1, input_size[2], kernel_size=kernel_size, stride=2, padding=1, bias=False),
            nn.Tanh()
        )
        
    def forward(self, x):
        output = self.network(x)
        return output

In [10]:
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, kernel_size=4):
        super(Discriminator, self).__init__()
        self.network = nn.Sequential(
            # 64 x 64 x c --> 32 x 32 x hidden_size
            nn.Conv2d(input_size[2], hidden_size, kernel_size=kernel_size, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, True),
            
            # 32 x 32 x hidden_size --> 16 x 16 x hidden_size * 2
            nn.Conv2d(hidden_size, hidden_size * 2, kernel_size=kernel_size, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_size * 2),
            nn.LeakyReLU(0.2, True),
            
            # 16 x 16 x hidden_size * 2 --> 8 x 8 x hidden_size * 4
            nn.Conv2d(hidden_size * 2, hidden_size * 4, kernel_size=kernel_size, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_size * 4),
            nn.LeakyReLU(0.2, True),
            
            # 8 x 8 x hidden_size * 4 --> 4 x 4 x hidden_size * 8
            nn.Conv2d(hidden_size * 4, hidden_size * 8, kernel_size=kernel_size, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_size * 8),
            nn.LeakyReLU(0.2, True),
            
            # 4 x 4 x hidden_size * 8 --> 1 x 1 x 1
            nn.Conv2d(hidden_size * 8, 1, kernel_size=kernel_size, stride=1, padding=0, bias=False),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        output = self.network(x)
        return output

In [11]:
G = Generator(img_size, latent_size, g_hidden_size, kernel_size=4)
D = Discriminator(img_size, d_hidden_size, kernel_size=4)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

G.to(device)
D.to(device)

cuda:0


Discriminator(
  (network): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

In [12]:
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

In [15]:
# Start training
total_step = len(trainloader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(trainloader):
        images = images.to(device)
        #print("images shape : ", images.size())
        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(images.size()[0], 1).to(device)
        fake_labels = torch.zeros(images.size()[0], 1).to(device)
        
        # ================================================================== #
        #                        Train the discriminator                     #
        # ================================================================== #

        outputs = D(images)
        
        #print("outputs : ", outputs.shape)
        #print("real labels : ", real_labels.shape)
        
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        z = torch.randn((images.size()[0], latent_size)).view(-1, latent_size, 1, 1).to(device)
        fake_images = G(z)
        
        #print("fake_images size : ", fake_images.shape)
        
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        d_loss = d_loss_real + d_loss_fake
        
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()

        d_loss.backward()
        d_optimizer.step()
        
        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #

        # Compute loss with fake images
        z = torch.randn((images.size()[0], latent_size)).view(-1, latent_size, 1, 1).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        
        g_loss = criterion(outputs, real_labels)
        
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))
    
    # Save real images
    if (epoch+1) == 1:
        images = images.reshape(images.size()[0], 1, 64, 64)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
    
    # Save sampled images
    fake_images = fake_images.reshape(fake_images.size()[0], 1, 64, 64)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

# Save the model checkpoints 
torch.save(G.state_dict(), sample_dir + '/G.ckpt')
torch.save(D.state_dict(), sample_dir + '/D.ckpt')

  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [0/100], Step [200/469], d_loss: 0.0112, g_loss: 7.0610, D(x): 1.00, D(G(z)): 0.01
Epoch [0/100], Step [400/469], d_loss: 0.0618, g_loss: 6.8804, D(x): 0.95, D(G(z)): 0.01


  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [1/100], Step [200/469], d_loss: 0.0326, g_loss: 7.7574, D(x): 0.98, D(G(z)): 0.00
Epoch [1/100], Step [400/469], d_loss: 0.0615, g_loss: 6.6831, D(x): 0.99, D(G(z)): 0.04
Epoch [2/100], Step [200/469], d_loss: 0.0186, g_loss: 4.3121, D(x): 1.00, D(G(z)): 0.01
Epoch [2/100], Step [400/469], d_loss: 0.0591, g_loss: 8.6713, D(x): 0.95, D(G(z)): 0.00
Epoch [3/100], Step [200/469], d_loss: 0.0634, g_loss: 7.2436, D(x): 1.00, D(G(z)): 0.05
Epoch [3/100], Step [400/469], d_loss: 0.0804, g_loss: 6.5258, D(x): 0.94, D(G(z)): 0.01
Epoch [4/100], Step [200/469], d_loss: 0.1907, g_loss: 2.7839, D(x): 0.87, D(G(z)): 0.01
Epoch [4/100], Step [400/469], d_loss: 0.1269, g_loss: 6.3250, D(x): 0.91, D(G(z)): 0.02
Epoch [5/100], Step [200/469], d_loss: 0.0472, g_loss: 5.9886, D(x): 0.98, D(G(z)): 0.03
Epoch [5/100], Step [400/469], d_loss: 0.3958, g_loss: 11.5125, D(x): 0.77, D(G(z)): 0.00
Epoch [6/100], Step [200/469], d_loss: 0.6871, g_loss: 6.6929, D(x): 0.62, D(G(z)): 0.00
Epoch [6/100], Step 