# Plan 2: GAN Training with Embeddings

## Objective
The goal of this plan is to train GANs using embeddings generated from Plan 1. This involves:
1. Setting up a GAN architecture that incorporates embeddings.
2. Training the GAN to generate high-quality synthetic embeddings or data.
3. Evaluating the GAN's performance using relevant metrics.

## Key Steps

### 1. Define GAN Architecture
- **Generator**: Accepts embeddings and noise as inputs, producing synthetic data.
- **Discriminator**: Evaluates the authenticity of generated data, optionally conditioned on embeddings.
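A minimal PyTorch sketch of the two roles described above. It assumes a generator that concatenates a noise vector with a conditioning embedding, and a discriminator that can optionally take the same embedding. The class names `EmbeddingConditionedGenerator` and `EmbeddingConditionedDiscriminator` and the layer sizes are hypothetical. They are not the classes in `plan2_gan_models.py`.

```python
import torch
import torch.nn as nn


class EmbeddingConditionedGenerator(nn.Module):
    """Hypothetical generator: [noise, conditioning embedding] -> synthetic embedding."""

    def __init__(self, latent_dim: int, cond_dim: int, out_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + cond_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, z: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # Concatenate noise and conditioning embedding along the feature axis
        return self.net(torch.cat([z, cond], dim=1))


class EmbeddingConditionedDiscriminator(nn.Module):
    """Hypothetical discriminator: scores a sample, optionally conditioned on an embedding."""

    def __init__(self, in_dim: int, cond_dim: int = 0, hidden_dim: int = 256):
        super().__init__()
        self.cond_dim = cond_dim
        self.net = nn.Sequential(
            nn.Linear(in_dim + cond_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),  # raw score (logit or critic value)
        )

    def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
        if self.cond_dim > 0:
            # Conditioning is only used when the discriminator was built with cond_dim > 0
            x = torch.cat([x, cond], dim=1)
        return self.net(x)


torch.manual_seed(0)
gen = EmbeddingConditionedGenerator(latent_dim=100, cond_dim=50, out_dim=50)
disc = EmbeddingConditionedDiscriminator(in_dim=50, cond_dim=50)
z = torch.randn(8, 100)
cond = torch.randn(8, 50)
fake = gen(z, cond)
scores = disc(fake, cond)
```

Concatenation is the simplest conditioning scheme; projection or FiLM-style conditioning are common alternatives if concatenation turns out to be too weak.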

### 2. Load Pre-Generated Embeddings
- Load embeddings created in Plan 1 from their respective directories.
- Normalize and preprocess embeddings for GAN training.
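One way to implement the normalization step, as a sketch: standardize each dimension using statistics computed on the training split only, and keep those statistics so generated embeddings can be mapped back to the original scale. The helper names are hypothetical. Whether per-dimension standardization actually helps with the Plan 1 embeddings is an assumption worth checking.

```python
import numpy as np


def fit_standardizer(train_embeddings: np.ndarray, eps: float = 1e-8):
    """Per-dimension mean and std, computed on the training split only."""
    mean = train_embeddings.mean(axis=0)
    std = train_embeddings.std(axis=0) + eps  # eps guards against constant dimensions
    return mean, std


def standardize(embeddings: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    return (embeddings - mean) / std


def destandardize(embeddings: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Map generator outputs back to the original embedding scale."""
    return embeddings * std + mean


rng = np.random.default_rng(0)
train = rng.normal(loc=3.0, scale=2.0, size=(1000, 50))  # stand-in for real embeddings
mean, std = fit_standardizer(train)
train_std = standardize(train, mean, std)
round_trip = destandardize(train_std, mean, std)
```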

### 3. Train the GAN
- Set up training loops for the generator and discriminator.
- Use appropriate loss functions, such as adversarial loss (e.g., Wasserstein loss).
- Optionally, include auxiliary tasks (e.g., reconstruction loss) for better embedding alignment.
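The Wasserstein loss with gradient penalty (WGAN-GP) can be sketched as follows. This is a generic, minimal version that assumes vector inputs of shape `(batch, embedding_dim)`. `gradient_penalty`, `critic_loss`, and `generator_loss` are illustrative names, not the project's `compute_gradient_penalty` or `train_wgan_gp`.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    """Mean of (||grad critic(x_hat)||_2 - 1)^2 over random interpolates x_hat."""
    # One mixing weight per sample, broadcast across the embedding dimension
    alpha = torch.rand(real.size(0), 1, device=real.device)
    x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(x_hat)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty itself can be backpropagated
    )[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()


def critic_loss(critic, real, fake, lambda_gp=10.0):
    # The critic maximizes E[D(real)] - E[D(fake)]; minimize the negative plus the penalty
    fake = fake.detach()  # no gradients flow into the generator during the critic step
    wasserstein_term = critic(fake).mean() - critic(real).mean()
    return wasserstein_term + lambda_gp * gradient_penalty(critic, real, fake)


def generator_loss(critic, fake):
    # The generator tries to raise the critic's score on its samples
    return -critic(fake).mean()


# Tiny usage example with random stand-in data
torch.manual_seed(0)
critic = nn.Sequential(nn.Linear(50, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1))
real = torch.randn(16, 50)
fake = torch.randn(16, 50)
d_loss = critic_loss(critic, real, fake)
d_loss.backward()
```

An auxiliary term, such as a reconstruction loss, would simply be added to `generator_loss` with its own weight.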

### 4. Save Outputs
- Save trained GAN models.
- Save generated embeddings or data for downstream evaluation.
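A minimal sketch of the save step, assuming PyTorch modules. Weights are stored via `state_dict()`, and generated embeddings are stored in the same `{"embeddings": ...}` dictionary layout that the notebook reads for Plan 1 files. The file names and the `save_gan_outputs` helper are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_gan_outputs(out_dir, generator, critic, generated, tag="WGAN-GP"):
    """Save model weights and generated embeddings (hypothetical file layout)."""
    os.makedirs(out_dir, exist_ok=True)
    ckpt_path = os.path.join(out_dir, f"{tag}_checkpoint.pth")
    samples_path = os.path.join(out_dir, f"{tag}_generated_embeddings.pt")
    # state_dict() holds only parameters and buffers, the usual way to persist modules
    torch.save({"generator": generator.state_dict(), "critic": critic.state_dict()}, ckpt_path)
    # Mirror the {"embeddings": ...} layout used by the Plan 1 embedding files
    torch.save({"embeddings": generated.detach().cpu()}, samples_path)
    return ckpt_path, samples_path


torch.manual_seed(0)
generator = nn.Sequential(nn.Linear(100, 256), nn.LeakyReLU(0.2), nn.Linear(256, 50))
critic = nn.Sequential(nn.Linear(50, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1))
z = torch.randn(32, 100)
with torch.no_grad():
    generated = generator(z)

out_dir = tempfile.mkdtemp()
ckpt_path, samples_path = save_gan_outputs(out_dir, generator, critic, generated)

# Reload into a freshly constructed generator with the same architecture
restored = nn.Sequential(nn.Linear(100, 256), nn.LeakyReLU(0.2), nn.Linear(256, 50))
restored.load_state_dict(torch.load(ckpt_path, map_location="cpu")["generator"])
reloaded_samples = torch.load(samples_path, map_location="cpu")["embeddings"]
```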

### 5. Evaluate GAN Performance
- Use metrics like FID, IS, and qualitative visualization.
- Compare performance with baseline models or methods.
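For reference, FID can be computed directly on the embedding vectors, as the `calculate_fid` function later in this notebook does. It compares Gaussian fits of the two sets, with means $\mu_r, \mu_g$ and covariances $\Sigma_r, \Sigma_g$ of the real and generated embeddings:

$$
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{1/2}\right)
$$

As a one-dimensional check, take $\mu_r = 0, \Sigma_r = 1$ and $\mu_g = 1, \Sigma_g = 4$. The mean term is $1$ and the trace term is $1 + 4 - 2\sqrt{4} = 1$, so $\mathrm{FID} = 2$. IS, by contrast, needs class probabilities $p(y \mid x)$ from a classifier, so it does not apply to raw embeddings unless such a classifier is available.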

---

## Starting Point

**Files Available**:
- `plan2_gan_models.py`: Contains the GAN architecture.
- `plan2_gan_training.py`: Implements the training pipeline.
- `main_plan2_gan_training.ipynb`: High-level control notebook.
- `plan2_experiments.ipynb`: For experimentation and evaluation.

## Next Steps
1. **Review Architecture**:
   - Inspect `plan2_gan_models.py` to understand the generator and discriminator setup.
2. **Pipeline Setup**:
   - Review and prepare the training pipeline in `plan2_gan_training.py`.
3. **Experimentation**:
   - Use `plan2_experiments.ipynb` for controlled experiments.


In [None]:
# Plan 2 GAN Training Notebook

import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn.functional as F
import logging
import json
from datetime import datetime
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision.models import inception_v3
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import NearestNeighbors
from scipy.linalg import sqrtm
from scipy.stats import wasserstein_distance, entropy, spearmanr, gaussian_kde


from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Repository path (adjust if needed)
repo_path = "/content/drive/MyDrive/GAN-thesis-project"

# Add repository path to sys.path for module imports
if repo_path not in sys.path:
    sys.path.append(repo_path)

# Change working directory to the repository
os.chdir(repo_path)

# Verify the working directory
print(f"Current working directory: {os.getcwd()}")


# Set random seed and device
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Mounted at /content/drive
Current working directory: /content/drive/MyDrive/GAN-thesis-project
Using device: cpu
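The cell above seeds PyTorch and NumPy. The broader helper sketched below also seeds Python's `random` module and all CUDA devices, in case either is used downstream. `seed_everything` is a hypothetical name. `torch.utils.data.random_split` and `torch.randperm` draw from PyTorch's global generator by default, so seeding it also makes the train/eval split reproducible. This does not make every GPU kernel deterministic.

```python
import random

import numpy as np
import torch


def seed_everything(seed: int = 42) -> None:
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable


seed_everything(42)
first = (random.random(), float(np.random.rand()), torch.rand(1).item())
seed_everything(42)
second = (random.random(), float(np.random.rand()), torch.rand(1).item())
```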


In [None]:
# List the functions and classes implemented in each project module, for quick reference

import inspect

# Import the entire modules
import src.data_utils as data_utils
import src.cl_loss_function as cl_loss
import src.losses as losses
import src.gan_workflows.plan2.plan2_gan_models as gan_models
import src.gan_workflows.plan2.plan2_gan_training as gan_training

# Function to list functions and classes in a module
def list_functions_and_classes(module):
    members = inspect.getmembers(module)
    functions = [name for name, obj in members if inspect.isfunction(obj)]
    classes = [name for name, obj in members if inspect.isclass(obj)]
    return functions, classes

# Function to print functions and classes in a readable format
def print_functions_and_classes(module_name, module):
    functions, classes = list_functions_and_classes(module)
    print(f"Module: {module_name}")
    print("  Functions:")
    for func in functions:
        print(f"    - {func}")
    print("  Classes:")
    for cls in classes:
        print(f"    - {cls}")
    print()  # Add a blank line for separation

# Print functions and classes for each module
print_functions_and_classes("src.data_utils", data_utils)
print_functions_and_classes("src.cl_loss_function", cl_loss)
print_functions_and_classes("src.losses", losses)
print_functions_and_classes("src.gan_workflows.plan2.plan2_gan_models", gan_models)
print_functions_and_classes("src.gan_workflows.plan2.plan2_gan_training", gan_training)

Module: src.data_utils
  Functions:
    - analyze_embeddings
    - analyze_embeddings_v2
    - create_dataloader
    - create_embedding_loaders
    - generate_embeddings
    - kurtosis
    - load_data
    - load_embeddings
    - load_mnist_data
    - pdist
    - preprocess_images
    - save_embeddings
    - skew
    - split_dataset
    - train_test_split
    - visualize_embeddings
  Classes:
    - DataLoader
    - LocalOutlierFactor
    - TensorDataset

Module: src.cl_loss_function
  Functions:
    - augment
    - compute_nt_xent_loss_with_augmentation
    - compute_triplet_loss_with_augmentation
    - contrastive_loss
    - hflip
    - info_nce_loss
    - resize
  Classes:
    - BYOLLoss
    - BarlowTwinsLoss
    - ContrastiveHead
    - DataLoader
    - NTXentLoss
    - PCA
    - Predictor
    - TensorDataset
    - TripletLoss
    - VicRegLoss

Module: src.losses
  Functions:
    - add_noise
    - cyclical_beta_schedule
    - linear_beta_schedule
    - loss_function_dae_ssim
    - vae

In [None]:
# GAN Models and Training Functions
from src.gan_workflows.plan2.plan2_gan_models import (
    SimpleGANGenerator, SimpleGANDiscriminator,
    ContrastiveGANGenerator, ContrastiveGANDiscriminator,
    VAEGANEncoder, VAEGANGenerator, VAEGANDiscriminator,
    WGANGenerator, WGANCritic,
    CrossDomainGenerator, CrossDomainDiscriminator,
    CycleGenerator, CycleDiscriminator,
    DualGANGenerator, DualGANDiscriminator,
    ContrastiveDualGANGenerator, ContrastiveDualGANDiscriminator,
    SemiSupervisedGANDiscriminator,
    ConditionalGANGenerator, ConditionalGANDiscriminator,
    InfoGANGenerator, InfoGANDiscriminator,
    compute_gradient_penalty
)

from src.gan_workflows.plan2.plan2_gan_training import (
    train_wgan_gp, train_vae_gan, train_contrastive_gan,
    train_cross_domain_gan, train_cycle_gan,
    train_dual_gan, train_contrastive_dual_gan,
    train_semi_supervised_gan,
    train_conditional_gan,
    train_infogan
)

from src.data_utils import (
    load_embeddings, analyze_embeddings
)

In [None]:
embedding_dir = "./saved_embeddings/embeddings/"  # Example embedding path
embedding_file = os.path.join(embedding_dir, "autoencoders_BasicAutoencoder/BasicAutoencoder_embeddings.pt")

# Load embeddings and labels
embeddings, labels, data_loader = load_embeddings(embedding_file, device)

INFO - Loading embeddings from: ./saved_embeddings/embeddings/autoencoders_BasicAutoencoder/BasicAutoencoder_embeddings.pt


In [None]:
# ## OLD VERSION: previously used to validate embeddings (kept for reference)

# def validate_embeddings(embeddings):
#     """
#     Validate and provide information about the shape and data type of embeddings.
#     """
#     if embeddings is None or len(embeddings) == 0:
#         raise ValueError("Embeddings are empty or not properly generated.")
#     print(f"Embeddings are of shape: {embeddings.shape}")
#     print(f"Data type: {embeddings.dtype}")
#     print(f"Device: {embeddings.device}")
#     if torch.isnan(embeddings).any():
#         raise ValueError("Embeddings contain NaN values.")
#     if torch.isinf(embeddings).any():
#         raise ValueError("Embeddings contain infinite values.")
#     if embeddings.ndim != 2:
#         raise ValueError("Embeddings should be a 2D tensor.")
#     print("Embeddings validation passed.")

# validate_embeddings(embeddings)

In [None]:
# Inspect one batch from the data loader to check its shape
embedding_batch = next(iter(data_loader))
print('embedding batches', embedding_batch.shape)

embedding batches torch.Size([64, 50])
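The batch above is a bare `(64, 50)` tensor because the loader wraps only the embeddings, so labels are not available during training. If a label-aware variant (Conditional-GAN, Semi-Supervised-GAN, InfoGAN) needs them, one option is to wrap both tensors in a `TensorDataset`, which yields `(embedding, label)` pairs. Whether the project's training functions expect batches in this form is an assumption. The random tensors below stand in for the loaded `embeddings` and `labels`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
embeddings = torch.randn(200, 50)          # stand-in for the loaded embeddings
labels = torch.randint(0, 10, (200,))      # stand-in for the loaded labels (10 classes)

# Each item of a TensorDataset is a tuple (embedding, label) taken at the same index
paired_loader = DataLoader(TensorDataset(embeddings, labels), batch_size=64, shuffle=True)
x, y = next(iter(paired_loader))
```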


In [None]:
# GAN Configuration Dictionary
gan_configurations = {
    "WGAN-GP": {
        "generator": WGANGenerator,
        "critic": WGANCritic,
        "train_function": train_wgan_gp,
        "train_kwargs": {"lambda_gp": 10}
    },
    "VAE-GAN": {
        "encoder": VAEGANEncoder,
        "generator": VAEGANGenerator,
        "discriminator": VAEGANDiscriminator,
        "train_function": train_vae_gan
    },
    "Contrastive-GAN": {
        "generator": ContrastiveGANGenerator,
        "discriminator": ContrastiveGANDiscriminator,
        "train_function": train_contrastive_gan
    },
    "Cross-Domain-GAN": {
        "generator": CrossDomainGenerator,
        "discriminator": CrossDomainDiscriminator,
        "train_function": train_cross_domain_gan
    },
    "Cycle-GAN": {
        "generator_a": CycleGenerator,
        "generator_b": CycleGenerator,
        "discriminator_a": CycleDiscriminator,
        "discriminator_b": CycleDiscriminator,
        "train_function": train_cycle_gan
    },
    "Dual-GAN": {
        "generator_a": DualGANGenerator,
        "generator_b": DualGANGenerator,
        "discriminator_a": DualGANDiscriminator,
        "discriminator_b": DualGANDiscriminator,
        "train_function": train_dual_gan
    },
    "Contrastive-Dual-GAN": {
        "generator_a": ContrastiveDualGANGenerator,
        "generator_b": ContrastiveDualGANGenerator,
        "discriminator_a": ContrastiveDualGANDiscriminator,
        "discriminator_b": ContrastiveDualGANDiscriminator,
        "train_function": train_contrastive_dual_gan
    },
    "Semi-Supervised-GAN": {
        "generator": SimpleGANGenerator,
        "discriminator": SemiSupervisedGANDiscriminator,
        "train_function": train_semi_supervised_gan
    },
    "Conditional-GAN": {
        "generator": ConditionalGANGenerator,
        "discriminator": ConditionalGANDiscriminator,
        "train_function": train_conditional_gan
    },
    "InfoGAN": {
        "generator": InfoGANGenerator,
        "discriminator": InfoGANDiscriminator,
        "train_function": train_infogan
    }
}
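Every entry is looked up by key later, so a small check can catch a misspelled or missing component before any model is built. The sketch below is hypothetical: `validate_gan_configurations` and its key rules are inferred from the dictionary above. Paired GANs need both generators and both discriminators, and every other type needs a `generator` plus either a `discriminator` or a `critic`.

```python
PAIRED_KEYS = {"generator_a", "generator_b", "discriminator_a", "discriminator_b", "train_function"}
SINGLE_KEYS = {"generator", "train_function"}


def validate_gan_configurations(configs):
    """Return a list of human-readable problems; an empty list means all entries look complete."""
    problems = []
    for name, entry in configs.items():
        keys = set(entry)
        if "generator_a" in keys:
            missing = PAIRED_KEYS - keys
        else:
            missing = SINGLE_KEYS - keys
            if "discriminator" not in keys and "critic" not in keys:
                missing = missing | {"discriminator or critic"}
        if missing:
            problems.append(f"{name}: missing {sorted(missing)}")
        if "train_function" in keys and not callable(entry["train_function"]):
            problems.append(f"{name}: train_function is not callable")
    return problems


# `object` and `print` are placeholders for model classes and training functions
example = {
    "WGAN-GP": {"generator": object, "critic": object, "train_function": print},
    "Cycle-GAN": {"generator_a": object, "generator_b": object,
                  "discriminator_a": object, "discriminator_b": object, "train_function": print},
    "Broken-GAN": {"generator": object, "train_function": print},
}
problems = validate_gan_configurations(example)
```

In this notebook, `validate_gan_configurations(gan_configurations)` should return an empty list, because every entry above follows these rules.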

In [None]:
gan_configurations.keys()

dict_keys(['WGAN-GP', 'VAE-GAN', 'Contrastive-GAN', 'Cross-Domain-GAN', 'Cycle-GAN', 'Dual-GAN', 'Contrastive-Dual-GAN', 'Semi-Supervised-GAN', 'Conditional-GAN', 'InfoGAN'])

In [None]:
import os

# Base directory where embeddings are stored
embedding_base_dir = "./saved_embeddings/embeddings/"

def list_available_embeddings(base_dir, filter_by=None):
    """
    List available embedding directories and files, optionally filtered by method.

    Args:
        base_dir (str): The base directory containing embeddings.
        filter_by (str or list, optional): Method(s) to filter by (e.g., "autoencoder", "vae").
                                           If None, all embeddings are displayed.
    """
    print("\n📂 Available Embeddings:\n")

    if isinstance(filter_by, str):
        filter_by = [filter_by]  # Convert single filter to list

    for method in sorted(os.listdir(base_dir)):
        method_path = os.path.join(base_dir, method)

        # Check if it's a directory
        if os.path.isdir(method_path):
            if filter_by is None or any(f.lower() in method.lower() for f in filter_by):
                pt_files = [f for f in sorted(os.listdir(method_path)) if f.endswith(".pt")]
                if pt_files:
                    print(f"\n🔹 {method}")  # Show only the category
                    for file in pt_files:
                        print(f"   📄 {method}/{file}")  # Show category + filename

# Default: Show everything
list_available_embeddings(embedding_base_dir)

# Example: Show only VAE-related embeddings
# list_available_embeddings(embedding_base_dir, filter_by="vae")

# Example: Show both autoencoder and VAE embeddings
# list_available_embeddings(embedding_base_dir, filter_by=["autoencoder", "vae"])


📂 Available Embeddings:


🔹 autoencoder_AdvancedAutoencoder_barlow_twins
   📄 autoencoder_AdvancedAutoencoder_barlow_twins/AdvancedAutoencoder_barlow_twins_embeddings.pt

🔹 autoencoder_AdvancedAutoencoder_contrastive
   📄 autoencoder_AdvancedAutoencoder_contrastive/AdvancedAutoencoder_contrastive_embeddings.pt

🔹 autoencoder_AdvancedAutoencoder_info_nce
   📄 autoencoder_AdvancedAutoencoder_info_nce/AdvancedAutoencoder_info_nce_embeddings.pt

🔹 autoencoder_AdvancedAutoencoder_mse
   📄 autoencoder_AdvancedAutoencoder_mse/AdvancedAutoencoder_mse_embeddings.pt

🔹 autoencoder_AdvancedAutoencoder_ntxent
   📄 autoencoder_AdvancedAutoencoder_ntxent/AdvancedAutoencoder_ntxent_embeddings.pt

🔹 autoencoder_AdvancedAutoencoder_vicreg
   📄 autoencoder_AdvancedAutoencoder_vicreg/AdvancedAutoencoder_vicreg_embeddings.pt

🔹 autoencoder_EnhancedAutoencoder_barlow_twins
   📄 autoencoder_EnhancedAutoencoder_barlow_twins/EnhancedAutoencoder_barlow_twins_embeddings.pt

🔹 autoencoder_EnhancedAutoencoder_co
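The listing above only prints paths. A variant that returns them, sketched here with `pathlib`, makes it easy to loop over every embedding file and derive the same identifier used later (`<name>_embeddings.pt` → `<name>`). `find_embedding_files` is hypothetical. Unlike `list_available_embeddings`, it only matches files ending in `_embeddings.pt` exactly one directory below the base.

```python
import tempfile
from pathlib import Path


def find_embedding_files(base_dir, filter_by=None):
    """Map embedding identifier -> path relative to base_dir, optionally filtered by method name."""
    base = Path(base_dir)
    if isinstance(filter_by, str):
        filter_by = [filter_by]
    found = {}
    for path in sorted(base.glob("*/*_embeddings.pt")):
        method = path.parent.name
        if filter_by is None or any(f.lower() in method.lower() for f in filter_by):
            identifier = path.stem.removesuffix("_embeddings")  # Python 3.9+
            found[identifier] = path.relative_to(base).as_posix()
    return found


# Demo on a throwaway directory tree that mimics the layout above
demo_dir = Path(tempfile.mkdtemp())
(demo_dir / "autoencoder_A_mse").mkdir()
(demo_dir / "autoencoder_A_mse" / "A_mse_embeddings.pt").touch()
(demo_dir / "vae_B").mkdir()
(demo_dir / "vae_B" / "B_embeddings.pt").touch()
(demo_dir / "notes.txt").touch()

found_all = find_embedding_files(demo_dir)
found_vae = find_embedding_files(demo_dir, filter_by="vae")
```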

In [None]:
# ==========================
# CONFIGURATION & EMBEDDING LOADING
# ==========================

embedding_dir = "./saved_embeddings/embeddings/"
embedding_relative_path = "autoencoder_EnhancedAutoencoder_ntxent/EnhancedAutoencoder_ntxent_embeddings.pt"

# Full path to the embedding file
embedding_file = os.path.join(embedding_dir, embedding_relative_path)

# Extract identifier from the path (remove directory and _embeddings.pt)
embedding_identifier = embedding_relative_path.split("/")[-1].replace("_embeddings.pt", "")

report_dir = "./reports/GANs_Evaluations"
os.makedirs(report_dir, exist_ok=True)

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = {
    "gan_type": 'WGAN-GP',
    "embedding_identifier": embedding_identifier,  # Dynamically extracted identifier
    "latent_dim": 100,
    "embedding_dim": None,  # Will be set after loading embeddings
    "num_classes": 10,
    "categorical_dim": 10,
    "epochs": 1,
    "batch_size": 64,
    "learning_rate": 1e-4,
    "device": device,
    "lambda_gp": 10,
    "beta1": 0.5,
    "beta2": 0.999,
    "save_path": "gan_model.pth",
    "eval_fraction": 0.1,  # Fraction of embeddings used for evaluation
    "kl_method": "histogram",  # Can be "histogram" or "kde"
    "show_model_architecture": True
}

def load_embeddings(embedding_file, device, batch_size=64):
    """Loads embeddings and labels from a specified file."""
    # map_location lets tensors saved on a GPU be loaded on a CPU-only runtime
    data = torch.load(embedding_file, map_location=device)
    embeddings = data["embeddings"].to(device)
    labels = data["labels"].to(device)
    data_loader = DataLoader(embeddings, batch_size=batch_size, shuffle=True)
    return embeddings, labels, data_loader

def split_embeddings(embeddings, labels, eval_fraction=0.1, batch_size=64):
    """Splits embeddings (and their labels) into training and evaluation sets.

    A single random permutation is shared by embeddings and labels, so each label
    stays aligned with its embedding (two independent random_split calls would
    shuffle them differently).
    """
    num_samples = embeddings.size(0)
    num_eval = int(num_samples * eval_fraction)

    perm = torch.randperm(num_samples).to(embeddings.device)
    eval_idx, train_idx = perm[:num_eval], perm[num_eval:]

    train_embeddings, eval_embeddings = embeddings[train_idx], embeddings[eval_idx]
    train_labels, eval_labels = labels[train_idx], labels[eval_idx]

    train_loader = DataLoader(train_embeddings, batch_size=batch_size, shuffle=True)
    eval_loader = DataLoader(eval_embeddings, batch_size=batch_size, shuffle=False)
    return train_loader, eval_loader, train_embeddings, eval_embeddings

# Load embeddings and split
embeddings, labels, full_data_loader = load_embeddings(embedding_file, device)
config["embedding_dim"] = embeddings.size(1)
train_loader, eval_loader, train_embeddings, eval_embeddings = split_embeddings(embeddings, labels, config["eval_fraction"], config["batch_size"])

# Update config with data loaders
config.update({
    "data_loader": train_loader,
    "data_loader_a": train_loader,
    "data_loader_b": train_loader,
    "eval_loader": eval_loader,
    "original_embeddings": embeddings  # Store original embeddings for evaluation
})

# ==========================
# MODEL INITIALIZATION
# ==========================

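def initialize_gan_components(config, gan_configurations):
    """Initialize GAN components based on type.

    Note: `gan_configurations` here is the entry for a single GAN type
    (e.g. gan_configurations["WGAN-GP"]), not the full registry.
    """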
    components = {}
    gan_type = config["gan_type"]
    multi_gan_types = ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]

    if gan_type in multi_gan_types:
        gen_args = {"embedding_dim": config["embedding_dim"]}
        if gan_type != "Cycle-GAN":
            gen_args["latent_dim"] = config["latent_dim"]

        components.update({
            "generator_a": gan_configurations["generator_a"](**gen_args).to(config["device"]),
            "generator_b": gan_configurations["generator_b"](**gen_args).to(config["device"]),
            "discriminator_a": gan_configurations["discriminator_a"](embedding_dim=config["embedding_dim"]).to(config["device"]),
            "discriminator_b": gan_configurations["discriminator_b"](embedding_dim=config["embedding_dim"]).to(config["device"])
        })
    else:
        gen_args = {"latent_dim": config["latent_dim"], "embedding_dim": config["embedding_dim"]}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            gen_args["num_classes"] = config["num_classes"]
        if gan_type == "InfoGAN":
            gen_args["categorical_dim"] = config["categorical_dim"]

        components["generator"] = gan_configurations["generator"](**gen_args).to(config["device"])

    if gan_type == "WGAN-GP":
        components["critic"] = gan_configurations["critic"](embedding_dim=config["embedding_dim"]).to(config["device"])
    elif gan_type == "VAE-GAN":
        components["encoder"] = gan_configurations["encoder"](embedding_dim=config["embedding_dim"], latent_dim=config["latent_dim"]).to(config["device"])
        components["discriminator"] = gan_configurations["discriminator"](embedding_dim=config["embedding_dim"]).to(config["device"])
    elif gan_type not in multi_gan_types:
        disc_args = {"embedding_dim": config["embedding_dim"]}
        if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"]:
            disc_args["num_classes"] = config["num_classes"]
        if gan_type == "InfoGAN":
            disc_args["categorical_dim"] = config["categorical_dim"]

        components["discriminator"] = gan_configurations["discriminator"](**disc_args).to(config["device"])

    if config["show_model_architecture"]:
        print(f"Initialized components: {components}")

    return components

def run_gan_training(config):
    """Runs GAN training for the selected model and returns the trained components."""
    gan_type = config["gan_type"]
    if gan_type not in gan_configurations:
        raise ValueError(f"Unsupported GAN type: {gan_type}")

    gan_config = gan_configurations[gan_type]
    components = initialize_gan_components(config, gan_config)
    train_function = gan_config["train_function"]

    print(f"🚀 Training {gan_type}...")

    if gan_type == "VAE-GAN":
        train_function(components["encoder"], components["generator"], components["discriminator"], **config)
    elif gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        train_function(components["generator_a"], components["generator_b"], components["discriminator_a"], components["discriminator_b"], **config)
    else:
        discriminator = components.get("discriminator", components.get("critic", None))
        train_function(components["generator"], discriminator, **config)

    print(f"✅ {gan_type} training completed!")

    # Return the trained components so the caller evaluates these, not a fresh untrained copy
    return components

# ==========================
# EVALUATION METRICS
# ==========================

def calculate_fid(real_embeddings, generated_embeddings, eps=1e-6):
    """Compute Fréchet Inception Distance (FID) with numerical stability."""
    mu1, sigma1 = np.mean(real_embeddings, axis=0), np.cov(real_embeddings, rowvar=False)
    mu2, sigma2 = np.mean(generated_embeddings, axis=0), np.cov(generated_embeddings, rowvar=False)

    diff = mu1 - mu2
    covmean = sqrtm(sigma1 @ sigma2 + np.eye(sigma1.shape[0]) * eps)  # Regularize singularity

    if np.iscomplexobj(covmean):
        covmean = covmean.real  # Remove imaginary component

    return np.sum(diff**2) + np.trace(sigma1 + sigma2 - 2 * covmean)

def calculate_kl_divergence(real_embeddings, generated_embeddings, method="histogram"):
    """
    Compute KL divergence between real and generated embeddings.

    Both inputs are flattened, so this compares the distributions of individual
    embedding values (pooled over all dimensions), not the joint distribution.

    Args:
        real_embeddings (np.ndarray): Real data embeddings.
        generated_embeddings (np.ndarray): Generated data embeddings.
        method (str): "histogram" (default) or "kde" for probability estimation.

    Returns:
        float: KL divergence score.
    """
    real_values = np.asarray(real_embeddings).ravel()
    gen_values = np.asarray(generated_embeddings).ravel()

    if method == "histogram":
        # Shared bin edges, so each bin covers the same interval in both histograms
        bin_edges = np.histogram_bin_edges(np.concatenate([real_values, gen_values]), bins=50)
        real_prob = np.histogram(real_values, bins=bin_edges, density=True)[0]
        gen_prob = np.histogram(gen_values, bins=bin_edges, density=True)[0]

        # Avoid division by zero
        real_prob += 1e-10
        gen_prob += 1e-10

    elif method == "kde":
        # 1-D Kernel Density Estimation on the pooled values; a multivariate KDE
        # (gaussian_kde(real_embeddings.T)) cannot be evaluated on the 1-D grid below
        real_kde = gaussian_kde(real_values)
        gen_kde = gaussian_kde(gen_values)

        # Sample points for estimation
        sample_points = np.linspace(
            min(real_values.min(), gen_values.min()),
            max(real_values.max(), gen_values.max()),
            100
        )

        real_prob = real_kde(sample_points) + 1e-10
        gen_prob = gen_kde(sample_points) + 1e-10

    else:
        raise ValueError("Invalid method. Choose 'histogram' or 'kde'.")

    return entropy(real_prob, gen_prob)

def calculate_cosine_similarity(real_embeddings, generated_embeddings):
    """Compute cosine similarity between real and generated embeddings."""
    return np.mean(cosine_similarity(real_embeddings, generated_embeddings))

def rank_similarity(real_embeddings, generated_embeddings):
    """Compute Spearman rank correlation between real and generated embedding values.

    Real and generated samples are not paired, so this is only a coarse check of
    value ordering, not a per-sample comparison.
    """
    min_size = min(real_embeddings.shape[0], generated_embeddings.shape[0])

    # Select the first min_size embeddings instead of random sampling
    real_embeddings = real_embeddings[:min_size]
    generated_embeddings = generated_embeddings[:min_size]

    # spearmanr ranks its inputs itself; passing argsort indices would correlate
    # sorting permutations rather than the ranks of the values
    return spearmanr(real_embeddings.flatten(), generated_embeddings.flatten()).correlation

def unique_embedding_ratio(generated_embeddings):
    """Compute the ratio of unique embeddings in the generated set."""
    unique_embeddings = np.unique(generated_embeddings, axis=0)
    return len(unique_embeddings) / len(generated_embeddings)

def aggregate_quality_score(fid, kl, cosine, rank_corr):
    """Compute a weighted quality score based on multiple metrics."""
    # Normalize metrics (assuming lower is better for FID & KL)
    fid_norm = 1 / (1 + fid)
    kl_norm = 1 / (1 + kl)
    cosine_norm = cosine  # Higher is better, no need to invert
    rank_norm = (rank_corr + 1) / 2  # Convert [-1,1] to [0,1]

    # Compute final weighted score
    return (0.4 * fid_norm) + (0.2 * kl_norm) + (0.2 * cosine_norm) + (0.2 * rank_norm)

def calculate_wasserstein_distance(real_embeddings, generated_embeddings):
    """Compute Wasserstein Distance between real and generated embeddings with normalization."""
    real_embeddings = (real_embeddings - np.mean(real_embeddings)) / (np.std(real_embeddings) + 1e-8)
    generated_embeddings = (generated_embeddings - np.mean(generated_embeddings)) / (np.std(generated_embeddings) + 1e-8)

    return wasserstein_distance(real_embeddings.flatten(), generated_embeddings.flatten())

def calculate_coverage_score(real_embeddings, generated_embeddings, n_neighbors=5):
    """Compute Coverage Score: Percentage of real embeddings with at least one close match in generated embeddings."""
    neigh = NearestNeighbors(n_neighbors=n_neighbors)
    neigh.fit(generated_embeddings)
    distances, _ = neigh.kneighbors(real_embeddings)
    return np.mean(distances[:, 0] < 0.1)  # Adjust threshold as needed

def calculate_memorization_score(real_embeddings, generated_embeddings, n_neighbors=1, tolerance=1e-3):
    """Compute Memorization Score: Percentage of generated embeddings that are close to real embeddings within a small tolerance."""
    neigh = NearestNeighbors(n_neighbors=n_neighbors)
    neigh.fit(real_embeddings)
    distances, _ = neigh.kneighbors(generated_embeddings)

    return np.mean(distances[:, 0] < tolerance)  # Allow small tolerance for near-exact matches

def generate_synthetic_samples(generator, num_samples=1000, latent_dim=100, device="cuda"):
    """Generate synthetic samples using a trained GAN generator."""
    generator.eval()
    with torch.no_grad():
        latent_vectors = torch.randn(num_samples, latent_dim).to(device)
        generated_samples = generator(latent_vectors)
    return generated_samples.cpu()

def create_dataloader_from_samples(samples, batch_size=64):
    return DataLoader(samples, batch_size=batch_size, shuffle=False)

def convert_to_float(metrics):
    """Ensure all values in the dictionary are JSON serializable."""
    return {k: float(v) for k, v in metrics.items()}

def evaluate_gan(gan_components, config):
    """Evaluate GAN with multiple metrics, ensuring JSON serialization."""
    device = config["device"]
    generator = gan_components["generator"]

    generated_samples = generate_synthetic_samples(generator, num_samples=1000, latent_dim=config["latent_dim"], device=device)
    generated_dataloader = create_dataloader_from_samples(generated_samples, batch_size=config["batch_size"])

    real_embeddings = config["original_embeddings"].cpu().numpy()
    eval_embeddings = torch.cat([batch for batch in config["eval_loader"]], dim=0).cpu().numpy()
    gen_embeddings = torch.cat([batch for batch in generated_dataloader], dim=0).cpu().numpy()

    kl_method = config.get("kl_method", "histogram")

    metrics = {
        "FID (Original vs Generated)": calculate_fid(real_embeddings, gen_embeddings),
        "FID (Original vs Eval)": calculate_fid(real_embeddings, eval_embeddings),
        "KL Divergence": calculate_kl_divergence(real_embeddings, gen_embeddings, method=kl_method),
        "Cosine Similarity": calculate_cosine_similarity(real_embeddings, gen_embeddings),
        "Spearman Rank Correlation": rank_similarity(real_embeddings, gen_embeddings),
        "Wasserstein Distance": calculate_wasserstein_distance(real_embeddings, gen_embeddings),
        "Coverage Score": calculate_coverage_score(real_embeddings, gen_embeddings),
        "Memorization Score": calculate_memorization_score(real_embeddings, gen_embeddings),
        "Unique Embedding Ratio": unique_embedding_ratio(gen_embeddings)
    }

    # Convert to native Python floats for JSON compatibility
    metrics = convert_to_float(metrics)

    # Compute aggregate quality score
    metrics["Aggregate Quality Score"] = aggregate_quality_score(
        metrics["FID (Original vs Generated)"],
        metrics["KL Divergence"],
        metrics["Cosine Similarity"],
        metrics["Spearman Rank Correlation"]
    )

    # Print results
    print(json.dumps(metrics, indent=4))

    # Save results
    # with open("evaluation_results.json", "w") as f:
    #     json.dump(metrics, f, indent=4)
    save_report(metrics, config["gan_type"], config["embedding_identifier"])

def save_report(metrics, gan_type, embedding_identifier):
    """Saves evaluation metrics as a JSON file with a meaningful name."""
    report_filename = f"evaluation_{gan_type}_{embedding_identifier}.json"
    report_path = os.path.join(report_dir, report_filename)

    with open(report_path, "w") as f:
        json.dump(metrics, f, indent=4)

    print(f"✅ Report saved: {report_path}")

# Run training and keep the trained components; evaluating a separately
# initialized set of components would score an untrained generator
gan_components = run_gan_training(config)

# Evaluate using the original embeddings and evaluation set
evaluate_gan(gan_components, config)



Initialized components: {'generator': WGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): Linear(in_features=256, out_features=50, bias=True)
  )
), 'critic': WGANCritic(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=50, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): Linear(in_features=256, out_features=1, bias=True)
  )
)}
Initialized components: {'generator': WGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=256, bias=Tr
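Metric values such as FID are easier to interpret next to reference points. The notebook already computes FID against the held-out evaluation split. The sketch below adds two simple baselines, a Gaussian fitted to the training embeddings and standard-normal noise, so that a trained generator can be placed between "as close as real data" and "no better than noise". The helper names are hypothetical, and the synthetic low-rank data only stands in for real embeddings.

```python
import numpy as np
from scipy.linalg import sqrtm


def fid(a: np.ndarray, b: np.ndarray) -> float:
    """FID between two sets of vectors (rows are samples)."""
    mu_a, mu_b = a.mean(axis=0), b.mean(axis=0)
    cov_a, cov_b = np.cov(a, rowvar=False), np.cov(b, rowvar=False)
    covmean = sqrtm(cov_a @ cov_b)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop tiny imaginary parts from numerical error
    return float(np.sum((mu_a - mu_b) ** 2) + np.trace(cov_a + cov_b - 2 * covmean))


def baseline_fids(train: np.ndarray, held_out: np.ndarray, n: int = 1000, seed: int = 0) -> dict:
    """FID of simple reference generators against the training embeddings."""
    rng = np.random.default_rng(seed)
    mu, cov = train.mean(axis=0), np.cov(train, rowvar=False)
    return {
        "held-out real": fid(train, held_out),
        "Gaussian fit": fid(train, rng.multivariate_normal(mu, cov, size=n)),
        "standard normal": fid(train, rng.standard_normal((n, train.shape[1]))),
    }


# Synthetic stand-in: 20-D vectors with low-rank structure, an offset, and a little noise
rng = np.random.default_rng(1)
data = rng.normal(size=(3000, 5)) @ rng.normal(size=(5, 20)) + 2.0 + 0.1 * rng.normal(size=(3000, 20))
baselines = baseline_fids(data[:2000], data[2000:])
```

Because FID only compares means and covariances, the Gaussian-fit baseline scores close to the held-out split. Matching it on FID therefore says little about structure beyond the first two moments, which is one reason the other metrics are computed as well.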

In [None]:
# ==========================
# CONFIGURATION & EMBEDDING LOADING
# ==========================

embedding_dir = "./saved_embeddings/embeddings/"
embedding_relative_path = "autoencoder_EnhancedAutoencoder_ntxent/EnhancedAutoencoder_ntxent_embeddings.pt"

# Full path to the embedding file
embedding_file = os.path.join(embedding_dir, embedding_relative_path)

# Extract identifier from the path (remove directory and `_embeddings.pt`)
embedding_identifier = embedding_relative_path.split("/")[-1].replace("_embeddings.pt", "")

report_dir = "./reports/GANs_Evaluations"
os.makedirs(report_dir, exist_ok=True)

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = {
    "gan_types": ['WGAN-GP', 'VAE-GAN', 'Contrastive-GAN', 'Cross-Domain-GAN'],
    "embedding_identifier": embedding_identifier,  # Dynamically extracted identifier
    "latent_dim": 100,
    "embedding_dim": None,  # Will be set after loading embeddings
    "num_classes": 10,
    "categorical_dim": 10,
    "epochs": 1,
    "batch_size": 64,
    "learning_rate": 1e-4,
    "device": device,
    "lambda_gp": 10,
    "beta1": 0.5,
    "beta2": 0.999,
    "save_path": "gan_model.pth",
    "eval_fraction": 0.1,  # Fraction of embeddings used for evaluation
    "kl_method": "histogram",  # Can be "histogram" or "kde"
    "show_model_architecture": False
}

def load_embeddings(embedding_file, device, batch_size=64):
    """Loads embeddings and labels from a specified file."""
    # map_location lets tensors saved on a GPU be loaded on a CPU-only runtime
    data = torch.load(embedding_file, map_location=device)
    embeddings = data["embeddings"].to(device)
    labels = data["labels"].to(device)
    data_loader = DataLoader(embeddings, batch_size=batch_size, shuffle=True)
    return embeddings, labels, data_loader

def split_embeddings(embeddings, labels, eval_fraction=0.1, batch_size=64):
    """Splits embeddings (and their labels) into training and evaluation sets.

    A single random permutation is shared by embeddings and labels, so each label
    stays aligned with its embedding (two independent random_split calls would
    shuffle them differently).
    """
    num_samples = embeddings.size(0)
    num_eval = int(num_samples * eval_fraction)

    perm = torch.randperm(num_samples).to(embeddings.device)
    eval_idx, train_idx = perm[:num_eval], perm[num_eval:]

    train_embeddings, eval_embeddings = embeddings[train_idx], embeddings[eval_idx]
    train_labels, eval_labels = labels[train_idx], labels[eval_idx]

    train_loader = DataLoader(train_embeddings, batch_size=batch_size, shuffle=True)
    eval_loader = DataLoader(eval_embeddings, batch_size=batch_size, shuffle=False)
    return train_loader, eval_loader, train_embeddings, eval_embeddings

# Load embeddings and split
embeddings, labels, full_data_loader = load_embeddings(embedding_file, device)
config["embedding_dim"] = embeddings.size(1)
train_loader, eval_loader, train_embeddings, eval_embeddings = split_embeddings(embeddings, labels, config["eval_fraction"], config["batch_size"])

# Update config with data loaders
config.update({
    "data_loader": train_loader,
    "data_loader_a": train_loader,
    "data_loader_b": train_loader,
    "eval_loader": eval_loader,
    "original_embeddings": embeddings  # Store original embeddings for evaluation
})

# ==========================
# MODEL INITIALIZATION
# ==========================

def initialize_gan_components(config, gan_configurations):
    """Initialize GAN components based on type."""
    components = {}
    gan_type = config["gan_type"]
    multi_gan_types = ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]

    if gan_type in multi_gan_types:
        gen_args = {"embedding_dim": config["embedding_dim"]}
        if gan_type != "Cycle-GAN":
            gen_args["latent_dim"] = config["latent_dim"]

        components.update({
            "generator_a": gan_configurations["generator_a"](**gen_args).to(config["device"]),
            "generator_b": gan_configurations["generator_b"](**gen_args).to(config["device"]),
            "discriminator_a": gan_configurations["discriminator_a"](embedding_dim=config["embedding_dim"]).to(config["device"]),
            "discriminator_b": gan_configurations["discriminator_b"](embedding_dim=config["embedding_dim"]).to(config["device"])
        })
    else:
        gen_args = {"latent_dim": config["latent_dim"], "embedding_dim": config["embedding_dim"]}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            gen_args["num_classes"] = config["num_classes"]
        if gan_type == "InfoGAN":
            gen_args["categorical_dim"] = config["categorical_dim"]

        components["generator"] = gan_configurations["generator"](**gen_args).to(config["device"])

    if gan_type == "WGAN-GP":
        components["critic"] = gan_configurations["critic"](embedding_dim=config["embedding_dim"]).to(config["device"])
    elif gan_type == "VAE-GAN":
        components["encoder"] = gan_configurations["encoder"](embedding_dim=config["embedding_dim"], latent_dim=config["latent_dim"]).to(config["device"])
        components["discriminator"] = gan_configurations["discriminator"](embedding_dim=config["embedding_dim"]).to(config["device"])
    elif gan_type not in multi_gan_types:
        disc_args = {"embedding_dim": config["embedding_dim"]}
        if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"]:
            disc_args["num_classes"] = config["num_classes"]
        if gan_type == "InfoGAN":
            disc_args["categorical_dim"] = config["categorical_dim"]

        components["discriminator"] = gan_configurations["discriminator"](**disc_args).to(config["device"])

    if config["show_model_architecture"]:
        print(f"Initialized components: {components}")

    return components


def run_gan_training(config):
    """Runs GAN training based on selected model."""
    gan_type = config["gan_type"]
    if gan_type not in gan_configurations:
        raise ValueError(f"Unsupported GAN type: {gan_type}")

    gan_config = gan_configurations[gan_type]
    components = initialize_gan_components(config, gan_config)
    train_function = gan_config["train_function"]

    print(f"🚀 Training {gan_type}...")

    if gan_type == "VAE-GAN":
        train_function(components["encoder"], components["generator"], components["discriminator"], **config)
    elif gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        train_function(components["generator_a"], components["generator_b"], components["discriminator_a"], components["discriminator_b"], **config)
    else:
        discriminator = components.get("discriminator", components.get("critic", None))
        train_function(components["generator"], discriminator, **config)

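    print(f"✅ {gan_type} training completed!")

    # Return the trained models so callers can evaluate them (not a fresh, untrained set)
    return components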


# ==========================
# EVALUATION METRICS
# ==========================

def calculate_fid(real_embeddings, generated_embeddings, eps=1e-6):
    """Compute the Fréchet distance (the FID formula) between two embedding sets.

    No Inception network is involved: Gaussians are fitted directly to the embeddings.
    """
    mu1, sigma1 = np.mean(real_embeddings, axis=0), np.cov(real_embeddings, rowvar=False)
    mu2, sigma2 = np.mean(generated_embeddings, axis=0), np.cov(generated_embeddings, rowvar=False)

    diff = mu1 - mu2
    # Regularize both covariances so their product is not near-singular
    offset = np.eye(sigma1.shape[0]) * eps
    covmean = sqrtm((sigma1 + offset) @ (sigma2 + offset))

    if np.iscomplexobj(covmean):
        covmean = covmean.real  # Drop tiny imaginary parts caused by numerical error

    return np.sum(diff**2) + np.trace(sigma1 + sigma2 - 2 * covmean)

def calculate_kl_divergence(real_embeddings, generated_embeddings, method="histogram"):
    """
    Compute KL(real || generated) between the pooled value distributions of the embeddings.

    All embedding dimensions are flattened into one 1-D sample per set, and both
    densities are evaluated on the same support so the probability vectors line up.

    Args:
        real_embeddings (np.ndarray): Real data embeddings.
        generated_embeddings (np.ndarray): Generated data embeddings.
        method (str): "histogram" (default) or "kde" for density estimation.

    Returns:
        float: KL divergence score.
    """
    real_values = np.ravel(real_embeddings)
    gen_values = np.ravel(generated_embeddings)
    lo = min(real_values.min(), gen_values.min())
    hi = max(real_values.max(), gen_values.max())

    if method == "histogram":
        # Shared bin edges; separate np.histogram calls would each pick their own range
        real_prob = np.histogram(real_values, bins=50, range=(lo, hi), density=True)[0]
        gen_prob = np.histogram(gen_values, bins=50, range=(lo, hi), density=True)[0]

    elif method == "kde":
        # 1-D Gaussian KDE on the pooled values, evaluated on a common grid
        sample_points = np.linspace(lo, hi, 100)
        real_prob = gaussian_kde(real_values)(sample_points)
        gen_prob = gaussian_kde(gen_values)(sample_points)

    else:
        raise ValueError("Invalid method. Choose 'histogram' or 'kde'.")

    # Avoid zeros (log(0) / division by zero); entropy() renormalizes both vectors
    return entropy(real_prob + 1e-10, gen_prob + 1e-10)

def calculate_cosine_similarity(real_embeddings, generated_embeddings):
    """Compute the mean pairwise cosine similarity over every (real, generated) pair."""
    return np.mean(cosine_similarity(real_embeddings, generated_embeddings))

def rank_similarity(real_embeddings, generated_embeddings):
    """Compute Spearman rank correlation between the flattened first-N real and generated embeddings."""
    min_size = min(real_embeddings.shape[0], generated_embeddings.shape[0])

    # Take the first `min_size` embeddings from each set (no random sampling)
    real_embeddings = real_embeddings[:min_size]
    generated_embeddings = generated_embeddings[:min_size]

    # spearmanr ranks its inputs internally; np.argsort returns sort indices, not ranks
    return spearmanr(real_embeddings.flatten(), generated_embeddings.flatten()).correlation

def unique_embedding_ratio(generated_embeddings):
    """Compute the ratio of unique embeddings in the generated set."""
    unique_embeddings = np.unique(generated_embeddings, axis=0)
    return len(unique_embeddings) / len(generated_embeddings)

def aggregate_quality_score(fid, kl, cosine, rank_corr):
    """Compute a weighted quality score based on multiple metrics."""
    # Map every metric to a higher-is-better scale; FID and KL are lower-is-better, so invert them
    fid_norm = 1 / (1 + fid)
    kl_norm = 1 / (1 + kl)
    cosine_norm = cosine  # Higher is better, no need to invert
    rank_norm = (rank_corr + 1) / 2  # Convert [-1,1] to [0,1]

    # Compute final weighted score
    return (0.4 * fid_norm) + (0.2 * kl_norm) + (0.2 * cosine_norm) + (0.2 * rank_norm)

def calculate_wasserstein_distance(real_embeddings, generated_embeddings):
    """Compute the 1-D Wasserstein distance between pooled embedding values.

    Both sets are standardized with the real mean and std, so a shift in the
    generated location or scale still shows up in the distance.
    """
    mean, std = np.mean(real_embeddings), np.std(real_embeddings) + 1e-8
    real_embeddings = (real_embeddings - mean) / std
    generated_embeddings = (generated_embeddings - mean) / std

    return wasserstein_distance(real_embeddings.flatten(), generated_embeddings.flatten())

def calculate_coverage_score(real_embeddings, generated_embeddings, n_neighbors=1, threshold=0.1):
    """Compute Coverage Score: fraction of real embeddings whose nearest generated embedding lies within `threshold` (Euclidean distance)."""
    neigh = NearestNeighbors(n_neighbors=n_neighbors)
    neigh.fit(generated_embeddings)
    distances, _ = neigh.kneighbors(real_embeddings)
    # Only the closest neighbor is used; the threshold is absolute, so tune it to the embedding scale
    return np.mean(distances[:, 0] < threshold)

def calculate_memorization_score(real_embeddings, generated_embeddings, n_neighbors=1, tolerance=1e-3):
    """Compute Memorization Score: Percentage of generated embeddings that are close to real embeddings within a small tolerance."""
    neigh = NearestNeighbors(n_neighbors=n_neighbors)
    neigh.fit(real_embeddings)
    distances, _ = neigh.kneighbors(generated_embeddings)

    return np.mean(distances[:, 0] < tolerance)  # Allow small tolerance for near-exact matches

def generate_synthetic_samples(generator, num_samples=1000, latent_dim=100, device="cuda"):
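    """Generate synthetic samples from a trained generator that maps latent noise to embeddings.

    Assumes the generator takes only a latent vector; generators that also need
    labels, or that map embeddings rather than noise, need other inputs.
    """
    generator.eval()
    with torch.no_grad():
        latent_vectors = torch.randn(num_samples, latent_dim, device=device)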
        generated_samples = generator(latent_vectors)
    return generated_samples.cpu()

def create_dataloader_from_samples(samples, batch_size=64):
    return DataLoader(samples, batch_size=batch_size, shuffle=False)

def convert_to_float(metrics):
    """Ensure all values in the dictionary are JSON serializable."""
    return {k: float(v) for k, v in metrics.items()}

def evaluate_gan(gan_components, config):
    """Evaluate GAN with multiple metrics, ensuring JSON serialization."""
    device = config["device"]
    gan_type = config["gan_type"]
    multi_gan_types = ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]

    # Check if the GAN type is a multi-generator GAN
    if gan_type in multi_gan_types:
        # Evaluate each generator separately
        metrics = {}
        for gen_key in ["generator_a", "generator_b"]:
            if gen_key in gan_components:
                print(f"📊 Evaluating {gen_key} for {gan_type}...")
                generator = gan_components[gen_key]
                gen_metrics = evaluate_single_generator(generator, config)
                metrics[gen_key] = gen_metrics
    else:
        # Evaluate single-generator GAN
        generator = gan_components["generator"]
        metrics = evaluate_single_generator(generator, config)

    # Print results
    print(json.dumps(metrics, indent=4))

    # Save results
    save_report(metrics, gan_type, config["embedding_identifier"])

def evaluate_single_generator(generator, config):
    """Evaluate a single generator and return metrics."""
    device = config["device"]
    generated_samples = generate_synthetic_samples(generator, num_samples=1000, latent_dim=config["latent_dim"], device=device)
    generated_dataloader = create_dataloader_from_samples(generated_samples, batch_size=config["batch_size"])

    real_embeddings = config["original_embeddings"].cpu().numpy()
    eval_embeddings = torch.cat([batch for batch in config["eval_loader"]], dim=0).cpu().numpy()
    gen_embeddings = torch.cat([batch for batch in generated_dataloader], dim=0).cpu().numpy()

    kl_method = config.get("kl_method", "histogram")

    metrics = {
        "FID (Original vs Generated)": calculate_fid(real_embeddings, gen_embeddings),
        "FID (Original vs Eval)": calculate_fid(real_embeddings, eval_embeddings),
        "KL Divergence": calculate_kl_divergence(real_embeddings, gen_embeddings, method=kl_method),
        "Cosine Similarity": calculate_cosine_similarity(real_embeddings, gen_embeddings),
        "Spearman Rank Correlation": rank_similarity(real_embeddings, gen_embeddings),
        "Wasserstein Distance": calculate_wasserstein_distance(real_embeddings, gen_embeddings),
        "Coverage Score": calculate_coverage_score(real_embeddings, gen_embeddings),
        "Memorization Score": calculate_memorization_score(real_embeddings, gen_embeddings),
        "Unique Embedding Ratio": unique_embedding_ratio(gen_embeddings)
    }

    # Convert to native Python floats for JSON compatibility
    metrics = convert_to_float(metrics)

    # Compute aggregate quality score
    metrics["Aggregate Quality Score"] = aggregate_quality_score(
        metrics["FID (Original vs Generated)"],
        metrics["KL Divergence"],
        metrics["Cosine Similarity"],
        metrics["Spearman Rank Correlation"]
    )

    return metrics

def save_report(metrics, gan_type, embedding_identifier):
    """Saves evaluation metrics as a JSON file with a meaningful name."""
    report_filename = f"evaluation_{gan_type}_{embedding_identifier}.json"
    report_path = os.path.join(report_dir, report_filename)

    with open(report_path, "w") as f:
        json.dump(metrics, f, indent=4)

    print(f"✅ Report saved: {report_path}")

# Loop through each GAN type
for gan_type in config["gan_types"]:
    print(f"\n🚀 Running pipeline for GAN type: {gan_type}")

    # Update the GAN type in the config
    config["gan_type"] = gan_type

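    # Initialize GAN components (kept as a fallback for evaluation)
    gan_components = initialize_gan_components(config, gan_configurations[gan_type])

    # Run GAN training. run_gan_training() builds and trains its own components,
    # so evaluate those trained models rather than the untrained ones above.
    print("🔄 Running GAN training...")
    trained_components = run_gan_training(config)
    if trained_components is not None:
        gan_components = trained_components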

    # Evaluate GAN
    print("📊 Evaluating GAN...")
    evaluate_gan(gan_components, config)

    print(f"✅ Completed pipeline for GAN type: {gan_type}\n")





🚀 Running pipeline for GAN type: WGAN-GP
🔄 Running GAN training...
🚀 Training WGAN-GP...
Epoch [1/1], Loss Critic: -80.2475, Loss Generator: -19.5842
✅ WGAN-GP training completed!
📊 Evaluating GAN...
{
    "FID (Original vs Generated)": 2363.023703740346,
    "FID (Original vs Eval)": 3.9708659932341845,
    "KL Divergence": 0.11302049284623569,
    "Cosine Similarity": 0.002239581197500229,
    "Spearman Rank Correlation": -0.003032662712662713,
    "Wasserstein Distance": 0.02205769618025731,
    "Coverage Score": 0.0,
    "Memorization Score": 0.0,
    "Unique Embedding Ratio": 1.0,
    "Aggregate Quality Score": 0.28000506380510365
}
✅ Report saved: ./reports/GANs_Evaluations/evaluation_WGAN-GP_EnhancedAutoencoder_ntxent.json
✅ Completed pipeline for GAN type: WGAN-GP


🚀 Running pipeline for GAN type: VAE-GAN
🔄 Running GAN training...
🚀 Training VAE-GAN...
Epoch [1/1], D Loss: 0.0087, G Loss: 325.4728
✅ VAE-GAN training completed!
📊 Evaluating GAN...
{
    "FID (Original vs Gener

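`calculate_fid` applies the FID formula, $d^2 = \lVert\mu_r-\mu_g\rVert^2 + \operatorname{Tr}\big(\Sigma_r+\Sigma_g-2(\Sigma_r\Sigma_g)^{1/2}\big)$, directly to the embeddings, so its scale depends on the embedding space rather than on Inception features. One way to sanity-check it is on synthetic Gaussians with a known answer: two samples from the same distribution should score close to 0, and shifting all $d$ dimensions by $c$ with the covariance unchanged should score close to $d\,c^2$. The standalone `frechet_distance` below is a sketch written for this check (the function name and the test data are assumptions, not part of the notebook); it adds the `eps` regularizer to both covariances before the matrix square root.

```python
import numpy as np
from scipy.linalg import sqrtm


def frechet_distance(x, y, eps=1e-6):
    """Fréchet distance between Gaussians fitted to two (n_samples, dim) arrays."""
    mu_x, mu_y = x.mean(axis=0), y.mean(axis=0)
    sigma_x = np.cov(x, rowvar=False)
    sigma_y = np.cov(y, rowvar=False)
    # Regularize both covariances so the matrix product stays well conditioned
    offset = np.eye(sigma_x.shape[0]) * eps
    covmean = sqrtm((sigma_x + offset) @ (sigma_y + offset))
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from numerical error
    diff = mu_x - mu_y
    return float(diff @ diff + np.trace(sigma_x + sigma_y - 2.0 * covmean))


rng = np.random.default_rng(0)
real = rng.normal(size=(5000, 8))
same = rng.normal(size=(5000, 8))  # same distribution, independent sample
shifted = same + 3.0               # every dimension's mean moved by c = 3

print(frechet_distance(real, same))     # expected to be close to 0
print(frechet_distance(real, shifted))  # expected to be close to 8 * 3**2 = 72
```

For comparison, the runs in this notebook report FID (Original vs Eval) in the single or double digits and FID (Original vs Generated) in the thousands or above, on the same embedding scale.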
In [None]:
import os

def list_available_embeddings(base_dir, filter_by=None):
    """
    List available embedding directories and files, optionally filtered by method.

    Args:
        base_dir (str): The base directory containing embeddings.
        filter_by (str or list, optional): Method(s) to filter by (e.g., "autoencoder", "vae").
                                           If None, all embeddings are returned.
    Returns:
        list: A list of paths to embedding files that match the filter.
    """
    embedding_paths = []

    if isinstance(filter_by, str):
        filter_by = [filter_by]  # Convert single filter to list

    for method in sorted(os.listdir(base_dir)):
        method_path = os.path.join(base_dir, method)

        # Check if it's a directory
        if os.path.isdir(method_path):
            if filter_by is None or any(f.lower() in method.lower() for f in filter_by):
                pt_files = [f for f in sorted(os.listdir(method_path)) if f.endswith(".pt")]
                for file in pt_files:
                    embedding_paths.append(os.path.join(method, file))  # Store relative path

    return embedding_paths

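`list_available_embeddings` returns paths built with `os.path.join`, so their separator is platform-dependent, and the pipeline cells then derive each identifier by dropping the directory and the `_embeddings.pt` suffix. A `pathlib` variant can return POSIX-style relative paths and identifiers together. The sketch below is hypothetical (`list_embeddings_with_ids` is not part of the notebook), needs Python 3.9+ for `str.removesuffix`, and runs against a throwaway directory tree with made-up names.

```python
import tempfile
from pathlib import Path


def list_embeddings_with_ids(base_dir, filter_by=None):
    """Hypothetical pathlib variant of list_available_embeddings.

    Returns (relative_posix_path, identifier) pairs; the identifier is the file
    name without the "_embeddings.pt" suffix.
    """
    if isinstance(filter_by, str):
        filter_by = [filter_by]
    results = []
    for method_dir in sorted(p for p in Path(base_dir).iterdir() if p.is_dir()):
        # Same case-insensitive substring filter as list_available_embeddings
        if filter_by is not None and not any(f.lower() in method_dir.name.lower() for f in filter_by):
            continue
        for pt_file in sorted(method_dir.glob("*.pt")):
            identifier = pt_file.name.removesuffix("_embeddings.pt")
            results.append((pt_file.relative_to(base_dir).as_posix(), identifier))
    return results


# Usage on a throwaway tree (directory and file names are made up)
with tempfile.TemporaryDirectory() as tmp:
    for rel in ["autoencoder_Basic_mse/Basic_mse_embeddings.pt",
                "vae_BasicVAE_mse/BasicVAE_mse_embeddings.pt",
                "vae_BasicVAE_mse/notes.txt"]:
        path = Path(tmp, rel)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.touch()
    vae_entries = list_embeddings_with_ids(tmp, filter_by="vae")

print(vae_entries)
# [('vae_BasicVAE_mse/BasicVAE_mse_embeddings.pt', 'BasicVAE_mse')]
```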
In [None]:
# ==========================
# CONFIGURATION & EMBEDDING LOADING
# ==========================

report_dir = "./reports/GANs_Evaluations"
os.makedirs(report_dir, exist_ok=True)

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = {
    "gan_types": ['WGAN-GP', 'VAE-GAN', 'Contrastive-GAN', 'Cross-Domain-GAN'],
    "embedding_identifier": None,  # Set per embedding inside the loop below
    "latent_dim": 100,
    "embedding_dim": None,  # Will be set after loading embeddings
    "num_classes": 10,
    "categorical_dim": 10,
    "epochs": 100,
    "batch_size": 64,
    "learning_rate": 1e-4,
    "device": device,
    "lambda_gp": 10,
    "beta1": 0.5,
    "beta2": 0.999,
    "save_path": "gan_model.pth",
    "eval_fraction": 0.1,  # Fraction of embeddings used for evaluation
    "kl_method": "histogram",  # Can be "histogram" or "kde"
    "show_model_architecture": False
}

# Base directory where embeddings are stored
embedding_base_dir = "./saved_embeddings/embeddings/"

# List all autoencoder embeddings
autoencoder_embeddings = list_available_embeddings(embedding_base_dir, filter_by="autoencoder")

# Loop through each embedding and run the GAN pipeline
for embedding_relative_path in autoencoder_embeddings:
    print(f"\n🚀 Processing embedding: {embedding_relative_path}")

    # Full path to the embedding file
    embedding_file = os.path.join(embedding_base_dir, embedding_relative_path)

    # Identifier = file name without `_embeddings.pt` (os.path.basename also handles Windows separators)
    embedding_identifier = os.path.basename(embedding_relative_path).replace("_embeddings.pt", "")

    # Update config with the current embedding
    config.update({
        "embedding_identifier": embedding_identifier,
        "embedding_file": embedding_file
    })

    # Load embeddings and split
    embeddings, labels, full_data_loader = load_embeddings(embedding_file, device)
    config["embedding_dim"] = embeddings.size(1)
    train_loader, eval_loader, train_embeddings, eval_embeddings = split_embeddings(embeddings, labels, config["eval_fraction"], config["batch_size"])

    # Update config with data loaders
    config.update({
        "data_loader": train_loader,
        "data_loader_a": train_loader,
        "data_loader_b": train_loader,
        "eval_loader": eval_loader,
        "original_embeddings": embeddings  # Store original embeddings for evaluation
    })

    # Loop through each GAN type
    for gan_type in config["gan_types"]:
        print(f"\n🚀 Running pipeline for GAN type: {gan_type}")

        # Update the GAN type in the config
        config["gan_type"] = gan_type

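        # Initialize GAN components (kept as a fallback for evaluation)
        gan_components = initialize_gan_components(config, gan_configurations[gan_type])

        # Run GAN training. run_gan_training() builds and trains its own components,
        # so evaluate those trained models rather than the untrained ones above.
        print("🔄 Running GAN training...")
        trained_components = run_gan_training(config)
        if trained_components is not None:
            gan_components = trained_components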

        # Evaluate GAN
        print("📊 Evaluating GAN...")
        evaluate_gan(gan_components, config)

        print(f"✅ Completed pipeline for GAN type: {gan_type}\n")


INFO - Loading embeddings from: ./saved_embeddings/embeddings/autoencoder_AdvancedAutoencoder_barlow_twins/AdvancedAutoencoder_barlow_twins_embeddings.pt



🚀 Processing embedding: autoencoder_AdvancedAutoencoder_barlow_twins/AdvancedAutoencoder_barlow_twins_embeddings.pt

🚀 Running pipeline for GAN type: WGAN-GP
🔄 Running GAN training...
🚀 Training WGAN-GP...




Epoch [1/100], Loss Critic: -239.6185, Loss Generator: -53.4397
Epoch [2/100], Loss Critic: -136.9057, Loss Generator: -55.1639
Epoch [3/100], Loss Critic: -125.6706, Loss Generator: -66.4566
Epoch [4/100], Loss Critic: -62.6210, Loss Generator: -64.3935
Epoch [5/100], Loss Critic: -61.0054, Loss Generator: -10.6017
Epoch [6/100], Loss Critic: -32.9666, Loss Generator: 0.6043
Epoch [7/100], Loss Critic: -39.3012, Loss Generator: -3.1413
Epoch [8/100], Loss Critic: -31.1133, Loss Generator: -5.5190
Epoch [9/100], Loss Critic: -34.1949, Loss Generator: -5.2892
Epoch [10/100], Loss Critic: -28.9618, Loss Generator: 4.1461
Epoch [11/100], Loss Critic: -34.5196, Loss Generator: 5.0132
Epoch [12/100], Loss Critic: -55.4187, Loss Generator: 8.3848
Epoch [13/100], Loss Critic: -82.0642, Loss Generator: -2.0862
Epoch [14/100], Loss Critic: -78.6131, Loss Generator: 0.4594
Epoch [15/100], Loss Critic: -76.8143, Loss Generator: 0.6475
Epoch [16/100], Loss Critic: -77.6656, Loss Generator: -6.1789

INFO - Loading embeddings from: ./saved_embeddings/embeddings/autoencoder_AdvancedAutoencoder_contrastive/AdvancedAutoencoder_contrastive_embeddings.pt


{
    "FID (Original vs Generated)": 9805.101701114867,
    "FID (Original vs Eval)": 14.572063988153559,
    "KL Divergence": 1.087592931943656,
    "Cosine Similarity": 0.022753773257136345,
    "Spearman Rank Correlation": 0.0001342818142818143,
    "Wasserstein Distance": 0.20614021000235144,
    "Coverage Score": 0.0,
    "Memorization Score": 0.0,
    "Unique Embedding Ratio": 1.0,
    "Aggregate Quality Score": 0.2004090919594395
}
✅ Report saved: ./reports/GANs_Evaluations/evaluation_Cross-Domain-GAN_AdvancedAutoencoder_barlow_twins.json
✅ Completed pipeline for GAN type: Cross-Domain-GAN


🚀 Processing embedding: autoencoder_AdvancedAutoencoder_contrastive/AdvancedAutoencoder_contrastive_embeddings.pt

🚀 Running pipeline for GAN type: WGAN-GP
🔄 Running GAN training...
🚀 Training WGAN-GP...




Epoch [1/100], Loss Critic: -344.0791, Loss Generator: -32.8177
Epoch [2/100], Loss Critic: -254.3502, Loss Generator: -67.9747
Epoch [3/100], Loss Critic: -207.2059, Loss Generator: -83.5320
Epoch [4/100], Loss Critic: -198.7449, Loss Generator: -77.0375
Epoch [5/100], Loss Critic: -163.8967, Loss Generator: -69.1879
Epoch [6/100], Loss Critic: -133.4260, Loss Generator: -15.3523
Epoch [7/100], Loss Critic: -107.1184, Loss Generator: -15.4292
Epoch [8/100], Loss Critic: -69.3341, Loss Generator: -17.4251
Epoch [9/100], Loss Critic: -68.5634, Loss Generator: -23.3014
Epoch [10/100], Loss Critic: -68.8129, Loss Generator: -21.5212
Epoch [11/100], Loss Critic: -61.1024, Loss Generator: -21.2977
Epoch [12/100], Loss Critic: -39.8845, Loss Generator: -8.8829
Epoch [13/100], Loss Critic: -36.1329, Loss Generator: -0.5370
Epoch [14/100], Loss Critic: -32.8410, Loss Generator: -9.2272
Epoch [15/100], Loss Critic: -28.4055, Loss Generator: -8.9424
Epoch [16/100], Loss Critic: -16.9508, Loss Ge

INFO - Loading embeddings from: ./saved_embeddings/embeddings/autoencoder_AdvancedAutoencoder_info_nce/AdvancedAutoencoder_info_nce_embeddings.pt


{
    "FID (Original vs Generated)": 11610.250352752015,
    "FID (Original vs Eval)": 21.660818827412584,
    "KL Divergence": 0.48849150040044664,
    "Cosine Similarity": 0.003447829745709896,
    "Spearman Rank Correlation": -0.003566887886887888,
    "Wasserstein Distance": 0.0318328487486994,
    "Coverage Score": 0.0,
    "Memorization Score": 0.0,
    "Unique Embedding Ratio": 1.0,
    "Aggregate Quality Score": 0.23473154689242354
}
✅ Report saved: ./reports/GANs_Evaluations/evaluation_Cross-Domain-GAN_AdvancedAutoencoder_contrastive.json
✅ Completed pipeline for GAN type: Cross-Domain-GAN


🚀 Processing embedding: autoencoder_AdvancedAutoencoder_info_nce/AdvancedAutoencoder_info_nce_embeddings.pt

🚀 Running pipeline for GAN type: WGAN-GP
🔄 Running GAN training...
🚀 Training WGAN-GP...




Epoch [1/100], Loss Critic: -537.3763, Loss Generator: -49.0260
Epoch [2/100], Loss Critic: -413.9948, Loss Generator: -103.7260
Epoch [3/100], Loss Critic: -379.7723, Loss Generator: -115.3143
Epoch [4/100], Loss Critic: -326.1129, Loss Generator: -110.3955
Epoch [5/100], Loss Critic: -265.7595, Loss Generator: -109.8174
Epoch [6/100], Loss Critic: -259.5085, Loss Generator: -94.4175
Epoch [7/100], Loss Critic: -191.3212, Loss Generator: -79.1310
Epoch [8/100], Loss Critic: -142.6521, Loss Generator: -10.7109
Epoch [9/100], Loss Critic: -122.6847, Loss Generator: -30.4980
Epoch [10/100], Loss Critic: -97.8936, Loss Generator: -25.4067
Epoch [11/100], Loss Critic: -90.9831, Loss Generator: -22.1375
Epoch [12/100], Loss Critic: -84.2026, Loss Generator: -48.8489
Epoch [13/100], Loss Critic: -78.8044, Loss Generator: -35.7615
Epoch [14/100], Loss Critic: -55.5019, Loss Generator: -9.2877
Epoch [15/100], Loss Critic: -58.6011, Loss Generator: 1.6077
Epoch [16/100], Loss Critic: -53.6036, 

INFO - Loading embeddings from: ./saved_embeddings/embeddings/autoencoder_AdvancedAutoencoder_mse/AdvancedAutoencoder_mse_embeddings.pt


{
    "FID (Original vs Generated)": 20836.58342176412,
    "FID (Original vs Eval)": 43.55558403556802,
    "KL Divergence": 0.504732735377227,
    "Cosine Similarity": -0.0037232432514429092,
    "Spearman Rank Correlation": 0.009121525441525443,
    "Wasserstein Distance": 0.03150981391654281,
    "Coverage Score": 0.0,
    "Memorization Score": 0.0,
    "Unique Embedding Ratio": 1.0,
    "Aggregate Quality Score": 0.2331006688826057
}
✅ Report saved: ./reports/GANs_Evaluations/evaluation_Cross-Domain-GAN_AdvancedAutoencoder_info_nce.json
✅ Completed pipeline for GAN type: Cross-Domain-GAN


🚀 Processing embedding: autoencoder_AdvancedAutoencoder_mse/AdvancedAutoencoder_mse_embeddings.pt

🚀 Running pipeline for GAN type: WGAN-GP
🔄 Running GAN training...
🚀 Training WGAN-GP...
Epoch [1/100], Loss Critic: -84.3829, Loss Generator: -0.0751




Epoch [2/100], Loss Critic: -187.2560, Loss Generator: -0.2011
Epoch [3/100], Loss Critic: -316.5254, Loss Generator: -0.4725
Epoch [4/100], Loss Critic: -474.1641, Loss Generator: -0.9036
Epoch [5/100], Loss Critic: -639.9299, Loss Generator: -1.5585
Epoch [6/100], Loss Critic: -842.8292, Loss Generator: -2.6515
Epoch [7/100], Loss Critic: -1099.3115, Loss Generator: -3.7667
Epoch [8/100], Loss Critic: -1363.0565, Loss Generator: -5.8119
Epoch [9/100], Loss Critic: -1625.7433, Loss Generator: -7.9601
Epoch [10/100], Loss Critic: -1981.9419, Loss Generator: -10.3709
Epoch [11/100], Loss Critic: -2242.9861, Loss Generator: -14.2733
Epoch [12/100], Loss Critic: -2630.1753, Loss Generator: -18.6024
Epoch [13/100], Loss Critic: -2990.2808, Loss Generator: -23.1401
Epoch [14/100], Loss Critic: -3458.2817, Loss Generator: -29.0569
Epoch [15/100], Loss Critic: -3907.2100, Loss Generator: -36.1066
Epoch [16/100], Loss Critic: -4408.2520, Loss Generator: -42.2402
Epoch [17/100], Loss Critic: -4

RuntimeError: all elements of input should be between 0 and 1

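The `RuntimeError` above is the message PyTorch's `binary_cross_entropy` gives when its input contains values outside [0, 1]. The traceback is cut off, so the model that raised it is not visible; one common cause (assumed here, not confirmed by the output) is a discriminator that returns raw logits, without a final sigmoid, being scored with BCE. The sketch below uses made-up logits to show the failure and the usual remedy, `torch.nn.BCEWithLogitsLoss`, which applies the sigmoid inside the loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up discriminator outputs: raw scores (logits), not probabilities
logits = torch.tensor([[2.5], [-1.0], [0.3]])
targets = torch.tensor([[1.0], [0.0], [1.0]])

# Plain BCE requires inputs in [0, 1], so raw logits are rejected
try:
    F.binary_cross_entropy(logits, targets)
    bce_raised = False
except RuntimeError:
    bce_raised = True

# BCEWithLogitsLoss applies the sigmoid internally, in a numerically stable way
loss_from_logits = nn.BCEWithLogitsLoss()(logits, targets)

# Same value, computed less stably: sigmoid first, then plain BCE
loss_from_probs = F.binary_cross_entropy(torch.sigmoid(logits), targets)

print(bce_raised, loss_from_logits.item(), loss_from_probs.item())
```

Switching to the logits-based loss also means removing any final sigmoid from the discriminator; otherwise the sigmoid would be applied twice.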
In [None]:
# ==========================
# CONFIGURATION & EMBEDDING LOADING
# ==========================

report_dir = "./reports/GANs_Evaluations"
os.makedirs(report_dir, exist_ok=True)

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = {
    "gan_types": ['WGAN-GP', 'VAE-GAN', 'Contrastive-GAN', 'Cross-Domain-GAN'],
    "embedding_identifier": None,  # Set per embedding inside the loop below
    "latent_dim": 100,
    "embedding_dim": None,  # Will be set after loading embeddings
    "num_classes": 10,
    "categorical_dim": 10,
    "epochs": 100,
    "batch_size": 64,
    "learning_rate": 1e-4,
    "device": device,
    "lambda_gp": 10,
    "beta1": 0.5,
    "beta2": 0.999,
    "save_path": "gan_model.pth",
    "eval_fraction": 0.1,  # Fraction of embeddings used for evaluation
    "kl_method": "histogram",  # Can be "histogram" or "kde"
    "show_model_architecture": False
}

# Base directory where embeddings are stored
embedding_base_dir = "./saved_embeddings/embeddings/"

# List all VAE embeddings
vae_embeddings = list_available_embeddings(embedding_base_dir, filter_by="vae")

# Loop through each embedding and run the GAN pipeline
for embedding_relative_path in vae_embeddings:
    print(f"\n🚀 Processing embedding: {embedding_relative_path}")

    # Full path to the embedding file
    embedding_file = os.path.join(embedding_base_dir, embedding_relative_path)

    # Identifier = file name without `_embeddings.pt` (os.path.basename also handles Windows separators)
    embedding_identifier = os.path.basename(embedding_relative_path).replace("_embeddings.pt", "")

    # Update config with the current embedding
    config.update({
        "embedding_identifier": embedding_identifier,
        "embedding_file": embedding_file
    })

    # Load embeddings and split
    embeddings, labels, full_data_loader = load_embeddings(embedding_file, device)
    config["embedding_dim"] = embeddings.size(1)
    train_loader, eval_loader, train_embeddings, eval_embeddings = split_embeddings(embeddings, labels, config["eval_fraction"], config["batch_size"])

    # Update config with data loaders
    config.update({
        "data_loader": train_loader,
        "data_loader_a": train_loader,
        "data_loader_b": train_loader,
        "eval_loader": eval_loader,
        "original_embeddings": embeddings  # Store original embeddings for evaluation
    })

    # Loop through each GAN type
    for gan_type in config["gan_types"]:
        print(f"\n🚀 Running pipeline for GAN type: {gan_type}")

        # Update the GAN type in the config
        config["gan_type"] = gan_type

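        # Initialize GAN components (kept as a fallback for evaluation)
        gan_components = initialize_gan_components(config, gan_configurations[gan_type])

        # Run GAN training. run_gan_training() builds and trains its own components,
        # so evaluate those trained models rather than the untrained ones above.
        print("🔄 Running GAN training...")
        trained_components = run_gan_training(config)
        if trained_components is not None:
            gan_components = trained_components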

        # Evaluate GAN
        print("📊 Evaluating GAN...")
        evaluate_gan(gan_components, config)

        print(f"✅ Completed pipeline for GAN type: {gan_type}\n")



🚀 Processing embedding: vae_BasicVAE_mse/BasicVAE_mse_embeddings.pt





🚀 Running pipeline for GAN type: WGAN-GP
🔄 Running GAN training...
🚀 Training WGAN-GP...
Epoch [1/100], Loss Critic: 6.0706, Loss Generator: -0.0817
Epoch [2/100], Loss Critic: 4.3931, Loss Generator: -0.2675
Epoch [3/100], Loss Critic: 3.1304, Loss Generator: -0.5523
Epoch [4/100], Loss Critic: 2.2276, Loss Generator: -0.9789
Epoch [5/100], Loss Critic: 1.8222, Loss Generator: -1.4033
Epoch [6/100], Loss Critic: 1.8292, Loss Generator: -1.6371
Epoch [7/100], Loss Critic: 2.1240, Loss Generator: -2.0117
Epoch [8/100], Loss Critic: 2.1363, Loss Generator: -2.0090
Epoch [9/100], Loss Critic: 2.1470, Loss Generator: -2.0071
Epoch [10/100], Loss Critic: 1.9713, Loss Generator: -1.8577
Epoch [11/100], Loss Critic: 1.9361, Loss Generator: -1.7511
Epoch [12/100], Loss Critic: 1.6118, Loss Generator: -1.4822
Epoch [13/100], Loss Critic: 1.3775, Loss Generator: -1.2311
Epoch [14/100], Loss Critic: 0.9190, Loss Generator: -0.7885
Epoch [15/100], Loss Critic: 0.4088, Loss Generator: -0.2661
Epoc




🚀 Running pipeline for GAN type: WGAN-GP
🔄 Running GAN training...
🚀 Training WGAN-GP...
Epoch [1/100], Loss Critic: 6.1957, Loss Generator: -0.0264
Epoch [2/100], Loss Critic: 4.4998, Loss Generator: -0.1535
Epoch [3/100], Loss Critic: 3.2681, Loss Generator: -0.4068
Epoch [4/100], Loss Critic: 2.3963, Loss Generator: -0.8001
Epoch [5/100], Loss Critic: 1.9303, Loss Generator: -1.2335
Epoch [6/100], Loss Critic: 1.9383, Loss Generator: -1.5336
Epoch [7/100], Loss Critic: 2.0846, Loss Generator: -1.7862
Epoch [8/100], Loss Critic: 2.1941, Loss Generator: -1.9137
Epoch [9/100], Loss Critic: 2.1647, Loss Generator: -1.8367
Epoch [10/100], Loss Critic: 2.1126, Loss Generator: -1.7505
Epoch [11/100], Loss Critic: 2.0124, Loss Generator: -1.6491
Epoch [12/100], Loss Critic: 1.6782, Loss Generator: -1.3703
Epoch [13/100], Loss Critic: 1.5018, Loss Generator: -1.1898
Epoch [14/100], Loss Critic: 1.1588, Loss Generator: -0.8881
Epoch [15/100], Loss Critic: 0.7634, Loss Generator: -0.4966
Epoc




🚀 Running pipeline for GAN type: WGAN-GP
🔄 Running GAN training...
🚀 Training WGAN-GP...
Epoch [1/100], Loss Critic: 6.5923, Loss Generator: -0.0701
Epoch [2/100], Loss Critic: 4.5649, Loss Generator: -0.2277
Epoch [3/100], Loss Critic: 2.9949, Loss Generator: -0.5184
Epoch [4/100], Loss Critic: 1.9001, Loss Generator: -0.8854
Epoch [5/100], Loss Critic: 1.6276, Loss Generator: -1.2805
Epoch [6/100], Loss Critic: 1.7391, Loss Generator: -1.5802
Epoch [7/100], Loss Critic: 1.9134, Loss Generator: -1.7463
Epoch [8/100], Loss Critic: 1.8855, Loss Generator: -1.7236
Epoch [9/100], Loss Critic: 1.9209, Loss Generator: -1.7360
Epoch [10/100], Loss Critic: 1.7077, Loss Generator: -1.5300
Epoch [11/100], Loss Critic: 1.4827, Loss Generator: -1.3124
Epoch [12/100], Loss Critic: 1.1819, Loss Generator: -1.0283
Epoch [13/100], Loss Critic: 0.7023, Loss Generator: -0.5763
Epoch [14/100], Loss Critic: 0.1831, Loss Generator: -0.0842
Epoch [15/100], Loss Critic: -0.2985, Loss Generator: 0.3934
Epoc

  data = torch.load(embedding_file)



🚀 Running pipeline for GAN type: WGAN-GP
🔄 Running GAN training...
🚀 Training WGAN-GP...
Epoch [1/100], Loss Critic: 6.5949, Loss Generator: -0.0593
Epoch [2/100], Loss Critic: 4.7480, Loss Generator: -0.2180
Epoch [3/100], Loss Critic: 3.2602, Loss Generator: -0.5442
Epoch [4/100], Loss Critic: 2.0770, Loss Generator: -0.9102
Epoch [5/100], Loss Critic: 1.7844, Loss Generator: -1.3975
Epoch [6/100], Loss Critic: 1.9360, Loss Generator: -1.6760
Epoch [7/100], Loss Critic: 2.0279, Loss Generator: -1.7469
Epoch [8/100], Loss Critic: 1.9167, Loss Generator: -1.6801
Epoch [9/100], Loss Critic: 1.9567, Loss Generator: -1.6018
Epoch [10/100], Loss Critic: 1.7125, Loss Generator: -1.4487
Epoch [11/100], Loss Critic: 1.4856, Loss Generator: -1.2002
Epoch [12/100], Loss Critic: 1.2574, Loss Generator: -0.9973
Epoch [13/100], Loss Critic: 0.8711, Loss Generator: -0.6681
Epoch [14/100], Loss Critic: 0.6037, Loss Generator: -0.4044
Epoch [15/100], Loss Critic: 0.2901, Loss Generator: -0.0910
Epoc
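The critic and generator losses logged above are easier to read next to the standard WGAN-GP objective. The sketch below is a minimal version of that objective. It assumes the notebook's WGAN-GP training function follows the usual formulation. The function names `gradient_penalty`, `critic_loss`, and `generator_loss` are hypothetical, and the toy critic is not the notebook's model.

Under this convention the generator minimises -E[D(fake)]. A negative generator loss therefore only means the critic scores generated embeddings above zero on average. Wasserstein losses are unbounded scores, not probabilities, so their sign on its own says nothing about sample quality.

```python
import torch
from torch import nn

def gradient_penalty(critic, real, fake):
    """Mean of (||grad_x critic(x_hat)||_2 - 1)^2 over random real/fake interpolates."""
    # One interpolation weight per sample, broadcast across the embedding dimension
    alpha = torch.rand(real.size(0), 1, device=real.device)
    x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(x_hat)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty itself can be backpropagated
    )[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

def critic_loss(critic, real, fake, lambda_gp=10.0):
    """Critic minimises E[D(fake)] - E[D(real)] + lambda_gp * GP (pass `fake` detached)."""
    return critic(fake).mean() - critic(real).mean() + lambda_gp * gradient_penalty(critic, real, fake)

def generator_loss(critic, fake):
    """Generator minimises -E[D(fake)], i.e. it pushes critic scores on fakes upward."""
    return -critic(fake).mean()

# Toy usage on 50-dim embeddings; this critic is illustrative, not the notebook's model
torch.manual_seed(0)
toy_critic = nn.Sequential(nn.Linear(50, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1))
toy_real, toy_fake = torch.randn(8, 50), torch.randn(8, 50)
print(f"critic loss {critic_loss(toy_critic, toy_real, toy_fake).item():.4f}, "
      f"generator loss {generator_loss(toy_critic, toy_fake).item():.4f}")
```

The default `lambda_gp=10.0` matches the `lambda_gp` value in the notebook's config.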

In [None]:
import time
from rich.progress import track
from rich.table import Table
from rich.console import Console
import pandas as pd

# Initialize Rich console for pretty printing
console = Console()

def print_summary_table(results):
    """Print a summary table of all results."""
    table = Table(title="GAN Evaluation Summary", show_header=True, header_style="bold magenta")
    table.add_column("Embedding", style="cyan")
    table.add_column("GAN Type", style="green")
    table.add_column("FID (Original vs Generated)", justify="right")
    table.add_column("Aggregate Quality Score", justify="right")
    table.add_column("Unique Embedding Ratio", justify="right")
    table.add_column("Time Taken (s)", justify="right")

    for result in results:
        table.add_row(
            result["embedding"],
            result["gan_type"],
            f"{result['metrics']['FID (Original vs Generated)']:.4f}",
            f"{result['metrics']['Aggregate Quality Score']:.4f}",
            f"{result['metrics']['Unique Embedding Ratio']:.4f}",
            f"{result['time_taken']:.2f}"
        )

    console.print(table)

def save_results_to_csv(results, filename="gan_evaluation_results.csv"):
    """Save all results to a CSV file."""
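    # Flatten the nested "metrics" dict into one column per metric (e.g. "metrics.FID (Original vs Generated)")
    df = pd.json_normalize(results)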
    df.to_csv(filename, index=False)
    console.print(f"✅ All results saved to [bold green]{filename}[/bold green]")

def run_pipeline(config, embedding_relative_path):
    """Run the GAN pipeline for a single embedding."""
    results = []

    # Full path to the embedding file
    embedding_file = os.path.join(embedding_base_dir, embedding_relative_path)

    # Extract identifier from the path (remove directory and `_embeddings.pt`)
    embedding_identifier = os.path.basename(embedding_relative_path).replace("_embeddings.pt", "")

    # Update config with the current embedding
    config.update({
        "embedding_identifier": embedding_identifier,
        "embedding_file": embedding_file
    })

    # Load embeddings and split
    embeddings, labels, full_data_loader = load_embeddings(embedding_file, device)
    config["embedding_dim"] = embeddings.size(1)
    train_loader, eval_loader, train_embeddings, eval_embeddings = split_embeddings(embeddings, labels, config["eval_fraction"], config["batch_size"])

    # Update config with data loaders
    config.update({
        "data_loader": train_loader,
        "data_loader_a": train_loader,
        "data_loader_b": train_loader,
        "eval_loader": eval_loader,
        "original_embeddings": embeddings  # Store original embeddings for evaluation
    })

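    multi_gan_types = ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]

    # Loop through each GAN type
    for gan_type in config["gan_types"]:
        console.print(f"\n🚀 [bold cyan]Running pipeline for GAN type: {gan_type}[/bold cyan]")

        # Update the GAN type in the config
        config["gan_type"] = gan_type
        gan_config = gan_configurations[gan_type]

        # Initialize GAN components once, then train and evaluate these same modules.
        # (run_gan_training builds its own fresh components, so evaluating a separately
        # initialized set would score untrained networks.)
        gan_components = initialize_gan_components(config, gan_config)
        train_function = gan_config["train_function"]

        # Run GAN training
        start_time = time.time()
        console.print("🔄 [bold yellow]Running GAN training...[/bold yellow]")
        if gan_type == "VAE-GAN":
            train_function(gan_components["encoder"], gan_components["generator"],
                           gan_components["discriminator"], **config)
        elif gan_type in multi_gan_types:
            train_function(gan_components["generator_a"], gan_components["generator_b"],
                           gan_components["discriminator_a"], gan_components["discriminator_b"], **config)
        else:
            # WGAN-GP registers a "critic"; every other single-generator GAN registers a "discriminator"
            discriminator = gan_components.get("discriminator", gan_components.get("critic"))
            train_function(gan_components["generator"], discriminator, **config)

        # Evaluate GAN
        console.print("📊 [bold yellow]Evaluating GAN...[/bold yellow]")
        metrics = evaluate_gan(gan_components, config)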

        # Calculate time taken
        time_taken = time.time() - start_time

        # Store results
        results.append({
            "embedding": embedding_identifier,
            "gan_type": gan_type,
            "metrics": metrics,
            "time_taken": time_taken
        })

        console.print(f"✅ [bold green]Completed pipeline for GAN type: {gan_type}[/bold green]")

    return results

# Main execution
all_results = []
autoencoder_embeddings = list_available_embeddings(embedding_base_dir, filter_by="vae")

for embedding_relative_path in track(autoencoder_embeddings, description="Processing embeddings..."):
    console.print(f"\n🚀 [bold blue]Processing embedding: {embedding_relative_path}[/bold blue]")
    results = run_pipeline(config, embedding_relative_path)
    all_results.extend(results)

# Print summary table
print_summary_table(all_results)

# Save all results to a CSV file
save_results_to_csv(all_results)


RecursionError: maximum recursion depth exceeded in comparison
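`run_pipeline` calls `split_embeddings(embeddings, labels, config["eval_fraction"], config["batch_size"])`, but that helper is not defined in this section. The sketch below is a hypothetical stand-in, named `split_embeddings_sketch` so it does not shadow the real helper. It matches only that call signature and the four returned values. The real helper may split differently, for example stratified by label. A seeded permutation keeps the split identical across GAN types, so their metrics are computed against the same held-out embeddings.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def split_embeddings_sketch(embeddings, labels, eval_fraction, batch_size, seed=0):
    """Hypothetical random train/eval split returning (train_loader, eval_loader, train_emb, eval_emb)."""
    n = embeddings.size(0)
    n_eval = int(round(n * eval_fraction))

    # Seeded permutation so every GAN type sees the same split
    rng = torch.Generator().manual_seed(seed)
    perm = torch.randperm(n, generator=rng).to(embeddings.device)
    eval_idx, train_idx = perm[:n_eval], perm[n_eval:]

    train_embeddings, eval_embeddings = embeddings[train_idx], embeddings[eval_idx]
    train_labels, eval_labels = labels[train_idx], labels[eval_idx]

    # Shuffle only the training loader; keep evaluation order fixed
    train_loader = DataLoader(TensorDataset(train_embeddings, train_labels), batch_size=batch_size, shuffle=True)
    eval_loader = DataLoader(TensorDataset(eval_embeddings, eval_labels), batch_size=batch_size, shuffle=False)
    return train_loader, eval_loader, train_embeddings, eval_embeddings
```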

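Each entry collected by `run_pipeline` nests its scores under a `metrics` dict. The sketch below uses placeholder values (not results) to contrast two ways of building a table. Building a `DataFrame` directly puts the whole dict in one cell. `pd.json_normalize` instead produces one `metrics.<name>` column per metric, which is what you usually want in a CSV.

```python
import pandas as pd

# Shaped like the dicts run_pipeline appends; placeholder values for illustration only
example_results = [
    {
        "embedding": "example_embedding",
        "gan_type": "WGAN-GP",
        "metrics": {
            "FID (Original vs Generated)": 0.0,
            "Aggregate Quality Score": 0.0,
            "Unique Embedding Ratio": 0.0,
        },
        "time_taken": 0.0,
    },
]

nested_df = pd.DataFrame(example_results)    # "metrics" column holds whole dicts
flat_df = pd.json_normalize(example_results)  # one "metrics.<name>" column per metric
print(nested_df.columns.tolist())
print(flat_df.columns.tolist())
```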
In [None]:
from torch.utils.data import TensorDataset, DataLoader

def load_embeddings_v2(embedding_file, device, batch_size=64):
    """
    Loads embeddings and their associated labels from a specified file and
    creates a DataLoader for batching the embeddings.

    Args:
        embedding_file (str): Path to the file containing embeddings and labels.
        device (torch.device): The device (CPU/GPU) to load the tensors onto.
        batch_size (int, optional): The batch size for DataLoader. Default is 64.

    Returns:
        tuple: A tuple containing:
            - embeddings (torch.Tensor): Loaded embeddings.
            - labels (torch.Tensor): Corresponding labels for the embeddings.
            - data_loader (DataLoader): DataLoader for batching embeddings.
    """
    logger.info(f"Loading embeddings from: {embedding_file}")
    data = torch.load(embedding_file, map_location=device)  # Load tensors onto `device`, even if they were saved from another device
    embeddings = data["embeddings"].to(device)
    labels = data["labels"].to(device)

    # Create a TensorDataset and DataLoader
    dataset = TensorDataset(embeddings, labels)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    return embeddings, labels, data_loader

# Load the embeddings, their labels, and a shuffled DataLoader
embeddings, labels, data_loader = load_embeddings_v2(embedding_file, device)


# Alternative loader that can leave labels out of the DataLoader.
# Note: it returns three values when return_labels=True but only two when return_labels=False.

def load_embeddings(embedding_file, device, batch_size=64, return_labels=True):
    """
    Loads embeddings and their associated labels from a specified file and
    creates a DataLoader for batching the embeddings (and optionally labels).

    Args:
        embedding_file (str): Path to the file containing embeddings and labels.
        device (torch.device): The device (CPU/GPU) to load the tensors onto.
        batch_size (int, optional): The batch size for DataLoader. Default is 64.
        return_labels (bool, optional): Whether to include labels in the DataLoader. Default is True.

    Returns:
        tuple: A tuple containing:
            - embeddings (torch.Tensor): Loaded embeddings.
            - labels (torch.Tensor, optional): Corresponding labels for the embeddings (if return_labels=True).
            - data_loader (DataLoader): DataLoader for batching embeddings (and labels if required).
    """
    print(f"Loading embeddings from: {embedding_file}")
    data = torch.load(embedding_file, map_location=device)  # Load tensors onto `device`, even if they were saved from another device
    embeddings = data["embeddings"].to(device)

    # Initialize the DataLoader only with embeddings if labels are not required
    if return_labels:
        labels = data["labels"].to(device)
        # Create a TensorDataset containing both embeddings and labels
        dataset = TensorDataset(embeddings, labels)
        return embeddings, labels, DataLoader(dataset, batch_size=batch_size, shuffle=True)
    else:
        # Create DataLoader for embeddings only
        dataset = TensorDataset(embeddings)  # Just embeddings
        return embeddings, DataLoader(dataset, batch_size=batch_size, shuffle=True)

embeddings, labels, data_loader_v2 = load_embeddings(embedding_file, device)


INFO - Loading embeddings from: ./saved_embeddings/embeddings/autoencoders_BasicAutoencoder/BasicAutoencoder_embeddings.pt
  data = torch.load(embedding_file)
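Both loaders above assume the `.pt` file holds a dict with an `"embeddings"` tensor of shape (N, D) and a `"labels"` tensor of shape (N,). The self-contained sketch below writes a toy file in that assumed layout, reads it back, and reports the batch sizes the `DataLoader` yields. The helper name `embedding_file_roundtrip` is hypothetical.

The reload uses two `torch.load` options. `map_location` lets tensors saved on a GPU load on a CPU-only machine. `weights_only=True` (PyTorch 1.13+) restricts unpickling to tensors and plain containers, and also silences the `torch.load` FutureWarning that some recent PyTorch versions emit when it is left unset.

```python
import os
import tempfile

import torch
from torch.utils.data import DataLoader, TensorDataset

def embedding_file_roundtrip(num_samples=130, embedding_dim=50, batch_size=64):
    """Save a toy embeddings file in the assumed layout, reload it, and return the batch sizes."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "toy_embeddings.pt")
        # Assumed layout: {"embeddings": (N, D) float tensor, "labels": (N,) integer tensor}
        torch.save(
            {"embeddings": torch.randn(num_samples, embedding_dim),
             "labels": torch.randint(0, 10, (num_samples,))},
            path,
        )
        # map_location puts the tensors on the CPU even if they were saved from a GPU;
        # weights_only=True only unpickles tensors and plain containers
        data = torch.load(path, map_location="cpu", weights_only=True)

    dataset = TensorDataset(data["embeddings"], data["labels"])
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return [batch_embeddings.size(0) for batch_embeddings, _ in loader]

print(embedding_file_roundtrip())  # [64, 64, 2]
```

With the default `drop_last=False`, the last batch holds whatever is left over, here only two samples. Passing `drop_last=True` to `DataLoader` discards that partial batch. Some GAN setups prefer this, because a tiny batch gives a very noisy discriminator update.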


In [None]:
import torch

# ========================
# MAIN CONFIGURATION
# ========================
config = {
    "gan_type": "Semi-Supervised-GAN",  # Change this to switch models
    "latent_dim": 100,  # Latent space dimension
    "embedding_dim": embeddings.size(1),  # Embedding dimension
    "num_classes": 10,  # For conditional models
    "categorical_dim": 10,  # For InfoGAN
    "epochs": 1,
    "batch_size": 64,
    "learning_rate": 1e-4,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "lambda_gp": 10,  # For WGAN-GP
    "beta1": 0.5,  # Adam optimizer beta1
    "beta2": 0.999,  # Adam optimizer beta2
    "save_path": "gan_model.pth",  # Path to save the model
    "data_loader": data_loader_v2,  # Main data loader
    "data_loader_a": data_loader,  # Domain A loader for cross-domain models (same data as domain B here)
    "data_loader_b": data_loader,  # Domain B loader for cross-domain models (same data as domain A here)
    "show_model_architecture": True  # Toggle to print model architectures
}

# ========================
# MODEL INITIALIZATION
# ========================

def initialize_gan_components(config, gan_config):
    """Initialize all components for the specified GAN type"""
    components = {}
    gan_type = config["gan_type"]
    multi_gan_types = ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]

    if gan_type in multi_gan_types:
        gen_args = {"embedding_dim": config["embedding_dim"]}
        if gan_type != "Cycle-GAN":
            gen_args["latent_dim"] = config["latent_dim"]

        components.update({
            "generator_a": gan_config["generator_a"](**gen_args).to(config["device"]),
            "generator_b": gan_config["generator_b"](**gen_args).to(config["device"]),
            "discriminator_a": gan_config["discriminator_a"](embedding_dim=config["embedding_dim"]).to(config["device"]),
            "discriminator_b": gan_config["discriminator_b"](embedding_dim=config["embedding_dim"]).to(config["device"])
        })
    else:
        gen_args = {
            "latent_dim": config["latent_dim"],
            "embedding_dim": config["embedding_dim"]
        }
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            gen_args["num_classes"] = config["num_classes"]
        if gan_type == "InfoGAN":
            gen_args["categorical_dim"] = config["categorical_dim"]

        components["generator"] = gan_config["generator"](**gen_args).to(config["device"])

    if gan_type == "WGAN-GP":
        components["critic"] = gan_config["critic"](embedding_dim=config["embedding_dim"]).to(config["device"])
    elif gan_type == "VAE-GAN":
        components["encoder"] = gan_config["encoder"](embedding_dim=config["embedding_dim"], latent_dim=config["latent_dim"]).to(config["device"])
        components["discriminator"] = gan_config["discriminator"](embedding_dim=config["embedding_dim"]).to(config["device"])
    elif gan_type not in multi_gan_types:
        disc_args = {"embedding_dim": config["embedding_dim"]}
        if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"]:
            disc_args["num_classes"] = config["num_classes"]
        if gan_type == "InfoGAN":
            disc_args["categorical_dim"] = config["categorical_dim"]

        components["discriminator"] = gan_config["discriminator"](**disc_args).to(config["device"])

    if config["show_model_architecture"]:
        print(f"Initialized components: {components}")

    return components

# ========================
# MAIN EXECUTION
# ========================

def run_gan_training(config):
    """Initialize and train the GAN selected by config["gan_type"].

    Returns the dict of trained components so the caller can evaluate or save
    the same modules that were trained.
    """
    gan_type = config["gan_type"]
    if gan_type not in gan_configurations:
        raise ValueError(f"Unsupported GAN type: {gan_type}")

    gan_config = gan_configurations[gan_type]
    components = initialize_gan_components(config, gan_config)
    train_function = gan_config["train_function"]

    print(f"🚀 Training {gan_type}...")

    if gan_type == "VAE-GAN":
        train_function(components["encoder"], components["generator"], components["discriminator"], **config)
    elif gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        train_function(components["generator_a"], components["generator_b"], components["discriminator_a"], components["discriminator_b"], **config)
    else:
        # WGAN-GP registers a "critic"; every other single-generator GAN registers a "discriminator"
        discriminator = components.get("discriminator", components.get("critic"))
        train_function(components["generator"], discriminator, **config)

    print(f"✅ {gan_type} training completed!")
    return components

# ========================
# EXECUTION
# ========================
trained_components = run_gan_training(config)


Initialized components: {'generator': SimpleGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=1024, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=1024, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (3): Linear(in_features=256, out_features=50, bias=True)
    (4): Tanh()
  )
), 'discriminator': SemiSupervisedGANDiscriminator(
  (shared_model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=50, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): Li
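`initialize_gan_components` decides which extra constructor arguments each GAN type receives through nested `if` chains. The manual test cell below repeats that logic by hand, and the two already disagree for InfoGAN: only `initialize_gan_components` passes `num_classes` to the InfoGAN generator and discriminator. A single lookup table makes such differences explicit. The sketch below is a hypothetical refactor whose table entries mirror `initialize_gan_components`. Which version matches the actual InfoGAN constructors is not settled here.

```python
# Hypothetical refactor: extra constructor kwargs per GAN type, mirroring initialize_gan_components
GENERATOR_EXTRA_ARGS = {
    "Conditional-GAN": ["num_classes"],
    "InfoGAN": ["num_classes", "categorical_dim"],
}
DISCRIMINATOR_EXTRA_ARGS = {
    "Conditional-GAN": ["num_classes"],
    "InfoGAN": ["num_classes", "categorical_dim"],
    "Semi-Supervised-GAN": ["num_classes"],
}

def constructor_kwargs(config, base_keys, extra_args_table):
    """Collect the base keys plus any type-specific extras from the config dict."""
    keys = list(base_keys) + extra_args_table.get(config["gan_type"], [])
    return {key: config[key] for key in keys}

# Example with this notebook's settings (embedding_dim 50, as in the printed models)
example_config = {"gan_type": "Semi-Supervised-GAN", "latent_dim": 100, "embedding_dim": 50,
                  "num_classes": 10, "categorical_dim": 10}
gen_kwargs = constructor_kwargs(example_config, ["latent_dim", "embedding_dim"], GENERATOR_EXTRA_ARGS)
disc_kwargs = constructor_kwargs(example_config, ["embedding_dim"], DISCRIMINATOR_EXTRA_ARGS)
print(gen_kwargs)   # {'latent_dim': 100, 'embedding_dim': 50}
print(disc_kwargs)  # {'embedding_dim': 50, 'num_classes': 10}
```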

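The printed `SimpleGANGenerator` maps a 100-dimensional noise vector through Linear + LeakyReLU(0.2) blocks of width 1024, 512, and 256 to a 50-dimensional output followed by `Tanh`. The sketch below rebuilds that layer stack from the printed structure only. `SketchLinearBlock` and `SketchGenerator` are hypothetical names, and the real module's `forward` may do more, such as conditioning.

Because `Tanh` bounds every output to (-1, 1), real embeddings shown to the discriminator generally need to be scaled into the same range. Otherwise the discriminator can separate real from generated samples by magnitude alone. `scale_to_tanh_range` and `unscale_from_tanh_range` are one hypothetical per-dimension way to do this while keeping the bounds, so generated samples can be mapped back.

```python
import torch
from torch import nn

class SketchLinearBlock(nn.Module):
    """Linear layer followed by LeakyReLU(0.2), matching the printed LinearBlock structure."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, x):
        return self.block(x)

class SketchGenerator(nn.Module):
    """Hypothetical rebuild of the printed generator: noise -> hidden blocks -> Tanh-bounded embedding."""

    def __init__(self, latent_dim=100, embedding_dim=50, hidden_dims=(1024, 512, 256)):
        super().__init__()
        dims = [latent_dim, *hidden_dims]
        layers = [SketchLinearBlock(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        layers += [nn.Linear(dims[-1], embedding_dim), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        return self.model(z)

def scale_to_tanh_range(x, eps=1e-8):
    """Min-max scale each embedding dimension to roughly [-1, 1]; also return the bounds used."""
    lo, hi = x.min(dim=0).values, x.max(dim=0).values
    return 2 * (x - lo) / (hi - lo + eps) - 1, (lo, hi)

def unscale_from_tanh_range(y, bounds, eps=1e-8):
    """Invert scale_to_tanh_range, e.g. to map generated samples back to embedding space."""
    lo, hi = bounds
    return (y + 1) / 2 * (hi - lo + eps) + lo
```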
In [None]:
# GAN Configuration
gan_type = "Semi-Supervised-GAN"  # Change to the desired GAN type
latent_dim = 100  # Latent dimension for the generator
embedding_dim = embeddings.size(1)  # Embedding dimension based on loaded embeddings
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10
epochs = 1
learning_rate = 0.0001
# Ensure device is properly set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the data loaders before the GAN models
embedding_loader_a = data_loader
embedding_loader_b = data_loader
embedding_loader = data_loader_v2

# Initialize the correct GAN models based on `gan_type` (fail fast on an unknown type)
if gan_type not in gan_configurations:
    raise ValueError(f"Unsupported GAN type: {gan_type}")

# Use a separate name so this cell does not overwrite the main `config` dict
gan_config = gan_configurations[gan_type]
print(f"Initializing {gan_type} configuration: {gan_config}")

multi_gan_types = ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]

if gan_type in multi_gan_types:
    # Two-domain models need two generators and two discriminators
    if gan_type == "Cycle-GAN":
        # Cycle-GAN generators map embedding to embedding, so they only take embedding_dim
        generator_a = gan_config["generator_a"](embedding_dim=embedding_dim).to(device)
        generator_b = gan_config["generator_b"](embedding_dim=embedding_dim).to(device)
    else:
        # Dual-GAN and Contrastive-Dual-GAN generators also take latent_dim
        generator_a = gan_config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
        generator_b = gan_config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
    discriminator_a = gan_config["discriminator_a"](embedding_dim=embedding_dim).to(device)
    discriminator_b = gan_config["discriminator_b"](embedding_dim=embedding_dim).to(device)
    print(f"{gan_type} generators and discriminators initialized: {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
else:
    # Single-generator models (WGAN-GP, VAE-GAN, Semi-Supervised-GAN, etc.)
    generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
    if gan_type == "InfoGAN":
        # InfoGAN generator takes latent_dim and categorical_dim
        generator_args["categorical_dim"] = categorical_dim
    if gan_type == "Conditional-GAN":
        # Conditional-GAN generator requires num_classes
        generator_args["num_classes"] = num_classes
    generator = gan_config["generator"](**generator_args).to(device)
    print(f"Generator initialized on {generator}")

    if gan_type == "WGAN-GP":
        # WGAN-GP uses a critic instead of a discriminator
        critic = gan_config["critic"](embedding_dim=embedding_dim).to(device)
        print(f"Critic initialized on {critic}")
    elif gan_type in ["Semi-Supervised-GAN", "Conditional-GAN"]:
        # Both discriminators require num_classes
        discriminator = gan_config["discriminator"](embedding_dim=embedding_dim, num_classes=num_classes).to(device)
        print(f"{gan_type} discriminator initialized on {discriminator}")
    elif gan_type == "InfoGAN":
        # InfoGAN discriminator requires embedding_dim and categorical_dim
        discriminator = gan_config["discriminator"](embedding_dim=embedding_dim, categorical_dim=categorical_dim).to(device)
        print(f"InfoGAN discriminator initialized on {discriminator}")
    else:
        # Remaining single-generator models (e.g. VAE-GAN) only need embedding_dim
        discriminator = gan_config["discriminator"](embedding_dim=embedding_dim).to(device)
        print(f"Discriminator initialized on {discriminator}")

    # VAE-GAN additionally needs an encoder
    if gan_type == "VAE-GAN":
        encoder = gan_config["encoder"](embedding_dim=embedding_dim, latent_dim=latent_dim).to(device)
        print(f"Encoder initialized on {encoder}")

# Select the appropriate training function
train_function = gan_config["train_function"]

# Copy the per-model kwargs so updating them does not mutate the shared gan_configurations entry
train_kwargs = dict(gan_config.get("train_kwargs", {}))
train_kwargs.update({
    "latent_dim": latent_dim,
    "epochs": epochs,
    "device": device,
    "learning_rate": learning_rate,
    "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"] else None,
    "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,
    "lambda_gp": 10 if gan_type == "WGAN-GP" else None,
})

# Two-domain models take two data loaders; all other models take a single loader
if gan_type in multi_gan_types:
    train_kwargs.update({"data_loader_a": embedding_loader_a, "data_loader_b": embedding_loader_b})
else:
    train_kwargs["data_loader"] = embedding_loader

# Dispatch explicitly rather than probing locals(): a `discriminator` left over from an
# earlier run would otherwise be passed to WGAN-GP instead of its critic
if gan_type == "VAE-GAN":
    train_function(encoder, generator, discriminator, **train_kwargs)
elif gan_type in multi_gan_types:
    train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs)
elif gan_type == "WGAN-GP":
    train_function(generator, critic, **train_kwargs)
else:
    train_function(generator, discriminator, **train_kwargs)

print(f"{gan_type} training test passed!")



Initializing Semi-Supervised-GAN configuration: {'generator': <class 'src.gan_workflows.plan2.plan2_gan_models.SimpleGANGenerator'>, 'discriminator': <class 'src.gan_workflows.plan2.plan2_gan_models.SemiSupervisedGANDiscriminator'>, 'train_function': <function train_semi_supervised_gan at 0x7d7ef691dda0>}
Generator initialized on SimpleGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=1024, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=1024, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (3): Linear(in_features=256, out_features=50, bias=True)
    