In [1]:
from email.generator import Generator

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [2]:
# hyper parameters
latent_dim = 100        # length of the noise vector fed to the generator
num_labels = 10         # number of digit classes used for conditioning
image_size = 28         # images are image_size x image_size pixels
batch_size = 64
epochs = 50
learning_rate = 2e-4
seed = 42

# Seed PyTorch's global RNG so weight initialisation, noise sampling and
# DataLoader shuffling are repeatable across "Restart & Run All"
# (as far as the selected backend is deterministic).
torch.manual_seed(seed)

# Query MPS availability once, through the documented torch.backends.mps API,
# and reuse the answer for both the device choice and the status message so
# the two can never disagree.
mps_available = torch.backends.mps.is_available()
device = torch.device("mps" if mps_available else "cpu")

if mps_available:
    print("MPS is available")
else:
    print("MPS is not available")

# ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) rescales them to
# [-1, 1], the same range the generators produce through their final Tanh.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# MNIST training split, downloaded to ./data on first use.
mnist = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
data_loader = DataLoader(mnist, batch_size=batch_size, shuffle=True)


MPS is available


In [3]:
class Generator(nn.Module):
    """Fully connected conditional generator.

    The class label is embedded into a ``num_labels``-dimensional vector,
    concatenated with the noise vector and pushed through an MLP whose final
    Tanh keeps every pixel in [-1, 1].

    Reads the module-level ``latent_dim``, ``num_labels`` and ``image_size``.
    Note: defining this class rebinds the name ``Generator`` that the import
    cell also brings in from ``email.generator``.
    """

    def __init__(self):
        super().__init__()

        self.label_embedding = nn.Embedding(num_embeddings=num_labels, embedding_dim=num_labels)

        # Hidden widths of the MLP; each hidden Linear is followed by LeakyReLU(0.2).
        widths = [latent_dim + num_labels, 256, 512, 1024]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))

        # Output layer: one value per pixel, squashed to [-1, 1].
        stack.append(nn.Linear(widths[-1], image_size * image_size))
        stack.append(nn.Tanh())

        self.layers = nn.Sequential(*stack)

    def forward(self, noise, labels):
        """Generate images for ``labels`` from ``noise``.

        Args:
            noise: tensor viewable as ``(batch, latent_dim)``.
            labels: integer class indices, one per sample.

        Returns:
            Tensor of shape ``(batch, 1, image_size, image_size)``.
        """
        n_samples = noise.size(0)
        flat_noise = noise.view(n_samples, latent_dim)
        label_vectors = self.label_embedding(labels)
        pixels = self.layers(torch.cat((flat_noise, label_vectors), dim=1))
        return pixels.view(n_samples, 1, image_size, image_size)

In [4]:
class CNN_Generator(nn.Module):
    """Convolutional conditional generator.

    The class label is embedded, concatenated with the noise vector and
    projected by a linear layer to a 128-channel feature map of side
    ``image_size // 4``. Two stride-2 transposed convolutions each double the
    spatial size, and a final stride-1 transposed convolution reduces the
    result to a single channel; Tanh keeps the pixels in [-1, 1].

    Reads the module-level ``latent_dim``, ``num_labels`` and ``image_size``.
    The feature-map side is derived from ``image_size`` at construction time
    (7 for the 28x28 default) instead of being hard-coded.

    Raises:
        ValueError: if ``image_size`` is not a positive multiple of 4, since
            the two x2 upsampling stages could not reach it exactly.
    """

    def __init__(self):
        super().__init__()

        if image_size <= 0 or image_size % 4 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 4, got {image_size}"
            )
        # Side length of the feature map before the two x2 upsampling stages.
        self.base_size = image_size // 4

        self.label_embedding = nn.Embedding(num_embeddings=num_labels, embedding_dim=num_labels)

        self.fc_layer = nn.Linear(latent_dim + num_labels, self.base_size * self.base_size * 128)

        # Shapes below are (channels, height, width) with s = self.base_size.
        self.conv_layers = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),  # (128, s, s) -> (64, 2s, 2s)
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1),  # (64, 2s, 2s) -> (32, 4s, 4s)
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1),  # (32, 4s, 4s) -> (1, 4s, 4s)
            nn.Tanh()
        )

    def forward(self, noise, labels):
        """Generate images for ``labels`` from ``noise``.

        Args:
            noise: tensor viewable as ``(batch, latent_dim)``.
            labels: integer class indices, one per sample.

        Returns:
            Tensor of shape ``(batch, 1, 4 * base_size, 4 * base_size)``,
            i.e. ``(batch, 1, image_size, image_size)`` for the ``image_size``
            in effect when the model was constructed.
        """
        noise = noise.view(noise.size(0), latent_dim)
        labels = self.label_embedding(labels)
        combined_input = torch.cat((noise, labels), dim=1)
        output = self.fc_layer(combined_input)
        reshaped = output.view(output.size(0), 128, self.base_size, self.base_size)
        # The conv stack already emits (batch, 1, H, W); no extra reshape needed.
        return self.conv_layers(reshaped)

In [5]:
class Discriminator(nn.Module):
    """Fully connected conditional discriminator.

    The flattened image is concatenated with a ``num_labels``-dimensional
    label embedding and scored by an MLP with LeakyReLU(0.2) activations and
    10% dropout; the closing Sigmoid turns the score into a probability.

    Reads the module-level ``num_labels`` and ``image_size``.
    """

    def __init__(self):
        super().__init__()

        self.label_embedding = nn.Embedding(num_embeddings=num_labels, embedding_dim=num_labels)

        # Each hidden stage is Linear -> LeakyReLU -> Dropout.
        widths = [image_size * image_size + num_labels, 1024, 512, 256]
        stages = []
        for n_in, n_out in zip(widths, widths[1:]):
            stages.extend([
                nn.Linear(n_in, n_out),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.1),
            ])

        # Single-unit head producing a probability.
        stages.extend([nn.Linear(widths[-1], 1), nn.Sigmoid()])

        self.layers = nn.Sequential(*stages)

    def forward(self, image, labels):
        """Score ``image`` given its class ``labels``.

        Args:
            image: tensor viewable as ``(batch, image_size * image_size)``.
            labels: integer class indices, one per sample.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in [0, 1].
        """
        flat_image = image.view(image.size(0), image_size * image_size)
        label_vectors = self.label_embedding(labels)
        return self.layers(torch.cat((flat_image, label_vectors), dim=1))

In [6]:
class CNN_Discriminator(nn.Module):
    """Convolutional conditional discriminator.

    Each label is embedded into ``image_size * image_size`` values, reshaped
    into an image-sized plane and stacked with the input image as a second
    channel. Two stride-2 convolutions halve the spatial size twice, and a
    fully connected head ending in Sigmoid outputs a probability.

    Reads the module-level ``num_labels`` and ``image_size`` at construction
    time. The image side is stored on the instance, and the flattened feature
    size is derived from it (7 * 7 * 64 for the 28x28 default) rather than
    hard-coded.

    Raises:
        ValueError: if ``image_size`` is not a positive multiple of 4, since
            the flattened size after two exact halvings would be undefined.
    """

    def __init__(self):
        super().__init__()

        if image_size <= 0 or image_size % 4 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 4, got {image_size}"
            )
        # Captured so forward() does not depend on later changes to the global.
        self.image_size = image_size
        # Side length of the feature map after the two stride-2 convolutions.
        reduced_size = image_size // 4

        self.label_embedding = nn.Embedding(num_embeddings=num_labels, embedding_dim=image_size * image_size)

        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=32, kernel_size=4, stride=2, padding=1),   # (2, H, W) -> (32, H/2, W/2)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),  # (32, H/2, W/2) -> (64, H/4, W/4)
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.fc_layers = nn.Sequential(
            nn.Linear(reduced_size * reduced_size * 64, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.1),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, image, labels):
        """Score ``image`` given its class ``labels``.

        Args:
            image: tensor of shape ``(batch, 1, image_size, image_size)``.
            labels: integer class indices, one per sample.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in [0, 1].
        """
        # embed labels and concat as extra channel
        label_embeddings = self.label_embedding(labels).view(
            labels.size(0), 1, self.image_size, self.image_size
        )
        input_with_labels = torch.cat((image, label_embeddings), dim=1)

        # conv layers
        processed_image = self.conv_layers(input_with_labels)
        flattened_image = processed_image.view(processed_image.size(0), -1)

        # fc layers
        return self.fc_layers(flattened_image)


In [7]:
# create models
# Build the convolutional generator/discriminator pair (generator first) and
# move both networks to the selected device.
generator, discriminator = (
    network.to(device) for network in (CNN_Generator(), CNN_Discriminator())
)

# Binary cross-entropy on probabilities; both discriminators end in a Sigmoid.
adversarial_loss = nn.BCELoss()

# Both optimizers share the same Adam settings.
adam_options = {"lr": learning_rate, "betas": (0.5, 0.999)}
optim_gen = optim.Adam(generator.parameters(), **adam_options)
optim_disc = optim.Adam(discriminator.parameters(), **adam_options)

In [8]:
# Training Loop
#
# Per batch: one generator step (try to make the discriminator output "real"
# for generated images), then one discriminator step on the real batch plus
# the same generated batch (detached). After every epoch the last batch's
# losses are printed and one sample per class is displayed.

for epoch in range(epochs):
    for imgs, lbls in data_loader:
        # Keep the per-batch size in its own name: assigning to `batch_size`
        # would overwrite the hyperparameter from the config cell, and the
        # final batch of an epoch can be smaller than the configured size.
        current_batch_size = imgs.size(0)
        real_images = imgs.to(device)
        real_labels = lbls.to(device)

        # Target tensors of shape (batch, 1): ones mean "real", zeros mean "fake".
        real = torch.ones(current_batch_size, 1, device=device)
        fake = torch.zeros(current_batch_size, 1, device=device)

        # Train generator: its loss is low when the discriminator labels the
        # generated images as real.
        optim_gen.zero_grad()
        rand_noise = torch.randn(current_batch_size, latent_dim, device=device)
        gen_labels = torch.randint(0, num_labels, (current_batch_size,), device=device)
        gen_images = generator(rand_noise, gen_labels)
        gen_loss = adversarial_loss(discriminator(gen_images, gen_labels), real)
        gen_loss.backward()
        optim_gen.step()

        # Train discriminator. zero_grad() also discards the gradients the
        # generator step left on the discriminator's parameters; detach()
        # stops this step from back-propagating into the generator.
        optim_disc.zero_grad()
        disc_real_loss = adversarial_loss(discriminator(real_images, real_labels), real)
        disc_fake_loss = adversarial_loss(discriminator(gen_images.detach(), gen_labels), fake)
        disc_loss_total = disc_real_loss + disc_fake_loss
        disc_loss_total.backward()
        optim_disc.step()

    # Print progress (losses of the last batch of the epoch, as plain floats)
    print(
        f"Epoch {epoch+1} of {epochs}: "
        f"gen_loss = {gen_loss.item():.4f}, disc_loss = {disc_loss_total.item():.4f}"
    )

    # Generate and show one sample per class after each epoch.
    # eval() keeps this preview pass from updating the generator's BatchNorm
    # running statistics (which end up in the saved state_dict); no_grad()
    # avoids building an autograd graph that is never used.
    generator.eval()
    with torch.no_grad():
        z = torch.randn(num_labels, latent_dim, device=device)
        sample_labels = torch.arange(num_labels, device=device)
        generated_imgs = generator(z, sample_labels).cpu()
    generator.train()

    # Lay the (1, H, W) samples side by side along the width axis.
    grid = torch.cat(tuple(generated_imgs), dim=2).squeeze()
    plt.imshow(grid, cmap="gray")
    plt.axis("off")
    plt.show()

In [9]:
# Persist the trained weights of both networks (state dicts only), generator first.
checkpoints = {
    "generator_cnn_v2.pth": generator,
    "discriminator_cnn_v2.pth": discriminator,
}
for checkpoint_path, network in checkpoints.items():
    torch.save(network.state_dict(), checkpoint_path)