In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# ---- Hyperparameters -------------------------------------------------------
z_dim = 100              # width of the noise vector fed to the generator
image_dim = 28 ** 2      # flattened 28 x 28 MNIST image
num_classes = 10         # MNIST digit classes
batch_size = 128
lr = 2e-4
epochs = 50

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ---- Data ------------------------------------------------------------------
# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 / std 0.5 then maps
# them to [-1, 1], the same range as the generator's final Tanh.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

class Generator(nn.Module):
    """Conditional MLP generator.

    Concatenates a noise vector with a learned ``z_dim``-wide embedding of the
    class label. The result passes through a 256-512-1024 ReLU MLP to an
    ``img_dim``-wide output, which Tanh squashes into [-1, 1].

    Args:
        z_dim: Width of the noise vector (and of each label embedding).
        num_classes: Number of rows in the label-embedding table.
        img_dim: Width of the flattened output image.
    """

    def __init__(self, z_dim, num_classes, img_dim):
        super().__init__()
        # The embedding is created before the MLP layers, so parameter
        # initialisation draws from the RNG in a fixed order: embedding first.
        self.label_embedding = nn.Embedding(num_classes, z_dim)

        widths = [z_dim * 2, 256, 512, 1024]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        layers.extend((nn.Linear(widths[-1], img_dim), nn.Tanh()))
        # Linear layers sit at indices 0/2/4/6, activations at 1/3/5/7.
        self.model = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Return generated flattened images for ``noise`` conditioned on ``labels``.

        Presumably ``noise`` is (batch, z_dim) and ``labels`` is a (batch,)
        integer tensor; the concatenation is along dim 1.
        """
        conditioned = torch.cat((noise, self.label_embedding(labels)), dim=1)
        return self.model(conditioned)

class Discriminator(nn.Module):
    """Conditional MLP discriminator.

    Concatenates a flattened image with a learned ``img_dim``-wide embedding of
    its label. The pair passes through a 1024-512-256 LeakyReLU(0.2) MLP,
    ending in a single Sigmoid unit, so the output lies in (0, 1).

    Args:
        img_dim: Width of the flattened input image (and of each label embedding).
        num_classes: Number of rows in the label-embedding table.
    """

    def __init__(self, img_dim, num_classes):
        super().__init__()
        # The embedding is created before the MLP layers, so parameter
        # initialisation draws from the RNG in a fixed order: embedding first.
        self.label_embedding = nn.Embedding(num_classes, img_dim)

        widths = [img_dim * 2, 1024, 512, 256]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)))
        layers.extend((nn.Linear(widths[-1], 1), nn.Sigmoid()))
        # Linear layers sit at indices 0/2/4/6, activations at 1/3/5/7.
        self.model = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Return a (batch, 1) score in (0, 1) for each (image, label) pair.

        Presumably ``img`` is (batch, img_dim) and ``labels`` is a (batch,)
        integer tensor; the concatenation is along dim 1.
        """
        conditioned = torch.cat((img, self.label_embedding(labels)), dim=1)
        return self.model(conditioned)

# Build both networks on the selected device: the generator first, then the
# discriminator. The criterion is plain BCE because the discriminator already
# ends in a Sigmoid. Each network gets its own Adam optimizer with identical
# settings.
generator = Generator(z_dim, num_classes, image_dim)
generator = generator.to(device)
discriminator = Discriminator(image_dim, num_classes)
discriminator = discriminator.to(device)

criterion = nn.BCELoss()
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

# Training loop. For every mini-batch, take one discriminator step on real vs.
# generated images. Then take one generator step that pushes the discriminator
# to score the generated images as real. After each epoch, print the last
# mini-batch's losses and show one generated sample per class label.
for epoch in range(epochs):
    for real_imgs, labels in dataloader:
        # Flatten images to (batch, image_dim) to match the MLP input width.
        real_imgs = real_imgs.view(-1, image_dim).to(device)
        labels = labels.to(device)

        # Size of *this* batch; the final batch of an epoch can be smaller.
        # It is kept in a local name so the module-level `batch_size`
        # hyperparameter is not overwritten.
        cur_batch = real_imgs.shape[0]
        real = torch.ones(cur_batch, 1, device=device)
        fake = torch.zeros(cur_batch, 1, device=device)

        # Sample latent noise and random target classes directly on `device`.
        noise = torch.randn(cur_batch, z_dim, device=device)
        fake_labels = torch.randint(0, num_classes, (cur_batch,), device=device)
        fake_imgs = generator(noise, fake_labels)

        # Discriminator step: real pairs -> 1, generated pairs -> 0.
        # detach() keeps this loss from back-propagating into the generator.
        optimizer_D.zero_grad()
        loss_real = criterion(discriminator(real_imgs, labels), real)
        loss_fake = criterion(discriminator(fake_imgs.detach(), fake_labels), fake)
        loss_D = (loss_real + loss_fake) / 2
        loss_D.backward()
        optimizer_D.step()

        # Generator step: the just-updated discriminator re-scores the same
        # fake batch, and the loss pushes that score towards the "real" target.
        optimizer_G.zero_grad()
        loss_G = criterion(discriminator(fake_imgs, fake_labels), real)
        loss_G.backward()
        optimizer_G.step()

    # The reported losses come from the last mini-batch of the epoch.
    print(f"Epoch [{epoch+1}/{epochs}]  Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}")

    # Visual check: one sample for each label 0..num_classes-1, in one row.
    # Separate names avoid clobbering the training-loop `noise` / `labels`.
    with torch.no_grad():
        sample_noise = torch.randn(num_classes, z_dim, device=device)
        sample_labels = torch.arange(num_classes, device=device)
        # 28 x 28 matches image_dim = 28 * 28 (MNIST).
        generated_imgs = generator(sample_noise, sample_labels).view(-1, 1, 28, 28).cpu()
        grid = torchvision.utils.make_grid(generated_imgs, nrow=num_classes, normalize=True)
        plt.imshow(grid.permute(1, 2, 0))
        plt.axis("off")
        plt.title(f"Epoch {epoch + 1}")
        plt.show()