In [1]:
import os
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch

In [2]:
epochs, batch_size  = 100, 64
lr, b1, b2 = 2e-4, 0.5, 0.999
latent_dim = 100
img_size = 32
channels = 1
n_classes = 10
img_shape = (channels, img_size, img_size)
if torch.cuda.is_available(): 
    print("Train on GPU \nCUDA is available")
    cuda = True 
else:
    print("Train on the CPU \nCUDA is not available")
    cuda = False

Train on GPU 
CUDA is available


In [3]:
os.makedirs("data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
)

In [4]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(n_classes, n_classes)
        
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim + n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        input_z = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(input_z)
        img = img.view(img.size(0), *img_shape)
        return img

In [5]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_embedding = nn.Embedding(n_classes, n_classes)
        
        self.model = nn.Sequential(
            nn.Linear(n_classes + int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img, labels):
        input_x = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
        pred = self.model(input_x)
        return pred

In [6]:
G = Generator()
D = Discriminator()
adversarial_loss = torch.nn.MSELoss()

if cuda:
    G.cuda()
    D.cuda()
    adversarial_loss.cuda()

In [7]:
def sample_save(n_row, epoch):
    z = Variable(FloatTensor(np.random.normal
                             (0, 1, (n_row ** 2, latent_dim))))
    labels = np.array([num for _ in range(n_row) 
                       for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    img_g = G(z, labels)
    save_image(img_g.data, "CGAN_results/%d.png" 
               %epoch, nrow=n_row, normalize=True)

In [9]:
optimizer_G = torch.optim.Adam(G.parameters(),lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(D.parameters(),lr=lr, betas=(b1, b2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

os.makedirs("CGAN_results", exist_ok=True)
for epoch in range(epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        real = Variable(FloatTensor(batch_size, 1).fill_(1.0),requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0),requires_grad=False)
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))
        
        ## Train Generator ##
        optimizer_G.zero_grad()
        z = Variable(FloatTensor(np.random.normal
                                 (0, 1, (batch_size, latent_dim))))
        gen_labels = Variable(LongTensor(np.random.randint(0, n_classes, batch_size)))
        gen_imgs = G(z, gen_labels)
        fake_pred = D(gen_imgs, gen_labels)
        g_loss = adversarial_loss(fake_pred, real)
        g_loss.backward()
        optimizer_G.step()
        
        ## Train Discriminator ##
        optimizer_D.zero_grad()
        real_pred = D(real_imgs, labels)
        d_real_loss = adversarial_loss(real_pred, real)
        fake_pred = D(gen_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(fake_pred, fake)
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

    print(
        "[Epoch %d/%d] [D loss: %f] [G loss: %f]"
        % (epoch, epochs, d_loss.item(), g_loss.item())
    )

    sample_save(n_row=10, epoch=epoch+1)
torch.save(G.state_dict(), './Generator.pth')
torch.save(D.state_dict(), './Discriminator.pth')

  real = Variable(FloatTensor(batch_size, 1).fill_(1.0),requires_grad=False)


[Epoch 0/100] [D loss: 0.136978] [G loss: 0.345224]
[Epoch 1/100] [D loss: 0.076319] [G loss: 0.675061]
[Epoch 2/100] [D loss: 0.096402] [G loss: 0.581036]
[Epoch 3/100] [D loss: 0.074213] [G loss: 0.718553]
[Epoch 4/100] [D loss: 0.148068] [G loss: 0.495342]
[Epoch 5/100] [D loss: 0.153791] [G loss: 0.514250]
[Epoch 6/100] [D loss: 0.231861] [G loss: 0.176790]
[Epoch 7/100] [D loss: 0.146515] [G loss: 0.541679]
[Epoch 8/100] [D loss: 0.180008] [G loss: 0.426851]
[Epoch 9/100] [D loss: 0.210126] [G loss: 0.389083]
[Epoch 10/100] [D loss: 0.176833] [G loss: 0.459842]
[Epoch 11/100] [D loss: 0.196350] [G loss: 0.651260]
[Epoch 12/100] [D loss: 0.184078] [G loss: 0.458031]
[Epoch 13/100] [D loss: 0.228992] [G loss: 0.258839]
[Epoch 14/100] [D loss: 0.189214] [G loss: 0.353010]
[Epoch 15/100] [D loss: 0.246752] [G loss: 0.247599]
[Epoch 16/100] [D loss: 0.222593] [G loss: 0.341538]
[Epoch 17/100] [D loss: 0.210312] [G loss: 0.235605]
[Epoch 18/100] [D loss: 0.241484] [G loss: 0.291656]
[Ep