In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os

# ✅ Config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed = 42          # seeds torch's RNGs: weight init, noise draws, DataLoader shuffle order
z_dim = 100        # size of the latent vector fed to the generator
lr = 0.0002
batch_size = 128
epochs = 50
image_dir = "mnist_dcgan_output"
os.makedirs(image_dir, exist_ok=True)

# Seed early so Restart & Run All reproduces the same run
# (cuDNN kernels may still be nondeterministic on GPU).
torch.manual_seed(seed)

# ✅ Data (MNIST, grayscale 28x28)
# ToTensor yields [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1],
# the same range as the generator's final Tanh.
transform = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
dataloader = DataLoader(
    datasets.MNIST(root="./data", train=True, transform=transform, download=True),
    batch_size=batch_size, shuffle=True
)

# ✅ Generator
class Generator(nn.Module):
    """DCGAN generator: latent noise -> 28x28 image.

    Args:
        z_dim: Latent channels; ``forward`` expects ``z`` of shape
            ``(batch, z_dim, 1, 1)``.
        img_channels: Channels of the generated image (1 for grayscale).
        feature_g: Base width; hidden layers use 4x, 2x and 1x this many channels.

    ``forward`` returns ``(batch, img_channels, 28, 28)`` with values in
    ``[-1, 1]`` (final Tanh).
    """

    def __init__(self, z_dim=100, img_channels=1, feature_g=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_dim, feature_g * 4, 3, 1, 0),    # 1x1 -> 3x3
            nn.BatchNorm2d(feature_g * 4),
            nn.ReLU(True),

            # output_padding=1 is required: without it (3-1)*2 - 2*1 + 4 = 6,
            # and the stack ends at 24x24 instead of 28x28.
            nn.ConvTranspose2d(feature_g * 4, feature_g * 2, 4, 2, 1, output_padding=1),  # 3x3 -> 7x7
            nn.BatchNorm2d(feature_g * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(feature_g * 2, feature_g, 4, 2, 1),      # 7x7 -> 14x14
            nn.BatchNorm2d(feature_g),
            nn.ReLU(True),

            nn.ConvTranspose2d(feature_g, img_channels, 4, 2, 1),       # 14x14 -> 28x28
            nn.Tanh()
        )

    def forward(self, z):
        """Map noise ``z`` (batch, z_dim, 1, 1) to images (batch, C, 28, 28)."""
        return self.net(z)

# ✅ Discriminator (final 4x4 conv gives per-patch scores that are averaged, instead of a single 7x7-kernel output)
class Discriminator(nn.Module):
    """DCGAN discriminator producing one real/fake probability per image.

    For a 28x28 input the conv stack goes 28 -> 14 -> 7 -> 4, giving a 4x4 grid
    of sigmoid scores. ``forward`` averages that grid into a ``(batch, 1)``
    tensor, so its output lines up with ``(batch, 1)`` label tensors.
    Note: the spatial mean means the input size is not checked here.

    Args:
        img_channels: Channels of the input image (1 for grayscale).
        feature_d: Width of the first conv layer; the second uses twice this.
    """

    def __init__(self, img_channels=1, feature_d=64):
        super().__init__()
        narrow, wide = feature_d, feature_d * 2
        layers = [
            # 28x28 -> 14x14 (no BatchNorm on the first layer)
            nn.Conv2d(img_channels, narrow, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 14x14 -> 7x7
            nn.Conv2d(narrow, wide, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(wide),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 7x7 -> 4x4 map of per-patch scores
            nn.Conv2d(wide, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the mean patch probability per image, shape (batch, 1)."""
        patch_scores = self.net(x)
        return patch_scores.mean(dim=(2, 3))


: 

In [None]:
# Initialize models
G = Generator(z_dim=z_dim).to(device)
D = Discriminator().to(device)

# Loss and optimizers
criterion = nn.BCELoss()  # D ends in Sigmoid, so BCELoss takes probabilities directly
opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# Fixed latent vectors: saved grids show the same 25 "seeds" every time,
# so progress across epochs is directly comparable.
fixed_noise = torch.randn(25, z_dim, 1, 1, device=device)

# Training loop
for epoch in range(epochs):
    for real, _ in dataloader:
        real = real.to(device)
        # Local name on purpose: the last batch can be smaller than `batch_size`,
        # and overwriting the config variable would leak into later/re-run cells.
        n = real.size(0)
        real_labels = torch.ones(n, 1, device=device)
        fake_labels = torch.zeros(n, 1, device=device)

        # === Train Discriminator ===
        z = torch.randn(n, z_dim, 1, 1, device=device)
        fake = G(z).detach()  # detach: the D update must not backprop into G

        d_real = D(real)
        d_fake = D(fake)
        d_loss = criterion(d_real, real_labels) + criterion(d_fake, fake_labels)

        opt_D.zero_grad()
        d_loss.backward()
        opt_D.step()

        # === Train Generator ===
        z = torch.randn(n, z_dim, 1, 1, device=device)
        fake = G(z)
        d_output = D(fake)
        g_loss = criterion(d_output, real_labels)  # Generator tries to fool D

        opt_G.zero_grad()
        g_loss.backward()
        opt_G.step()

    # Losses reported are from the last batch of the epoch.
    print(f"Epoch [{epoch+1}/{epochs}] | Loss D: {d_loss.item():.4f} | Loss G: {g_loss.item():.4f}")

    if (epoch + 1) % 10 == 0:
        # Sample without building an autograd graph.
        with torch.no_grad():
            samples = G(fixed_noise)
        save_image(samples, f"{image_dir}/fake_epoch_{epoch+1}.png", nrow=5, normalize=True)
