# DCGAN

### Imports

In [None]:
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim

Tensor = torch.FloatTensor

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

### Parameters

In [None]:
# Output folders for generated sample grids and model checkpoints.
os.makedirs("./images/dcgan", exist_ok=True)
os.makedirs("./models", exist_ok=True)

# Hyper-parameters shared by the data pipeline, the GAN, and the evaluation cells.
params = dict(
    n_epochs=200,          # GAN training epochs
    batch_size=64,
    lr=2e-4,               # Adam learning rate (generator and discriminator)
    b1=0.5,                # Adam beta1
    b2=0.999,              # Adam beta2
    latent_dim=100,        # length of the generator's noise vector
    img_size=32,           # MNIST digits are resized to this size
    channels=1,            # MNIST is single-channel
    sample_interval=1000,  # save a sample grid every N batches
    load_chk=True,         # True: load saved checkpoints instead of training
)

### Dataset - MNIST

In [None]:
# Configure the MNIST train/test splits and their loaders.
os.makedirs("./data/mnist", exist_ok=True)

# One preprocessing pipeline shared by both splits: resize to params['img_size'] and map
# pixel values from [0, 1] to [-1, 1], the same range as the generator's Tanh output.
mnist_transform = transforms.Compose([
    transforms.Resize(params['img_size']),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

dataset = datasets.MNIST("./data/mnist", train=True, download=True, transform=mnist_transform)
dataloader = DataLoader(dataset, batch_size=params['batch_size'], shuffle=True)

testset = datasets.MNIST("./data/mnist", train=False, download=True, transform=mnist_transform)
testloader = DataLoader(testset, batch_size=params['batch_size'], shuffle=True)

# Generator & Discriminator

In [None]:
def weights_init_normal(m):
    """DCGAN-style weight initialisation, intended for ``nn.Module.apply``.

    * Modules whose class name contains "Conv": weights drawn from N(0, 0.02).
    * Modules whose class name contains "BatchNorm2d": weights from N(1, 0.02), biases set to 0.
    * Any other module is left unchanged.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm2d" in layer_type:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


class Generator(nn.Module):
    """DCGAN generator: maps a batch of latent vectors to images with values in [-1, 1].

    Args:
        latent_dim: length of each input noise vector. Defaults to ``params['latent_dim']``.
        img_size: output height/width. It must be a positive multiple of 4, because the
            network starts at ``img_size // 4`` and upsamples twice by 2.
            Defaults to ``params['img_size']``.
        channels: number of output image channels. Defaults to ``params['channels']``.

    Raises:
        ValueError: if ``img_size`` is not a positive multiple of 4.

    ``forward(z)`` takes ``z`` of shape ``(batch, latent_dim)`` and returns a tensor of
    shape ``(batch, channels, img_size, img_size)``.
    """

    def __init__(self, latent_dim=None, img_size=None, channels=None):
        super().__init__()
        # Fall back to the notebook-wide hyper-parameters so ``Generator()`` keeps working.
        if latent_dim is None:
            latent_dim = params['latent_dim']
        if img_size is None:
            img_size = params['img_size']
        if channels is None:
            channels = params['channels']
        if img_size < 4 or img_size % 4 != 0:
            raise ValueError(
                "Generator img_size must be a positive multiple of 4, got %r" % (img_size,)
            )

        self.init_size = img_size // 4
        # Kept as a one-element Sequential so checkpoint keys ("l1.0.*") stay unchanged.
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        # NOTE(review): the second positional argument of BatchNorm2d is ``eps``, not
        # ``momentum``. eps=0.8 is kept exactly as before (written as a keyword) so
        # existing checkpoints behave identically; it was probably meant as momentum.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from latent vectors ``z`` of shape ``(batch, latent_dim)``."""
        out = self.l1(z)
        # Reshape the flat projection into a 128-channel init_size x init_size feature map.
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    """DCGAN discriminator: maps a batch of images to the probability that each is real.

    Args:
        img_size: input height/width. Defaults to ``params['img_size']``.
        channels: number of input image channels. Defaults to ``params['channels']``.

    ``forward(img)`` takes ``img`` of shape ``(batch, channels, img_size, img_size)`` and
    returns a tensor of shape ``(batch, 1)`` with values in [0, 1] (Sigmoid output).
    """

    def __init__(self, img_size=None, channels=None):
        super().__init__()
        # Fall back to the notebook-wide hyper-parameters so ``Discriminator()`` keeps working.
        if img_size is None:
            img_size = params['img_size']
        if channels is None:
            channels = params['channels']

        def discriminator_block(in_filters, out_filters, bn=True):
            """Return one downsampling block: Conv(stride 2) -> LeakyReLU -> Dropout2d [-> BatchNorm]."""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                # NOTE(review): second positional BatchNorm2d arg is ``eps``; eps=0.8 kept
                # as before so existing checkpoints behave identically.
                block.append(nn.BatchNorm2d(out_filters, eps=0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # Height/width after the 4 blocks. A kernel-3, padding-1, stride-2 conv maps n to
        # ceil(n / 2), so four of them give ceil(img_size / 16). Floor division would
        # undercount (and mis-size the Linear layer) when img_size is not a multiple of 16,
        # e.g. 28 -> 14 -> 7 -> 4 -> 2, but 28 // 16 == 1.
        ds_size = -(-img_size // 2 ** 4)
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        """Return the real/fake probability for each image in ``img``."""
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity

### Training

In [None]:
# Train the DCGAN from scratch. Runs only when params['load_chk'] is False; otherwise the
# checkpoint-loading cell restores the weights instead. At the end, the generator and
# discriminator state dicts are saved to ./models/dcgan.pth under the keys 'gen' and 'disc'.
if not params['load_chk'] :

    # Loss function (binary cross-entropy on the discriminator's Sigmoid output)
    adversarial_loss = torch.nn.BCELoss()

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    generator.to(device)
    discriminator.to(device)
    adversarial_loss.to(device)

    # Initialize weights (DCGAN-style normal init for Conv and BatchNorm2d layers)
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # Optimizers: one Adam per network, both with the same lr/betas from params
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=params['lr'], betas=(params['b1'], params['b2']))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=params['lr'], betas=(params['b1'], params['b2']))

    # Per-epoch loss history (last batch of each epoch, see the append calls below)
    g_loss_list = []
    d_loss_list = []

    for epoch in range(params['n_epochs']):
        # Class labels are discarded: GAN training only needs the images.
        for i, (imgs, _) in enumerate(dataloader):

            # Adversarial ground truths: target 1.0 for "real", 0.0 for "fake", one per image.
            # NOTE: Variable is a deprecated no-op wrapper in modern PyTorch; the tensors are
            # built on the CPU (Tensor = torch.FloatTensor) and then moved to `device`.
            valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)
            valid = valid.to(device)
            fake = fake.to(device)

            # Configure input (cast the real batch to float32 and move it to `device`)
            real_imgs = Variable(imgs.type(Tensor))
            real_imgs = real_imgs.to(device)

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Sample noise as generator input: standard normal, shape (batch, latent_dim)
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], params['latent_dim']))))
            z = z.to(device)

            # Generate a batch of images
            gen_imgs = generator(z)

            # Loss measures generator's ability to fool the discriminator: the generator is
            # rewarded when D labels its images as "valid".
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)

            # This backward pass also accumulates gradients into the discriminator's
            # parameters; they are cleared by optimizer_D.zero_grad() below before D's update.
            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples.
            # gen_imgs is detached so this loss does not backpropagate into the generator.
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            # Global batch counter; every `sample_interval` batches save the first 25
            # generated images as a 5x5 grid PNG.
            batches_done = epoch * len(dataloader) + i
            if batches_done % params['sample_interval'] == 0:
                save_image(gen_imgs.data[:25], "./images/dcgan/%d.png" % batches_done, nrow=5, normalize=True)

        # Log and record the losses of the epoch's last batch (not an epoch average).
        print(
            "[Epoch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch+1, params['n_epochs'], d_loss.item(), g_loss.item())
        )
        d_loss_list.append(d_loss.item())
        g_loss_list.append(g_loss.item())

    # Persist both networks' weights for the checkpoint-loading cell.
    torch.save({
        'gen' : generator.state_dict(),
        'disc' : discriminator.state_dict(),
    }, './models/dcgan.pth')

# Evaluation

### CNN for Classification

In [None]:
class LeNet(nn.Module):
    """LeNet-5 style classifier for 1x32x32 images; returns 10 unnormalised class scores."""

    def __init__(self):
        super().__init__()
        # Feature extractor:
        #   1x32x32 --conv 5x5--> 6x28x28 --pool 2--> 6x14x14
        #           --conv 5x5--> 16x10x10 --pool 2--> 16x5x5
        self.conv1 = nn.Conv2d(1, 6, (5, 5))
        self.conv2 = nn.Conv2d(6, 16, (5, 5))
        # Classifier head: 400 (= 16 * 5 * 5) -> 120 -> 84 -> 10
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a ``(batch, 1, 32, 32)`` tensor to ``(batch, 10)`` logits."""
        # Two conv -> ReLU -> 2x2 max-pool stages.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), (2, 2))
        # Flatten every dimension except the batch dimension.
        x = torch.flatten(x, start_dim=1)
        # Two hidden fully connected layers with ReLU, then the linear output layer.
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)

    def num_flat_features(self, x):
        """Return the number of elements per sample (product of all non-batch dims of ``x``)."""
        count = 1
        for dim in x.shape[1:]:
            count *= dim
        return count

# Train & test part from https://github.com/activatedgeek/LeNet-5
def train(epoch):
    """Run one training epoch of the global ``net`` over the global ``dataloader``.

    Relies on the notebook globals ``net``, ``optimizer``, ``criterion``, ``dataloader``
    and ``device``.

    Args:
        epoch: epoch index. It is not used by the loop; the parameter is kept for
            interface compatibility with existing callers.

    Returns:
        list[float]: the training loss of every batch, in iteration order.
    """
    net.train()
    batch_losses = []
    for images, labels in dataloader:
        optimizer.zero_grad()
        output = net(images.to(device))
        loss = criterion(output, labels.to(device))
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return batch_losses

def evaluate(target_loader, target_dataset):
    """Evaluate the global ``net`` on ``target_loader``, print and return loss/accuracy.

    Relies on the notebook globals ``net``, ``criterion`` and ``device``.

    Args:
        target_loader: iterable of ``(images, labels)`` batches.
        target_dataset: the dataset behind ``target_loader``. Only ``len()`` is used, as
            the number of samples to average over.

    Returns:
        tuple[float, float]: ``(accuracy, avg_loss)``, both per-sample averages.
    """
    net.eval()
    total_correct = 0
    total_loss = 0.0
    # no_grad: evaluation needs no autograd graph. Without it, summing the per-batch loss
    # tensors kept every batch's graph (and activations) alive until the loop finished.
    with torch.no_grad():
        for images, labels in target_loader:
            images, labels = images.to(device), labels.to(device)
            output = net(images)
            # Assumes criterion returns the batch *mean*, as the default-constructed
            # nn.CrossEntropyLoss() used for training does. Weighting by the batch size
            # makes total_loss / n_samples a true per-sample mean, including for a
            # smaller final batch.
            total_loss += criterion(output, labels).item() * labels.size(0)
            pred = output.argmax(dim=1)
            total_correct += pred.eq(labels.view_as(pred)).sum().item()

    n_samples = len(target_dataset)
    avg_loss = total_loss / n_samples
    accuracy = total_correct / n_samples
    print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss, accuracy))
    return accuracy, avg_loss

### Training

In [None]:
# Number of LeNet training epochs. Defined at cell level, not inside the ``if`` below,
# because the checkpoint-loading cell builds the checkpoint filename from EPOCHS even when
# training is skipped (params['load_chk'] is True).
EPOCHS = 20

if not params['load_chk']:
    # Train the LeNet classifier on the MNIST training loader, validating after each epoch.
    net = LeNet()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001)
    net.to(device)

    print("Training...")
    val_acc_list = []
    val_loss_list = []

    for e in range(EPOCHS):
        print("Epoch : {}".format(e + 1))
        train(e)
        val_acc, val_loss = evaluate(testloader, testset)
        val_acc_list.append(val_acc)
        val_loss_list.append(val_loss)

    # Saved under the key 'cnn'; the filename encodes the epoch count.
    torch.save({
        'cnn': net.state_dict()
        }, './models/cnn_{}.pth'.format(EPOCHS))

### Load Checkpoint

In [None]:
if params['load_chk']:
    # Rebuild the three architectures and restore the weights written by the training cells
    # ('gen'/'disc' in ./models/dcgan.pth, 'cnn' in ./models/cnn_<EPOCHS>.pth).
    generator = Generator()
    discriminator = Discriminator()
    net = LeNet()

    # map_location lets checkpoints saved on a GPU be restored on the current device.
    # NOTE(review): EPOCHS must already be defined (by the CNN-training cell) for this filename.
    chk = torch.load('./models/dcgan.pth', map_location=device)
    chk_cnn = torch.load('./models/cnn_{}.pth'.format(EPOCHS), map_location=device)
    generator.load_state_dict(chk['gen'])
    discriminator.load_state_dict(chk['disc'])
    net.load_state_dict(chk_cnn['cnn'])

    generator.to(device)
    discriminator.to(device)
    net.to(device)

    # Restored models are used for inference, so switch them to eval mode: BatchNorm then
    # uses its stored running statistics and Dropout is disabled.
    generator.eval()
    discriminator.eval()
    net.eval()

### Inception Score

In [None]:
def inception_score(r):
    """Compute the Inception Score from a matrix of predicted class probabilities.

    IS = exp( mean_i KL( p(y|x_i) || p(y) ) ), where p(y) is the column mean of ``r``.

    Args:
        r: array-like of shape ``(n_samples, n_classes)``. Each row is a predicted class
            distribution (e.g. a softmax output).

    Returns:
        numpy.float64: the score. For proper probability rows it lies between 1 and
        n_classes.
    """
    r = np.asarray(r, dtype=np.float64)
    p_y = np.mean(r, axis=0)
    # Ratio p(y|x) / p(y). A class with p(y) == 0 has p(y|x) == 0 in every row, so its
    # terms contribute nothing: skip the 0/0 division instead of producing NaN.
    ratio = np.divide(r, p_y, out=np.zeros_like(r), where=p_y > 0)
    # Treat 0 * log(0) as 0. ``out=`` is required together with ``where=``: without it,
    # numpy leaves the masked-out entries uninitialised and garbage can reach the sum.
    log_ratio = np.log(ratio, out=np.zeros_like(r), where=ratio > 0)
    kl_per_sample = np.sum(r * log_ratio, axis=1)
    return np.exp(np.mean(kl_per_sample, axis=0))

# Generated data for IS evaluation: 1,000 samples in total.
test_size = 1000

# Score in inference mode. eval() makes the generator's BatchNorm layers use their running
# statistics instead of the statistics of this one batch. no_grad() avoids building an
# autograd graph for 1,000 generated images.
generator.eval()
net.eval()
with torch.no_grad():
    # Sample standard-normal latent vectors and generate a batch of images.
    z = Tensor(np.random.normal(0, 1, (test_size, params['latent_dim']))).to(device)
    gen_imgs = generator(z)

    # Class probabilities from the LeNet classifier (softmax over its 10 outputs).
    preds = F.softmax(net(gen_imgs), dim=1)

r = preds.cpu().numpy()

i_score = inception_score(r)
print("Inception Score : {}".format(i_score))