# CGAN

### Imports

In [9]:
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim

FloatTensor = torch.FloatTensor
LongTensor = torch.LongTensor

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

cpu


### Parameters

In [10]:
os.makedirs("./images/cgan", exist_ok=True)
os.makedirs("./models", exist_ok=True)

params = {
    "n_epochs" : 200,
    "batch_size" : 64,
    "lr" : 2e-4,
    "b1" : 0.5,
    "b2" : 0.999,
    #"n_cpu" : 8,
    "n_classes" : 10,
    "latent_dim" : 100,
    "img_size" : 32,
    "channels" : 1,
    "sample_interval" : 1000,
    "load_chk" : True,
}

img_shape = (params['channels'], params['img_size'], params['img_size'])

### Dataset - MNIST

In [11]:
# Configure data loader
os.makedirs("./data/mnist", exist_ok=True)

dataset = datasets.MNIST(
            "./data/mnist",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(params['img_size']), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        )
dataloader = DataLoader(
    dataset,
    batch_size=params['batch_size'],
    shuffle=True,
)

testset = datasets.MNIST(
            "./data/mnist",
            train=False,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(params['img_size']), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        )
testloader = DataLoader(
    testset,
    batch_size=params['batch_size'],
    shuffle=True,
)

# Generator & Discriminator

In [12]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(params['n_classes'], params['n_classes'])

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(params['latent_dim'] + params['n_classes'], 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        # Concatenate label embedding and image to produce input
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_embedding = nn.Embedding(params['n_classes'], params['n_classes'])

        self.model = nn.Sequential(
            nn.Linear(params['n_classes'] + int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img, labels):
        # Concatenate label embedding and image to produce input
        d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
        validity = self.model(d_in)
        return validity

### Training

In [13]:
def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, params['latent_dim'])))).to(device)
    # Get labels ranging from 0 to n_classes for n rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels)).to(device)
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "./images/cgan/%d.png" % batches_done, nrow=n_row, normalize=True)

if not params['load_chk'] :

    # Loss functions
    adversarial_loss = torch.nn.MSELoss()

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()
    #adversarial_loss.to(device)

    generator.to(device)
    discriminator.to(device)

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=params['lr'], betas=(params['b1'], params['b2']))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=params['lr'], betas=(params['b1'], params['b2']))

    # ----------
    #  Training
    # ----------
    g_loss_list = []
    d_loss_list = []

    for epoch in range(params['n_epochs']):
        for i, (imgs, labels) in enumerate(dataloader):

            batch_size = imgs.shape[0]

            # Adversarial ground truths
            valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False).to(device)
            fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False).to(device)

            # Configure input
            real_imgs = Variable(imgs.type(FloatTensor)).to(device)
            labels = Variable(labels.type(LongTensor)).to(device)

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Sample noise and labels as generator input
            z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, params['latent_dim'])))).to(device)
            gen_labels = Variable(LongTensor(np.random.randint(0, params['n_classes'], batch_size))).to(device)

            # Generate a batch of images
            gen_imgs = generator(z, gen_labels)

            # Loss measures generator's ability to fool the discriminator
            validity = discriminator(gen_imgs, gen_labels)
            g_loss = adversarial_loss(validity, valid)

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Loss for real images
            validity_real = discriminator(real_imgs, labels)
            d_real_loss = adversarial_loss(validity_real, valid)

            # Loss for fake images
            validity_fake = discriminator(gen_imgs.detach(), gen_labels)
            d_fake_loss = adversarial_loss(validity_fake, fake)

            # Total discriminator loss
            d_loss = (d_real_loss + d_fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            batches_done = epoch * len(dataloader) + i
            if batches_done % params['sample_interval'] == 0:
                sample_image(n_row=10, batches_done=batches_done)

        print(
            "[Epoch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch+1, params['n_epochs'], d_loss.item(), g_loss.item())
        )
        d_loss_list.append(d_loss.item())
        g_loss_list.append(g_loss.item())

    torch.save({
        'gen' : generator.state_dict(),
        'disc' : discriminator.state_dict(),
    }, './models/cgan.pth')

# Evaluation

### CNN for classification

In [14]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # Layer 1: Convolutional. Input = 32x32x1. Output = 28x28x6.
        self.conv1 = nn.Conv2d(1, 6, (5,5))
        # Layer 2: Convolutional. Output = 10x10x16.
        self.conv2 = nn.Conv2d(6, 16, (5,5))
        # Layer 3: Fully Connected. Input = 400. Output = 120.
        self.fc1   = nn.Linear(400, 120)
        # Layer 4: Fully Connected. Input = 120. Output = 84.
        self.fc2   = nn.Linear(120, 84)
        # Layer 5: Fully Connected. Input = 84. Output = 10.
        self.fc3   = nn.Linear(84, 10)
    def forward(self, x):
        # Activation. # Pooling. Input = 28x28x6. Output = 14x14x6.
        x = F.max_pool2d(F.relu(self.conv1(x)), (2,2))
         # Activation. # Pooling. Input = 10x10x16. Output = 5x5x16.
        x = F.max_pool2d(F.relu(self.conv2(x)), (2,2))
        # Flatten. Input = 5x5x16. Output = 400.
        x = x.flatten(start_dim=1)
        # Activation.
        x = F.relu(self.fc1(x))
        # Activation.
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features 

# Train & test part from https://github.com/activatedgeek/LeNet-5
def train(epoch):
    global cur_batch_win
    net.train()
    loss_list, batch_list = [], []
    for i, (images, labels) in enumerate(dataloader):
        optimizer.zero_grad()
        output = net(images.to(device))
        loss = criterion(output, labels.to(device))

        loss_list.append(loss.detach().cpu().item())
        batch_list.append(i+1)

        #if i % 10 == 0:
        #    print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.detach().cpu().item()))

        loss.backward()
        optimizer.step()

def evaluate(target_loader, target_dataset):
    predictions = []
    net.eval()
    total_correct = 0
    avg_loss = 0.0
    for i, (images, labels) in enumerate(target_loader):
        output = net(images.to(device))
        avg_loss += criterion(output, labels.to(device)).sum()
        pred = output.detach().max(1)[1]
        total_correct += pred.eq(labels.to(device).view_as(pred)).sum()
        predictions.append(pred)

    avg_loss /= len(target_dataset)
    avg_loss = avg_loss.detach().cpu().item()
    accuracy    = float(total_correct) / len(target_dataset)
    print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss, accuracy))
    #return accuracy, np.array(torch.cat(predictions))
    return accuracy, avg_loss

### Training

In [15]:
EPOCHS = 20
if not params['load_chk'] :

    net = LeNet()
    criterion = nn.CrossEntropyLoss()
    #optimizer = optim.Adam(net.parameters(), lr=0.001)
    optimizer = optim.SGD(net.parameters(), lr=0.001)
    net.to(device)

    print("Training...")
    val_acc_list = []
    val_loss_list = []

    for e in range(EPOCHS):
        print("Epoch : {}".format(e+1))
        train(e)
        val_acc, val_loss = evaluate(testloader, testset)
        val_acc_list.append(val_acc)
        val_loss_list.append(val_loss)

    torch.save({
        'cnn' : net.state_dict()
        }, './models/cnn_{}.pth'.format(EPOCHS))

### Load Checkpoint

In [16]:
if params['load_chk'] :

    #loading models from file
    generator = Generator()
    discriminator = Discriminator()
    net = LeNet()

    chk = torch.load('./models/cgan.pth', map_location=device)
    chk_cnn = torch.load('./models/cnn_{}.pth'.format(EPOCHS), map_location=device)
    generator.load_state_dict(chk['gen'])
    discriminator.load_state_dict(chk['disc'])
    net.load_state_dict(chk_cnn['cnn'])

    generator.to(device)
    discriminator.to(device)
    net.to(device)

### Inception Score

In [17]:
def inception_score(r):
    p_y = np.mean(r, axis=0)
    e = r/p_y
    e = r*np.log(e, where= e>0)
    e = np.sum(e, axis=1)
    e = np.mean(e, axis=0)
    return np.exp(e)

# Generated Data for IS evaluation : 1,000 total
test_size = 1000

# Generate a batch of images
z = Variable(FloatTensor(np.random.normal(0, 1, (test_size, params['latent_dim'])))).to(device)
gen_labels = Variable(LongTensor(np.random.randint(0, params['n_classes'], test_size))).to(device)
gen_imgs = generator(z, gen_labels)

# Obtain Predictions (w/ softmax)
preds = F.softmax(net(gen_imgs),dim=1)
r = preds.cpu().detach().numpy()

i_score = inception_score(r)
print("Inception Score : {}".format(i_score))

Inception Score : 8.422941207885742


### TARR
- Use label (i) to generate image w/ CGAN -> (i) becomes ground-truth
- Compare Prediction Accuracy

In [18]:
class TestDataset(Dataset):
    def __init__(self, x_data, y_data):
        self.x_data = x_data
        self.y_data = y_data
        self.len = self.y_data.shape[0]

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

In [19]:
# Generated Data for IS evaluation : 10,000 total, 1000 of each class
test_size = 1000
criterion = nn.CrossEntropyLoss()
test_labels = torch.arange(0, params['n_classes']).repeat(test_size)

# Generate a batch of images
z = Variable(FloatTensor(np.random.normal(0, 1, (test_size*params['n_classes'], params['latent_dim'])))).to(device)
gen_labels = Variable(LongTensor(test_labels).to(device))
gen_imgs = generator(z, gen_labels)

gen_dataset = TestDataset(gen_imgs, gen_labels)
gen_loader = DataLoader(gen_dataset, batch_size=params['batch_size'], shuffle=True)

acc = evaluate(gen_loader, gen_dataset)

Test Avg. Loss: 0.001321, Accuracy: 0.978100
