# Plan 2: Test GAN Models

## Testing All GAN Types

This section ensures that all GAN types in the pipeline are working correctly by running each one with a minimal training configuration (1 epoch).

In [1]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import logging
from datetime import datetime
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset

from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Repository path (adjust if needed)
repo_path = "/content/drive/MyDrive/GAN-thesis-project"

# Add repository path to sys.path for module imports
if repo_path not in sys.path:
    sys.path.append(repo_path)

# Change working directory to the repository
os.chdir(repo_path)

# Verify the working directory
print(f"Current working directory: {os.getcwd()}")

# GAN Models and Training Functions
from src.gan_workflows.plan2.plan2_gan_models import (
    SimpleGANGenerator, SimpleGANDiscriminator,
    ContrastiveGANGenerator, ContrastiveGANDiscriminator,
    VAEGANEncoder, VAEGANGenerator, VAEGANDiscriminator,
    WGANGenerator, WGANCritic,
    CrossDomainGenerator, CrossDomainDiscriminator,
    CycleGenerator, CycleDiscriminator,
    DualGANGenerator, DualGANDiscriminator,
    ContrastiveDualGANGenerator, ContrastiveDualGANDiscriminator,
    SemiSupervisedGANDiscriminator,
    ConditionalGANGenerator, ConditionalGANDiscriminator,
    InfoGANGenerator, InfoGANDiscriminator,
    compute_gradient_penalty
)
from src.gan_workflows.plan2.plan2_gan_training import (
    train_wgan_gp, train_vae_gan, train_contrastive_gan,
    train_cross_domain_gan, train_cycle_gan,
    train_dual_gan, train_contrastive_dual_gan,
    train_semi_supervised_gan,
    train_conditional_gan,
    train_infogan
)

from src.data_utils import (
    load_embeddings, analyze_embeddings
)

# Set random seed and device
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Current working directory: /content/drive/MyDrive/GAN-thesis-project
Using device: cpu


In [2]:
embedding_dir = "./saved_embeddings/embeddings/autoencoders_BasicAutoencoder"  # Example embedding path
embedding_file = os.path.join(embedding_dir, "BasicAutoencoder_embeddings.pt")

# Load embeddings and labels using the new function
embeddings, labels, data_loader = load_embeddings(embedding_file, device)

# Analyze the embeddings (optional: set expected_dim if necessary)
analyze_embeddings(embeddings, expected_dim=50, labels=labels)

  data = torch.load(embedding_file)
INFO - Embeddings are of shape: torch.Size([7000, 50])
INFO - Data type: torch.float32
INFO - Device: cpu
INFO - Mean: -0.10591454058885574
INFO - Standard Deviation: 4.491614818572998
INFO - Min: -22.35531997680664
INFO - Max: 15.808435440063477
INFO - Median: 0.15539658069610596
INFO - Mean L2 Norm: 31.488544464111328
INFO - Standard Deviation of L2 Norms: 4.21454381942749
INFO - Each embedding has 50 dimensions.
INFO - Sparsity (proportion of non-zero elements): 1.0
INFO - Skewness of embeddings: -0.6672645237194799
INFO - Kurtosis of embeddings: 1.5908177243367119


Loading embeddings from: ./saved_embeddings/embeddings/autoencoders_BasicAutoencoder/BasicAutoencoder_embeddings.pt


INFO - Pairwise distance (mean): 25.41992176637177
INFO - Average Cosine similarity with true labels: -0.00153539318125695
INFO - Number of outliers detected in embeddings: 30
INFO - Embeddings analysis completed.


In [3]:
# Step 1: Setup configuration for testing
gan_types = [
    "WGAN-GP", "VAE-GAN", "Contrastive-GAN", "Cross-Domain-GAN",
    "Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN",
    "Semi-Supervised-GAN", "Conditional-GAN", "InfoGAN"
]

latent_dim = 100
embedding_dim = embeddings.size(1)  # Assuming embeddings are already loaded in your workspace
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10  # For InfoGAN (e.g., 10 categories for MNIST)
epochs = 3  # Run for 2 epochs as requested
learning_rate = 1e-4
# Initialize the data loaders before the GAN models
embedding_loader_a = data_loader
embedding_loader_b = data_loader
embedding_loader = data_loader

# GAN Configuration Dictionary
gan_configurations = {
    "WGAN-GP": {
        "generator": WGANGenerator,
        "critic": WGANCritic,
        "train_function": train_wgan_gp,
        "train_kwargs": {"lambda_gp": 10}
    },
    "VAE-GAN": {
        "encoder": VAEGANEncoder,
        "generator": VAEGANGenerator,
        "discriminator": VAEGANDiscriminator,
        "train_function": train_vae_gan
    },
    "Contrastive-GAN": {
        "generator": ContrastiveGANGenerator,
        "discriminator": ContrastiveGANDiscriminator,
        "train_function": train_contrastive_gan
    },
    "Cross-Domain-GAN": {
        "generator": CrossDomainGenerator,
        "discriminator": CrossDomainDiscriminator,
        "train_function": train_cross_domain_gan
    },
    "Cycle-GAN": {
        "generator_a": CycleGenerator,
        "generator_b": CycleGenerator,
        "discriminator_a": CycleDiscriminator,
        "discriminator_b": CycleDiscriminator,
        "train_function": train_cycle_gan
    },
    "Dual-GAN": {
        "generator_a": DualGANGenerator,
        "generator_b": DualGANGenerator,
        "discriminator_a": DualGANDiscriminator,
        "discriminator_b": DualGANDiscriminator,
        "train_function": train_dual_gan
    },
    "Contrastive-Dual-GAN": {
        "generator_a": ContrastiveDualGANGenerator,
        "generator_b": ContrastiveDualGANGenerator,
        "discriminator_a": ContrastiveDualGANDiscriminator,
        "discriminator_b": ContrastiveDualGANDiscriminator,
        "train_function": train_contrastive_dual_gan
    },
    "Semi-Supervised-GAN": {
        "generator": SimpleGANGenerator,
        "discriminator": SemiSupervisedGANDiscriminator,
        "train_function": train_semi_supervised_gan
    },
    "Conditional-GAN": {
        "generator": ConditionalGANGenerator,
        "discriminator": ConditionalGANDiscriminator,
        "train_function": train_conditional_gan
    },
    "InfoGAN": {
        "generator": InfoGANGenerator,
        "discriminator": InfoGANDiscriminator,
        "train_function": train_infogan
    }
}

In [4]:
# Step 2: Loop through GAN types and test each one
for gan_type in gan_types:
    print(f"\nTesting {gan_type}...")

    try:
        # Initialize the correct GAN models based on `gan_type`
        if gan_type in gan_configurations:
            config = gan_configurations[gan_type]
            print(f"Initializing {gan_type} configuration: {config}")

            # Generator initialization for Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN
            if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
                if gan_type == "Cycle-GAN":
                    # Cycle-GAN requires two generators and two discriminators
                    generator_a = config["generator_a"](embedding_dim=embedding_dim).to(device)
                    generator_b = config["generator_b"](embedding_dim=embedding_dim).to(device)
                    discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
                    discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
                    print(f"CycleGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
                else:
                    # Dual-GAN or Contrastive-Dual-GAN requires both latent_dim and embedding_dim for generators
                    generator_a = config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
                    generator_b = config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
                    discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
                    discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
                    print(f"DualGAN/ContrastiveDualGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
            else:
                # For other GANs (WGAN-GP, VAE-GAN, etc.), initialize a single generator
                generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
                if gan_type == "InfoGAN":
                    # Correctly initialize InfoGAN generator with latent_dim and categorical_dim
                    generator_args["categorical_dim"] = categorical_dim
                generator = config["generator"](**generator_args).to(device)
                print(f"Generator initialized on {generator}")

            # Handle critic/discriminator initialization conditionally
            if gan_type == "WGAN-GP":
                # WGAN-GP uses a critic instead of a discriminator
                critic = config["critic"](embedding_dim=embedding_dim).to(device)
                print(f"Critic initialized on {critic}")
            elif gan_type == "Semi-Supervised-GAN":
                # Semi-Supervised-GAN requires num_classes for the discriminator
                discriminator = config["discriminator"](embedding_dim=embedding_dim, num_classes=num_classes).to(device)
                print(f"Semi-Supervised-GAN discriminator initialized on {discriminator}")
            elif gan_type == "InfoGAN":
                # InfoGAN discriminator requires embedding_dim and categorical_dim
                discriminator = config["discriminator"](embedding_dim=embedding_dim, categorical_dim=categorical_dim).to(device)
                print(f"InfoGAN discriminator initialized on {discriminator}")
            elif gan_type not in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
                # For other models (VAE-GAN, Conditional-GAN, etc.), initialize a single discriminator
                discriminator_args = {"embedding_dim": embedding_dim}
                if gan_type in ["Conditional-GAN", "InfoGAN"]:
                    discriminator_args["num_classes"] = num_classes
                if gan_type == "InfoGAN":
                    discriminator_args["categorical_dim"] = categorical_dim
                discriminator = config["discriminator"](**discriminator_args).to(device)
                print(f"Discriminator initialized on {discriminator}")
            else:
                # If discriminator is part of the configuration (Cycle-GAN, Dual-GAN, etc.), use the ones in the config
                discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
                discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
                print(f"Discriminators initialized on {discriminator_a}, {discriminator_b}")

            # Initialize encoder for VAE-GAN
            if gan_type == "VAE-GAN":
                encoder = config["encoder"](embedding_dim=embedding_dim, latent_dim=latent_dim).to(device)
                print(f"Encoder initialized on {encoder}")

            # Select the appropriate training function
            train_function = config["train_function"]

            # Handle any additional configurations (like learning rate, lambda_gp)
            train_kwargs = config.get("train_kwargs", {})  # Default to empty dict if not specified
            train_kwargs.update({
                "latent_dim": latent_dim,
                "epochs": epochs,  # Pass epochs as part of train_kwargs
                "device": device,
                "learning_rate": learning_rate,
                "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"] else None,  # Add for Conditional, InfoGAN, and Semi-Supervised-GAN
                "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,  # Add only for InfoGAN
                "lambda_gp": 10 if gan_type == "WGAN-GP" else None,  # Add only for WGAN-GP
            })

            # Add data loaders conditionally
            if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
                # Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN: pass two data loaders
                train_kwargs.update({
                    "data_loader_a": embedding_loader_a,  # Two loaders for these models
                    "data_loader_b": embedding_loader_b
                })
            else:
                # For all other models (WGAN-GP, VAE-GAN, etc.), use a single loader
                train_kwargs.update({
                    "data_loader": embedding_loader  # Single data loader for other models
                })

            # Training loop (now handled inside the train_function)
            if gan_type == "VAE-GAN":
                # VAE-GAN requires encoder, generator, and discriminator
                train_function(encoder, generator, discriminator, **train_kwargs)
            else:
                # For other models, pass the appropriate generator/discriminator/critic
                train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs) if gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"] else train_function(generator, discriminator if "discriminator" in locals() else critic, **train_kwargs)

            print(f"{gan_type} training test passed!")

        else:
            raise ValueError(f"Unsupported GAN type: {gan_type}")

    except Exception as e:
        print(f"Error testing {gan_type}: {e}")


Testing WGAN-GP...
Initializing WGAN-GP configuration: {'generator': <class 'src.gan_workflows.plan2.plan2_gan_models.WGANGenerator'>, 'critic': <class 'src.gan_workflows.plan2.plan2_gan_models.WGANCritic'>, 'train_function': <function train_wgan_gp at 0x7ee3f05ee8e0>, 'train_kwargs': {'lambda_gp': 10}}
Generator initialized on WGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): Linear(in_features=256, out_features=50, bias=True)
  )
)
Critic initialized on WGANCritic(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=50, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): LeakyReL