In [None]:
import torch
import os
from torch import nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torch.nn.functional as F
from PIL import Image
from matplotlib import pyplot as plt

In [None]:
# Spatial side length of an MNIST image (each image is 1 x 28 x 28).
# The discriminator uses this as a side length, so it must be 28, not 28*28.
image_size = 28
batch_size = 128
num_epochs = 30
latent_dim = 100   # length of the generator's input noise vector
num_classes = 10   # MNIST digit classes 0-9

OUTPUT_DIR = "./data/dcgan"

# Seed PyTorch's global RNG so runs are more reproducible
# (cuDNN kernels may still introduce some nondeterminism on GPU).
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
# Create the sample-output directory together with any missing parents (e.g. ./data).
# exist_ok=True makes re-running this cell a no-op instead of an error.
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
class Generator(nn.Module):
    """Conditional DCGAN generator mapping (noise, digit label) to a 1 x 28 x 28 image.

    The label is looked up in a learned ``num_classes``-dimensional embedding and
    concatenated to the noise along the channel axis. The resulting 1 x 1 "image" is
    upsampled 1 -> 7 -> 14 -> 28 by three transposed convolutions. A final Tanh
    keeps pixel values in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # One learned vector of length num_classes per class label.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        layers = [
            # (latent_dim + num_classes) x 1 x 1 -> 128 x 7 x 7
            nn.ConvTranspose2d(latent_dim + num_classes, 128, kernel_size=7, stride=1, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 128 x 7 x 7 -> 64 x 14 x 14
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 64 x 14 x 14 -> 1 x 28 x 28, squashed into [-1, 1]
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate images for ``labels``.

        Args:
            noise: latent batch holding ``latent_dim`` values per sample; both
                (B, latent_dim) and (B, latent_dim, 1, 1) work because it is
                reshaped with ``view``.
            labels: integer class indices, shape (B,).

        Returns:
            Tensor of shape (B, 1, 28, 28) with values in [-1, 1].
        """
        cond = self.label_emb(labels).view(-1, num_classes, 1, 1)
        z = noise.view(-1, latent_dim, 1, 1)
        # Stack noise and label embedding as channels of a 1x1 input map.
        return self.model(torch.cat([z, cond], dim=1))

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator mapping (image, digit label) to the probability the image is real.

    The label embedding is broadcast over the image's spatial grid and stacked onto
    the 1-channel image as ``num_classes`` extra channels before the conv stack.
    For 28 x 28 input the spatial size goes 28 -> 14 -> 7 -> 1, so the output has
    shape (B, 1, 1, 1).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # One learned vector of length num_classes per class label.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        self.model = nn.Sequential(
            # (1 + num_classes) x 28 x 28 -> 64 x 14 x 14
            nn.Conv2d(1 + num_classes, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            # 64 x 14 x 14 -> 128 x 7 x 7
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            # 128 x 7 x 7 -> 1 x 1 x 1, squashed into a probability
            nn.Conv2d(128, 1, kernel_size=7, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        """Score ``img`` conditioned on ``labels``.

        Args:
            img: image batch, assumed (B, 1, H, W); the conv stack reduces it to a
                1 x 1 map only when H = W = 28.
            labels: integer class indices, shape (B,).

        Returns:
            Tensor of shape (B, 1, 1, 1) with values in (0, 1) for 28 x 28 input.
        """
        label_maps = self.label_emb(labels).view(-1, num_classes, 1, 1)
        # Size the label planes from the actual image (not a global constant) so the
        # channel-wise concat below always matches spatially.
        label_maps = label_maps.expand(-1, -1, img.size(2), img.size(3))
        return self.model(torch.cat((img, label_maps), 1))

In [None]:
# Build both networks (generator first, then discriminator) on the selected device.
generator, discriminator = Generator().to(device), Discriminator().to(device)

# Binary cross-entropy on the discriminator's sigmoid probabilities.
criterion = nn.BCELoss()

# One Adam optimizer per network; each is stepped only after its own loss in the training loop.
g_optimizer = optim.Adam(params=generator.parameters(), lr=1e-4)
d_optimizer = optim.Adam(params=discriminator.parameters(), lr=1e-4)

In [None]:
# BCE targets shaped like the discriminator output (B, 1, 1, 1). The loader uses
# drop_last=True, so every batch has exactly batch_size samples; build the targets once.
real_targets = torch.ones(batch_size, 1, 1, 1, device=device)
fake_targets = torch.zeros(batch_size, 1, 1, 1, device=device)
LOG_EVERY = 400  # print losses every this many steps

for epoch in range(num_epochs):
    # Make sure dropout / BatchNorm are in training mode even if another cell switched to eval.
    generator.train()
    discriminator.train()

    for i, (images, labels) in enumerate(train_loader):
        real_images, labels = images.to(device), labels.to(device)

        # ---- Discriminator step: push D(real) toward 1 and D(G(z)) toward 0 ----
        d_optimizer.zero_grad()
        real_outputs = discriminator(real_images, labels)
        d_real_loss = criterion(real_outputs, real_targets)

        z = torch.randn(batch_size, latent_dim, 1, 1, device=device)
        # detach(): this step updates D only; don't backprop into (or accumulate grads on) G.
        fake_images = generator(z, labels).detach()
        fake_outputs = discriminator(fake_images, labels)
        d_fake_loss = criterion(fake_outputs, fake_targets)
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator step: make D classify fresh fakes as real ----
        g_optimizer.zero_grad()
        z = torch.randn(batch_size, latent_dim, 1, 1, device=device)
        fake_images = generator(z, labels)
        outputs = discriminator(fake_images, labels)
        g_loss = criterion(outputs, real_targets)
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % LOG_EVERY == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], '
                  f'D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}')

    # Save the last batch of generated images every epoch. The generator ends in Tanh,
    # so samples lie in [-1, 1]; map them to the [0, 1] range save_image expects.
    save_image(fake_images.detach() * 0.5 + 0.5, f'{OUTPUT_DIR}/fake_image-{epoch + 1:03d}.png')

In [None]:
# Save one grid of `batch_size` generated samples per digit class.
# eval(): use BatchNorm running statistics instead of statistics of these single-class batches.
# no_grad(): inference only, so no autograd graph is needed.
# The previous mode is restored afterwards so a later training re-run is unaffected.
was_training = generator.training
generator.eval()
try:
    with torch.no_grad():
        for label in range(num_classes):
            z = torch.randn(batch_size, latent_dim, device=device)
            labels = torch.full((batch_size,), label, device=device, dtype=torch.long)
            fake_images = generator(z, labels)  # (B, 1, 28, 28), values in [-1, 1]
            # Map Tanh output from [-1, 1] to the [0, 1] range save_image expects.
            save_image(fake_images * 0.5 + 0.5, os.path.join(OUTPUT_DIR, f'result-{label}.png'))
finally:
    generator.train(was_training)

In [None]:
# Display the per-class sample grids (result-0.png ... result-9.png) written by the
# generation cell, one figure per digit. This replaces ten copy-pasted cells.
for label in range(num_classes):
    fig, ax = plt.subplots()
    # `with` closes the file handle; imshow has already copied the pixel data.
    with Image.open(os.path.join(OUTPUT_DIR, f"result-{label}.png")) as result_image:
        ax.imshow(result_image)
    ax.set_title(f"Generated samples for digit {label}")
    ax.axis("off")
    plt.show()