# Implementation of Conditional GANs
Reference: https://arxiv.org/pdf/1411.1784.pdf

In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import imageio

In [2]:
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image

In [3]:
import numpy as np
import datetime
import scipy.misc

In [4]:
MODEL_NAME = 'ConditionalGAN'
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
num_pixels=4096
num_classes=25

In [6]:
def to_cuda(x):
    return x.to(DEVICE)

In [7]:
def to_onehot(x, num_classes=25):
    assert isinstance(x, int) or isinstance(x, (torch.LongTensor, torch.cuda.LongTensor))
    if isinstance(x, int):
        c = torch.zeros(1, num_classes).long()
        c[0][x] = 1
    else:
        x = x.cpu()
        c = torch.LongTensor(x.size(0), num_classes)
        c.zero_()
        c.scatter_(1, x, 1) # dim, index, src value
    return c

In [8]:
def get_sample_image(G, n_noise=100):
    """
        save sample 250 images
    """
    for num in range(25):
        c = to_cuda(to_onehot(num))
        for i in range(10):
            z = to_cuda(torch.randn(1, n_noise))
            y_hat = G(z,c)
            line_img = torch.cat((line_img, y_hat.view(64, 64, 3)), dim=1) if i > 0 else y_hat.view(64, 64, 3)
        all_img = torch.cat((all_img, line_img), dim=0) if num > 0 else line_img
    img = all_img.cpu().data.numpy()
    return img

In [9]:
class Discriminator(nn.Module):
    """
        Simple Discriminator w/ MLP
    """
    def __init__(self, input_size=12288, label_size=25, num_classes=1):
        super(Discriminator, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(input_size+label_size, 200),
            nn.ReLU(),
            nn.Dropout(),
        )
        self.layer2 = nn.Sequential(
            nn.Linear(200, 200),
            nn.ReLU(),
            nn.Dropout(),
        )
        self.layer3 = nn.Sequential(
            nn.Linear(200, num_classes),
            nn.Sigmoid(),
        )
    
    def forward(self, x, y):        
        x, y = x.view(x.size(0), -1), y.view(y.size(0), -1).float()
        v = torch.cat((x, y), 1) # v: [input, label] concatenated vector
        y_ = self.layer1(v)
        y_ = self.layer2(y_)
        y_ = self.layer3(y_)
        return y_

In [10]:
class Generator(nn.Module):
    """
        Simple Generator w/ MLP
    """
    def __init__(self, input_size=100, label_size=25, num_classes=12288):
        super(Generator, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_size+label_size, 200),
            nn.LeakyReLU(0.2),
            nn.Linear(200, 200),
            nn.LeakyReLU(0.2),
            nn.Linear(200, num_classes),
            nn.Tanh()
        )
        
    def forward(self, x, y):
        x, y = x.view(x.size(0), -1), y.view(y.size(0), -1).float()
        v = torch.cat((x, y), 1) # v: [input, label] concatenated vector
        y_ = self.layer(v)
        y_ = y_.view(x.size(0), 1, 64, 64,3)
        return y_

In [11]:
D = to_cuda(Discriminator())
G = to_cuda(Generator())

In [12]:
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5),std=(0.5, 0.5, 0.5))
])

In [13]:
dataset = datasets.ImageFolder('arcDataset', transform=transform)

In [14]:
batch_size = 64
condition_size = 25

In [15]:

data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [16]:
criterion = nn.BCELoss()
D_opt = torch.optim.Adam(D.parameters())
G_opt = torch.optim.Adam(G.parameters())

In [17]:
max_epoch = 100 # need more than 200 epochs for training generator
step = 0
n_critic = 5 # for training more k steps about Discriminator
n_noise = 100

In [18]:
D_labels = to_cuda(torch.ones(batch_size)) # Discriminator Label to real
D_fakes = to_cuda(torch.zeros(batch_size)) # Discriminator Label to fake

In [19]:
for epoch in range(max_epoch):
    for idx, (images, labels) in enumerate(data_loader):
        step += 1
        # Training Discriminator
        x = to_cuda(images)
        y = labels.view(batch_size, 1)
        y = to_cuda(to_onehot(y))
        x_outputs = D(x, y)
        D_x_loss = criterion(x_outputs, D_labels)

        z = to_cuda(torch.randn(batch_size, n_noise))
        z_outputs = D(G(z, y), y)
        D_z_loss = criterion(z_outputs, D_fakes)
        D_loss = D_x_loss + D_z_loss
        
        D.zero_grad()
        D_loss.backward()
        D_opt.step()
        
        if step % n_critic == 0:
            # Training Generator
            z = to_cuda(torch.randn(batch_size, n_noise))
            z_outputs = D(G(z, y), y)
            G_loss = criterion(z_outputs, D_labels)

            G.zero_grad()
            G_loss.backward()
            G_opt.step()
        
        if step % 1000 == 0:
            print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, D_loss.data[0], G_loss.data[0]))
            
        if epoch % 5 == 0:
            G.eval()
            img = get_sample_image(G)
            imageio.imwrite('sample/{}_epoch_{}_type1.jpg'.format(MODEL_NAME, epoch), img)
            G.train()

  "Please ensure they have the same size.".format(target.size(), input.size()))










Epoch: 13/100, Step: 1000, D Loss: 0.028675712645053864, G Loss: 4.3559699058532715












Epoch: 27/100, Step: 2000, D Loss: 0.05598441883921623, G Loss: 5.900113105773926










Epoch: 40/100, Step: 3000, D Loss: 0.037481945008039474, G Loss: 4.589871406555176






PermissionError: [Errno 13] Permission denied: 'C:\\Users\\Cole\\compsci\\CS230\\GAN-Tutorial\\Notebooks\\sample\\ConditionalGAN_epoch_45_type1.jpg'

## Sample

In [None]:
# generation to image
G.eval()
scipy.misc.toimage(get_sample_image(G))

In [None]:
def save_checkpoint(state, file_name='checkpoint.pth.tar'):
    torch.save(state, file_name)

In [None]:
# Saving params.
# torch.save(D.state_dict(), 'D_c.pkl')
# torch.save(G.state_dict(), 'G_c.pkl')
save_checkpoint({'epoch': epoch + 1, 'state_dict':D.state_dict(), 'optimizer' : D_opt.state_dict()}, 'D_c.pth.tar')
save_checkpoint({'epoch': epoch + 1, 'state_dict':G.state_dict(), 'optimizer' : G_opt.state_dict()}, 'G_c.pth.tar')