In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os

In [12]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [13]:
# Training hyper-parameters and output location.
batch_size = 128      # images per mini-batch
z_dim = 100           # length of the generator's input noise vector
num_classes = 10      # number of conditioning classes (MNIST digits 0-9)
img_size = 28         # side length of the square input images
channels = 1          # single-channel (grayscale) images
epochs = 50
lr = 0.0002           # Adam learning rate for both networks
beta1 = 0.5           # Adam beta_1 (first-moment decay) for both networks
betal = beta1         # legacy alias: the optimizer cell reads `betal`
seed = 42             # fixed RNG seed so weight init, shuffling and noise are reproducible

# Seed PyTorch's RNGs up front so a Restart & Run All reproduces the same run.
torch.manual_seed(seed)

os.makedirs("cgan_generated", exist_ok=True)

In [14]:
# ToTensor yields pixel values in [0, 1]; Normalize with mean 0.5 / std 0.5 then
# maps them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST training split, downloaded into the current directory if missing.
mnist_train = datasets.MNIST(root='.', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

In [15]:
class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, class label) to an image tensor.

    The label is looked up in a learned embedding of width ``num_classes`` and
    appended to the noise vector. The result passes through three
    Linear -> BatchNorm1d -> ReLU stages (256, 512, 1024 units). A final Linear
    layer projects to ``prod(img_shape)`` values squashed by Tanh. The output is
    reshaped to ``(batch, *img_shape)``.
    """

    def __init__(self, z_dim, num_classes, img_shape):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.img_shape = img_shape

        layers = []
        width_in = z_dim + num_classes
        for width_out in (256, 512, 1024):
            layers += [
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(True),
            ]
            width_in = width_out
        # Output projection: one value per pixel, bounded to [-1, 1].
        layers += [nn.Linear(width_in, torch.Size(img_shape).numel()), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, noise, labels):
        # Condition generation by appending the label embedding to each noise row.
        conditioned = torch.cat((noise, self.label_emb(labels)), dim=-1)
        flat = self.model(conditioned)
        return flat.view(conditioned.size(0), *self.img_shape)


In [16]:
class Discriminator(nn.Module):
  """Conditional MLP discriminator scoring (image, label) pairs.

  Each image is flattened and the label is mapped through a learned embedding of
  width ``num_classes``. Their concatenation goes through
  Linear(512) -> LeakyReLU(0.2) -> Linear(256) -> LeakyReLU(0.2) -> Linear(1) -> Sigmoid.
  The result is one score in [0, 1] per sample, shape (batch, 1).
  """

  def __init__(self, num_classes, img_shape):
    super().__init__()
    self.label_emb = nn.Embedding(num_classes, num_classes)
    pixel_count = torch.Size(img_shape).numel()

    self.model = nn.Sequential(
        nn.Linear(pixel_count + num_classes, 512),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(512, 256),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(256, 1),
        nn.Sigmoid(),
    )

  def forward(self, img, labels):
    # One row per sample: flattened pixels followed by the label embedding.
    pixels = img.view(img.size(0), -1)
    features = torch.cat((pixels, self.label_emb(labels)), dim=1)
    return self.model(features)

In [17]:
# Shape of a single image as (C, H, W): the generator reshapes its output to it
# and the discriminator sizes its input layer from it.
img_shape = (channels, img_size, img_size)

# Build both networks on the target device (generator first, then discriminator).
generator = Generator(z_dim, num_classes, img_shape).to(device)
discriminator = Discriminator(num_classes, img_shape).to(device)

# Binary cross-entropy on the discriminator's sigmoid scores drives both players.
criterion = nn.BCELoss()

# Both players share the same Adam hyper-parameters.
adam_kwargs = dict(lr=lr, betas=(betal, 0.999))
optimizer_G = optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_kwargs)

In [18]:
# Per mini-batch update schedule: p discriminator steps, then k generator steps.
p = 1  # discriminator updates per iteration
k = 3  # generator updates per iteration

In [19]:
# Adversarial training: every mini-batch runs `p` discriminator updates followed
# by `k` generator updates. One sample per class is saved once per epoch.
for epoch in range(1, epochs + 1):
  for i, (real_imgs, real_labels) in enumerate(train_loader):
    # The last batch of an epoch can be smaller than `batch_size`.
    batch_size_curr = real_imgs.size(0)
    real_imgs = real_imgs.to(device)
    real_labels = real_labels.to(device)

    # BCE targets: 1 means "real", 0 means "fake".
    real = torch.ones(batch_size_curr, 1, device=device)
    fake = torch.zeros(batch_size_curr, 1, device=device)

    ## Train Discriminator p times
    for _ in range(p):
      z = torch.randn(batch_size_curr, z_dim, device=device)
      fake_labels = torch.randint(0, num_classes, (batch_size_curr,), device=device)

      # No graph through the generator: this step must only update D.
      with torch.no_grad():
        fake_imgs = generator(z, fake_labels)

      # Real images
      real_validity = discriminator(real_imgs, real_labels)
      d_real_loss = criterion(real_validity, real)

      # Fake images
      fake_validity = discriminator(fake_imgs, fake_labels)
      d_fake_loss = criterion(fake_validity, fake)

      d_loss = d_real_loss + d_fake_loss

      optimizer_D.zero_grad()
      d_loss.backward()
      optimizer_D.step()

    ## Train Generator k times
    for _ in range(k):
      z = torch.randn(batch_size_curr, z_dim, device=device)
      gen_labels = torch.randint(0, num_classes, (batch_size_curr,), device=device)
      gen_imgs = generator(z, gen_labels)

      # Try to fool the discriminator: score G's fakes against the "real" target.
      validity = discriminator(gen_imgs, gen_labels)
      g_loss = criterion(validity, real)

      # This backward also fills D's .grad, but optimizer_D.zero_grad() clears it
      # before D's next step, so only G is updated here.
      optimizer_G.zero_grad()
      g_loss.backward()
      optimizer_G.step()

    # Print progress
    if i % 200 == 0:
      print(f'[Epoch {epoch}/{epochs}] [Batch {i}/{len(train_loader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]')

  # Save example images after each epoch. This sits at epoch level, outside the
  # batch loop, so the grid is sampled and written once per epoch rather than
  # being re-sampled and overwritten after every batch.
  generator.eval()  # BatchNorm layers use their running statistics while sampling
  with torch.no_grad():
    z = torch.randn(num_classes, z_dim, device=device)
    labels = torch.arange(0, num_classes, dtype=torch.long, device=device)  # one sample per class
    samples = generator(z, labels)
    samples = samples * 0.5 + 0.5  # Denormalize from [-1, 1] back to [0, 1]
    save_image(samples, f'cgan_generated/epoch_{epoch}.png', nrow=num_classes)
  generator.train()




[Epoch 1/50] [Batch 0/469] [D loss: 1.3255] [G loss: 0.6091]
[Epoch 1/50] [Batch 200/469] [D loss: 1.4029] [G loss: 0.6220]
[Epoch 1/50] [Batch 400/469] [D loss: 1.3976] [G loss: 0.6563]
[Epoch 2/50] [Batch 0/469] [D loss: 1.3754] [G loss: 0.6693]
[Epoch 2/50] [Batch 200/469] [D loss: 1.3672] [G loss: 0.7090]
[Epoch 2/50] [Batch 400/469] [D loss: 1.4098] [G loss: 0.6814]
[Epoch 3/50] [Batch 0/469] [D loss: 1.3763] [G loss: 0.7083]
[Epoch 3/50] [Batch 200/469] [D loss: 1.4047] [G loss: 0.6685]
[Epoch 3/50] [Batch 400/469] [D loss: 1.3768] [G loss: 0.7066]
[Epoch 4/50] [Batch 0/469] [D loss: 1.3901] [G loss: 0.6906]
[Epoch 4/50] [Batch 200/469] [D loss: 1.3834] [G loss: 0.6818]
[Epoch 4/50] [Batch 400/469] [D loss: 1.3721] [G loss: 0.7274]
[Epoch 5/50] [Batch 0/469] [D loss: 1.3974] [G loss: 0.6891]
[Epoch 5/50] [Batch 200/469] [D loss: 1.3688] [G loss: 0.7032]
[Epoch 5/50] [Batch 400/469] [D loss: 1.3909] [G loss: 0.7056]
[Epoch 6/50] [Batch 0/469] [D loss: 1.3924] [G loss: 0.7163]
[Epo

In [20]:
def generate_digit_images(generator, digit, num_samples=16, save_path=None, latent_dim=None, nrow=4):
  """Sample images of a single class from a conditional generator.

  Args:
    generator: model called as ``generator(noise, labels)``.
    digit: integer class label used for every sample.
    num_samples: number of images to draw.
    save_path: if truthy, the generated images are also written there as a grid.
    latent_dim: length of each noise vector; defaults to the notebook's ``z_dim``.
    nrow: number of images per row in the saved grid.

  Returns:
    The generated images after the affine map ``x * 0.5 + 0.5``. This undoes the
    ``Normalize((0.5,), (0.5,))`` applied to the training data. Computed without
    gradient tracking.

  The generator's train/eval mode is restored before returning, even if sampling raises.
  """
  if latent_dim is None:
    latent_dim = z_dim

  # Sample on the device the generator's weights live on; fall back to the
  # notebook-level `device` for a parameter-less module.
  first_param = next(generator.parameters(), None)
  sample_device = first_param.device if first_param is not None else device

  was_training = generator.training
  generator.eval()
  try:
    with torch.no_grad():
      z = torch.randn(num_samples, latent_dim, device=sample_device)
      labels = torch.full((num_samples,), digit, device=sample_device, dtype=torch.long)
      gen_images = generator(z, labels) * 0.5 + 0.5
  finally:
    # Don't leave the caller's model stuck in eval mode (it affects BatchNorm in training).
    generator.train(was_training)

  if save_path:
    # Save the images generated above (not any batch left over from training).
    save_image(gen_images, save_path, nrow=nrow)
    print(f'Saved to {save_path}')
  return gen_images

In [None]:
# Sample 16 images (the function's default count) conditioned on a single digit,
# forwarding a save_path for the image grid.
target_digit = 7
generated_images = generate_digit_images(generator, digit=target_digit, save_path='generated_digit_7.png')