# In [None]:
# Plan 2 GAN Training Notebook

# Import necessary libraries
import torch
from torch.utils.data import DataLoader
from plan2_gan_models import (
    SimpleGANGenerator, SimpleGANDiscriminator,
    ContrastiveGANGenerator, ContrastiveGANDiscriminator,
    VAEGANEncoder, VAEGANGenerator, VAEGANDiscriminator,
    WGANGenerator, WGANCritic,
    CrossDomainGenerator, CrossDomainDiscriminator,
    CycleGenerator,
    DualGANGenerator, DualGANDiscriminator,
    ContrastiveDualGANGenerator, ContrastiveDualGANDiscriminator,
    SemiSupervisedGANDiscriminator
)
from plan2_gan_training import (
    train_wgan_gp, train_vae_gan, train_contrastive_gan,
    train_cross_domain_gan, train_cycle_gan,
    train_dual_gan, train_contrastive_dual_gan,
    train_semi_supervised_gan
)

# Pick the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Sample DataLoader for demonstration purposes
def create_dummy_loader(
    num_samples: int = 1000,
    embedding_dim: int = 50,
    batch_size: int = 64,
    num_classes: int = 10,
) -> DataLoader:
    """
    Creates a dummy DataLoader for testing GAN training.

    Every item of the underlying dataset is an ``(embedding, label)`` pair.
    Embeddings are drawn with ``torch.randn`` and labels with
    ``torch.randint`` from the half-open range ``[0, num_classes)``.

    Args:
        num_samples (int): Number of samples in the dataset.
        embedding_dim (int): Dimensionality of embeddings.
        batch_size (int): Batch size for DataLoader.
        num_classes (int): Number of distinct label values to draw from.
            Defaults to 10, the value that was previously hard-coded.

    Returns:
        DataLoader: Shuffled DataLoader over a TensorDataset of
        ``(embeddings, labels)``.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        # torch.randint would fail with a less descriptive error for an empty
        # label range, so reject it up front.
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    embeddings = torch.randn(num_samples, embedding_dim)
    # Labels are only needed by label-aware training (e.g. Semi-Supervised GAN).
    labels = torch.randint(0, num_classes, (num_samples,))
    dataset = torch.utils.data.TensorDataset(embeddings, labels)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Shared hyperparameters. The dummy data, the model constructors and the
# training calls must all agree on these values, so each is named once here
# instead of being repeated as a literal in every section below.
EMBEDDING_DIM = 50    # embedding width for the loaders and every model
LATENT_DIM = 100      # latent_dim for all generators except the VAE-GAN pair
VAE_LATENT_DIM = 20   # latent_dim shared by the VAE-GAN encoder and generator
NUM_CLASSES = 10      # matches the label range create_dummy_loader draws from
EPOCHS = 10           # number of epochs for every training run

# Example loaders for embeddings. embedding_dim is passed explicitly so the
# data cannot silently drift from the model dimensions if a default changes.
# NOTE(review): each batch is an (embeddings, labels) pair; confirm that the
# trainers which do not use labels accept that batch format.
embedding_loader = create_dummy_loader(embedding_dim=EMBEDDING_DIM)
embedding_loader_a = create_dummy_loader(embedding_dim=EMBEDDING_DIM)
embedding_loader_b = create_dummy_loader(embedding_dim=EMBEDDING_DIM)

# Train WGAN-GP
print("\n--- Training WGAN-GP ---\n")
generator = WGANGenerator(latent_dim=LATENT_DIM, embedding_dim=EMBEDDING_DIM).to(device)
critic = WGANCritic(embedding_dim=EMBEDDING_DIM).to(device)
train_wgan_gp(
    generator,
    critic,
    embedding_loader,
    latent_dim=LATENT_DIM,
    epochs=EPOCHS,
    device=device
)

# Train VAE-GAN (uses its own, smaller latent size)
print("\n--- Training VAE-GAN ---\n")
encoder = VAEGANEncoder(embedding_dim=EMBEDDING_DIM, latent_dim=VAE_LATENT_DIM).to(device)
generator = VAEGANGenerator(latent_dim=VAE_LATENT_DIM, embedding_dim=EMBEDDING_DIM).to(device)
discriminator = VAEGANDiscriminator(embedding_dim=EMBEDDING_DIM).to(device)
train_vae_gan(
    encoder,
    generator,
    discriminator,
    embedding_loader,
    latent_dim=VAE_LATENT_DIM,
    epochs=EPOCHS,
    device=device
)

# Train Contrastive GAN
print("\n--- Training Contrastive GAN ---\n")
generator = ContrastiveGANGenerator(latent_dim=LATENT_DIM, embedding_dim=EMBEDDING_DIM).to(device)
discriminator = ContrastiveGANDiscriminator(embedding_dim=EMBEDDING_DIM).to(device)
train_contrastive_gan(
    generator,
    discriminator,
    embedding_loader,
    latent_dim=LATENT_DIM,
    epochs=EPOCHS,
    device=device
)

# Train Cross-Domain GAN
print("\n--- Training Cross-Domain GAN ---\n")
generator = CrossDomainGenerator(latent_dim=LATENT_DIM, embedding_dim=EMBEDDING_DIM).to(device)
discriminator = CrossDomainDiscriminator(embedding_dim=EMBEDDING_DIM).to(device)
train_cross_domain_gan(
    generator,
    discriminator,
    embedding_loader,
    latent_dim=LATENT_DIM,
    epochs=EPOCHS,
    device=device
)

# Train CycleGAN (two generators, each paired with a SimpleGANDiscriminator,
# fed from the two separate loaders)
print("\n--- Training CycleGAN ---\n")
generator_a = CycleGenerator(embedding_dim=EMBEDDING_DIM).to(device)
generator_b = CycleGenerator(embedding_dim=EMBEDDING_DIM).to(device)
discriminator_a = SimpleGANDiscriminator(embedding_dim=EMBEDDING_DIM).to(device)
discriminator_b = SimpleGANDiscriminator(embedding_dim=EMBEDDING_DIM).to(device)
train_cycle_gan(
    generator_a,
    generator_b,
    discriminator_a,
    discriminator_b,
    embedding_loader_a,
    embedding_loader_b,
    epochs=EPOCHS,
    device=device
)

# Train Dual GAN
print("\n--- Training Dual GAN ---\n")
generator_a = DualGANGenerator(latent_dim=LATENT_DIM, embedding_dim=EMBEDDING_DIM).to(device)
generator_b = DualGANGenerator(latent_dim=LATENT_DIM, embedding_dim=EMBEDDING_DIM).to(device)
discriminator_a = DualGANDiscriminator(embedding_dim=EMBEDDING_DIM).to(device)
discriminator_b = DualGANDiscriminator(embedding_dim=EMBEDDING_DIM).to(device)
train_dual_gan(
    generator_a,
    generator_b,
    discriminator_a,
    discriminator_b,
    embedding_loader_a,
    embedding_loader_b,
    epochs=EPOCHS,
    device=device
)

# Train Contrastive-Guided Dual GAN
print("\n--- Training Contrastive-Guided Dual GAN ---\n")
generator_a = ContrastiveDualGANGenerator(latent_dim=LATENT_DIM, embedding_dim=EMBEDDING_DIM).to(device)
generator_b = ContrastiveDualGANGenerator(latent_dim=LATENT_DIM, embedding_dim=EMBEDDING_DIM).to(device)
discriminator_a = ContrastiveDualGANDiscriminator(embedding_dim=EMBEDDING_DIM).to(device)
discriminator_b = ContrastiveDualGANDiscriminator(embedding_dim=EMBEDDING_DIM).to(device)
train_contrastive_dual_gan(
    generator_a,
    generator_b,
    discriminator_a,
    discriminator_b,
    embedding_loader_a,
    embedding_loader_b,
    epochs=EPOCHS,
    device=device
)

# Train Semi-Supervised GAN (reuses the WGAN generator architecture; the
# discriminator's num_classes must match the labels in embedding_loader)
print("\n--- Training Semi-Supervised GAN ---\n")
generator = WGANGenerator(latent_dim=LATENT_DIM, embedding_dim=EMBEDDING_DIM).to(device)
discriminator = SemiSupervisedGANDiscriminator(embedding_dim=EMBEDDING_DIM, num_classes=NUM_CLASSES).to(device)
train_semi_supervised_gan(
    generator,
    discriminator,
    embedding_loader,
    latent_dim=LATENT_DIM,
    num_classes=NUM_CLASSES,
    epochs=EPOCHS,
    device=device
)
