In [None]:
import torch
from torch import nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image

In [None]:
# Reproducibility: seed torch's RNGs (CPU and CUDA) so weight init, DataLoader
# shuffling and noise samples repeat across Restart & Run All (up to
# nondeterministic GPU kernels).
seed = 0
torch.manual_seed(seed)

# Training configuration.
image_size = 28*28   # flattened MNIST image size (28 x 28 pixels)
batch_size = 128
num_epochs = 200
latent_dim = 100     # size of the generator's noise vector z

In [None]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST dataset.
# ToTensor() already scales pixels to [0, 1]. That matches the generator's Sigmoid
# output range and what save_image expects. Normalizing real images to [-1, 1]
# would let the discriminator tell real from fake by value range alone.
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
# drop_last=True keeps every batch at exactly batch_size.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, context) to a flattened image.

    The noise vector and the context vector each pass through their own hidden
    branch (Linear -> Dropout -> ReLU). The two activations are concatenated and
    passed through one joint hidden layer, then a Linear + Sigmoid output layer,
    so every output element lies in [0, 1].

    Args:
        latent_dim: size of the noise input.
        context_dim: size of the conditioning input (e.g. a one-hot class label).
        output_dim: number of output features (28 * 28 for flattened MNIST).
        z_hidden_dim: width of the noise branch.
        context_hidden_dim: width of the context branch.
        hidden_dim: width of the joint hidden layer.
        dropout: dropout probability used in every hidden layer.
    """

    def __init__(self, latent_dim=100, context_dim=10, output_dim=28 * 28,
                 z_hidden_dim=200, context_hidden_dim=1000, hidden_dim=1200,
                 dropout=0.5):
        super().__init__()

        self.hidden1_z = nn.Sequential(nn.Linear(latent_dim, z_hidden_dim),
                                       nn.Dropout(p=dropout),
                                       nn.ReLU())

        self.hidden1_context = nn.Sequential(nn.Linear(context_dim, context_hidden_dim),
                                             nn.Dropout(p=dropout),
                                             nn.ReLU())

        # The joint layer's input width is derived from the two branch widths
        # (it was a hard-coded 1200 = 200 + 1000), so resizing a branch stays consistent.
        self.hidden2 = nn.Sequential(nn.Linear(z_hidden_dim + context_hidden_dim, hidden_dim),
                                     nn.Dropout(p=dropout),
                                     nn.ReLU())

        self.out_layer = nn.Sequential(nn.Linear(hidden_dim, output_dim),
                                       nn.Sigmoid())

    def forward(self, noise, context):
        """Generate a batch of flattened images.

        Args:
            noise: float tensor, expected shape (batch, latent_dim).
            context: float tensor, expected shape (batch, context_dim).

        Returns:
            Tensor of shape (batch, output_dim) with values in [0, 1].
        """
        # Concatenate the two branch activations along the feature dimension.
        h = torch.cat((self.hidden1_z(noise), self.hidden1_context(context)), dim=1)
        h = self.hidden2(h)
        return self.out_layer(h)

In [None]:
class Discriminator(nn.Module):
    """Conditional MLP discriminator scoring (image, context) pairs.

    The flattened image and the context vector each pass through their own
    hidden branch (Linear -> Dropout -> LeakyReLU). The two activations are
    concatenated and passed through one joint hidden layer, then a Linear +
    Sigmoid output, giving one score in [0, 1] per sample. The training loop
    treats that score as the probability that the image is real.

    Args:
        input_dim: number of features in a flattened image (28 * 28 for MNIST).
        context_dim: size of the conditioning input (e.g. a one-hot class label).
        x_hidden_dim: width of the image branch.
        context_hidden_dim: width of the context branch.
        hidden_dim: width of the joint hidden layer.
        dropout: dropout probability used in every hidden layer.
    """

    def __init__(self, input_dim=28 * 28, context_dim=10,
                 x_hidden_dim=240, context_hidden_dim=50, hidden_dim=240,
                 dropout=0.5):
        super().__init__()

        self.hidden1_x = nn.Sequential(nn.Linear(input_dim, x_hidden_dim),
                                       nn.Dropout(p=dropout),
                                       nn.LeakyReLU())

        self.hidden1_context = nn.Sequential(nn.Linear(context_dim, context_hidden_dim),
                                             nn.Dropout(p=dropout),
                                             nn.LeakyReLU())

        # Input width derived from the branch widths (was a hard-coded 290 = 240 + 50).
        self.hidden2 = nn.Sequential(nn.Linear(x_hidden_dim + context_hidden_dim, hidden_dim),
                                     nn.Dropout(p=dropout),
                                     nn.LeakyReLU())

        self.out_layer = nn.Sequential(nn.Linear(hidden_dim, 1),
                                       nn.Sigmoid())

    def forward(self, img, context):
        """Score a batch of images under the given context.

        Args:
            img: batch of images; every dimension after the first is flattened,
                so e.g. (batch, 1, 28, 28) and (batch, 784) are both accepted.
            context: float tensor, expected shape (batch, context_dim).

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1].
        """
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # transposed or sliced image batches.
        img = img.reshape(img.size(0), -1)
        h = torch.cat((self.hidden1_x(img), self.hidden1_context(context)), dim=1)
        h = self.hidden2(h)
        return self.out_layer(h)

In [None]:
# Initialize models. Pass the config-cell sizes so that changing latent_dim or
# image_size there keeps the models consistent with the noise drawn during training.
generator = Generator(latent_dim=latent_dim, output_dim=image_size).to(device)
discriminator = Discriminator(input_dim=image_size).to(device)

# Optimizers: SGD with momentum for both networks.
d_optimizer = optim.SGD(discriminator.parameters(), lr=0.1, momentum=0.5)
g_optimizer = optim.SGD(generator.parameters(), lr=0.1, momentum=0.5)

# Loss function: binary cross-entropy on the discriminator's Sigmoid output.
criterion = nn.BCELoss()

# Multiply each learning rate by 1 / 1.00004 on every scheduler.step() call
# (the training loop steps the schedulers once per mini-batch).
lr_decay = 1 / 1.00004
schedulers = [torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=lr_decay),
              torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=lr_decay)]

In [None]:
# MNIST has 10 digit classes; this must match the models' context_dim (default 10).
num_classes = 10

# Fixed noise plus class-ordered one-hot context, used to render a progress grid
# every epoch: samples_per_class images per digit, one digit per grid row.
samples_per_class = 10
fixed_z = torch.randn(num_classes * samples_per_class, latent_dim, device=device)
fixed_context = torch.nn.functional.one_hot(
    torch.arange(num_classes, device=device).repeat_interleave(samples_per_class),
    num_classes).float()

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        real_images = images.to(device)
        # The models' context layers take a float one-hot vector of size
        # num_classes, not the integer class indices the DataLoader yields.
        context = torch.nn.functional.one_hot(labels.to(device), num_classes).float()

        # BCE targets (1 = real, 0 = fake), sized from the actual batch.
        cur_batch = real_images.size(0)
        real_labels = torch.ones(cur_batch, 1, device=device)
        fake_labels = torch.zeros(cur_batch, 1, device=device)

        # --------- Train the Teacher (D) --------- #
        d_optimizer.zero_grad()
        outputs = discriminator(real_images, context)
        d_real_loss = criterion(outputs, real_labels)

        z = torch.randn(cur_batch, latent_dim, device=device)
        fake_images = generator(z, context)
        # detach(): the discriminator update must not backpropagate into G.
        outputs = discriminator(fake_images.detach(), context)
        d_fake_loss = criterion(outputs, fake_labels)

        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        d_optimizer.step()

        # --------- Train the Student (G) --------- #
        # Clear G's gradients from the previous iteration; without this they
        # accumulate over the whole run.
        g_optimizer.zero_grad()
        z = torch.randn(cur_batch, latent_dim, device=device)
        fake_images = generator(z, context)
        outputs = discriminator(fake_images, context)
        # Non-saturating generator loss: G is rewarded when D scores its samples as real.
        g_loss = criterion(outputs, real_labels)
        # This also fills D's .grad; d_optimizer.zero_grad() clears it next iteration.
        g_loss.backward()
        g_optimizer.step()

        for scheduler in schedulers:
            scheduler.step()

        if (i+1) % 400 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')

    # Save a class-ordered sample grid every epoch. It is generated in eval mode
    # (dropout off) without tracking gradients, then the generator returns to train mode.
    generator.eval()
    with torch.no_grad():
        samples = generator(fixed_z, fixed_context)
    generator.train()
    save_image(samples.reshape(samples.size(0), 1, 28, 28),
               f'./data/fake_images-{epoch+1}.png', nrow=samples_per_class)

print("Training complete.")