# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from contextlib import contextmanager

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import pandas as pd

# Compute device shared by the whole script: the GPU when CUDA is available,
# otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generator
class Generator(nn.Module):
    """Map a noise vector to a single-channel 28x28 image.

    A linear layer projects the noise to a 128x7x7 feature map, two
    transposed-convolution stages double the spatial size each (7 -> 14 -> 28),
    and a final 3x3 convolution followed by Tanh produces values in [-1, 1].

    Args:
        noise_dim (int): Size of the input noise vector.
    """

    def __init__(self, noise_dim):
        super().__init__()

        def upsample_stage(in_channels, out_channels):
            # One "double the height/width" stage: transposed conv -> BN -> ReLU.
            return [
                nn.ConvTranspose2d(
                    in_channels, out_channels,
                    kernel_size=4, stride=2, padding=1, bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]

        seed_shape = (128, 7, 7)
        layers = [
            # bsz d -> bsz 128*7*7
            nn.Linear(noise_dim, seed_shape[0] * seed_shape[1] * seed_shape[2]),
            nn.ReLU(inplace=True),
            # bsz 128*7*7 -> bsz 128 7 7
            nn.Unflatten(1, seed_shape),
            # bsz 128 7 7 -> bsz 64 14 14
            *upsample_stage(128, 64),
            # bsz 64 14 14 -> bsz 32 28 28
            *upsample_stage(64, 32),
            # bsz 32 28 28 -> bsz 1 28 28
            nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        # Layers are created in the same order as they are applied, so the
        # Sequential indices (and hence state_dict keys) follow that order.
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return generated images of shape (bsz, 1, 28, 28) for noise `z` of shape (bsz, noise_dim)."""
        images = self.model(z)
        return images


# Discriminator
class Discriminator(nn.Module):
    """Score single-channel 28x28 images with one unbounded logit each.

    Two stride-2 convolutions halve the spatial size twice (28 -> 14 -> 7),
    then the 64x7x7 feature map is flattened and projected to a single value.
    No sigmoid is applied; the output is a raw logit.
    """

    def __init__(self):
        super().__init__()
        leaky_slope = 0.2
        layers = [
            # bsz 1 28 28 -> bsz 32 14 14
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(leaky_slope, inplace=True),
            # bsz 32 14 14 -> bsz 64 7 7
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(leaky_slope, inplace=True),
            # bsz 64 7 7 -> bsz 64*7*7
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape (bsz, 1) for images `x` of shape (bsz, 1, 28, 28)."""
        logits = self.model(x)
        return logits

# Loss function
def gan_loss(real_logits, fake_logits, optimizer_idx):
    """Compute the hinge GAN loss for the generator or the discriminator.

    Args:
        real_logits (torch.Tensor | None): Discriminator logits on real images.
            Only read for the discriminator loss; may be None otherwise.
        fake_logits (torch.Tensor): Discriminator logits on generated images.
        optimizer_idx (int): 0 for the generator loss, 1 for the
            discriminator loss.

    Returns:
        torch.Tensor: Scalar loss.
            Generator: ``-mean(fake_logits)``.
            Discriminator: ``mean(relu(1 - real_logits)) + mean(relu(1 + fake_logits))``.

    Raises:
        ValueError: If `optimizer_idx` is neither 0 nor 1.
    """
    if optimizer_idx == 0:  # generator update
        return -torch.mean(fake_logits)
    if optimizer_idx == 1:  # discriminator update
        relu = torch.nn.functional.relu
        real_loss = torch.mean(relu(1.0 - real_logits))
        fake_loss = torch.mean(relu(1.0 + fake_logits))
        return real_loss + fake_loss
    # Fail loudly instead of returning None, which would only surface later
    # as an AttributeError on `loss.backward()`.
    raise ValueError(
        f"optimizer_idx must be 0 (generator) or 1 (discriminator), got {optimizer_idx!r}"
    )


@contextmanager
def freeze_grads(model):
    """
    Disable gradient tracking on every parameter of a model for the duration
    of a ``with`` block, then put each parameter's flag back as it was.

    The previous ``requires_grad`` values are restored even if the body of the
    ``with`` block raises.

    Args:
        model (torch.nn.Module): The model whose parameters will be frozen.
    Yields:
        None
    """
    # Remember each parameter together with its flag before touching anything.
    snapshot = [(p, p.requires_grad) for p in model.parameters()]
    for p, _ in snapshot:
        p.requires_grad = False

    try:
        yield
    finally:
        for p, previous_flag in snapshot:
            p.requires_grad = previous_flag

# Training loop
def _show_generated_samples(generator, noise_dim):
    """Switch the generator to eval mode and plot a 2x8 grid of 16 fresh samples."""
    generator.eval()
    # Sampling only: no autograd graph is needed here.
    with torch.no_grad():
        noise = torch.randn(16, noise_dim, device=device)
        fake_images = generator(noise)
        fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
        # Map the generator's Tanh output range [-1, 1] to [0, 1] for display.
        fake_images = (fake_images + 1) / 2
        fake_images = fake_images.clamp(0, 1)
    grid = fake_images.cpu().detach().numpy()
    grid = grid.transpose(0, 2, 3, 1)  # channels-last for imshow
    plt.figure(figsize=(8, 2))
    for i in range(16):
        plt.subplot(2, 8, i+1)
        plt.imshow(grid[i], cmap='gray')
        plt.axis('off')
    plt.show()


def train_gan(generator, discriminator, data_loader, noise_dim, num_epochs, lr, version, indices):
    """Train a GAN with the hinge loss and plot generated samples after every epoch.

    Both networks are moved to the module-level `device` and optimized with
    Adam (betas=(0.5, 0.999)). For every batch, one batch of fake images is
    generated and the updates listed in `indices` are applied in that order.

    Args:
        generator (nn.Module): Maps (bsz, noise_dim) noise to (bsz, 1, 28, 28) images.
        discriminator (nn.Module): Maps (bsz, 1, 28, 28) images to logits.
        data_loader: Iterable yielding ``(images, labels)`` batches; labels are
            ignored and images are viewed as (bsz, 1, 28, 28).
        noise_dim (int): Size of the generator's input noise vector.
        num_epochs (int): Number of passes over `data_loader`.
        lr (float): Learning rate used for both Adam optimizers.
        version (str): ``"v0"`` runs zero_grad / backward / step separately for
            each update, one after the other. ``"v1"`` zeroes all selected
            optimizers first, back-propagates every loss (with the network
            that is not being updated frozen via `freeze_grads`), and only
            then steps all selected optimizers.
        indices (Sequence[int]): Update order; 0 = generator, 1 = discriminator.

    Returns:
        list[list[dict[str, float]]]: One list per epoch holding one dict per
        batch with keys ``'loss_0'`` (generator) and/or ``'loss_1'``
        (discriminator), depending on `indices`.

    Raises:
        ValueError: If `version` is not ``"v0"``/``"v1"`` or `indices` contains
            anything other than 0 and 1.
    """
    # Validate up front: otherwise an unknown version trains nothing and only
    # fails at the epoch summary, and a bad index reuses a stale loss or
    # indexes past the optimizer list.
    if version not in ("v0", "v1"):
        raise ValueError(f"version must be 'v0' or 'v1', got {version!r}")
    if any(optimizer_idx not in (0, 1) for optimizer_idx in indices):
        raise ValueError(f"indices may only contain 0 (generator) and 1 (discriminator), got {indices!r}")

    generator.to(device)
    discriminator.to(device)

    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizers = [optimizer_G, optimizer_D]  # position matches optimizer_idx

    all_stats = []
    for epoch in range(num_epochs):

        # The end-of-epoch sampling switches the generator to eval mode.
        generator.train()

        stats_per_epoch = []
        for real_images, _ in data_loader:
            real_images = real_images.view(real_images.size(0), 1, 28, 28).to(device)
            batch_size = real_images.size(0)

            # Pass forward generator
            noise = torch.randn(batch_size, noise_dim, device=device)
            fake_images = generator(noise)

            # Train 2 steps
            stat = {}

            if version == "v0":
                for optimizer_idx in indices:
                    if optimizer_idx == 0:  # generator update
                        fake_logits = discriminator(fake_images)
                        loss = gan_loss(None, fake_logits, optimizer_idx=0)
                    elif optimizer_idx == 1:  # discriminator update
                        real_logits = discriminator(real_images)
                        # detach: the discriminator loss must not reach the generator
                        fake_logits = discriminator(fake_images.detach())
                        loss = gan_loss(real_logits, fake_logits, optimizer_idx=1)
                    optimizers[optimizer_idx].zero_grad()
                    loss.backward()
                    optimizers[optimizer_idx].step()
                    stat[f'loss_{optimizer_idx}'] = loss.item()

            elif version == "v1":
                for optimizer_idx in indices:
                    optimizers[optimizer_idx].zero_grad()
                for optimizer_idx in indices:
                    if optimizer_idx == 0:  # generator update
                        # Discriminator parameters are frozen during this forward
                        # pass, so the later backward leaves their grads untouched.
                        with freeze_grads(discriminator):
                            fake_logits = discriminator(fake_images)
                            loss = gan_loss(None, fake_logits, optimizer_idx=0)
                    elif optimizer_idx == 1:  # discriminator update
                        with freeze_grads(generator):
                            real_logits = discriminator(real_images)
                            fake_logits = discriminator(fake_images.detach())
                            loss = gan_loss(real_logits, fake_logits, optimizer_idx=1)
                    stat[f'loss_{optimizer_idx}'] = loss.item()
                    loss.backward()
                for optimizer_idx in indices:
                    optimizers[optimizer_idx].step()

            stats_per_epoch.append(stat)

        all_stats.append(stats_per_epoch)
        # Average each recorded loss over the epoch; tolerate an empty loader.
        recorded_keys = stats_per_epoch[0] if stats_per_epoch else {}
        avg_stats = {key: sum(stat[key] for stat in stats_per_epoch) / len(stats_per_epoch) for key in recorded_keys}
        # A loss that was not trained (absent from `indices`) is reported as nan.
        missing = float("nan")
        print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {avg_stats.get('loss_1', missing):.4f}, Loss G: {avg_stats.get('loss_0', missing):.4f}")

        # Generate images
        _show_generated_samples(generator, noise_dim)

    return all_stats

# Hyperparameters (module-level; read by do_expe)
noise_dim = 100  # size of the generator's input noise vector
batch_size = 64  # batch size of the MNIST DataLoader
num_epochs = 3  # number of training epochs per experiment
lr = 0.0002  # learning rate passed to train_gan for both Adam optimizers

def do_expe(version, indices):
    """Run one seeded GAN training experiment on MNIST.

    Seeds PyTorch, builds the MNIST training loader (downloaded to ``./data``
    if missing, pixels normalized to [-1, 1]), creates a fresh generator and
    discriminator, and trains them with `train_gan` using the module-level
    hyperparameters `noise_dim`, `batch_size`, `num_epochs` and `lr`.

    Args:
        version (str): Training-loop variant forwarded to `train_gan`
            (``"v0"`` or ``"v1"``).
        indices (Sequence[int]): Update order forwarded to `train_gan`
            (0 = generator, 1 = discriminator).
    """
    # Seed and make deterministic
    torch.manual_seed(42)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and possibly pick a different
    # algorithm on each run, which defeats reproducibility; keep it off.
    torch.backends.cudnn.benchmark = False

    # Data loader
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])  # [0, 1] -> [-1, 1], matching the generator's Tanh range
    ])
    mnist = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    data_loader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

    # Initialize models
    generator = Generator(noise_dim).to(device)
    discriminator = Discriminator().to(device)

    # Train GAN
    train_stats = train_gan(generator, discriminator, data_loader, noise_dim, num_epochs, lr, version=version, indices=indices)

    # # plot
    # df = pd.DataFrame(train_stats)
    # df.plot()
    # plt.show()

# Run both training-loop variants for each update order
# (0 = generator, 1 = discriminator).
for update_order in ([0, 1], [1, 0]):
    for loop_version in ("v0", "v1"):
        do_expe(loop_version, update_order)