In [30]:
import os
import math
import random

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torch.utils.data
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [29]:
# (removed) A bare `reset` here ran *after* the imports cell. It blocked Restart & Run All
# on an interactive y/n prompt and, once confirmed, deleted every imported name
# (torch, datasets, transforms, ...), so every later cell failed with NameError.
# Restart the kernel instead when a clean namespace is needed.


In [31]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [32]:
image_size = 32
batch_size = 8

# ToTensor yields values in [0, 1]; Normalize with mean = std = 0.5 per channel
# maps them to [-1, 1]. CIFAR-10 images are already 32x32, so Resize(32) is
# effectively a pass-through that only documents the expected size.
cifar_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])

dataset = datasets.CIFAR10(
    root="./data/cifar",
    transform=cifar_transform,
    download=True,
)

# shuffle=True reshuffles at the start of every pass over the data.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

Files already downloaded and verified


In [33]:
def weights_init(m):
    """Re-initialise one module in place; takes a single module so it works with `nn.Module.apply`.

    * Conv2d / ConvTranspose2d: weight ~ N(0, 0.02); the bias is left untouched.
    * BatchNorm2d: weight ~ N(1, 0.02) and bias = 0, each only if the layer has it
      (both are None when the layer was built with affine=False).
    * Every other module type is ignored.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        return
    if not isinstance(m, nn.BatchNorm2d):
        return
    if m.weight is not None:
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
    if m.bias is not None:
        nn.init.zeros_(m.bias)

In [42]:
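# Discriminator layer layout (out_channels, conv params, spatial size in -> out):
#64  kernel_size 3 stride 2 padding 1   32 -> 16
#128 kernel_size 4 stride 2 padding 1   16 -> 8
#256 kernel_size 4 stride 2 padding 1    8 -> 4
#512 kernel_size 4 stride 2 padding 1    4 -> 2
#1   kernel_size 2 stride 1 padding 0    2 -> 1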

class Discriminator(nn.Module):
    """Convolutional real/fake classifier for 3x32x32 images.

    Four stride-2 stages (conv -> batch-norm -> LeakyReLU(0.2)) shrink the map
    32 -> 16 -> 8 -> 4 -> 2 while widening channels 3 -> 64 -> 128 -> 256 -> 512.
    A final 2x2 conv collapses the map to one logit per image, and a sigmoid turns
    it into a probability.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Attribute names stay conv_N / batch_norm_N so state_dict keys are unchanged.
        self.conv_1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
        self.batch_norm_1 = nn.BatchNorm2d(64)
        self.conv_2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.batch_norm_2 = nn.BatchNorm2d(128)
        self.conv_3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.batch_norm_3 = nn.BatchNorm2d(256)
        self.conv_4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
        self.batch_norm_4 = nn.BatchNorm2d(512)
        # 2x2 kernel, no padding: turns the final 2x2 map into a single value.
        self.conv_5 = nn.Conv2d(512, 1, kernel_size=2, stride=1, padding=0)

    def forward(self, x):
        '''
        Inputs:
            x: (batch x 3 x 32 x 32)
        Outputs:
            prob: (batch x 1), sigmoid outputs in [0, 1]
        '''
        downsampling_stages = (
            (self.conv_1, self.batch_norm_1),
            (self.conv_2, self.batch_norm_2),
            (self.conv_3, self.batch_norm_3),
            (self.conv_4, self.batch_norm_4),
        )
        features = x
        for conv, norm in downsampling_stages:
            features = F.leaky_relu(norm(conv(features)), 0.2)

        batch = features.size(0)
        logit = self.conv_5(features).view(batch, -1)
        return torch.sigmoid(logit)

          
class Generator(nn.Module):
    """Transposed-convolution generator mapping latent vectors to 3x32x32 images.

    Layer stack (channels, spatial size):
        z (z_size, 1x1)
        -> ConvT k4 s1 p0           -> 512, 4x4   -> LeakyReLU(0.2)
        -> ConvT k4 s2 p1 + BN      -> 256, 8x8   -> LeakyReLU(0.2)
        -> ConvT k4 s2 p1 + BN      -> 128, 16x16 -> LeakyReLU(0.2)
        -> ConvT k4 s2 p1           -> 3,   32x32 -> tanh

    Args:
        z_size: dimensionality of the latent noise vector.
    """

    def __init__(self, z_size):
        super(Generator, self).__init__()
        # Kept so forward() reshapes by the real latent size instead of a hard-coded 100.
        self.z_size = z_size

        self.conv_1 = nn.ConvTranspose2d(z_size, 512, kernel_size=4, stride=1, padding=0, bias=False)

        self.conv_2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False)
        self.batch_norm_2 = nn.BatchNorm2d(256)

        self.conv_3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False)
        self.batch_norm_3 = nn.BatchNorm2d(128)

        self.conv_4 = nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1, bias=False)

    def forward(self, noise):
        '''
        Inputs:
            noise: (batch x z_size) or (batch x z_size x 1 x 1)
        Outputs:
            image: (batch x 3 x 32 x 32), values in [-1, 1]
        '''
        # Treat each latent vector as a z_size-channel 1x1 "image".
        image = noise.view(noise.size(0), self.z_size, 1, 1)

        image = F.leaky_relu(self.conv_1(image), 0.2)
        image = F.leaky_relu(self.batch_norm_2(self.conv_2(image)), 0.2)
        image = F.leaky_relu(self.batch_norm_3(self.conv_3(image)), 0.2)

        # tanh -> [-1, 1], the same range the Normalize(0.5, 0.5) transform gives real images.
        return torch.tanh(self.conv_4(image))

In [43]:
latent_size = 100

generator     = Generator(latent_size).to(device)
discriminator = Discriminator().to(device)

# Apply the custom initialisation from `weights_init` to every conv / batch-norm layer:
# conv weights ~ N(0, 0.02), BN weights ~ N(1, 0.02), BN biases = 0.
generator.apply(weights_init)
discriminator.apply(weights_init)

# The discriminator already ends in a sigmoid, so plain BCE on probabilities fits.
adversarial_loss = nn.BCELoss()

# Adam settings used in the DCGAN paper (lr = 2e-4, beta1 = 0.5).
lr    = 0.0002
beta1 = 0.5
beta2 = 0.999

gen_optimizer = optim.Adam(generator.parameters(),     lr=lr, betas=(beta1, beta2))
dis_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

# `epoch` lives outside the training cell so an interrupted run can be resumed.
epoch      = 0
num_epochs = 25

# Per-batch loss history, appended to by the training loop.
dis_losses = []
gen_losses = []

In [44]:
# One latent batch held fixed for the whole run; every saved snapshot renders these same
# vectors, so the images show how the generator's output evolves over training.
fixed_noise = torch.randn((batch_size, latent_size)).to(device=device)

In [47]:
# Snapshots of generator(fixed_noise) go here every SAMPLE_EVERY batches.
SAMPLE_DIR = './dcan'
SAMPLE_EVERY = 500
# save_image does not create directories; without this a fresh checkout crashes.
os.makedirs(SAMPLE_DIR, exist_ok=True)

# Resumable: `epoch` is only reset by the setup cell, so re-running this cell continues.
while epoch < num_epochs:
    for batch_idx, (image, _) in enumerate(dataloader):
        image = image.to(device)
        current_batch = image.size(0)

        real_labels = torch.ones(current_batch, 1, device=device)
        fake_labels = torch.zeros(current_batch, 1, device=device)

        # --- Generator step: push D to classify fresh fakes as real. ---
        generated_image = generator(torch.randn(current_batch, latent_size).to(device))
        generator_loss = adversarial_loss(discriminator(generated_image), real_labels)

        gen_optimizer.zero_grad()
        generator_loss.backward()
        gen_optimizer.step()

        # --- Discriminator step: real -> 1, detached fakes -> 0. ---
        # generator_loss.backward() also left gradients on D's parameters;
        # dis_optimizer.zero_grad() below discards them before D's own backward.
        real_loss = adversarial_loss(discriminator(image), real_labels)
        fake_loss = adversarial_loss(discriminator(generated_image.detach()), fake_labels)
        discriminator_loss = real_loss + fake_loss

        dis_optimizer.zero_grad()
        discriminator_loss.backward()
        dis_optimizer.step()

        dis_losses.append(discriminator_loss.item())
        gen_losses.append(generator_loss.item())

        if batch_idx % SAMPLE_EVERY == 0:
            # No autograd graph is needed for a snapshot. NOTE: the generator is still in
            # train mode here, so its batch-norm running stats are updated by fixed_noise.
            with torch.no_grad():
                samples = generator(fixed_noise)
            torchvision.utils.save_image(samples,
                                         os.path.join(SAMPLE_DIR, '%d_%d.png' % (epoch, batch_idx)),
                                         normalize=True)

    epoch += 1



KeyboardInterrupt: 