In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's random number generators so that weight initialisation,
# data shuffling and noise sampling are repeatable across
# "Restart Kernel -> Run All" on the same setup.
seed = 42
torch.manual_seed(seed)

# ---- Hyper-parameters ----
batch_size = 128   # samples per mini-batch
z_dim = 100        # length of the generator's noise vector
num_classes = 10   # number of digit classes used for conditioning
img_size = 28      # image height and width in pixels
channels = 1       # image channels
epochs = 50        # number of passes over the training set
lr = 0.0002        # Adam learning rate (shared by both networks)
beta1 = 0.5        # Adam beta1 (shared by both networks)

# Output directory for generated sample grids.
os.makedirs("cgan_generated", exist_ok=True)

In [None]:
# Pre-processing: convert each image to a tensor, then apply
# (x - 0.5) / 0.5 so the data is centred the same way the generated
# images are de-normalised later (x * 0.5 + 0.5).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Shuffled mini-batch loader over the MNIST training split
# (downloaded into the current directory when missing).
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(root='.', train=True, transform=transform, download=True),
    shuffle=True,
    batch_size=batch_size,
)

100%|██████████| 9.91M/9.91M [00:00<00:00, 56.6MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.77MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 12.4MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.27MB/s]


In [None]:
class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, class label) to an image.

    The label is embedded into a ``num_classes``-wide vector, concatenated
    with the noise vector, and pushed through three
    Linear -> BatchNorm1d -> ReLU stages followed by a Linear -> Tanh head,
    so every output value lies in [-1, 1].

    Args:
        z_dim: Length of the noise vector.
        num_classes: Number of labels; also used as the embedding width.
        img_shape: Shape of a single output image. The forward pass returns
            a tensor of shape ``(batch, *img_shape)``.
    """

    def __init__(self, z_dim, num_classes, img_shape):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.img_shape = img_shape

        # Number of values in one image (product of all dims of img_shape).
        n_pixels = int(torch.prod(torch.tensor(img_shape)))
        # Input width followed by the three hidden widths.
        widths = [z_dim + num_classes, 256, 512, 1024]

        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(True),
            ]
        layers += [nn.Linear(widths[-1], n_pixels), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate one image per (noise row, label) pair."""
        # Condition on the class by appending its embedding to the noise.
        conditioned = torch.cat((noise, self.label_emb(labels)), dim=1)
        flat = self.model(conditioned)
        return flat.view(conditioned.size(0), *self.img_shape)

In [None]:
class Discriminator(nn.Module):
    """Conditional MLP discriminator scoring (image, class label) pairs.

    The image is flattened and concatenated with a ``num_classes``-wide
    label embedding, then passed through two Linear -> LeakyReLU(0.2)
    stages and a Linear -> Sigmoid head that yields one value in [0, 1]
    per sample.

    Args:
        num_classes: Number of labels; also used as the embedding width.
        img_shape: Shape of a single input image (its element count sets
            the width of the first linear layer).
    """

    def __init__(self, num_classes, img_shape):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Flattened image size plus the label-embedding width.
        in_features = int(torch.prod(torch.tensor(img_shape))) + num_classes

        stack = []
        for width in (512, 256):
            stack.append(nn.Linear(in_features, width))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
            in_features = width
        stack.extend([nn.Linear(in_features, 1), nn.Sigmoid()])

        self.model = nn.Sequential(*stack)

    def forward(self, img, labels):
        """Return a ``(batch, 1)`` tensor of validity scores."""
        n = img.size(0)
        # One row per sample: flattened pixels followed by the label embedding.
        features = torch.cat((img.view(n, -1), self.label_emb(labels)), dim=1)
        return self.model(features)

In [None]:
# Shape of a single image: (channels, img_size, img_size).
img_shape = (channels, img_size, img_size)

# Build the two networks (generator first, then discriminator) and move
# each one to the selected device.
generator = Generator(z_dim, num_classes, img_shape)
generator = generator.to(device)
discriminator = Discriminator(num_classes, img_shape)
discriminator = discriminator.to(device)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate and betas.
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=lr, betas=(beta1, 0.999))
    for net in (generator, discriminator)
)

In [None]:
# Update schedule applied to every mini-batch by the training loop.
p = 1  # number of discriminator optimisation steps per batch
k = 3  # number of generator optimisation steps per batch

In [None]:
# Adversarial training. For every mini-batch: `p` discriminator updates,
# then `k` generator updates. After each epoch one sample per class is
# written to cgan_generated/epoch_<n>.png.
for epoch in range(1, epochs + 1):
    for i, (real_imgs, real_labels) in enumerate(train_loader):
        # Size everything from the batch actually received rather than from
        # `batch_size`, since the final batch of an epoch can be smaller.
        batch_size_curr = real_imgs.size(0)
        real_imgs = real_imgs.to(device)
        real_labels = real_labels.to(device)

        # BCE targets: 1 means "real", 0 means "fake".
        real = torch.ones(batch_size_curr, 1, device=device)
        fake = torch.zeros(batch_size_curr, 1, device=device)

        ### ---- Train Discriminator p times ---- ###
        for _ in range(p):
            z = torch.randn(batch_size_curr, z_dim, device=device)
            fake_labels = torch.randint(0, num_classes, (batch_size_curr,), device=device)

            # The discriminator step needs no generator gradients. Tensors
            # produced under no_grad carry no graph, so no extra .detach()
            # is required below.
            with torch.no_grad():
                gen_imgs = generator(z, fake_labels)

            # Real images should be scored as 1.
            real_validity = discriminator(real_imgs, real_labels)
            d_real_loss = criterion(real_validity, real)

            # Generated images should be scored as 0.
            fake_validity = discriminator(gen_imgs, fake_labels)
            d_fake_loss = criterion(fake_validity, fake)

            d_loss = d_real_loss + d_fake_loss

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

        ### ---- Train Generator k times ---- ###
        for _ in range(k):
            z = torch.randn(batch_size_curr, z_dim, device=device)
            gen_labels = torch.randint(0, num_classes, (batch_size_curr,), device=device)
            gen_imgs = generator(z, gen_labels)

            # Try to fool the discriminator
            validity = discriminator(gen_imgs, gen_labels)
            g_loss = criterion(validity, real)  # want D(G(z)) = 1

            # Only the generator's optimizer steps here. Gradients that reach
            # the discriminator's parameters are cleared by
            # optimizer_D.zero_grad() before its next backward pass.
            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

        # Print progress (losses are those of the last inner update).
        if i % 200 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(train_loader)}] "
                  f"D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

    # Save example images after each epoch: one sample per class, in label
    # order. Sized from num_classes so the labels always stay inside the
    # embedding table.
    generator.eval()
    with torch.no_grad():
        z = torch.randn(num_classes, z_dim, device=device)
        labels = torch.arange(num_classes, dtype=torch.long, device=device)
        samples = generator(z, labels)
        samples = samples * 0.5 + 0.5  # Denormalize
        save_image(samples, f"cgan_generated/epoch_{epoch}.png", nrow=num_classes)
    # Back to training mode for the next epoch.
    generator.train()

[Epoch 1/50] [Batch 0/469] D Loss: 1.3981 | G Loss: 0.5717
[Epoch 1/50] [Batch 200/469] D Loss: 1.3837 | G Loss: 0.6472
[Epoch 1/50] [Batch 400/469] D Loss: 1.3707 | G Loss: 0.6674
[Epoch 2/50] [Batch 0/469] D Loss: 1.3861 | G Loss: 0.6555
[Epoch 2/50] [Batch 200/469] D Loss: 1.3817 | G Loss: 0.6983
[Epoch 2/50] [Batch 400/469] D Loss: 1.3584 | G Loss: 0.7138
[Epoch 3/50] [Batch 0/469] D Loss: 1.3716 | G Loss: 0.6829
[Epoch 3/50] [Batch 200/469] D Loss: 1.4026 | G Loss: 0.6727
[Epoch 3/50] [Batch 400/469] D Loss: 1.3662 | G Loss: 0.6929
[Epoch 4/50] [Batch 0/469] D Loss: 1.4064 | G Loss: 0.6960
[Epoch 4/50] [Batch 200/469] D Loss: 1.3955 | G Loss: 0.7097
[Epoch 4/50] [Batch 400/469] D Loss: 1.4073 | G Loss: 0.7166
[Epoch 5/50] [Batch 0/469] D Loss: 1.4045 | G Loss: 0.6947
[Epoch 5/50] [Batch 200/469] D Loss: 1.3706 | G Loss: 0.7089
[Epoch 5/50] [Batch 400/469] D Loss: 1.3769 | G Loss: 0.7099
[Epoch 6/50] [Batch 0/469] D Loss: 1.3949 | G Loss: 0.6984
[Epoch 6/50] [Batch 200/469] D Loss:

In [None]:
def generate_digit_images(generator, digit, num_samples=16, save_path=None):
    """Sample images of a single digit class from a conditional generator.

    Reads the notebook-level ``z_dim`` and ``device``. The generator is
    switched to eval mode while sampling and is returned to whatever mode it
    was in before the call, so calling this helper between training runs does
    not silently leave the model in eval mode.

    Args:
        generator: Module called as ``generator(noise, labels)``.
        digit: Integer class label used for every sample.
        num_samples: Number of images to generate.
        save_path: Optional output file. When truthy, the samples are saved
            as an image grid with 4 images per row.

    Returns:
        The generated batch after the ``x * 0.5 + 0.5`` rescale.
    """
    was_training = generator.training
    generator.eval()
    try:
        # Create inputs directly on the target device (no host-side copy).
        z = torch.randn(num_samples, z_dim, device=device)
        labels = torch.full((num_samples,), digit, dtype=torch.long, device=device)

        with torch.no_grad():
            gen_imgs = generator(z, labels)
            gen_imgs = gen_imgs * 0.5 + 0.5  # undo the (x - 0.5) / 0.5 scaling
    finally:
        # Restore the caller's train/eval mode even if sampling fails.
        generator.train(was_training)

    if save_path:
        save_image(gen_imgs, save_path, nrow=4)
        print(f"Saved to {save_path}")
    return gen_imgs

In [None]:
# Generate 16 samples of digit "7" and save them as an image grid.
seven_imgs = generate_digit_images(generator, digit=7, num_samples=16, save_path="cgan_generated/seven.png")

# Display only the shape; echoing the full tensor floods the cell output.
seven_imgs.shape

Saved to cgan_generated/seven.png


tensor([[[[2.9802e-08, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 5.9605e-08,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 3.0100e-06],
          ...,
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 4.0531e-06, 8.0466e-07,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [7.4506e-07, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00]]],


        [[[2.9802e-08, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 5.9605e-08,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 3.3379e-06],
          ...,
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.00