In [1]:
import torch
import os
from torch import nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torch.nn.functional as F
from PIL import Image
from matplotlib import pyplot as plt

In [None]:
# --- Hyperparameters -------------------------------------------------------
image_size = 28 ** 2    # length of one flattened 28x28 image
batch_size = 128        # samples per optimisation step
num_epochs = 100        # full passes over the training set
latent_dim = 100        # width of the generator's noise input
num_classes = 10        # number of distinct conditioning labels

# --- Paths -----------------------------------------------------------------
OUTPUT_DIR = "./data/cgan"  # directory for generated image grids

# Load Data

In [3]:
# Pick the compute device: CUDA when PyTorch reports it available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Preprocessing: convert to a tensor, then shift/scale with mean 0.5 and std 0.5.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

# MNIST training split, downloaded into ./data when not already present.
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)

# Shuffled mini-batches; drop_last=True discards a final incomplete batch so
# every batch holds exactly `batch_size` samples.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
# Create the output directory (and any missing parents). exist_ok=True makes
# the cell safe to re-run and avoids the check-then-create race of
# os.path.exists() followed by os.mkdir().
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model

In [4]:
class Generator(nn.Module):
    """Conditional generator mapping (noise, class label) to a flat image.

    The label is looked up in a learned embedding of width ``num_classes``
    and concatenated to the noise vector; the result passes through a stack
    of fully connected layers with LeakyReLU activations and a final Tanh,
    producing ``image_size`` values per sample.
    """

    def __init__(self):
        super().__init__()

        # Learned per-class vector used to condition the generator.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Input width followed by the hidden widths, in forward order.
        widths = [latent_dim + num_classes, 256, 512, 1024]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2, inplace=True)]
        # Output head: one value per pixel, bounded by Tanh.
        layers += [nn.Linear(widths[-1], image_size), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, noise, context):
        """Generate flattened images.

        Args:
            noise: tensor that is viewed as ``(-1, latent_dim)``.
            context: integer class indices fed to the label embedding.

        Returns:
            Tensor with ``image_size`` values per sample.
        """
        flat_noise = noise.view(-1, latent_dim)
        label_vec = self.label_emb(context)
        return self.model(torch.cat((flat_noise, label_vec), dim=1))

In [5]:
class Discriminator(nn.Module):
    """Conditional discriminator scoring (image, class label) pairs.

    The label is looked up in a learned embedding of width ``num_classes``
    and concatenated to the flattened image; the result passes through fully
    connected layers with LeakyReLU and Dropout, ending in a Sigmoid that
    yields one value per sample.
    """

    def __init__(self):
        super().__init__()

        # Learned per-class vector appended to every flattened image.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Input width followed by the hidden widths, in forward order.
        widths = [image_size + num_classes, 1024, 512, 256]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Linear(fan_in, fan_out),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.3),
            ]
        # Output head: a single sigmoid-activated score per sample.
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, img, context):
        """Score a batch of images given their labels.

        Args:
            img: tensor that is viewed as ``(-1, image_size)``.
            context: integer class indices fed to the label embedding.

        Returns:
            Tensor with one sigmoid output per sample.
        """
        flat_img = img.view(-1, image_size)
        label_vec = self.label_emb(context)
        return self.model(torch.cat([flat_img, label_vec], 1))

# Train

In [6]:
# Build both networks (generator first, then discriminator) and move their
# parameters to the selected device. Module.to() moves parameters in place.
generator = Generator()
generator.to(device)
discriminator = Discriminator()
discriminator.to(device)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with learning rate 1e-4.
g_optimizer = optim.Adam(generator.parameters(), lr=1e-4)
d_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4)

In [7]:
# Adversarial training: each step updates the discriminator on a real batch
# and a generated batch, then updates the generator to fool the discriminator.
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        real_images, labels = images.to(device), labels.to(device)

        # Size the targets and noise from the actual batch and allocate them
        # directly on the device (no CPU allocation + copy every step).
        n_samples = real_images.size(0)
        real_labels = torch.ones(n_samples, 1, device=device)
        fake_labels = torch.zeros(n_samples, 1, device=device)

        # --------- Train the Discriminator (D) --------- #
        d_optimizer.zero_grad()
        outputs = discriminator(real_images, labels)
        d_real_loss = criterion(outputs, real_labels)

        z = torch.randn(n_samples, latent_dim, device=device)
        fake_images = generator(z, labels)
        # detach(): the D update must not backpropagate into the generator;
        # without it the generator's graph is traversed for gradients that
        # are thrown away by g_optimizer.zero_grad() below.
        outputs = discriminator(fake_images.detach(), labels)
        d_fake_loss = criterion(outputs, fake_labels)

        d_loss = d_real_loss + d_fake_loss

        d_loss.backward()
        d_optimizer.step()

        # --------- Train the Generator (G) --------- #
        g_optimizer.zero_grad()
        z = torch.randn(n_samples, latent_dim, device=device)
        fake_images = generator(z, labels)
        outputs = discriminator(fake_images, labels)
        # The generator is rewarded when D labels its samples as real.
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 400 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')

    # Save the last generated batch of every epoch. The generator ends in
    # Tanh, so its output lies in [-1, 1]; map it back to [0, 1] before
    # saving, otherwise save_image clips every negative value to black.
    # The path is built from OUTPUT_DIR so it follows the configured folder.
    epoch_grid = (fake_images.detach().reshape(fake_images.size(0), 1, 28, 28) + 1) / 2
    save_image(epoch_grid, os.path.join(OUTPUT_DIR, 'fake_image-%03d.png' % (epoch+1)))

print("Training complete.")

Epoch [1/200], Step [400/468], D Loss: 0.9366745948791504, G Loss: 2.1608455181121826
Epoch [2/200], Step [400/468], D Loss: 0.7984415292739868, G Loss: 1.1351810693740845
Epoch [3/200], Step [400/468], D Loss: 0.691845178604126, G Loss: 2.7750165462493896
Epoch [4/200], Step [400/468], D Loss: 0.1668279469013214, G Loss: 3.363769054412842
Epoch [5/200], Step [400/468], D Loss: 0.25167202949523926, G Loss: 5.232165813446045
Epoch [6/200], Step [400/468], D Loss: 0.12460343539714813, G Loss: 5.728146553039551
Epoch [7/200], Step [400/468], D Loss: 0.16725042462348938, G Loss: 4.162979602813721
Epoch [8/200], Step [400/468], D Loss: 0.38697531819343567, G Loss: 5.2110066413879395
Epoch [9/200], Step [400/468], D Loss: 0.11041513085365295, G Loss: 5.125596523284912
Epoch [10/200], Step [400/468], D Loss: 0.32399511337280273, G Loss: 4.737241268157959
Epoch [11/200], Step [400/468], D Loss: 0.33340802788734436, G Loss: 3.6825966835021973
Epoch [12/200], Step [400/468], D Loss: 0.1875279843

KeyboardInterrupt: 

# Test

In [None]:
# Generate and save one grid of samples for every digit class.
# no_grad(): this is inference only, so no autograd graph is needed.
with torch.no_grad():
    for label in range(num_classes):
        z = torch.randn(batch_size, latent_dim, device=device)
        # One conditioning label per noise vector, all equal to `label`.
        fakes_classes = torch.ones(batch_size, device=device, dtype=torch.long) * label
        # Condition on the requested class (not on `labels` left over from
        # the last training batch).
        fake_images = generator(z, fakes_classes)

        # Tanh output is in [-1, 1]; rescale to [0, 1] so save_image does not
        # clip negative values.
        class_grid = (fake_images.reshape(fake_images.size(0), 1, 28, 28) + 1) / 2
        save_image(class_grid, os.path.join(OUTPUT_DIR, 'result-%d.png' % (label)))

In [None]:
# Display the grid saved for label 0. os.path.join inserts the path separator
# that plain concatenation with OUTPUT_DIR leaves out.
image = Image.open(os.path.join(OUTPUT_DIR, "result-0.png"))
plt.imshow(image)
plt.show()

In [None]:
# Display the grid saved for label 1. os.path.join inserts the path separator
# that plain concatenation with OUTPUT_DIR leaves out.
image = Image.open(os.path.join(OUTPUT_DIR, "result-1.png"))
plt.imshow(image)
plt.show()

In [None]:
# Display the grid saved for label 2. os.path.join inserts the path separator
# that plain concatenation with OUTPUT_DIR leaves out.
image = Image.open(os.path.join(OUTPUT_DIR, "result-2.png"))
plt.imshow(image)
plt.show()

In [None]:
# Display the grid saved for label 3. os.path.join inserts the path separator
# that plain concatenation with OUTPUT_DIR leaves out.
image = Image.open(os.path.join(OUTPUT_DIR, "result-3.png"))
plt.imshow(image)
plt.show()

In [None]:
# Display the grid saved for label 4. os.path.join inserts the path separator
# that plain concatenation with OUTPUT_DIR leaves out.
image = Image.open(os.path.join(OUTPUT_DIR, "result-4.png"))
plt.imshow(image)
plt.show()

In [None]:
# Display the grid saved for label 5. os.path.join inserts the path separator
# that plain concatenation with OUTPUT_DIR leaves out.
image = Image.open(os.path.join(OUTPUT_DIR, "result-5.png"))
plt.imshow(image)
plt.show()

In [None]:
# Display the grid saved for label 6. os.path.join inserts the path separator
# that plain concatenation with OUTPUT_DIR leaves out.
image = Image.open(os.path.join(OUTPUT_DIR, "result-6.png"))
plt.imshow(image)
plt.show()

In [None]:
# Display the grid saved for label 7. os.path.join inserts the path separator
# that plain concatenation with OUTPUT_DIR leaves out.
image = Image.open(os.path.join(OUTPUT_DIR, "result-7.png"))
plt.imshow(image)
plt.show()

In [None]:
# Display the grid saved for label 8. os.path.join inserts the path separator
# that plain concatenation with OUTPUT_DIR leaves out.
image = Image.open(os.path.join(OUTPUT_DIR, "result-8.png"))
plt.imshow(image)
plt.show()

In [None]:
# Display the grid saved for label 9. os.path.join inserts the path separator
# that plain concatenation with OUTPUT_DIR leaves out.
image = Image.open(os.path.join(OUTPUT_DIR, "result-9.png"))
plt.imshow(image)
plt.show()