In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import os, time
import numpy as np
import matplotlib.pyplot as plt
import itertools
import pickle
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from tqdm import tqdm_notebook
import kornia
from torchvision.utils import save_image
from IPython.core.display import Image, display

In [2]:
# refactored from https://github.com/znxlwm/pytorch-MNIST-CelebA-cGAN-cDCGAN/blob/master/pytorch_MNIST_cDCGAN.py
# G(z)
class generator(nn.Module):
    """Conditional DCGAN generator G(z, y) -> image in [-1, 1] (tanh output).

    The noise ``z`` and the one-hot label ``y`` are each projected by a transposed
    conv to ``d*2`` feature maps of size 4x4. The two are concatenated along
    channels (``d*4``) and upsampled 4 -> 8 -> 16 -> 32.

    Args:
        d: base channel width.
        n_classes: number of label channels expected in ``label``.
        nz: latent dimension; ``input`` must have ``nz`` channels (default 100).
        nc: number of output image channels (default 1, e.g. MNIST).
    """
    # initializers
    def __init__(self, d=128, n_classes=10, nz=100, nc=1):
        super(generator, self).__init__()
        # Noise branch: (N, nz, 1, 1) -> (N, d*2, 4, 4)
        self.deconv1_1 = nn.ConvTranspose2d(nz, d*2, 4, 1, 0)
        self.deconv1_1_bn = nn.BatchNorm2d(d*2)
        # Label branch: (N, n_classes, 1, 1) -> (N, d*2, 4, 4)
        self.deconv1_2 = nn.ConvTranspose2d(n_classes, d*2, 4, 1, 0)
        self.deconv1_2_bn = nn.BatchNorm2d(d*2)
        # Shared trunk on the concatenated (d*4)-channel features.
        self.deconv2 = nn.ConvTranspose2d(d*4, d*2, 4, 2, 1)   # 4 -> 8
        self.deconv2_bn = nn.BatchNorm2d(d*2)
        self.deconv3 = nn.ConvTranspose2d(d*2, d, 4, 2, 1)     # 8 -> 16
        self.deconv3_bn = nn.BatchNorm2d(d)
        self.deconv4 = nn.ConvTranspose2d(d, nc, 4, 2, 1)      # 16 -> 32

    # weight_init
    def weight_init(self, mean, std):
        """Apply ``normal_init`` to every direct child module."""
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    # forward method
    def forward(self, input, label):
        """Generate images from noise ``input`` (N, nz, 1, 1) and labels ``label`` (N, n_classes, 1, 1)."""
        x = F.relu(self.deconv1_1_bn(self.deconv1_1(input)))
        y = F.relu(self.deconv1_2_bn(self.deconv1_2(label)))
        x = torch.cat([x, y], 1)
        x = F.relu(self.deconv2_bn(self.deconv2(x)))
        x = F.relu(self.deconv3_bn(self.deconv3(x)))
        x = torch.tanh(self.deconv4(x))
        return x

class discriminator(nn.Module):
    """Conditional DCGAN discriminator D(x, y) -> probability (sigmoid) that x is real.

    The image and a per-class label map are each downsampled by a strided conv.
    The results are concatenated along channels and reduced to one score.
    For a 32x32 input the spatial path is 32 -> 16 -> 8 -> 4 -> 1, so the output
    is (N, 1, 1, 1).

    Args:
        d: base channel width.
        n_classes: number of channels expected in ``labels``.
        nc: number of image channels (default 1, e.g. MNIST).
    """
    # initializers
    def __init__(self, d=128, n_classes=10, nc=1):
        super(discriminator, self).__init__()
        self.conv1_1 = nn.Conv2d(nc, d//2, 4, 2, 1)
        self.conv1_2 = nn.Conv2d(n_classes, d//2, 4, 2, 1)
        # The two stems are concatenated, giving 2*(d//2) channels (== d for even d).
        self.conv2 = nn.Conv2d(2 * (d//2), d*2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(d*2)
        self.conv3 = nn.Conv2d(d*2, d*4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(d*4)
        # 4x4 -> 1x1 single logit.
        self.conv4 = nn.Conv2d(d * 4, 1, 4, 1, 0)

    # weight_init
    def weight_init(self, mean, std):
        """Apply ``normal_init`` to every direct child module."""
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    # forward method
    def forward(self, images, labels):
        """Score ``images`` (N, nc, H, W) against ``labels`` (N, n_classes, H, W).

        ``labels`` must have the same spatial size as ``images``, because their
        conv outputs are concatenated channel-wise.
        """
        x = F.leaky_relu(self.conv1_1(images), 0.2)
        y = F.leaky_relu(self.conv1_2(labels), 0.2)
        x = torch.cat([x, y], 1)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        x = self.conv4(x)
        return torch.sigmoid(x)

def normal_init(m, mean, std):
    """Re-initialise a 2-D (transposed) convolution in place.

    Weights are drawn from N(mean, std) and the bias, if present, is set to zero.
    Any other module type is left untouched, so this can be applied blindly to
    every child of a model.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        nn.init.normal_(m.weight, mean, std)
        # Layers created with bias=False have m.bias == None.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [3]:
# help(nn.Conv2d)

In [4]:
def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
    """Plot the discriminator and generator loss curves stored in ``hist``.

    Args:
        hist: dict holding sequences under 'D_losses' and 'G_losses'. The
            x-axis is taken from len(hist['D_losses']), so both must be the same length.
        show: if True, call plt.show(); otherwise the figure is closed.
        save: if True, write the figure to ``path`` (done before showing).
        path: output image file used when ``save`` is True.
    """
    # A dedicated figure, so nothing is drawn onto whichever figure happens to be current.
    fig, ax = plt.subplots()
    epochs = range(len(hist['D_losses']))
    ax.plot(epochs, hist['D_losses'], label='D_loss')
    ax.plot(epochs, hist['G_losses'], label='G_loss')
    ax.set_title('Training losses')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend(loc=4)
    ax.grid(True)
    fig.tight_layout()
    if save:
        fig.savefig(path)
    if show:
        plt.show()
    else:
        plt.close(fig)

In [5]:
# training parameters
batch_size = 128
lr = 0.0002
train_epoch = 50

# Seed the global torch/numpy RNGs. Weight init, data shuffling, the fixed preview
# noise and the sampled fake labels are then repeatable across kernel restarts
# (GPU kernels may still add small nondeterminism).
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# data_loader
# 32x32 matches the generator's output size (4 -> 8 -> 16 -> 32).
img_size = 32


# MNIST loader.
n_classes = 10
transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        # [0, 1] -> [-1, 1], the same range as the generator's tanh output.
        transforms.Normalize([0.5], [0.5])
])
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)

# Alternative dataset (requires the local symbols_dataset module):
# from symbols_dataset import make_loader
# n_classes = 22
# dataloader, dataset = make_loader(batch_size, img_size)


# results save folder
root = 'MNIST_cDCGAN_results/'
model = 'MNIST_cDCGAN_'
# Creates `root` as well; no-op if the folders already exist.
os.makedirs(root + 'Fixed_results', exist_ok=True)

train_hist = {
    'D_losses': [],
    'G_losses': [],
    'per_epoch_ptimes': [],
    'total_ptime': [],
}

# onehot[c] is the (n_classes, 1, 1) one-hot condition for class c (generator input).
onehot = torch.eye(n_classes).view(n_classes, n_classes, 1, 1)
# fill[c] is (n_classes, img_size, img_size) with channel c all ones (discriminator input).
fill = onehot.repeat(1, 1, img_size, img_size)


In [6]:
# Fixed latent vectors and class conditions. show_result reuses them every epoch,
# so successive sample grids are directly comparable.
num_fixed = 100

fixed_z = torch.randn((num_fixed, 100, 1, 1)).cuda()

# Class for each fixed sample: advances every 4 samples and wraps around at n_classes.
fixed_classes = torch.tensor([(k // 4) % n_classes for k in range(num_fixed)])
fixed_y_label = torch.zeros(num_fixed, n_classes, 1, 1)
fixed_y_label[torch.arange(num_fixed), fixed_classes] = 1
fixed_y_label = fixed_y_label.cuda()
    
def show_result(num_epoch, show = True, save = False, path = 'result.png'):
    """Sample the generator on the fixed noise/labels and write the image grid to ``path``.

    Reads the notebook globals ``G``, ``fixed_z`` and ``fixed_y_label``. At most
    100 samples are tiled, 10 per row.

    Args:
        num_epoch: unused; kept for call compatibility.
        show: if True, display the saved PNG inline.
        save: unused; the grid is always written to ``path`` (the inline display reads it back).
        path: destination PNG file.
    """
    G.eval()
    try:
        # Inference only: no_grad avoids building an autograd graph for the samples.
        with torch.no_grad():
            test_images = G(fixed_z, fixed_y_label)
    finally:
        # Restore training mode even if sampling fails, so BatchNorm is not left in eval mode.
        G.train()
    # The generator ends in tanh ([-1, 1]); rescale to [0, 1] for save_image.
    test_images = (test_images.cpu() + 1) * 0.5
    save_image(test_images[:100], path, nrow=10, padding=1, pad_value=1, scale_each=False, normalize=False)
    if show:
        display(Image(path))

In [7]:
import matplotlib.pyplot as plt
import matplotlib as mpl

def show_dataset(dataset, n=6, cmap='gray'):
    """Show the first ``n`` samples of ``dataset`` side by side in a single row.

    Assumes ``dataset[i]`` is an (image, label) pair whose image is a CPU
    tensor/array indexed channel-first. Only channel 0 is shown, and all samples
    must share the same height for the horizontal stack.

    Args:
        dataset: indexable dataset.
        n: number of samples to show.
        cmap: matplotlib colormap for the single-channel images. Pass None for
            matplotlib's default (the previous behaviour).
    """
    strip = np.hstack([np.asarray(dataset[i][0][0]) for i in range(n)])
    plt.figure(figsize=(10, 2))
    plt.imshow(strip, cmap=cmap)
    plt.axis('off')

In [9]:
# show_dataset(dataset, 12)

In [None]:
# network
G = generator(128, n_classes=n_classes)
D = discriminator(128, n_classes=n_classes)
G.weight_init(mean=0.0, std=0.02)
D.weight_init(mean=0.0, std=0.02)
G.cuda()
D.cuda()

# Binary Cross Entropy loss
BCE_loss = nn.BCELoss()

# Adam optimizer
G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

print('training start!')
start_time = time.time()
for epoch in range(train_epoch):
    D_losses = []
    G_losses = []

    # Step decay: divide the learning rate by 10 at the start of epochs 10 and 16.
    if (epoch + 1) in [10, 16]:
        for optimizer in (G_optimizer, D_optimizer):
            for group in optimizer.param_groups:
                group['lr'] /= 10
        print("learning rate change!")

    epoch_start_time = time.time()

    for images, y_ in tqdm_notebook(dataloader):
        images = images.cuda()
        mini_batch = images.size(0)
        # Targets sized to the actual batch (the last batch of an epoch may be smaller).
        y_real = torch.ones(mini_batch, device=images.device)
        y_fake = torch.zeros(mini_batch, device=images.device)

        #--------------------------------------------------
        # train discriminator
        #--------------------------------------------------
        D.zero_grad()

        # Train discriminator on real images, conditioned on their true labels.
        # view(-1) (not squeeze) keeps a batch dimension even when mini_batch == 1.
        D_result = D(images, fill[y_].cuda()).view(-1)
        D_real_loss = BCE_loss(D_result, y_real)

        # Train discriminator on fake, generated images for random labels.
        z_ = torch.randn((mini_batch, 100, 1, 1)).cuda()
        y_ = torch.randint(0, n_classes, (mini_batch,))
        G_result = G(z_, onehot[y_].cuda())
        # detach(): the D update must not back-propagate into G (G's grads are reset before its own step anyway).
        D_result = D(G_result.detach(), fill[y_].cuda()).view(-1)
        D_fake_loss = BCE_loss(D_result, y_fake)

        # Total loss is classifying fake + classifying real.
        D_train_loss = D_real_loss + D_fake_loss

        D_train_loss.backward()
        D_optimizer.step()
        # .item() stores a plain float, so no autograd history is kept alive for the whole epoch.
        D_losses.append(D_train_loss.item())

        #--------------------------------------------------
        # train generator
        #--------------------------------------------------
        G.zero_grad()

        z = torch.randn((mini_batch, 100, 1, 1)).cuda()
        y = torch.randint(0, n_classes, (mini_batch,))

        G_result = G(z, onehot[y].cuda())
        D_result = D(G_result, fill[y].cuda()).view(-1)

        # Non-saturating generator loss: G is rewarded when D labels its samples as real.
        G_train_loss = BCE_loss(D_result, y_real)

        G_train_loss.backward()
        G_optimizer.step()
        G_losses.append(G_train_loss.item())

    fixed_p = root + 'Fixed_results/' + model + str(epoch + 1) + '.png'
    show_result((epoch+1), show=True, path=fixed_p)

    per_epoch_ptime = time.time() - epoch_start_time
    D_loss_mean = float(np.mean(D_losses))
    G_loss_mean = float(np.mean(G_losses))

    print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f' % (
        (epoch + 1), train_epoch, per_epoch_ptime, D_loss_mean, G_loss_mean))

    train_hist['D_losses'].append(D_loss_mean)
    train_hist['G_losses'].append(G_loss_mean)
    train_hist['per_epoch_ptimes'].append(per_epoch_ptime)

train_hist['total_ptime'].append(time.time() - start_time)

training start!


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

In [None]:
# Plot the per-epoch mean losses recorded during training; display inline and save a copy.
show_train_hist(train_hist, path=os.path.join(root, 'train_hist.png'), save=True, show=True)