In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as f
import torch.optim as optim 
from torchvision import datasets, transforms
import torchvision.utils
import os
import sys
import time
import math

In [2]:
# Hyperparameters shared by the model definitions and the training loop.
latent_dim = 100  # channel count of the (N, latent_dim, 1, 1) noise fed to the generator
channels = 3      # image channels produced by the generator / consumed by the discriminator
batch_size = 64   # samples per mini-batch

# Seed PyTorch's global RNG so weight initialisation, data shuffling and noise
# sampling are repeatable across "Restart & Run All" (GPU kernels may still
# introduce small non-determinism).
seed = 42
torch.manual_seed(seed);  # trailing ';' suppresses the Generator repr in the cell output

In [3]:
class Generator(nn.Module):
    """DCGAN-style generator that upsamples a latent tensor to a 32x32 image.

    The network is a stack of transposed convolutions; the final ``Tanh``
    bounds every output value to ``[-1, 1]``.

    Args:
        in_dim: Channel count of the latent input. When omitted, the
            notebook-level ``latent_dim`` is used.
        out_channels: Channel count of the generated image. When omitted, the
            notebook-level ``channels`` is used.
    """

    def __init__(self, in_dim=None, out_channels=None):
        super().__init__()
        # Fall back to the notebook configuration at construction time so that
        # a bare `Generator()` keeps behaving exactly as before.
        in_dim = latent_dim if in_dim is None else in_dim
        out_channels = channels if out_channels is None else out_channels

        self.main = nn.Sequential(
            # Input : in_dim x 1 x 1
            nn.ConvTranspose2d(in_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # Size : 512 x 4 x 4
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # Size : 256 x 8 x 8
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # Size : 128 x 16 x 16
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
            # Size : out_channels x 32 x 32
        )

    def forward(self, input):
        """Generate images from latent noise.

        Args:
            input: Tensor of shape ``(N, in_dim, 1, 1)``.

        Returns:
            Tensor of shape ``(N, out_channels, 32, 32)`` with values in ``[-1, 1]``.
        """
        return self.main(input)

In [4]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator that scores 32x32 images as real or fake.

    Three stride-2 convolutions reduce the image to a 4x4 feature map, and a
    final 4x4 convolution followed by ``Sigmoid`` yields one probability per
    image.

    Args:
        in_channels: Channel count of the input images. When omitted, the
            notebook-level ``channels`` is used.
    """

    def __init__(self, in_channels=None):
        super().__init__()
        # Fall back to the notebook configuration at construction time so that
        # a bare `Discriminator()` keeps behaving exactly as before.
        in_channels = channels if in_channels is None else in_channels

        self.main = nn.Sequential(
            # Size : in_channels x 32 x 32
            nn.Conv2d(in_channels, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # Size : 128 x 16 x 16
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # Size : 256 x 8 x 8
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # Size : 512 x 4 x 4
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
            # Size : 1 x 1 x 1
        )

    def forward(self, input):
        """Score a batch of images.

        Args:
            input: Tensor of shape ``(N, in_channels, 32, 32)``.

        Returns:
            Tensor of shape ``(N, 1, 1, 1)`` with values in ``[0, 1]``.
        """
        return self.main(input)

In [5]:
# Build both networks with their default constructor arguments; no custom
# weight initialisation is applied here, so PyTorch's layer defaults are used.
gen = Generator()
disc = Discriminator()

In [6]:
# Preprocessing: keep images at 32x32, convert to tensors, then normalise each
# RGB channel with mean 0.5 / std 0.5.
cifar_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 training split, downloaded into ../data when not already present.
cifar_train = datasets.CIFAR10(
    '../data',
    train=True,
    download=True,
    transform=cifar_transform,
)

# Shuffled mini-batches; the final incomplete batch is dropped so every batch
# holds exactly `batch_size` images.
trainLoader = torch.utils.data.DataLoader(
    cifar_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

Files already downloaded and verified


In [7]:
# Pick the first CUDA device when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Move both networks to the chosen device; the last expression is left bare so
# the notebook still displays the discriminator's layer summary.
gen.to(device)
disc.to(device)

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (10): Sigmoid()
  )
)

In [10]:
# Both networks are optimised with Adam using the same learning rate and betas.
adam_settings = {'lr': 0.0002, 'betas': (0.5, 0.999)}
gen_optimizer = optim.Adam(gen.parameters(), **adam_settings)
desc_optimizer = optim.Adam(disc.parameters(), **adam_settings)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

In [13]:
# Number of full passes over the training set.
num_epochs = 20

for epoch in range(num_epochs):
    for i, (input_batch, _) in enumerate(trainLoader, 0):
        # Move the real images to the training device.
        input_batch = input_batch.to(device)
        # Size the noise from the batch actually received, so the loop does not
        # silently depend on the loader dropping the last partial batch.
        cur_batch_size = input_batch.size(0)

        #=======================#
        #==Train discriminator==#
        #=======================#

        disc.zero_grad()

        # Loss on real data: the discriminator should output 1.
        real_desc = disc(input_batch)
        desc_loss_real = criterion(real_desc, torch.ones_like(real_desc))

        # Loss on fake data: the discriminator should output 0.
        noise = torch.randn(cur_batch_size, latent_dim, 1, 1, device=device)
        gen_images = gen(noise)
        # Detach so this backward pass stops at the discriminator's input;
        # generator gradients computed here would be thrown away by
        # gen.zero_grad() below, so back-propagating through it is wasted work.
        gen_desc = disc(gen_images.detach())
        desc_loss_fake = criterion(gen_desc, torch.zeros_like(gen_desc))

        desc_loss = desc_loss_fake + desc_loss_real
        desc_loss.backward()
        desc_optimizer.step()

        #=======================#
        #====Train generator====#
        #=======================#

        gen.zero_grad()

        # Fresh noise, scored by the just-updated discriminator; the generator
        # is rewarded when its images are classified as real (target 1).
        noise = torch.randn(cur_batch_size, latent_dim, 1, 1, device=device)
        gen_imgs = gen(noise)
        gen_desc = disc(gen_imgs)
        gen_loss = criterion(gen_desc, torch.ones_like(gen_desc))

        gen_loss.backward()
        gen_optimizer.step()

        # Periodic progress report. D(G(z)) here is the discriminator's mean
        # score on the generator-step batch, i.e. after the discriminator update.
        if i % 200 == 0:
            print('Epoch [{}], Step [{}], disc_loss: {:.4f}, gen_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch, i+1, desc_loss.item(), gen_loss.item(),
                          real_desc.mean().item(), gen_desc.mean().item()))

    # Save the last generated batch of the epoch, rescaled from the Tanh range
    # [-1, 1] to [0, 1]; detach first so no autograd graph is carried along.
    disp_imgs = (gen_imgs.detach() + 1.) / 2
    torchvision.utils.save_image(disp_imgs, './img_epoch%.i.png' % epoch)

Epoch [0], Step [1], disc_loss: 2.7418, gen_loss: 1.5480, D(x): 0.89, D(G(z)): 0.23
Epoch [0], Step [201], disc_loss: 0.1501, gen_loss: 4.7465, D(x): 0.95, D(G(z)): 0.01
Epoch [0], Step [401], disc_loss: 0.5548, gen_loss: 4.5999, D(x): 0.95, D(G(z)): 0.02
Epoch [0], Step [601], disc_loss: 0.5684, gen_loss: 5.7115, D(x): 0.97, D(G(z)): 0.01
Epoch [1], Step [1], disc_loss: 0.3719, gen_loss: 3.3883, D(x): 0.81, D(G(z)): 0.06
Epoch [1], Step [201], disc_loss: 0.2520, gen_loss: 3.8270, D(x): 0.89, D(G(z)): 0.04
Epoch [1], Step [401], disc_loss: 0.3673, gen_loss: 4.1812, D(x): 0.83, D(G(z)): 0.03
Epoch [1], Step [601], disc_loss: 0.2087, gen_loss: 2.8964, D(x): 0.85, D(G(z)): 0.10
Epoch [2], Step [1], disc_loss: 0.2712, gen_loss: 3.5839, D(x): 0.86, D(G(z)): 0.04
Epoch [2], Step [201], disc_loss: 1.7683, gen_loss: 1.0918, D(x): 0.24, D(G(z)): 0.46
Epoch [2], Step [401], disc_loss: 0.1659, gen_loss: 2.9390, D(x): 0.95, D(G(z)): 0.07
Epoch [2], Step [601], disc_loss: 0.5879, gen_loss: 4.0339, 

In [14]:
# Manual repeat of the end-of-epoch image save. This cell relies on `gen_imgs`
# and `epoch` still being in the kernel from the training cell, so it fails on
# a fresh kernel unless that cell has run first.
# NOTE(review): presumably added because the training run was interrupted
# before its own save step executed - confirm, and drop this cell if not needed.

# Rescale generator output from the Tanh range [-1, 1] to [0, 1] before saving.
disp_imgs = (gen_imgs + 1.) / 2
torchvision.utils.save_image(disp_imgs, './img_epoch%.i.png' % epoch)