In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [12]:
batch_size = 128
dim_z = 100
dim_label = 10  # Number of classes (0-9)
num_epochs = 50
lr = 0.0002
embed_len = 16

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
# Generator using CNN
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_embedding = nn.Embedding(dim_label, embed_len)  # Embedding layer for labels
        self.model = nn.Sequential(
            nn.ConvTranspose2d(dim_z + dim_label, 128, 4, 1, 0, bias=False),  # Upsample to 4x4
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 3, 2, 1, bias=False),  # Upsample to 7x7
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),  # Upsample to 28x28
            nn.Tanh()
        )

    def forward(self, z, labels):
        labels = self.label_embedding(labels).unsqueeze(2).unsqueeze(3)  # Convert labels to embedding and reshape
        input = torch.cat([z, labels], dim=1)  # Concatenate latent vector and labels over C
        return self.model(input)

In [15]:
# Discriminator using CNN
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_embedding = nn.Embedding(dim_label, embed_len)  # Embedding layer for labels
        self.model = nn.Sequential(
            nn.Conv2d(1 + dim_label, 64, 4, 2, 1, bias=False),  # Downsample to 14x14
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),  # Downsample to 7x7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 1, 3, 1, 0, bias=False),  # Final classification layer (1x1 output)
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        labels = self.label_embedding(labels).unsqueeze(2).unsqueeze(3)  # Convert labels to embedding and reshape
        labels = labels.expand(-1, -1, img.size(2), img.size(3))  # Expand labels to match image dimensions
        input = torch.cat([img, labels], dim=1)  # Concatenate image and label
        return self.model(input).view(-1, 1)  # Flatten output

In [16]:
# Initialize networks
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Loss and Optimizers
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))


writer_fake = SummaryWriter(f"logs/fake")
writer_real = SummaryWriter(f"logs/real")
step = 0

In [19]:
# Load MNIST Dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)


In [20]:
# Training Loop
for epoch in range(num_epochs):
    for i, (real_images, labels) in enumerate(dataloader):
        batch_size = real_images.size(0)
        real_images = real_images.to(device)
        labels = labels.to(device)

        # Train Discriminator
        z = torch.randn(batch_size, dim_z, 1, 1).to(device)  # Generate random latent vector
        fake_images = generator(z, labels).detach()  # Generate fake images and detach from computation graph
        real_preds = discriminator(real_images, labels)  # Discriminator prediction on real images
        fake_preds = discriminator(fake_images, labels)  # Discriminator prediction on fake images

        d_loss_real = criterion(real_preds, torch.ones_like(real_preds))  # Loss for real images
        d_loss_fake = criterion(fake_preds, torch.zeros_like(fake_preds))  # Loss for fake images
        d_loss = (d_loss_real + d_loss_fake)/2  # Total loss for discriminator

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Train Generator
        z = torch.randn(batch_size, dim_z, 1, 1).to(device)  # Generate new latent vector
        fake_images = generator(z, labels)  # Generate new fake images
        fake_preds = discriminator(fake_images, labels)  # Discriminator's response to fake images
        g_loss = criterion(fake_preds, torch.ones_like(fake_preds))  # Generator wants to fool discriminator

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()



        print(f"Epoch [{epoch+1}/{num_epochs}] D Loss: {d_loss.item():.4f} G Loss: {g_loss.item():.4f}")




RuntimeError: Given transposed=1, weight of size [110, 128, 4, 4], expected input[128, 116, 1, 1] to have 110 channels, but got 116 channels instead

In [21]:
# Save Model
torch.save(generator.state_dict(), "generator_cnn.pth")

In [22]:
# Generate Sample Images
def generate_images(generator, num_samples=10):
    generator.eval()
    z = torch.randn(num_samples, dim_z, 1, 1).to(device)  # Generate latent vectors
    labels = torch.arange(num_samples).to(device)  # Generate sequential labels
    with torch.no_grad():
        fake_images = generator(z, labels).cpu()  # Generate images without tracking gradients
    return fake_images, labels



In [23]:
# Generate and visualize
generated_images, labels = generate_images(generator)
fig, axes = plt.subplots(1, 10, figsize=(10, 2))
for i in range(10):
    axes[i].imshow(generated_images[i].squeeze(), cmap="gray")  # Display generated images
    axes[i].set_title(f"Label: {labels[i].item()}")
    axes[i].axis("off")
plt.show()

RuntimeError: Given transposed=1, weight of size [110, 128, 4, 4], expected input[10, 116, 1, 1] to have 110 channels, but got 116 channels instead

In [None]:
'''
import matplotlib.pyplot as plt

def plot_images(images, labels):
    fig, axes = plt.subplots(1, len(images), figsize=(10, 2))
    for i, img in enumerate(images):
        axes[i].imshow(img.squeeze(), cmap="gray")  # Display generated images
        axes[i].set_title(f"Label: {labels[i].item()}")
        axes[i].axis("off")
    plt.show()
'''



'\nimport matplotlib.pyplot as plt\n\ndef plot_images(images, labels):\n    fig, axes = plt.subplots(1, len(images), figsize=(10, 2))\n    for i, img in enumerate(images):\n        axes[i].imshow(img.squeeze(), cmap="gray")  # Display generated images\n        axes[i].set_title(f"Label: {labels[i].item()}")\n        axes[i].axis("off")\n    plt.show()\n'