<a href="https://colab.research.google.com/github/lohaoxi/basic-pytorch-gans/blob/master/02_conditional_gan_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
import os
import random
import math

import scipy.linalg
import numpy as np
from matplotlib import pyplot as plt

import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
from torchvision.models import inception_v3

if not os.path.exists('visuals'):
    os.mkdir('visuals')

In [18]:
BATCH_SIZE = 64
K = 4
N_EPOCHS = 4096
NOISE_DIM = 100
IMAGE_DIM = 28*28
MAXOUT_SIZE = 5
HIDDEN_DIM = (240, 240)
LABEL_DIM = 10

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class FlattenTransform:
    
    def __call__(self, inputs):
        return inputs.view(inputs.shape[0], -1)

data_train = torchvision.datasets.MNIST(
    "./data/mnist", 
    train=True, 
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(), 
        FlattenTransform()
        ])
    )

loader_train = torch.utils.data.DataLoader(
    data_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4
    )

In [20]:
print("Number of observation: {0}".format(len(data_train)))
print("Size of each observation: {0}".format(np.array(data_train[0][0]).shape))

Number of observation: 60000
Size of each observation: (1, 784)


In [21]:


class Maxout(nn.Module):

    def __init__(self, n_pieces):
        super(Maxout, self).__init__()
        self.n_pieces = n_pieces
        
    def forward(self, batch):
        assert batch.shape[1] % self.n_pieces == 0
        batch = batch.view(
            batch.shape[0], 
            batch.shape[1] // self.n_pieces, 
            self.n_pieces
            )
        batch, _ = batch.max(dim=2)
        return batch
    

class Generator(nn.Module):

    def __init__(self, noise_dim, lbl_dim, hid_dim, out_dim):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        self.lbl_dim = lbl_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim

        self.fc1 = nn.Linear(self.noise_dim + self.lbl_dim, self.hid_dim, bias=True)
        self.fc2 = nn.Linear(self.hid_dim, self.hid_dim, bias=True)
        self.fc3 = nn.Linear(self.hid_dim, self.hid_dim, bias=True)
        self.fc4 = nn.Linear(self.hid_dim, self.hid_dim, bias=True)
        self.fc5 = nn.Linear(self.hid_dim, self.hid_dim, bias=True)
        self.fc6 = nn.Linear(self.hid_dim, self.out_dim, bias=True)

    def forward(self, batch, label):

        batch = batch.view(batch.size(0), -1)
        batch = torch.cat((batch, label), dim=1)

        batch = F.dropout(F.relu(self.fc1(batch)))
        batch = F.dropout(F.relu(self.fc2(batch)))
        batch = F.dropout(F.relu(self.fc3(batch)))
        batch = F.dropout(F.relu(self.fc4(batch)))
        batch = F.dropout(F.relu(self.fc5(batch)))

        batch = torch.sigmoid(self.fc6(batch))

        return batch


class Discriminator(nn.Module):

    def __init__(self, in_dim, lbl_dim, hid_dim, out_dim, maxout_size):
        super(Discriminator, self).__init__()
        self.in_dim = in_dim
        self.lbl_dim = lbl_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim
        self.maxout_size = maxout_size
        self.Maxout = Maxout(maxout_size)

        self.fc1 = nn.Linear(self.in_dim + self.lbl_dim, self.hid_dim, bias=True)
        self.fc2 = nn.Linear(self.hid_dim // self.maxout_size, self.hid_dim, bias=True)
        self.fc3 = nn.Linear(self.hid_dim // self.maxout_size, self.hid_dim, bias=True)
        self.fc4 = nn.Linear(self.hid_dim // self.maxout_size, self.out_dim, bias=True)

    def forward(self, batch, label):

        batch = batch.view(batch.size(0), -1)
        batch = torch.cat((batch, label), dim=1)

        batch = self.Maxout(self.fc1(batch))
        batch = self.Maxout(self.fc2(batch))
        batch = self.Maxout(self.fc3(batch))
        batch = torch.sigmoid(self.fc4(batch))
        
        return batch

In [28]:
generator = Generator(NOISE_DIM, LABEL_DIM, HIDDEN_DIM[0], IMAGE_DIM).to(device)
discriminator = Discriminator(IMAGE_DIM, LABEL_DIM, HIDDEN_DIM[1], 1, MAXOUT_SIZE).to(device)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Number of parameters in the Generatoris: {}".format(count_parameters(generator)))
print("Number of parameters in the Discriminator: {}".format(count_parameters(discriminator)))

Number of parameters in the Generatoris: 446944
Number of parameters in the Discriminator: 214369


In [29]:
def encodeOneHot(lbls, lbl_dim):
    ret = torch.FloatTensor(lbls.shape[0], lbl_dim)
    ret.zero_()
    ret.scatter_(dim=1, index=lbls.view(-1, 1), value=1)
    return ret

In [30]:
lbls_real = torch.ones(BATCH_SIZE, 1).to(device)
lbls_fake = torch.zeros(BATCH_SIZE, 1).to(device)

test_z = (2 * torch.randn(10, NOISE_DIM) - 1).to(device)
test_y = encodeOneHot(lbls=torch.tensor(np.arange(0, 10)), 
                      lbl_dim=LABEL_DIM).to(device)

num_steps = len(loader_train) // BATCH_SIZE

discriminator_optimizer = torch.optim.SGD(
    discriminator.parameters(),
    lr=0.002,
    momentum=0.7
)

generator_optimizer = torch.optim.SGD(
    generator.parameters(),
    lr=0.002,
    momentum=0.7
)

criterion = torch.nn.BCELoss()

In [31]:
def visualizeGAN(imgs, lbls, epoch):

    fig, axes = plt.subplots(2, 5, figsize=(20, 18))
    
    fig.suptitle('Epoch {}'.format(str(epoch).zfill(4)))

    for row, axe in enumerate(axes):
        for col, cell in enumerate(axe):
            cell.imshow(
                imgs[row * 5 + col],
                cmap='gray'
            )
            
            cell.set_title('{}'.format(
                torch.argmax(lbls[row * 5 + col])
            ))

            cell.axis("off")


    plt.axis("off")
    plt.tight_layout()

    fig.savefig(os.path.join("visuals", "{}.jpg".format(str(epoch).zfill(6))))
    
    plt.close()

In [None]:
d_loss_ls = []
g_loss_ls = []
d_lr_ls = []
g_lr_ls = []


for epoch in range(N_EPOCHS):
    
    # Loss Log
    d_counter = 0
    g_counter = 0
    d_loss = 0
    g_loss = 0

    for i, (images, labels) in enumerate(loader_train):

        if i == num_steps:
            break

        # Train Discriminator
        for _ in range(4):
        
            real_images = images.to(device)
            real_conditions = encodeOneHot(labels, LABEL_DIM).to(device)
            
            fake_conditions = encodeOneHot(torch.randint(0, 10, (BATCH_SIZE,)), LABEL_DIM).to(device)

            fake_images = generator(
                (2 * torch.randn(BATCH_SIZE, NOISE_DIM) - 1).to(device),
                fake_conditions
            )

            discriminator_optimizer.zero_grad()
            
            real_outputs = discriminator(real_images, real_conditions)
            fake_outputs = discriminator(fake_images, fake_conditions)
            
            d_x = criterion(real_outputs, lbls_real)
            d_g_z = criterion(fake_outputs, lbls_fake)

            d_x.backward()
            d_g_z.backward()

            discriminator_optimizer.step()
            
            # Loss Log
            d_counter += 1
            d_loss = d_x.item() + d_g_z.item()


        # Train Generator
        z = (2 * torch.randn(BATCH_SIZE, NOISE_DIM) - 1).to(device)
        y = encodeOneHot(torch.randint(0, 10, (BATCH_SIZE,)), LABEL_DIM).to(device)

        generator.zero_grad()

        outputs = discriminator(generator(z, y), y)

        loss = criterion(outputs, lbls_real)

        loss.backward()

        generator_optimizer.step()
        
        # LR Decay
#         discriminator_scheduler.step()
#         generator_scheduler.step()
        
        # Loss Log
        g_counter += 1
        g_loss += loss.item()

    # Loss Log
    if epoch % 10 == 0:
        print(
            'e:{}, G:{:.3f}, D:{:.3f}'.format(
                epoch,
                g_loss / g_counter,
                d_loss / d_counter
#                 generator_scheduler.get_lr(),
#                 discriminator_scheduler.get_lr()
            )
        )
    
    # Loss Log for Plot
    g_loss_ls.append(g_loss / g_counter)
    d_loss_ls.append(d_loss / d_counter)
    
    # Learning Rate Decay Log
#     g_lr_ls.append(generator_scheduler.get_lr())
#     d_lr_ls.append(discriminator_scheduler.get_lr())


    # Visualize Results
    if epoch % 5 == 0:

        generated = generator(test_z, test_y).detach().cpu().view(-1, 28, 28)

        visualizeGAN(generated, test_y, epoch)

  cpuset_checked))


e:0, G:1.540, D:0.017
e:10, G:6.016, D:0.000
e:20, G:2.799, D:0.005
e:30, G:5.766, D:0.000
e:40, G:6.221, D:0.000
e:50, G:6.856, D:0.001
e:60, G:6.781, D:0.000
e:70, G:7.638, D:0.000
e:80, G:7.727, D:0.000
e:90, G:4.411, D:0.001
e:100, G:7.090, D:0.000
e:110, G:9.897, D:0.000
e:120, G:8.455, D:0.000
e:130, G:9.367, D:0.000
e:140, G:8.567, D:0.000
e:150, G:8.290, D:0.000
e:160, G:7.698, D:0.000


In [27]:
for epoch in range(N_EPOCHS):

    D_loss = 0
    G_loss = 0
    D_x = 0
    D_g_z1 = 0
    D_g_z2 = 0

    for i, (imgs, lbls) in enumerate(loader_train):

        D_counter = 0
        G_counter = 0

        if i == num_steps: break

        # Train Discriminator

        for _ in range(K):
            
            conds_real = encodeOneHot(lbls=lbls, 
                                      lbl_dim=LABEL_DIM).to(device)

            imgs_real = imgs.to(device)

            conds_fake = encodeOneHot(lbls=torch.randint(0, 10, (BATCH_SIZE,)), 
                                      lbl_dim=LABEL_DIM).to(device)

            imgs_fake = generator(batch=(2 * torch.randn(BATCH_SIZE, NOISE_DIM) - 1).to(device), 
                                  label=conds_fake)
            

            discriminator_optimizer.zero_grad()

            outs_real = discriminator(imgs_real, conds_real)
            outs_fake = discriminator(imgs_fake, conds_fake)

            discriminator_err_real = criterion(outs_real, lbls_real)
            discriminator_err_fake = criterion(outs_fake, lbls_fake)

            discriminator_err_real.backward()
            discriminator_err_fake.backward()

            discriminator_optimizer.step()
            
            D_loss += (discriminator_err_real.item() + discriminator_err_fake.item())
            D_g_z1 += outs_fake.mean().item()
            D_g_z2 += outs.mean().item()
            D_counter += 1

        # Train Generator

        z = (2 * torch.randn(BATCH_SIZE, NOISE_DIM) - 1).to(device)

        y = encodeOneHot(lbls=torch.randint(0, 10, (BATCH_SIZE,)), 
                         lbl_dim=LABEL_DIM).to(device)

        generator.zero_grad()

        outs = discriminator(generator(z, y), y)

        generator_err = criterion(outs, lbls_real)

        generator_err.backward()

        generator_optimizer.step()

        G_loss += generator_err.item()
        D_x += outs_real.mean().item()
        G_counter += 1


    print("epcoh: {}\t D_loss: {:.4f}\t G_loss: {:.4f}\t D_x: {:.4f}\t D_g_z1: {:.4f}\t D_g_z2: {:.4f}\t".format(
        str(epoch).zfill(6), 
        D_loss / D_counter,
        G_loss / G_counter,
        D_x / D_counter,
        D_g_z1 / D_counter,
        D_g_z2 / G_counter,
    )
    )

    if epoch % 10 == 0:

        generated = generator(test_z, test_y).detach().cpu().view(-1, 28, 28)

        visualizeGAN(generated, test_y, epoch)




  cpuset_checked))


epcoh: 000000	 D_loss: 1.1294	 G_loss: 1.1701	 D_x: 0.1233	 D_g_z1: 0.3376	 D_g_z2: 1.2422	
epcoh: 000001	 D_loss: 0.6595	 G_loss: 2.2695	 D_x: 0.1478	 D_g_z1: 0.1135	 D_g_z2: 0.4672	
epcoh: 000002	 D_loss: 0.3138	 G_loss: 3.2330	 D_x: 0.1925	 D_g_z1: 0.0422	 D_g_z2: 0.1726	
epcoh: 000003	 D_loss: 0.1338	 G_loss: 4.0181	 D_x: 0.2238	 D_g_z1: 0.0190	 D_g_z2: 0.0777	
epcoh: 000004	 D_loss: 0.0654	 G_loss: 4.6523	 D_x: 0.2369	 D_g_z1: 0.0099	 D_g_z2: 0.0402	
epcoh: 000005	 D_loss: 0.0373	 G_loss: 5.1339	 D_x: 0.2425	 D_g_z1: 0.0060	 D_g_z2: 0.0245	
epcoh: 000006	 D_loss: 0.0256	 G_loss: 5.5083	 D_x: 0.2448	 D_g_z1: 0.0041	 D_g_z2: 0.0167	
epcoh: 000007	 D_loss: 0.0192	 G_loss: 5.7919	 D_x: 0.2461	 D_g_z1: 0.0031	 D_g_z2: 0.0125	
epcoh: 000008	 D_loss: 0.0153	 G_loss: 6.0293	 D_x: 0.2468	 D_g_z1: 0.0024	 D_g_z2: 0.0098	
epcoh: 000009	 D_loss: 0.0117	 G_loss: 6.2248	 D_x: 0.2476	 D_g_z1: 0.0020	 D_g_z2: 0.0081	
epcoh: 000010	 D_loss: 0.0096	 G_loss: 6.3990	 D_x: 0.2481	 D_g_z1: 0.0017	 D_g_

KeyboardInterrupt: ignored

In [None]:
# Visualize Results
test_z = (2 * torch.randn(10, NOISE_DIM) - 1).to(device)

generated = generator(test_z, test_y).detach().cpu().view(-1, 1, 28, 28)

grid = torchvision.utils.make_grid(
    generated,
    nrow=5,
    padding=10,
    pad_value=1
)

img = np.transpose(
    grid.numpy(),
    (1, 2, 0)
)

fig = plt.figure(figsize=(16, 16))
plt.axis("off")
plt.imshow(img);