# In [None]:
# Plan 1 GAN Training Notebook

# Import necessary libraries
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Import GAN models and training functions
from plan1_gan_models import (
    SimpleGenerator,
    EmbeddingAsInputGenerator,
    Discriminator,
    AC_Generator,
    AC_Discriminator,
    Info_Generator,
    Info_Discriminator,
)
from plan1_gan_training import (
    train_normal_gan,
    train_acgan,
    train_infogan,
)

# Embeddings: the random tensor below is only a stand-in (1000 vectors of
# width 50). It does NOT load anything; replace it with the real precomputed
# embeddings before training.
embeddings = torch.randn((1000, 50))

# Wrap the embeddings in a DataLoader for the generators.
# NOTE(review): add_noise/noise_dim are forwarded to prepare_embedding_loader;
# presumably it appends 50 noise dimensions to each embedding — TODO confirm.
from plan1_gan_training import prepare_embedding_loader
embedding_loader = prepare_embedding_loader(
    embeddings,
    add_noise=True,
    noise_dim=50,
    batch_size=64,
)

# Real samples: the MNIST training split, passed as `real_data_loader` to all
# three training runs below. Swap in the project's actual dataset if needed.
from torchvision import datasets, transforms

# ToTensor gives pixel values in [0, 1]; Normalize(mean=0.5, std=0.5) then
# rescales them to [-1, 1].
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
mnist_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=mnist_transform,
)
real_data_loader = DataLoader(
    mnist_dataset,
    batch_size=64,
    shuffle=True,
)

# Choose where models run: CUDA whenever it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Train Normal GAN
print("\n--- Training Normal GAN ---\n")
# The generator's input width and the trainer's latent_dim must agree, so both
# read this one variable instead of two independent literals.
# NOTE(review): 100 presumably equals embedding width (50) + noise_dim (50)
# as produced by prepare_embedding_loader — TODO confirm against that helper.
normal_latent_dim = 100
# Later sections rebind `generator`/`discriminator`; keeping this run's models
# under their own names leaves them reachable for inspection afterwards.
# `generator`/`discriminator` are still bound here so later cells keep working.
normal_generator = SimpleGenerator(input_dim=normal_latent_dim).to(device)
normal_discriminator = Discriminator().to(device)
generator, discriminator = normal_generator, normal_discriminator
train_normal_gan(
    normal_generator,
    normal_discriminator,
    embedding_loader,
    real_data_loader,
    latent_dim=normal_latent_dim,
    epochs=20,
    device=device,
)

# Train ACGAN
print("\n--- Training ACGAN ---\n")
# Sizes shared by the models and the trainer, defined once so they cannot drift.
acgan_latent_dim = 100
acgan_num_classes = 10  # MNIST (the real dataset above) has 10 digit classes
# Distinct names keep these trained models reachable after later sections
# rebind `generator`/`discriminator`; the aliases below preserve those names.
acgan_generator = AC_Generator(
    latent_dim=acgan_latent_dim,
    num_classes=acgan_num_classes,
).to(device)
acgan_discriminator = AC_Discriminator(num_classes=acgan_num_classes).to(device)
generator, discriminator = acgan_generator, acgan_discriminator
train_acgan(
    acgan_generator,
    acgan_discriminator,
    embedding_loader,
    real_data_loader,
    latent_dim=acgan_latent_dim,
    num_classes=acgan_num_classes,
    epochs=20,
    device=device,
)

# Train InfoGAN
print("\n--- Training InfoGAN ---\n")
# Latent-code sizes are used by the generator, the discriminator and the
# trainer; defining them once keeps all three in agreement.
info_continuous_dim = 2
info_discrete_dim = 10
# Distinct names keep the trained InfoGAN models addressable on their own;
# `generator`/`discriminator` are still bound for any later cells.
info_generator = Info_Generator(
    embedding_dim=50,
    continuous_dim=info_continuous_dim,
    discrete_dim=info_discrete_dim,
).to(device)
info_discriminator = Info_Discriminator(
    continuous_dim=info_continuous_dim,
    discrete_dim=info_discrete_dim,
).to(device)
generator, discriminator = info_generator, info_discriminator
# NOTE(review): the generator is built with embedding_dim=50 while the trainer
# is given latent_dim=100. Confirm which width embedding_loader actually yields
# and that Info_Generator / train_infogan expect it — TODO verify.
train_infogan(
    info_generator,
    info_discriminator,
    embedding_loader,
    real_data_loader,
    latent_dim=100,
    continuous_dim=info_continuous_dim,
    discrete_dim=info_discrete_dim,
    epochs=20,
    device=device,
)
