In [None]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import transforms
import torch.optim as optim
import os
import shutil

In [None]:
class Discriminator(nn.Module):
    """Conditional MLP discriminator: scores a flattened image together with its class label.

    Args:
        image_dim: length of each flattened input image.
        num_classes: number of rows in the label-embedding table.
        embed_length: size of the learned label embedding concatenated onto the image.
    """

    def __init__(self, image_dim, num_classes, embed_length):
        super().__init__()
        in_features = image_dim + embed_length
        # The label embedding is appended to the image features before the first layer,
        # which is what makes the discriminator class-conditional.
        self.disc = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(in_features=128, out_features=1),
            nn.Sigmoid(),  # probability that (x, label) is a real sample
        )
        self.embed = nn.Embedding(num_classes, embed_length)

    def forward(self, x, label):
        """Return the real-probability for each (image, label) pair.

        ``x`` is presumably (batch, image_dim) (the training cell flattens images to that),
        and ``label`` a (batch,) integer tensor; the result is then (batch, 1).
        """
        conditioned = torch.cat((x, self.embed(label)), dim=1)
        return self.disc(conditioned)

class Generator(nn.Module):
    """Conditional MLP generator: maps latent noise plus a class label to a flattened image.

    Args:
        z_dim: length of the input noise vector.
        image_dim: length of the flattened image produced.
        num_classes: number of rows in the label-embedding table.
        embed_length: size of the learned label embedding concatenated onto the noise.
    """

    def __init__(self, z_dim, image_dim, num_classes, embed_length):
        super().__init__()
        hidden = 256
        # Noise and label embedding are concatenated, so the input width is their sum.
        self.gen = nn.Sequential(
            nn.Linear(z_dim + embed_length, hidden),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hidden, image_dim),
            nn.Tanh(),  # outputs in [-1, 1], the range of the Normalize(0.5, 0.5)-ed images
        )
        self.embed = nn.Embedding(num_classes, embed_length)

    def forward(self, z, label):
        """Generate one flattened image per (noise, label) pair.

        ``z`` is presumably (batch, z_dim) and ``label`` a (batch,) integer tensor;
        the result is then (batch, image_dim).
        """
        conditioned = torch.cat((z, self.embed(label)), dim=1)
        return self.gen(conditioned)

In [None]:
# Hyperparameters and runtime configuration for the conditional GAN.
image_size = 28          # MNIST images are 28x28 pixels
image_channels = 1       # grayscale
image_dim = image_size * image_size * image_channels  # flattened length fed to both MLPs
batch_size = 32
z_dim = 100              # length of the generator's latent noise vector
epochs = 100
lr = 3e-4                # learning rate shared by both Adam optimizers
num_classes = 10         # digit classes 0-9
embed_length = 16        # size of the learned label embedding in both networks

# Seed torch's RNGs (CPU and CUDA) so weight init, noise and DataLoader shuffling
# are reproducible on Restart & Run All.
seed = 42
torch.manual_seed(seed)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
disc = Discriminator(image_dim, num_classes, embed_length).to(device)
gen = Generator(z_dim, image_dim, num_classes, embed_length).to(device)

# Latent vectors and labels sampled once, presumably for visualising generator
# progress on identical inputs (they are not used by the training cell shown).
# Labels are drawn from num_classes rather than a hard-coded 10 so they stay
# within the embedding tables if num_classes changes.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
fixed_label = torch.randint(0, num_classes, (batch_size,)).to(device)

# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1],
# the same range as the generator's final Tanh.
pixel_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]
)

In [None]:
# MNIST training split (torchvision's default train=True). download=True fetches the
# files only when they are not already present under root, so a fresh checkout can
# Restart & Run All instead of failing with "Dataset not found".
dataset = datasets.MNIST(root="dataset/", transform=pixel_transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# One optimizer per network so each is stepped only by its own loss.
disc_optim = optim.Adam(disc.parameters(), lr=lr)
gen_optim = optim.Adam(gen.parameters(), lr=lr)
# Plain BCE on probabilities: the Discriminator already ends in a Sigmoid.
criterion = nn.BCELoss()

In [None]:
for epoch in range(epochs):
    for real, label in loader:
        # Flatten images to (batch, image_dim) for the MLP discriminator.
        real = real.view(-1, image_dim).to(device)
        label = label.to(device)
        # Size the noise from the actual batch: the final batch can be smaller than
        # batch_size, and a mismatch would make torch.cat in Generator.forward fail.
        cur_batch_size = real.shape[0]

        # --- Discriminator step: minimise -log D(x|y) - log(1 - D(G(z|y)|y)) ---
        noise = torch.randn((cur_batch_size, z_dim), device=device)
        fake = gen(noise, label)
        disc_real = disc(real, label).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach() stops the discriminator loss from back-propagating into the
        # generator, so the graph of `fake` need not be retained for the G step.
        disc_fake = disc(fake.detach(), label).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc_optim.zero_grad()
        lossD.backward()
        disc_optim.step()

        # --- Generator step: non-saturating loss, minimise -log D(G(z|y)|y) ---
        # Re-scores `fake` with the freshly updated discriminator.
        output = disc(fake, label).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen_optim.zero_grad()
        lossG.backward()
        gen_optim.step()

    # Losses of the last batch in the epoch, as a coarse progress signal.
    print(f"Epoch [{epoch + 1}/{epochs}]  lossD: {lossD.item():.4f}  lossG: {lossG.item():.4f}")

In [None]:
model_path = "RudiCGAN.pth"
# torch.save overwrites an existing file, so only a stale *directory* at this path
# needs clearing. Calling shutil.rmtree on a regular file raises NotADirectoryError,
# which previously made re-running this cell fail once the checkpoint existed.
if os.path.isdir(model_path):
    shutil.rmtree(model_path)

# Only the generator's weights are saved here; the discriminator is not persisted.
torch.save(gen.state_dict(), model_path)