In [None]:
import torch
import os
from torch import nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torch.nn.functional as F
from PIL import Image
from matplotlib import pyplot as plt

In [None]:
# --- Training hyperparameters ---
num_epochs = 30
batch_size = 128
latent_dim = 100  # number of channels in the noise tensor fed to the generator

# --- Data description (MNIST) ---
num_classes = 10  # not referenced by the visible cells
image_shape = (1, 28, 28)  # (channels, height, width) used when reshaping images
image_size = image_shape[1] * image_shape[2]  # pixels per image

# --- Output location for the generated sample grids ---
OUTPUT_DIR = "./data/dcgan"

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# MNIST training split: convert each image to a tensor, then normalize it
# with mean 0.5 and std 0.5 on its single channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# drop_last=True discards the final incomplete batch, so every batch holds
# exactly `batch_size` images.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Create the output directory together with any missing parent directories.
# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator that maps noise to a 1x28x28 image.

    The spatial size grows 1 -> 7 -> 14 -> 28 across the three transposed
    convolutions, so the output matches ``image_shape`` (1, 28, 28). The final
    Tanh keeps pixel values in [-1, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()
        # ConvTranspose2d output size: (in - 1) * stride - 2 * padding + kernel
        self.model = nn.Sequential(
            # (latent_dim, 1, 1) -> (128, 7, 7): (1 - 1) * 1 - 0 + 7 = 7.
            # A 4x4 kernel here would end at 16x16 instead of 28x28.
            nn.ConvTranspose2d(latent_dim, 128, 7, 1, 0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # (128, 7, 7) -> (64, 14, 14): (7 - 1) * 2 - 2 + 4 = 14
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # (64, 14, 14) -> (1, 28, 28): (14 - 1) * 2 - 2 + 4 = 28
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        """Generate images from noise.

        Args:
            z: Noise tensor with ``latent_dim`` values per sample; it is
                reshaped to (N, latent_dim, 1, 1) before the first layer.

        Returns:
            Tensor of shape (N, 1, 28, 28).
        """
        z = z.view(-1, latent_dim, 1, 1)
        return self.model(z)

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing one sigmoid score per image.

    Two stride-2 convolutions reduce a 28x28 input to 14x14 and then 7x7;
    a final 7x7 convolution collapses it to a single value per sample.
    """

    def __init__(self):
        super().__init__()
        layers = [
            # (1, 28, 28) -> (64, 14, 14)
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=4,
                      stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # (64, 14, 14) -> (128, 7, 7)
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4,
                      stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # (128, 7, 7) -> (1, 1, 1)
            nn.Conv2d(in_channels=128, out_channels=1, kernel_size=7,
                      stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Image tensor; it is reshaped to (N, *image_shape) first.

        Returns:
            Tensor of shape (N, 1, 1, 1) with values in [0, 1].
        """
        batch = img.view(-1, *image_shape)
        return self.model(batch)


In [None]:
# Build both networks and move them to the selected device
generator = Generator().to(device=device)
discriminator = Discriminator().to(device=device)

# Binary cross-entropy, applied to the discriminator's sigmoid outputs
criterion = nn.BCELoss()

# One Adam optimizer per network, both with learning rate 1e-4
g_optimizer = optim.Adam(params=generator.parameters(), lr=0.0001)
d_optimizer = optim.Adam(params=discriminator.parameters(), lr=0.0001)

In [None]:
# Adversarial training: every step updates the discriminator first, then the
# generator, both with binary cross-entropy on the discriminator's output.
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        real_images = images.to(device)

        # Size targets and noise from the actual batch so that a shorter final
        # batch (e.g. if drop_last is disabled) cannot cause a shape mismatch.
        # Allocating directly on `device` avoids a CPU allocation plus copy.
        current_batch = real_images.size(0)
        real_labels = torch.ones(current_batch, 1, 1, 1, device=device)
        fake_labels = torch.zeros(current_batch, 1, 1, 1, device=device)

        # Train Discriminator: real images should score 1, generated ones 0.
        d_optimizer.zero_grad()
        real_loss = criterion(discriminator(real_images), real_labels)
        z = torch.randn(current_batch, latent_dim, 1, 1, device=device)
        fake_images = generator(z)
        # detach() keeps the discriminator loss from back-propagating into the generator.
        fake_loss = criterion(discriminator(fake_images.detach()), fake_labels)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optimizer.step()

        # Train Generator: it is rewarded when the discriminator scores its
        # images as real, hence the `real_labels` target.
        g_optimizer.zero_grad()
        g_loss = criterion(discriminator(fake_images), real_labels)
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 400 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')

    # Save the last generated batch of the epoch as an image grid.
    # detach() replaces the legacy `.data` attribute.
    save_image(fake_images.detach(), f'{OUTPUT_DIR}/fake_image-{epoch+1:03d}.png', normalize=True)

In [None]:
# Display the sample grid saved after epoch 5.
# The path is built once from OUTPUT_DIR; prefixing it twice pointed at a non-existent file.
with Image.open(f'{OUTPUT_DIR}/fake_image-005.png') as image:
    plt.imshow(image)
plt.show()

In [None]:
# Display the sample grid saved after epoch 10.
# The path is built once from OUTPUT_DIR; prefixing it twice pointed at a non-existent file.
with Image.open(f'{OUTPUT_DIR}/fake_image-010.png') as image:
    plt.imshow(image)
plt.show()

In [None]:
# Display the sample grid saved after epoch 15.
# The path is built once from OUTPUT_DIR; prefixing it twice pointed at a non-existent file.
with Image.open(f'{OUTPUT_DIR}/fake_image-015.png') as image:
    plt.imshow(image)
plt.show()

In [None]:
# Display the sample grid saved after epoch 20.
# The path is built once from OUTPUT_DIR; prefixing it twice pointed at a non-existent file.
with Image.open(f'{OUTPUT_DIR}/fake_image-020.png') as image:
    plt.imshow(image)
plt.show()

In [None]:
# Display the sample grid saved after epoch 25.
# The path is built once from OUTPUT_DIR; prefixing it twice pointed at a non-existent file.
with Image.open(f'{OUTPUT_DIR}/fake_image-025.png') as image:
    plt.imshow(image)
plt.show()

In [None]:
# Display the sample grid saved after epoch 29.
# The path is built once from OUTPUT_DIR; prefixing it twice pointed at a non-existent file.
# NOTE(review): with num_epochs = 30 the last saved grid is fake_image-030.png — confirm 029 is intended.
with Image.open(f'{OUTPUT_DIR}/fake_image-029.png') as image:
    plt.imshow(image)
plt.show()