<a href="https://colab.research.google.com/github/jeffvun/digit-generator-app/blob/main/train_digit_generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# !pip install torch torchvision matplotlib

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os


In [2]:
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 100
num_classes = 10
image_shape = (1, 28, 28)
img_size = 28

# Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

train_loader = DataLoader(
    datasets.MNIST(root='./data', train=True, download=True, transform=transform),
    batch_size=128, shuffle=True
)

100%|██████████| 9.91M/9.91M [00:00<00:00, 12.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 337kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 2.69MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.33MB/s]


In [3]:
# Generator
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, int(np.prod(image_shape))),
            nn.Tanh()
        )

    def forward(self, z, labels):
        c = self.label_emb(labels)
        x = torch.cat([z, c], dim=1)
        out = self.model(x)
        return out.view(out.size(0), *image_shape)


In [4]:

# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(num_classes + int(np.prod(image_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        c = self.label_emb(labels)
        x = torch.cat([img.view(img.size(0), -1), c], dim=1)
        return self.model(x)



In [5]:
# Initialize
G = Generator().to(device)
D = Discriminator().to(device)

adversarial_loss = nn.BCELoss()
optimizer_G = torch.optim.Adam(G.parameters(), lr=0.0002)
optimizer_D = torch.optim.Adam(D.parameters(), lr=0.0002)


In [6]:

# Training loop (light version for demo – increase epochs if needed)
epochs = 10
for epoch in range(epochs):
    for imgs, labels in train_loader:
        batch_size = imgs.size(0)
        valid = torch.ones(batch_size, 1).to(device)
        fake = torch.zeros(batch_size, 1).to(device)
        real_imgs = imgs.to(device)
        labels = labels.to(device)

        # Train Generator
        z = torch.randn(batch_size, latent_dim).to(device)
        gen_labels = torch.randint(0, num_classes, (batch_size,), device=device)
        gen_imgs = G(z, gen_labels)
        g_loss = adversarial_loss(D(gen_imgs, gen_labels), valid)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        # Train Discriminator
        real_loss = adversarial_loss(D(real_imgs, labels), valid)
        fake_loss = adversarial_loss(D(gen_imgs.detach(), gen_labels), fake)
        d_loss = (real_loss + fake_loss) / 2

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

    print(f"Epoch {epoch+1}/{epochs} | D loss: {d_loss.item():.4f} | G loss: {g_loss.item():.4f}")


Epoch 1/10 | D loss: 0.0485 | G loss: 5.4930
Epoch 2/10 | D loss: 0.2127 | G loss: 2.4785
Epoch 3/10 | D loss: 0.0642 | G loss: 7.0778
Epoch 4/10 | D loss: 0.1115 | G loss: 7.4612
Epoch 5/10 | D loss: 0.1059 | G loss: 4.1566
Epoch 6/10 | D loss: 0.0777 | G loss: 7.8390
Epoch 7/10 | D loss: 0.0356 | G loss: 3.7385
Epoch 8/10 | D loss: 0.1385 | G loss: 3.0959
Epoch 9/10 | D loss: 0.1613 | G loss: 4.4026
Epoch 10/10 | D loss: 0.1846 | G loss: 5.3350


In [7]:

# Save Generator
os.makedirs("models", exist_ok=True)
torch.save(G.state_dict(), "models/cgan_generator.pt")