# Plan 2: GAN Training with Embeddings

## Objective
The goal of this plan is to train GANs using embeddings generated from Plan 1. This involves:
1. Setting up a GAN architecture that incorporates embeddings.
2. Training the GAN to generate high-quality synthetic embeddings or data.
3. Evaluating the GAN's performance using relevant metrics.

## Key Steps

### 1. Define GAN Architecture
- **Generator**: Accepts embeddings and noise as inputs, producing synthetic data.
- **Discriminator**: Evaluates the authenticity of generated data, optionally conditioned on embeddings.

### 2. Load Pre-Generated Embeddings
- Load embeddings created in Plan 1 from their respective directories.
- Normalize and preprocess embeddings for GAN training.

### 3. Train the GAN
- Set up training loops for the generator and discriminator.
- Use appropriate loss functions, such as adversarial loss (e.g., Wasserstein loss).
- Optionally, include auxiliary tasks (e.g., reconstruction loss) for better embedding alignment.

### 4. Save Outputs
- Save trained GAN models.
- Save generated embeddings or data for downstream evaluation.

### 5. Evaluate GAN Performance
- Use metrics like FID, IS, and qualitative visualization.
- Compare performance with baseline models or methods.

---

## Starting Point

**Files Available**:
- `plan2_gan_models.py`: Contains the GAN architecture.
- `plan2_gan_training.py`: Implements the training pipeline.
- `main_plan2_gan_training.ipynb`: High-level control notebook.
- `plan2_experiments.ipynb`: For experimentation and evaluation.

## Next Steps
1. **Review Architecture**:
   - Inspect `plan2_gan_models.py` to understand the generator and discriminator setup.
2. **Pipeline Setup**:
   - Review and prepare the training pipeline in `plan2_gan_training.py`.
3. **Experimentation**:
   - Use `plan2_experiments.ipynb` for controlled experiments.


In [1]:
# Plan 2 GAN Training Notebook

import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import logging
from datetime import datetime
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset

from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Repository path (adjust if needed)
repo_path = "/content/drive/MyDrive/GAN-thesis-project"

# Add repository path to sys.path for module imports
if repo_path not in sys.path:
    sys.path.append(repo_path)

# Change working directory to the repository
os.chdir(repo_path)

# Verify the working directory
print(f"Current working directory: {os.getcwd()}")


# Set random seed and device
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Mounted at /content/drive
Current working directory: /content/drive/MyDrive/GAN-thesis-project
Using device: cpu


In [2]:
# just for you to see what are the implemented function and classes that have been coded up...

import inspect

# Import the entire modules
import src.data_utils as data_utils
import src.cl_loss_function as cl_loss
import src.losses as losses
import src.gan_workflows.plan2.plan2_gan_models as gan_models
import src.gan_workflows.plan2.plan2_gan_training as gan_training

# Function to list functions and classes in a module
def list_functions_and_classes(module):
    members = inspect.getmembers(module)
    functions = [name for name, obj in members if inspect.isfunction(obj)]
    classes = [name for name, obj in members if inspect.isclass(obj)]
    return functions, classes

# Function to print functions and classes in a readable format
def print_functions_and_classes(module_name, module):
    functions, classes = list_functions_and_classes(module)
    print(f"Module: {module_name}")
    print("  Functions:")
    for func in functions:
        print(f"    - {func}")
    print("  Classes:")
    for cls in classes:
        print(f"    - {cls}")
    print()  # Add a blank line for separation

# Print functions and classes for each module
print_functions_and_classes("src.data_utils", data_utils)
print_functions_and_classes("src.cl_loss_function", cl_loss)
print_functions_and_classes("src.losses", losses)
print_functions_and_classes("src.embeddings.encoder_models", gan_models)
print_functions_and_classes("src.embeddings.encoder_training", gan_training)

Module: src.data_utils
  Functions:
    - analyze_embeddings
    - analyze_embeddings_v2
    - create_dataloader
    - create_embedding_loaders
    - generate_embeddings
    - kurtosis
    - load_data
    - load_embeddings
    - load_mnist_data
    - pdist
    - preprocess_images
    - save_embeddings
    - skew
    - split_dataset
    - train_test_split
    - visualize_embeddings
  Classes:
    - DataLoader
    - LocalOutlierFactor
    - TensorDataset

Module: src.cl_loss_function
  Functions:
    - augment
    - compute_nt_xent_loss_with_augmentation
    - compute_triplet_loss_with_augmentation
    - contrastive_loss
    - hflip
    - info_nce_loss
    - resize
  Classes:
    - BYOLLoss
    - BarlowTwinsLoss
    - ContrastiveHead
    - DataLoader
    - NTXentLoss
    - PCA
    - Predictor
    - TensorDataset
    - TripletLoss
    - VicRegLoss

Module: src.losses
  Functions:
    - add_noise
    - cyclical_beta_schedule
    - linear_beta_schedule
    - loss_function_dae_ssim
    - vae

In [3]:
# GAN Models and Training Functions
from src.gan_workflows.plan2.plan2_gan_models import (
    SimpleGANGenerator, SimpleGANDiscriminator,
    ContrastiveGANGenerator, ContrastiveGANDiscriminator,
    VAEGANEncoder, VAEGANGenerator, VAEGANDiscriminator,
    WGANGenerator, WGANCritic,
    CrossDomainGenerator, CrossDomainDiscriminator,
    CycleGenerator, CycleDiscriminator,
    DualGANGenerator, DualGANDiscriminator,
    ContrastiveDualGANGenerator, ContrastiveDualGANDiscriminator,
    SemiSupervisedGANDiscriminator,
    ConditionalGANGenerator, ConditionalGANDiscriminator,
    InfoGANGenerator, InfoGANDiscriminator,
    compute_gradient_penalty
)

from src.gan_workflows.plan2.plan2_gan_training import (
    train_wgan_gp, train_vae_gan, train_contrastive_gan,
    train_cross_domain_gan, train_cycle_gan,
    train_dual_gan, train_contrastive_dual_gan,
    train_semi_supervised_gan,
    train_conditional_gan,
    train_infogan
)

from src.data_utils import (
    load_embeddings, analyze_embeddings
)

In [35]:
embedding_dir = "./saved_embeddings/embeddings/autoencoders_BasicAutoencoder"  # Example embedding path
embedding_file = os.path.join(embedding_dir, "BasicAutoencoder_embeddings.pt")

# Load embeddings and labels using the new function
embeddings, labels, data_loader = load_embeddings(embedding_file, device)

INFO - Loading embeddings from: ./saved_embeddings/embeddings/autoencoders_BasicAutoencoder/BasicAutoencoder_embeddings.pt
  data = torch.load(embedding_file)


In [None]:
# ## OLD VERSION

# def validate_embeddings(embeddings):
#     """
#     Validate and provide information about the shape and data type of embeddings.
#     """
#     if embeddings is None or len(embeddings) == 0:
#         raise ValueError("Embeddings are empty or not properly generated.")
#     print(f"Embeddings are of shape: {embeddings.shape}")
#     print(f"Data type: {embeddings.dtype}")
#     print(f"Device: {embeddings.device}")
#     if torch.isnan(embeddings).any():
#         raise ValueError("Embeddings contain NaN values.")
#     if torch.isinf(embeddings).any():
#         raise ValueError("Embeddings contain infinite values.")
#     if embeddings.ndim != 2:
#         raise ValueError("Embeddings should be a 2D tensor.")
#     print("Embeddings validation passed.")

# validate_embeddings(embeddings)

In [18]:
# Get a batch of training data
embedding_batch = next(iter(data_loader))
print('embedding batches', embedding_batch.shape)

embedding batches torch.Size([64, 50])


In [83]:
# GAN Configuration Dictionary
gan_configurations = {
    "WGAN-GP": {
        "generator": WGANGenerator,
        "critic": WGANCritic,
        "train_function": train_wgan_gp,
        "train_kwargs": {"lambda_gp": 10}
    },
    "VAE-GAN": {
        "encoder": VAEGANEncoder,
        "generator": VAEGANGenerator,
        "discriminator": VAEGANDiscriminator,
        "train_function": train_vae_gan
    },
    "Contrastive-GAN": {
        "generator": ContrastiveGANGenerator,
        "discriminator": ContrastiveGANDiscriminator,
        "train_function": train_contrastive_gan
    },
    "Cross-Domain-GAN": {
        "generator": CrossDomainGenerator,
        "discriminator": CrossDomainDiscriminator,
        "train_function": train_cross_domain_gan
    },
    "Cycle-GAN": {
        "generator_a": CycleGenerator,
        "generator_b": CycleGenerator,
        "discriminator_a": CycleDiscriminator,
        "discriminator_b": CycleDiscriminator,
        "train_function": train_cycle_gan
    },
    "Dual-GAN": {
        "generator_a": DualGANGenerator,
        "generator_b": DualGANGenerator,
        "discriminator_a": DualGANDiscriminator,
        "discriminator_b": DualGANDiscriminator,
        "train_function": train_dual_gan
    },
    "Contrastive-Dual-GAN": {
        "generator_a": ContrastiveDualGANGenerator,
        "generator_b": ContrastiveDualGANGenerator,
        "discriminator_a": ContrastiveDualGANDiscriminator,
        "discriminator_b": ContrastiveDualGANDiscriminator,
        "train_function": train_contrastive_dual_gan
    },
    "Semi-Supervised-GAN": {
        "generator": SimpleGANGenerator,
        "discriminator": SemiSupervisedGANDiscriminator,
        "train_function": train_semi_supervised_gan
    },
    "Conditional-GAN": {
        "generator": ConditionalGANGenerator,
        "discriminator": ConditionalGANDiscriminator,
        "train_function": train_conditional_gan
    },
    "InfoGAN": {
        "generator": InfoGANGenerator,
        "discriminator": InfoGANDiscriminator,
        "train_function": train_infogan
    }
}

In [47]:
gan_configurations.keys()

dict_keys(['WGAN-GP', 'VAE-GAN', 'Contrastive-GAN', 'Cross-Domain-GAN', 'Cycle-GAN', 'Dual-GAN', 'Contrastive-Dual-GAN', 'Semi-Supervised-GAN', 'Conditional-GAN', 'InfoGAN'])

In [110]:
import os

# Base directory where embeddings are stored
embedding_base_dir = "./saved_embeddings/embeddings/"

def list_available_embeddings(base_dir, filter_by=None):
    """
    List available embedding directories and files, optionally filtered by method.

    Args:
        base_dir (str): The base directory containing embeddings.
        filter_by (str or list, optional): Method(s) to filter by (e.g., "autoencoder", "vae").
                                           If None, all embeddings are displayed.
    """
    print("\n📂 Available Embeddings:\n")

    if isinstance(filter_by, str):
        filter_by = [filter_by]  # Convert single filter to list

    for method in sorted(os.listdir(base_dir)):
        method_path = os.path.join(base_dir, method)

        # Check if it's a directory
        if os.path.isdir(method_path):
            if filter_by is None or any(f.lower() in method.lower() for f in filter_by):
                pt_files = [f for f in sorted(os.listdir(method_path)) if f.endswith(".pt")]
                if pt_files:
                    print(f"\n🔹 {method}")  # Show only the category
                    for file in pt_files:
                        print(f"   📄 {method}/{file}")  # Show category + filename

# Default: Show everything
list_available_embeddings(embedding_base_dir)

# Example: Show only autoencoder-related embeddings
# list_available_embeddings(embedding_base_dir, filter_by="vae")

# Example: Show both autoencoder and VAE embeddings
# list_available_embeddings(embedding_base_dir, filter_by=["autoencoder", "vae"])


📂 Available Embeddings:


🔹 autoencoder_AdvancedAutoencoder_barlow_twins
   📄 autoencoder_AdvancedAutoencoder_barlow_twins/AdvancedAutoencoder_barlow_twins_embeddings.pt

🔹 autoencoder_AdvancedAutoencoder_contrastive
   📄 autoencoder_AdvancedAutoencoder_contrastive/AdvancedAutoencoder_contrastive_embeddings.pt

🔹 autoencoder_AdvancedAutoencoder_info_nce
   📄 autoencoder_AdvancedAutoencoder_info_nce/AdvancedAutoencoder_info_nce_embeddings.pt

🔹 autoencoder_AdvancedAutoencoder_mse
   📄 autoencoder_AdvancedAutoencoder_mse/AdvancedAutoencoder_mse_embeddings.pt

🔹 autoencoder_AdvancedAutoencoder_ntxent
   📄 autoencoder_AdvancedAutoencoder_ntxent/AdvancedAutoencoder_ntxent_embeddings.pt

🔹 autoencoder_AdvancedAutoencoder_vicreg
   📄 autoencoder_AdvancedAutoencoder_vicreg/AdvancedAutoencoder_vicreg_embeddings.pt

🔹 autoencoder_EnhancedAutoencoder_barlow_twins
   📄 autoencoder_EnhancedAutoencoder_barlow_twins/EnhancedAutoencoder_barlow_twins_embeddings.pt

🔹 autoencoder_EnhancedAutoencoder_co

In [152]:
import os
import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.models import inception_v3
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt
from scipy.stats import entropy, spearmanr
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import NearestNeighbors
from scipy.stats import wasserstein_distance
import json

# ==========================
# CONFIGURATION & EMBEDDING LOADING
# ==========================
embedding_dir = "./saved_embeddings/embeddings/"
embedding_file = os.path.join(embedding_dir, "autoencoder_EnhancedAutoencoder_barlow_twins/EnhancedAutoencoder_barlow_twins_embeddings.pt")

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = {
    "gan_type": "WGAN-GP",
    "latent_dim": 100,
    "embedding_dim": None,  # Will be set after loading embeddings
    "num_classes": 10,
    "categorical_dim": 10,
    "epochs": 1,
    "batch_size": 64,
    "learning_rate": 1e-4,
    "device": device,
    "lambda_gp": 10,
    "beta1": 0.5,
    "beta2": 0.999,
    "save_path": "gan_model.pth",
    "eval_fraction": 0.1,  # Fraction of embeddings used for evaluation
    "show_model_architecture": True
}

def load_embeddings(embedding_file, device, batch_size=64):
    """Loads embeddings and labels from a specified file."""
    data = torch.load(embedding_file)
    embeddings = data["embeddings"].to(device)
    labels = data["labels"].to(device)
    data_loader = DataLoader(embeddings, batch_size=batch_size, shuffle=True)
    return embeddings, labels, data_loader

def split_embeddings(embeddings, labels, eval_fraction=0.1, batch_size=64):
    """Splits embeddings into training and evaluation sets."""
    num_samples = embeddings.size(0)
    num_eval = int(num_samples * eval_fraction)
    num_train = num_samples - num_eval

    train_embeddings, eval_embeddings = random_split(embeddings, [num_train, num_eval])
    train_labels, eval_labels = random_split(labels, [num_train, num_eval])

    train_loader = DataLoader(train_embeddings, batch_size=batch_size, shuffle=True)
    eval_loader = DataLoader(eval_embeddings, batch_size=batch_size, shuffle=False)
    return train_loader, eval_loader, train_embeddings, eval_embeddings

# Load embeddings and split
embeddings, labels, full_data_loader = load_embeddings(embedding_file, device)
config["embedding_dim"] = embeddings.size(1)
train_loader, eval_loader, train_embeddings, eval_embeddings = split_embeddings(embeddings, labels, config["eval_fraction"], config["batch_size"])

# Update config with data loaders
config.update({
    "data_loader": train_loader,
    "data_loader_a": train_loader,
    "data_loader_b": train_loader,
    "eval_loader": eval_loader,
    "original_embeddings": embeddings  # Store original embeddings for evaluation
})

# ==========================
# MODEL INITIALIZATION
# ==========================

def initialize_gan_components(config, gan_configurations):
    """Initialize GAN components based on type."""
    components = {}
    gan_type = config["gan_type"]
    multi_gan_types = ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]

    if gan_type in multi_gan_types:
        gen_args = {"embedding_dim": config["embedding_dim"]}
        if gan_type != "Cycle-GAN":
            gen_args["latent_dim"] = config["latent_dim"]

        components.update({
            "generator_a": gan_configurations["generator_a"](**gen_args).to(config["device"]),
            "generator_b": gan_configurations["generator_b"](**gen_args).to(config["device"]),
            "discriminator_a": gan_configurations["discriminator_a"](embedding_dim=config["embedding_dim"]).to(config["device"]),
            "discriminator_b": gan_configurations["discriminator_b"](embedding_dim=config["embedding_dim"]).to(config["device"])
        })
    else:
        gen_args = {"latent_dim": config["latent_dim"], "embedding_dim": config["embedding_dim"]}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            gen_args["num_classes"] = config["num_classes"]
        if gan_type == "InfoGAN":
            gen_args["categorical_dim"] = config["categorical_dim"]

        components["generator"] = gan_configurations["generator"](**gen_args).to(config["device"])

    if gan_type == "WGAN-GP":
        components["critic"] = gan_configurations["critic"](embedding_dim=config["embedding_dim"]).to(config["device"])
    elif gan_type == "VAE-GAN":
        components["encoder"] = gan_configurations["encoder"](embedding_dim=config["embedding_dim"], latent_dim=config["latent_dim"]).to(config["device"])
        components["discriminator"] = gan_configurations["discriminator"](embedding_dim=config["embedding_dim"]).to(config["device"])
    elif gan_type not in multi_gan_types:
        disc_args = {"embedding_dim": config["embedding_dim"]}
        if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"]:
            disc_args["num_classes"] = config["num_classes"]
        if gan_type == "InfoGAN":
            disc_args["categorical_dim"] = config["categorical_dim"]

        components["discriminator"] = gan_configurations["discriminator"](**disc_args).to(config["device"])

    if config["show_model_architecture"]:
        print(f"Initialized components: {components}")

    return components

# ==========================
# EVALUATION METRICS
# ==========================

def calculate_fid(real_embeddings, generated_embeddings):
    """Compute Fréchet Inception Distance (FID)."""
    mu1, sigma1 = np.mean(real_embeddings, axis=0), np.cov(real_embeddings, rowvar=False)
    mu2, sigma2 = np.mean(generated_embeddings, axis=0), np.cov(generated_embeddings, rowvar=False)
    diff = mu1 - mu2
    covmean = sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return np.sum(diff**2) + np.trace(sigma1 + sigma2 - 2 * covmean)

def calculate_kl_divergence(real_embeddings, generated_embeddings):
    """Compute KL divergence between real and generated embeddings."""
    real_prob = np.histogram(real_embeddings, bins=50, density=True)[0]
    gen_prob = np.histogram(generated_embeddings, bins=50, density=True)[0]
    real_prob += 1e-10  # Prevent division by zero
    gen_prob += 1e-10
    return entropy(real_prob, gen_prob)

def calculate_cosine_similarity(real_embeddings, generated_embeddings):
    """Compute cosine similarity between real and generated embeddings."""
    return np.mean(cosine_similarity(real_embeddings, generated_embeddings))

def rank_similarity(real_embeddings, generated_embeddings):
    """Compute Spearman Rank Correlation between real and generated embeddings, ensuring matched dimensions."""
    min_size = min(real_embeddings.shape[0], generated_embeddings.shape[0])

    # Randomly sample real embeddings if larger
    if real_embeddings.shape[0] > min_size:
        real_embeddings = real_embeddings[np.random.choice(real_embeddings.shape[0], min_size, replace=False)]

    # Randomly sample generated embeddings if larger
    if generated_embeddings.shape[0] > min_size:
        generated_embeddings = generated_embeddings[np.random.choice(generated_embeddings.shape[0], min_size, replace=False)]

    # Ensure the shapes are aligned for ranking
    real_rank = np.argsort(real_embeddings, axis=0)
    gen_rank = np.argsort(generated_embeddings, axis=0)

    return spearmanr(real_rank.flatten(), gen_rank.flatten()).correlation

def unique_embedding_ratio(generated_embeddings):
    """Compute the ratio of unique embeddings in the generated set."""
    unique_embeddings = np.unique(generated_embeddings, axis=0)
    return len(unique_embeddings) / len(generated_embeddings)

def aggregate_quality_score(fid, kl, cosine, rank_corr):
    """Compute a weighted quality score based on multiple metrics."""
    # Normalize metrics (assuming lower is better for FID & KL)
    fid_norm = 1 / (1 + fid)
    kl_norm = 1 / (1 + kl)
    cosine_norm = cosine  # Higher is better, no need to invert
    rank_norm = (rank_corr + 1) / 2  # Convert [-1,1] to [0,1]

    # Compute final weighted score
    return (0.4 * fid_norm) + (0.2 * kl_norm) + (0.2 * cosine_norm) + (0.2 * rank_norm)

def calculate_wasserstein_distance(real_embeddings, generated_embeddings):
    """Compute Wasserstein Distance between real and generated embeddings."""
    return wasserstein_distance(real_embeddings.flatten(), generated_embeddings.flatten())

def calculate_coverage_score(real_embeddings, generated_embeddings, n_neighbors=5):
    """Compute Coverage Score: Percentage of real embeddings with at least one close match in generated embeddings."""
    neigh = NearestNeighbors(n_neighbors=n_neighbors)
    neigh.fit(generated_embeddings)
    distances, _ = neigh.kneighbors(real_embeddings)
    return np.mean(distances[:, 0] < 0.1)  # Adjust threshold as needed

def calculate_memorization_score(real_embeddings, generated_embeddings, n_neighbors=1, tolerance=1e-3):
    """Compute Memorization Score: Percentage of generated embeddings that are close to real embeddings within a small tolerance."""
    neigh = NearestNeighbors(n_neighbors=n_neighbors)
    neigh.fit(real_embeddings)
    distances, _ = neigh.kneighbors(generated_embeddings)

    return np.mean(distances[:, 0] < tolerance)  # Allow small tolerance for near-exact matches

def generate_synthetic_samples(generator, num_samples=1000, latent_dim=100, device="cuda"):
    """Generate synthetic samples using a trained GAN generator."""
    generator.eval()
    with torch.no_grad():
        latent_vectors = torch.randn(num_samples, latent_dim).to(device)
        generated_samples = generator(latent_vectors)
    return generated_samples.cpu()

def create_dataloader_from_samples(samples, batch_size=64):
    return DataLoader(samples, batch_size=batch_size, shuffle=False)

def convert_to_float(metrics):
    """Ensure all values in the dictionary are JSON serializable."""
    return {k: float(v) for k, v in metrics.items()}

def evaluate_gan(gan_components, config):
    """Evaluate GAN with multiple metrics, ensuring JSON serialization."""
    device = config["device"]
    generator = gan_components["generator"]

    generated_samples = generate_synthetic_samples(generator, num_samples=1000, latent_dim=config["latent_dim"], device=device)
    generated_dataloader = create_dataloader_from_samples(generated_samples, batch_size=config["batch_size"])

    real_embeddings = config["original_embeddings"].cpu().numpy()
    eval_embeddings = torch.cat([batch for batch in config["eval_loader"]], dim=0).cpu().numpy()
    gen_embeddings = torch.cat([batch for batch in generated_dataloader], dim=0).cpu().numpy()

    metrics = {
        "FID (Original vs Generated)": calculate_fid(real_embeddings, gen_embeddings),
        "FID (Original vs Eval)": calculate_fid(real_embeddings, eval_embeddings),
        "KL Divergence": calculate_kl_divergence(real_embeddings, gen_embeddings),
        "Cosine Similarity": calculate_cosine_similarity(real_embeddings, gen_embeddings),
        "Spearman Rank Correlation": rank_similarity(real_embeddings, gen_embeddings),
        "Wasserstein Distance": calculate_wasserstein_distance(real_embeddings, gen_embeddings),
        "Coverage Score": calculate_coverage_score(real_embeddings, gen_embeddings),
        "Memorization Score": calculate_memorization_score(real_embeddings, gen_embeddings),
        "Unique Embedding Ratio": unique_embedding_ratio(gen_embeddings)
    }

    # Convert to native Python floats for JSON compatibility
    metrics = convert_to_float(metrics)

    # Compute aggregate quality score
    metrics["Aggregate Quality Score"] = aggregate_quality_score(
        metrics["FID (Original vs Generated)"],
        metrics["KL Divergence"],
        metrics["Cosine Similarity"],
        metrics["Spearman Rank Correlation"]
    )

    # Print results
    print(json.dumps(metrics, indent=4))

    # Save results
    with open("evaluation_results.json", "w") as f:
        json.dump(metrics, f, indent=4)

# Initialize GAN components
gan_components = initialize_gan_components(config, gan_configurations[config["gan_type"]])

# Run training
run_gan_training(config)

# Evaluate using the original embeddings and evaluation set
evaluate_gan(gan_components, config)



  data = torch.load(embedding_file)


Initialized components: {'generator': WGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): Linear(in_features=256, out_features=50, bias=True)
  )
), 'critic': WGANCritic(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=50, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): Linear(in_features=256, out_features=1, bias=True)
  )
)}
Initialized components: {'generator': WGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=256, bias=Tr

In [157]:
import os
import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, TensorDataset
from torchvision.models import inception_v3
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt
from scipy.stats import entropy, spearmanr, wasserstein_distance
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import NearestNeighbors
import json
import datetime

# ==========================
# CONFIGURATION & EMBEDDING LOADING
# ==========================
embedding_dir = "./saved_embeddings/embeddings/"
embedding_file = os.path.join(embedding_dir, "autoencoder_EnhancedAutoencoder_barlow_twins/EnhancedAutoencoder_barlow_twins_embeddings.pt")

# Report directory
report_dir = "./reports/"
os.makedirs(report_dir, exist_ok=True)

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = {
    "gan_type": "Semi-Supervised-GAN",
    "latent_dim": 100,
    "embedding_dim": None,  # Will be set after loading embeddings
    "num_classes": 10,
    "categorical_dim": 10,
    "epochs": 1,
    "batch_size": 64,
    "learning_rate": 1e-4,
    "device": device,
    "lambda_gp": 10,
    "beta1": 0.5,
    "beta2": 0.999,
    "save_path": "gan_model.pth",
    "eval_fraction": 0.1,  # Fraction of embeddings used for evaluation
    "show_model_architecture": True
}

# ==========================
# EMBEDDING LOADING
# ==========================
def load_embeddings(embedding_file, device, batch_size=64, return_labels=True):
    """Loads embeddings and their associated labels from a file and creates a DataLoader."""
    print(f"Loading embeddings from: {embedding_file}")
    data = torch.load(embedding_file)
    embeddings = data["embeddings"].to(device)

    if return_labels:
        labels = data["labels"].to(device)
        dataset = TensorDataset(embeddings, labels)
        return embeddings, labels, DataLoader(dataset, batch_size=batch_size, shuffle=True)
    else:
        dataset = TensorDataset(embeddings)
        return embeddings, DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Load embeddings and select appropriate DataLoader
embeddings, labels, data_loader = load_embeddings(embedding_file, device)
embeddings, labels, data_loader_v2 = load_embeddings(embedding_file, device)  # Alternative loader for Semi-Supervised-GAN

config["embedding_dim"] = embeddings.size(1)

# Use data_loader_v2 for Semi-Supervised-GAN
if config["gan_type"] == "Semi-Supervised-GAN":
    config["data_loader"] = data_loader_v2
else:
    config["data_loader"] = data_loader

# ==========================
# REPORT SAVING
# ==========================
def save_report(metrics, gan_type):
    """Saves evaluation metrics as a JSON file with a unique name."""
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    report_filename = f"evaluation_{gan_type}_{timestamp}.json"
    report_path = os.path.join(report_dir, report_filename)

    with open(report_path, "w") as f:
        json.dump(metrics, f, indent=4)

    print(f"✅ Report saved: {report_path}")

# ==========================
# GAN INITIALIZATION & TRAINING
# ==========================
def initialize_gan_components(config, gan_configurations):
    """Initialize GAN components based on type."""
    components = {}
    gan_type = config["gan_type"]
    multi_gan_types = ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]

    if gan_type in multi_gan_types:
        gen_args = {"embedding_dim": config["embedding_dim"]}
        if gan_type != "Cycle-GAN":
            gen_args["latent_dim"] = config["latent_dim"]

        components.update({
            "generator_a": gan_configurations["generator_a"](**gen_args).to(config["device"]),
            "generator_b": gan_configurations["generator_b"](**gen_args).to(config["device"]),
            "discriminator_a": gan_configurations["discriminator_a"](embedding_dim=config["embedding_dim"]).to(config["device"]),
            "discriminator_b": gan_configurations["discriminator_b"](embedding_dim=config["embedding_dim"]).to(config["device"])
        })
    else:
        gen_args = {"latent_dim": config["latent_dim"], "embedding_dim": config["embedding_dim"]}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            gen_args["num_classes"] = config["num_classes"]
        if gan_type == "InfoGAN":
            gen_args["categorical_dim"] = config["categorical_dim"]

        components["generator"] = gan_configurations["generator"](**gen_args).to(config["device"])

    if gan_type == "WGAN-GP":
        components["critic"] = gan_configurations["critic"](embedding_dim=config["embedding_dim"]).to(config["device"])
    elif gan_type == "VAE-GAN":
        components["encoder"] = gan_configurations["encoder"](embedding_dim=config["embedding_dim"], latent_dim=config["latent_dim"]).to(config["device"])
        components["discriminator"] = gan_configurations["discriminator"](embedding_dim=config["embedding_dim"]).to(config["device"])
    elif gan_type not in multi_gan_types:
        disc_args = {"embedding_dim": config["embedding_dim"]}
        if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"]:
            disc_args["num_classes"] = config["num_classes"]
        if gan_type == "InfoGAN":
            disc_args["categorical_dim"] = config["categorical_dim"]

        components["discriminator"] = gan_configurations["discriminator"](**disc_args).to(config["device"])

    if config["show_model_architecture"]:
        print(f"Initialized components: {components}")

    return components

def run_gan_training(config):
    """Runs GAN training based on selected model."""
    gan_type = config["gan_type"]
    if gan_type not in gan_configurations:
        raise ValueError(f"Unsupported GAN type: {gan_type}")

    gan_config = gan_configurations[gan_type]
    components = initialize_gan_components(config, gan_config)
    train_function = gan_config["train_function"]

    print(f"🚀 Training {gan_type}...")

    if gan_type == "VAE-GAN":
        train_function(components["encoder"], components["generator"], components["discriminator"], **config)
    elif gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        train_function(components["generator_a"], components["generator_b"], components["discriminator_a"], components["discriminator_b"], **config)
    else:
        discriminator = components.get("discriminator", components.get("critic", None))
        train_function(components["generator"], discriminator, **config)

    print(f"✅ {gan_type} training completed!")

# Run Training
run_gan_training(config)

evaluate_gan(gan_components, config)


Loading embeddings from: ./saved_embeddings/embeddings/autoencoder_EnhancedAutoencoder_barlow_twins/EnhancedAutoencoder_barlow_twins_embeddings.pt
Loading embeddings from: ./saved_embeddings/embeddings/autoencoder_EnhancedAutoencoder_barlow_twins/EnhancedAutoencoder_barlow_twins_embeddings.pt
Initialized components: {'generator': SimpleGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=1024, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=1024, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (3): Linear(in_features=256, out_features=50, bias=True)
    

  data = torch.load(embedding_file)


Epoch [1/1], D Loss: 0.0937, G Loss: 8.5292
✅ Semi-Supervised-GAN training completed!


KeyError: 'original_embeddings'

In [154]:
import torch

# ========================
# MAIN CONFIGURATION
# ========================
config = {
    "gan_type": "Semi-Supervised-GAN",  # Change this to switch models
    "latent_dim": 100,  # Latent space dimension
    "embedding_dim": embeddings.size(1),  # Embedding dimension
    "num_classes": 10,  # For conditional models
    "categorical_dim": 10,  # For InfoGAN
    "epochs": 1,
    "batch_size": 64,
    "learning_rate": 1e-4,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "lambda_gp": 10,  # For WGAN-GP
    "beta1": 0.5,  # Adam optimizer beta1
    "beta2": 0.999,  # Adam optimizer beta2
    "save_path": "gan_model.pth",  # Path to save the model
    "data_loader": data_loader_v2,  # Main data loader
    "data_loader_a": data_loader,  # For cross-domain models
    "data_loader_b": data_loader,  # For cross-domain models
    "show_model_architecture": True  # Toggle to print model architectures
}

# ========================
# MODEL INITIALIZATION
# ========================

def initialize_gan_components(config, gan_config):
    """Initialize all components for the specified GAN type"""
    components = {}
    gan_type = config["gan_type"]
    multi_gan_types = ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]

    if gan_type in multi_gan_types:
        gen_args = {"embedding_dim": config["embedding_dim"]}
        if gan_type != "Cycle-GAN":
            gen_args["latent_dim"] = config["latent_dim"]

        components.update({
            "generator_a": gan_config["generator_a"](**gen_args).to(config["device"]),
            "generator_b": gan_config["generator_b"](**gen_args).to(config["device"]),
            "discriminator_a": gan_config["discriminator_a"](embedding_dim=config["embedding_dim"]).to(config["device"]),
            "discriminator_b": gan_config["discriminator_b"](embedding_dim=config["embedding_dim"]).to(config["device"])
        })
    else:
        gen_args = {
            "latent_dim": config["latent_dim"],
            "embedding_dim": config["embedding_dim"]
        }
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            gen_args["num_classes"] = config["num_classes"]
        if gan_type == "InfoGAN":
            gen_args["categorical_dim"] = config["categorical_dim"]

        components["generator"] = gan_config["generator"](**gen_args).to(config["device"])

    if gan_type == "WGAN-GP":
        components["critic"] = gan_config["critic"](embedding_dim=config["embedding_dim"]).to(config["device"])
    elif gan_type == "VAE-GAN":
        components["encoder"] = gan_config["encoder"](embedding_dim=config["embedding_dim"], latent_dim=config["latent_dim"]).to(config["device"])
        components["discriminator"] = gan_config["discriminator"](embedding_dim=config["embedding_dim"]).to(config["device"])
    elif gan_type not in multi_gan_types:
        disc_args = {"embedding_dim": config["embedding_dim"]}
        if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"]:
            disc_args["num_classes"] = config["num_classes"]
        if gan_type == "InfoGAN":
            disc_args["categorical_dim"] = config["categorical_dim"]

        components["discriminator"] = gan_config["discriminator"](**disc_args).to(config["device"])

    if config["show_model_architecture"]:
        print(f"Initialized components: {components}")

    return components

# ========================
# MAIN EXECUTION
# ========================

def run_gan_training(config):
    """Main function to initialize and train the GAN"""
    gan_type = config["gan_type"]
    if gan_type not in gan_configurations:
        raise ValueError(f"Unsupported GAN type: {gan_type}")

    gan_config = gan_configurations[gan_type]
    components = initialize_gan_components(config, gan_config)
    train_function = gan_config["train_function"]

    print(f"🚀 Training {gan_type}...")

    if gan_type == "VAE-GAN":
        train_function(components["encoder"], components["generator"], components["discriminator"], **config)
    elif gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        train_function(components["generator_a"], components["generator_b"], components["discriminator_a"], components["discriminator_b"], **config)
    else:
        discriminator = components.get("discriminator", components.get("critic", None))
        train_function(components["generator"], discriminator, **config)

    print(f"✅ {gan_type} training completed!")

# ========================
# EXECUTION
# ========================
run_gan_training(config)


Initialized components: {'generator': SimpleGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=1024, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=1024, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (3): Linear(in_features=256, out_features=50, bias=True)
    (4): Tanh()
  )
), 'discriminator': SemiSupervisedGANDiscriminator(
  (shared_model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=50, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): Li

In [156]:
import torch

# ========================
# MAIN CONFIGURATION
# ========================
config = {
    "gan_types": ['WGAN-GP', 'VAE-GAN', 'Contrastive-GAN', 'Cross-Domain-GAN', 'Cycle-GAN', 'Dual-GAN', 'Contrastive-Dual-GAN', 'Semi-Supervised-GAN', 'Conditional-GAN', 'InfoGAN'],  # List of GAN types to test
    "latent_dim": 100,  # Latent space dimension
    "embedding_dim": embeddings.size(1),  # Embedding dimension
    "num_classes": 10,  # For conditional models
    "categorical_dim": 10,  # For InfoGAN
    "epochs": 2,  # Short training for quick verification
    "batch_size": 64,
    "learning_rate": 1e-4,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "lambda_gp": 10,  # For WGAN-GP
    "beta1": 0.5,  # Adam optimizer beta1
    "beta2": 0.999,  # Adam optimizer beta2
    "save_path": "gan_model.pth",  # Path to save the model
    "data_loader": data_loader,  # Main data loader
    "data_loader_a": data_loader,  # For cross-domain models
    "data_loader_b": data_loader,  # For cross-domain models
    "show_model_architecture": False  # Toggle to print model architectures
}

# ========================
# MODEL INITIALIZATION
# ========================

def initialize_gan_components(config, gan_config):
    """Initialize all components for the specified GAN type"""
    components = {}
    gan_type = config["gan_type"]
    multi_gan_types = ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]

    if gan_type in multi_gan_types:
        gen_args = {"embedding_dim": config["embedding_dim"]}
        if gan_type != "Cycle-GAN":
            gen_args["latent_dim"] = config["latent_dim"]

        components.update({
            "generator_a": gan_config["generator_a"](**gen_args).to(config["device"]),
            "generator_b": gan_config["generator_b"](**gen_args).to(config["device"]),
            "discriminator_a": gan_config["discriminator_a"](embedding_dim=config["embedding_dim"]).to(config["device"]),
            "discriminator_b": gan_config["discriminator_b"](embedding_dim=config["embedding_dim"]).to(config["device"])
        })
    else:
        gen_args = {
            "latent_dim": config["latent_dim"],
            "embedding_dim": config["embedding_dim"]
        }
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            gen_args["num_classes"] = config["num_classes"]
        if gan_type == "InfoGAN":
            gen_args["categorical_dim"] = config["categorical_dim"]

        components["generator"] = gan_config["generator"](**gen_args).to(config["device"])

    if gan_type == "WGAN-GP":
        components["critic"] = gan_config["critic"](embedding_dim=config["embedding_dim"]).to(config["device"])
    elif gan_type == "VAE-GAN":
        components["encoder"] = gan_config["encoder"](embedding_dim=config["embedding_dim"], latent_dim=config["latent_dim"]).to(config["device"])
        components["discriminator"] = gan_config["discriminator"](embedding_dim=config["embedding_dim"]).to(config["device"])
    elif gan_type not in multi_gan_types:
        disc_args = {"embedding_dim": config["embedding_dim"]}
        if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"]:
            disc_args["num_classes"] = config["num_classes"]
        if gan_type == "InfoGAN":
            disc_args["categorical_dim"] = config["categorical_dim"]

        components["discriminator"] = gan_config["discriminator"](**disc_args).to(config["device"])

    if config["show_model_architecture"]:
        print(f"Initialized components for {gan_type}: {components}")

    return components

# ========================
# MAIN EXECUTION
# ========================

def test_all_gans(config):
    """Function to test all GAN models for a short number of epochs with improved error handling"""
    for gan_type in config["gan_types"]:
        print(f"\n🚀 Testing {gan_type}...")
        config["gan_type"] = gan_type

        try:
            if gan_type not in gan_configurations:
                raise ValueError(f"Unsupported GAN type: {gan_type}")

            gan_config = gan_configurations[gan_type]
            components = initialize_gan_components(config, gan_config)
            train_function = gan_config["train_function"]

            if gan_type == "VAE-GAN":
                train_function(components["encoder"], components["generator"], components["discriminator"], **config)
            elif gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
                train_function(components["generator_a"], components["generator_b"], components["discriminator_a"], components["discriminator_b"], **config)
            else:
                discriminator = components.get("discriminator", components.get("critic", None))
                train_function(components["generator"], discriminator, **config)

            print(f"✅ {gan_type} training successful!")
        except Exception as e:
            print(f"❌ Error testing {gan_type}: {str(e)}")

# ========================
# EXECUTION
# ========================
test_all_gans(config)



🚀 Testing WGAN-GP...
Epoch [1/2], Loss Critic: -88.1208, Loss Generator: -24.3678
Epoch [2/2], Loss Critic: -55.1855, Loss Generator: -37.9971
✅ WGAN-GP training successful!

🚀 Testing VAE-GAN...
Epoch [1/2], D Loss: 0.0235, G Loss: 2166.8838
Epoch [2/2], D Loss: 0.0059, G Loss: 1393.7501
✅ VAE-GAN training successful!

🚀 Testing Contrastive-GAN...
Epoch [1/2], D Loss: 0.0000, G Loss: -216.4480
Epoch [2/2], D Loss: 0.0001, G Loss: -1216.6046
✅ Contrastive-GAN training successful!

🚀 Testing Cross-Domain-GAN...
Epoch [1/2], D Loss: 0.0126, G Loss: 5.7212
Epoch [2/2], D Loss: 0.0806, G Loss: 4.9677
✅ Cross-Domain-GAN training successful!

🚀 Testing Cycle-GAN...
Epoch [1/2], D Loss A: 1.2441, D Loss B: 1.0757, G Loss: 9.2320
Epoch [2/2], D Loss A: 0.4399, D Loss B: 0.4349, G Loss: 7.3315
✅ Cycle-GAN training successful!

🚀 Testing Dual-GAN...
❌ Error testing Dual-GAN: mat1 and mat2 shapes cannot be multiplied (64x50 and 100x512)

🚀 Testing Contrastive-Dual-GAN...
❌ Error testing Contrast

In [24]:
from torch.utils.data import TensorDataset, DataLoader

def load_embeddings_v2(embedding_file, device, batch_size=64):
    """
    Loads embeddings and their associated labels from a specified file and
    creates a DataLoader for batching the embeddings.

    Args:
        embedding_file (str): Path to the file containing embeddings and labels.
        device (torch.device): The device (CPU/GPU) to load the tensors onto.
        batch_size (int, optional): The batch size for DataLoader. Default is 64.

    Returns:
        tuple: A tuple containing:
            - embeddings (torch.Tensor): Loaded embeddings.
            - labels (torch.Tensor): Corresponding labels for the embeddings.
            - data_loader (DataLoader): DataLoader for batching embeddings.
    """
    logger.info(f"Loading embeddings from: {embedding_file}")
    data = torch.load(embedding_file)
    embeddings = data["embeddings"].to(device)
    labels = data["labels"].to(device)

    # Create a TensorDataset and DataLoader
    dataset = TensorDataset(embeddings, labels)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    return embeddings, labels, data_loader

# Now it will work correctly!
embeddings, labels, data_loader = load_embeddings_v2(embedding_file, device)


# interesting way to load the embeddings...

from torch.utils.data import DataLoader, TensorDataset

def load_embeddings(embedding_file, device, batch_size=64, return_labels=True):
    """
    Loads embeddings and their associated labels from a specified file and
    creates a DataLoader for batching the embeddings (and optionally labels).

    Args:
        embedding_file (str): Path to the file containing embeddings and labels.
        device (torch.device): The device (CPU/GPU) to load the tensors onto.
        batch_size (int, optional): The batch size for DataLoader. Default is 64.
        return_labels (bool, optional): Whether to include labels in the DataLoader. Default is True.

    Returns:
        tuple: A tuple containing:
            - embeddings (torch.Tensor): Loaded embeddings.
            - labels (torch.Tensor, optional): Corresponding labels for the embeddings (if return_labels=True).
            - data_loader (DataLoader): DataLoader for batching embeddings (and labels if required).
    """
    print(f"Loading embeddings from: {embedding_file}")
    data = torch.load(embedding_file)
    embeddings = data["embeddings"].to(device)

    # Initialize the DataLoader only with embeddings if labels are not required
    if return_labels:
        labels = data["labels"].to(device)
        # Create a TensorDataset containing both embeddings and labels
        dataset = TensorDataset(embeddings, labels)
        return embeddings, labels, DataLoader(dataset, batch_size=batch_size, shuffle=True)
    else:
        # Create DataLoader for embeddings only
        dataset = TensorDataset(embeddings)  # Just embeddings
        return embeddings, DataLoader(dataset, batch_size=batch_size, shuffle=True)

embeddings, labels, data_loader_v2 = load_embeddings(embedding_file, device)


INFO - Loading embeddings from: ./saved_embeddings/embeddings/autoencoders_BasicAutoencoder/BasicAutoencoder_embeddings.pt
  data = torch.load(embedding_file)


In [149]:
# GAN Configuration
gan_type = "Cross-Domain-GAN"  # Change to the desired GAN type
latent_dim = 100  # Latent dimension for the generator
embedding_dim = embeddings.size(1)  # Embedding dimension based on loaded embeddings
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10
epochs = 2
learning_rate = 0.0001
# Ensure device is properly set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the data loaders before the GAN models
embedding_loader_a = data_loader
embedding_loader_b = data_loader
embedding_loader = data_loader

# Initialize the correct GAN models based on `gan_type`
if gan_type in gan_configurations:
    config = gan_configurations[gan_type]
    print(f"Initializing {gan_type} configuration: {config}")

    # Generator initialization for Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        if gan_type == "Cycle-GAN":
            # Cycle-GAN requires two generators and two discriminators
            generator_a = config["generator_a"](embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"CycleGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
        else:
            # Dual-GAN or Contrastive-Dual-GAN requires both latent_dim and embedding_dim for generators
            generator_a = config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"DualGAN/ContrastiveDualGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
    else:
        # For other GANs (WGAN-GP, VAE-GAN, etc.), initialize a single generator
        generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
        if gan_type == "InfoGAN":
            # Correctly initialize InfoGAN generator with latent_dim and categorical_dim
            generator_args["categorical_dim"] = categorical_dim
        generator = config["generator"](**generator_args).to(device)
        print(f"Generator initialized on {generator}")

    # Handle critic/discriminator initialization conditionally
    if gan_type == "WGAN-GP":
        # WGAN-GP uses a critic instead of a discriminator
        critic = config["critic"](embedding_dim=embedding_dim).to(device)
        print(f"Critic initialized on {critic}")
    elif gan_type == "Semi-Supervised-GAN":
        # Semi-Supervised-GAN requires num_classes for the discriminator
        discriminator = config["discriminator"](embedding_dim=embedding_dim, num_classes=num_classes).to(device)
        print(f"Semi-Supervised-GAN discriminator initialized on {discriminator}")
    elif gan_type == "InfoGAN":
        # InfoGAN discriminator requires embedding_dim and categorical_dim
        discriminator = config["discriminator"](embedding_dim=embedding_dim, categorical_dim=categorical_dim).to(device)
        print(f"InfoGAN discriminator initialized on {discriminator}")
    elif gan_type not in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        # For other models (VAE-GAN, Conditional-GAN, etc.), initialize a single discriminator
        discriminator_args = {"embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            discriminator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            discriminator_args["categorical_dim"] = categorical_dim
        discriminator = config["discriminator"](**discriminator_args).to(device)
        print(f"Discriminator initialized on {discriminator}")
    else:
        # If discriminator is part of the configuration (Cycle-GAN, Dual-GAN, etc.), use the ones in the config
        discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
        discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
        print(f"Discriminators initialized on {discriminator_a}, {discriminator_b}")

    # Initialize encoder for VAE-GAN
    if gan_type == "VAE-GAN":
        encoder = config["encoder"](embedding_dim=embedding_dim, latent_dim=latent_dim).to(device)
        print(f"Encoder initialized on {encoder}")

    # Select the appropriate training function
    train_function = config["train_function"]

    # Handle any additional configurations (like learning rate, lambda_gp)
    train_kwargs = config.get("train_kwargs", {})  # Default to empty dict if not specified
    train_kwargs.update({
        "latent_dim": latent_dim,
        "epochs": epochs,  # Pass epochs as part of train_kwargs
        "device": device,
        "learning_rate": learning_rate,
        "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"] else None,  # Add for Conditional, InfoGAN, and Semi-Supervised-GAN
        "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,  # Add only for InfoGAN
        "lambda_gp": 10 if gan_type == "WGAN-GP" else None,  # Add only for WGAN-GP
    })

    # Add data loaders conditionally
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        # Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN: pass two data loaders
        train_kwargs.update({
            "data_loader_a": embedding_loader_a,  # Two loaders for these models
            "data_loader_b": embedding_loader_b
        })
    else:
        # For all other models (WGAN-GP, VAE-GAN, etc.), use a single loader
        train_kwargs.update({
            "data_loader": embedding_loader  # Single data loader for other models
        })

    # Training loop (now handled inside the train_function)
    if gan_type == "VAE-GAN":
        # VAE-GAN requires encoder, generator, and discriminator
        train_function(encoder, generator, discriminator, **train_kwargs)
    else:
        # For other models, pass the appropriate generator/discriminator/critic
        train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs) if gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"] else train_function(generator, discriminator if "discriminator" in locals() else critic, **train_kwargs)

    print(f"{gan_type} training test passed!")

else:
    raise ValueError(f"Unsupported GAN type: {gan_type}")


Initializing Cross-Domain-GAN configuration: {'generator': <class 'src.gan_workflows.plan2.plan2_gan_models.CrossDomainGenerator'>, 'discriminator': <class 'src.gan_workflows.plan2.plan2_gan_models.CrossDomainDiscriminator'>, 'train_function': <function train_cross_domain_gan at 0x7d7ef691db20>}
Generator initialized on CrossDomainGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): Linear(in_features=256, out_features=50, bias=True)
  )
)
Discriminator initialized on CrossDomainDiscriminator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=50, out_features=512, bias=True)
  

In [147]:
# GAN Configuration
gan_type = "Semi-Supervised-GAN"  # Change to the desired GAN type
latent_dim = 100  # Latent dimension for the generator
embedding_dim = embeddings.size(1)  # Embedding dimension based on loaded embeddings
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10
epochs = 1
learning_rate = 0.0001
# Ensure device is properly set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the data loaders before the GAN models
embedding_loader_a = data_loader
embedding_loader_b = data_loader
embedding_loader = data_loader_v2

# Initialize the correct GAN models based on `gan_type`
if gan_type in gan_configurations:
    config = gan_configurations[gan_type]
    print(f"Initializing {gan_type} configuration: {config}")

    # Generator initialization for Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        if gan_type == "Cycle-GAN":
            # Cycle-GAN requires two generators and two discriminators
            generator_a = config["generator_a"](embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"CycleGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
        else:
            # Dual-GAN or Contrastive-Dual-GAN requires both latent_dim and embedding_dim for generators
            generator_a = config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"DualGAN/ContrastiveDualGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
    else:
        # For other GANs (WGAN-GP, VAE-GAN, etc.), initialize a single generator
        generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
        if gan_type == "InfoGAN":
            # Correctly initialize InfoGAN generator with latent_dim and categorical_dim
            generator_args["categorical_dim"] = categorical_dim
        generator = config["generator"](**generator_args).to(device)
        print(f"Generator initialized on {generator}")

    # Handle critic/discriminator initialization conditionally
    if gan_type == "WGAN-GP":
        # WGAN-GP uses a critic instead of a discriminator
        critic = config["critic"](embedding_dim=embedding_dim).to(device)
        print(f"Critic initialized on {critic}")
    elif gan_type == "Semi-Supervised-GAN":
        # Semi-Supervised-GAN requires num_classes for the discriminator
        discriminator = config["discriminator"](embedding_dim=embedding_dim, num_classes=num_classes).to(device)
        print(f"Semi-Supervised-GAN discriminator initialized on {discriminator}")
    elif gan_type == "InfoGAN":
        # InfoGAN discriminator requires embedding_dim and categorical_dim
        discriminator = config["discriminator"](embedding_dim=embedding_dim, categorical_dim=categorical_dim).to(device)
        print(f"InfoGAN discriminator initialized on {discriminator}")
    elif gan_type not in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        # For other models (VAE-GAN, Conditional-GAN, etc.), initialize a single discriminator
        discriminator_args = {"embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            discriminator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            discriminator_args["categorical_dim"] = categorical_dim
        discriminator = config["discriminator"](**discriminator_args).to(device)
        print(f"Discriminator initialized on {discriminator}")
    else:
        # If discriminator is part of the configuration (Cycle-GAN, Dual-GAN, etc.), use the ones in the config
        discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
        discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
        print(f"Discriminators initialized on {discriminator_a}, {discriminator_b}")

    # Initialize encoder for VAE-GAN
    if gan_type == "VAE-GAN":
        encoder = config["encoder"](embedding_dim=embedding_dim, latent_dim=latent_dim).to(device)
        print(f"Encoder initialized on {encoder}")

    # Select the appropriate training function
    train_function = config["train_function"]

    # Handle any additional configurations (like learning rate, lambda_gp)
    train_kwargs = config.get("train_kwargs", {})  # Default to empty dict if not specified
    train_kwargs.update({
        "latent_dim": latent_dim,
        "epochs": epochs,  # Pass epochs as part of train_kwargs
        "device": device,
        "learning_rate": learning_rate,
        "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"] else None,  # Add for Conditional, InfoGAN, and Semi-Supervised-GAN
        "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,  # Add only for InfoGAN
        "lambda_gp": 10 if gan_type == "WGAN-GP" else None,  # Add only for WGAN-GP
    })

    # Add data loaders conditionally
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        # Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN: pass two data loaders
        train_kwargs.update({
            "data_loader_a": embedding_loader_a,  # Two loaders for these models
            "data_loader_b": embedding_loader_b
        })
    else:
        # For all other models (WGAN-GP, VAE-GAN, etc.), use a single loader
        train_kwargs.update({
            "data_loader": embedding_loader  # Single data loader for other models
        })

    # Training loop (now handled inside the train_function)
    if gan_type == "VAE-GAN":
        # VAE-GAN requires encoder, generator, and discriminator
        train_function(encoder, generator, discriminator, **train_kwargs)
    else:
        # For other models, pass the appropriate generator/discriminator/critic
        train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs) if gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"] else train_function(generator, discriminator if "discriminator" in locals() else critic, **train_kwargs)

    print(f"{gan_type} training test passed!")

else:
    raise ValueError(f"Unsupported GAN type: {gan_type}")


Initializing Semi-Supervised-GAN configuration: {'generator': <class 'src.gan_workflows.plan2.plan2_gan_models.SimpleGANGenerator'>, 'discriminator': <class 'src.gan_workflows.plan2.plan2_gan_models.SemiSupervisedGANDiscriminator'>, 'train_function': <function train_semi_supervised_gan at 0x7d7ef691dda0>}
Generator initialized on SimpleGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=1024, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=1024, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (3): Linear(in_features=256, out_features=50, bias=True)
    

In [144]:
# GAN Configuration
gan_type = "Semi-Supervised-GAN"  # Change to the desired GAN type
latent_dim = 100  # Latent dimension for the generator
embedding_dim = embeddings.size(1)  # Embedding dimension based on loaded embeddings
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10
epochs = 1
learning_rate = 0.0001
# Ensure device is properly set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the data loaders before the GAN models
embedding_loader_a = data_loader
embedding_loader_b = data_loader
embedding_loader = data_loader_v2

# Initialize the correct GAN models based on `gan_type`
if gan_type in gan_configurations:
    config = gan_configurations[gan_type]
    print(f"Initializing {gan_type} configuration: {config}")

    # Generator initialization for Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        if gan_type == "Cycle-GAN":
            # Cycle-GAN requires two generators and two discriminators
            generator_a = config["generator_a"](embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"CycleGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
        else:
            # Dual-GAN or Contrastive-Dual-GAN requires both latent_dim and embedding_dim for generators
            generator_a = config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"DualGAN/ContrastiveDualGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
    else:
        # For other GANs (WGAN-GP, VAE-GAN, etc.), initialize a single generator
        generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
        if gan_type == "InfoGAN":
            # Correctly initialize InfoGAN generator with latent_dim and categorical_dim
            generator_args["categorical_dim"] = categorical_dim
        if gan_type == "Conditional-GAN":
            # Conditional-GAN generator requires num_classes
            generator_args["num_classes"] = num_classes
        generator = config["generator"](**generator_args).to(device)
        print(f"Generator initialized on {generator}")

    # Handle critic/discriminator initialization conditionally
    if gan_type == "WGAN-GP":
        # WGAN-GP uses a critic instead of a discriminator
        critic = config["critic"](embedding_dim=embedding_dim).to(device)
        print(f"Critic initialized on {critic}")
    elif gan_type == "Semi-Supervised-GAN":
        # Semi-Supervised-GAN requires num_classes for the discriminator
        discriminator = config["discriminator"](embedding_dim=embedding_dim, num_classes=num_classes).to(device)
        print(f"Semi-Supervised-GAN discriminator initialized on {discriminator}")
    elif gan_type == "InfoGAN":
        # InfoGAN discriminator requires embedding_dim and categorical_dim
        discriminator = config["discriminator"](embedding_dim=embedding_dim, categorical_dim=categorical_dim).to(device)
        print(f"InfoGAN discriminator initialized on {discriminator}")
    elif gan_type == "Conditional-GAN":
        # Conditional-GAN discriminator requires embedding_dim and num_classes
        discriminator = config["discriminator"](embedding_dim=embedding_dim, num_classes=num_classes).to(device)
        print(f"Conditional-GAN discriminator initialized on {discriminator}")
    elif gan_type not in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        # For other models (VAE-GAN, Conditional-GAN, etc.), initialize a single discriminator
        discriminator_args = {"embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            discriminator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            discriminator_args["categorical_dim"] = categorical_dim
        discriminator = config["discriminator"](**discriminator_args).to(device)
        print(f"Discriminator initialized on {discriminator}")
    else:
        # If discriminator is part of the configuration (Cycle-GAN, Dual-GAN, etc.), use the ones in the config
        discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
        discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
        print(f"Discriminators initialized on {discriminator_a}, {discriminator_b}")

    # Initialize encoder for VAE-GAN
    if gan_type == "VAE-GAN":
        encoder = config["encoder"](embedding_dim=embedding_dim, latent_dim=latent_dim).to(device)
        print(f"Encoder initialized on {encoder}")

    # Select the appropriate training function
    train_function = config["train_function"]

    # Handle any additional configurations (like learning rate, lambda_gp)
    train_kwargs = config.get("train_kwargs", {})  # Default to empty dict if not specified
    train_kwargs.update({
        "latent_dim": latent_dim,
        "epochs": epochs,  # Pass epochs as part of train_kwargs
        "device": device,
        "learning_rate": learning_rate,
        "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"] else None,  # Add for Conditional, InfoGAN, and Semi-Supervised-GAN
        "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,  # Add only for InfoGAN
        "lambda_gp": 10 if gan_type == "WGAN-GP" else None,  # Add only for WGAN-GP
    })

    # Add data loaders conditionally
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        # Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN: pass two data loaders
        train_kwargs.update({
            "data_loader_a": embedding_loader_a,  # Two loaders for these models
            "data_loader_b": embedding_loader_b
        })
    else:
        # For all other models (WGAN-GP, VAE-GAN, etc.), use a single loader
        train_kwargs.update({
            "data_loader": embedding_loader  # Single data loader for other models
        })

    # Training loop (now handled inside the train_function)
    if gan_type == "VAE-GAN":
        # VAE-GAN requires encoder, generator, and discriminator
        train_function(encoder, generator, discriminator, **train_kwargs)
    else:
        # For other models, pass the appropriate generator/discriminator/critic
        train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs) if gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"] else train_function(generator, discriminator if "discriminator" in locals() else critic, **train_kwargs)

    print(f"{gan_type} training test passed!")

else:
    raise ValueError(f"Unsupported GAN type: {gan_type}")



Initializing Semi-Supervised-GAN configuration: {'generator': <class 'src.gan_workflows.plan2.plan2_gan_models.SimpleGANGenerator'>, 'discriminator': <class 'src.gan_workflows.plan2.plan2_gan_models.SemiSupervisedGANDiscriminator'>, 'train_function': <function train_semi_supervised_gan at 0x7d7ef691dda0>}
Generator initialized on SimpleGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=1024, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=1024, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (3): Linear(in_features=256, out_features=50, bias=True)
    

# Evaluating the GANs

In [100]:
import os
import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.models import inception_v3
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt
from scipy.stats import entropy, spearmanr
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import NearestNeighbors
from scipy.stats import wasserstein_distance
import json

# ==========================
# CONFIGURATION & EMBEDDING LOADING
# ==========================
embedding_dir = "./saved_embeddings/embeddings/"
embedding_file = os.path.join(embedding_dir, "autoencoder_EnhancedAutoencoder_barlow_twins/EnhancedAutoencoder_barlow_twins_embeddings.pt")

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = {
    "gan_type": "Contrastive-GAN",
    "latent_dim": 100,
    "embedding_dim": None,  # Will be set after loading embeddings
    "num_classes": 10,
    "categorical_dim": 10,
    "epochs": 1,
    "batch_size": 64,
    "learning_rate": 1e-4,
    "device": device,
    "lambda_gp": 10,
    "beta1": 0.5,
    "beta2": 0.999,
    "save_path": "gan_model.pth",
    "eval_fraction": 0.1,  # Fraction of embeddings used for evaluation
    "show_model_architecture": True
}

def load_embeddings(embedding_file, device, batch_size=64):
    """Loads embeddings and labels from a specified file."""
    data = torch.load(embedding_file)
    embeddings = data["embeddings"].to(device)
    labels = data["labels"].to(device)
    data_loader = DataLoader(embeddings, batch_size=batch_size, shuffle=True)
    return embeddings, labels, data_loader

def split_embeddings(embeddings, labels, eval_fraction=0.1, batch_size=64):
    """Splits embeddings into training and evaluation sets."""
    num_samples = embeddings.size(0)
    num_eval = int(num_samples * eval_fraction)
    num_train = num_samples - num_eval

    train_embeddings, eval_embeddings = random_split(embeddings, [num_train, num_eval])
    train_labels, eval_labels = random_split(labels, [num_train, num_eval])

    train_loader = DataLoader(train_embeddings, batch_size=batch_size, shuffle=True)
    eval_loader = DataLoader(eval_embeddings, batch_size=batch_size, shuffle=False)
    return train_loader, eval_loader, train_embeddings, eval_embeddings

# Load embeddings and split
embeddings, labels, full_data_loader = load_embeddings(embedding_file, device)
config["embedding_dim"] = embeddings.size(1)
train_loader, eval_loader, train_embeddings, eval_embeddings = split_embeddings(embeddings, labels, config["eval_fraction"], config["batch_size"])

# Update config with data loaders
config.update({
    "data_loader": train_loader,
    "data_loader_a": train_loader,
    "data_loader_b": train_loader,
    "eval_loader": eval_loader,
    "original_embeddings": embeddings  # Store original embeddings for evaluation
})

# ==========================
# MODEL INITIALIZATION
# ==========================

def initialize_gan_components(config, gan_configurations):
    """Initialize GAN components based on type."""
    components = {}
    gan_type = config["gan_type"]
    multi_gan_types = ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]

    if gan_type in multi_gan_types:
        gen_args = {"embedding_dim": config["embedding_dim"]}
        if gan_type != "Cycle-GAN":
            gen_args["latent_dim"] = config["latent_dim"]

        components.update({
            "generator_a": gan_configurations["generator_a"](**gen_args).to(config["device"]),
            "generator_b": gan_configurations["generator_b"](**gen_args).to(config["device"]),
            "discriminator_a": gan_configurations["discriminator_a"](embedding_dim=config["embedding_dim"]).to(config["device"]),
            "discriminator_b": gan_configurations["discriminator_b"](embedding_dim=config["embedding_dim"]).to(config["device"])
        })
    else:
        gen_args = {"latent_dim": config["latent_dim"], "embedding_dim": config["embedding_dim"]}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            gen_args["num_classes"] = config["num_classes"]
        if gan_type == "InfoGAN":
            gen_args["categorical_dim"] = config["categorical_dim"]

        components["generator"] = gan_configurations["generator"](**gen_args).to(config["device"])

    if gan_type == "WGAN-GP":
        components["critic"] = gan_configurations["critic"](embedding_dim=config["embedding_dim"]).to(config["device"])
    elif gan_type == "VAE-GAN":
        components["encoder"] = gan_configurations["encoder"](embedding_dim=config["embedding_dim"], latent_dim=config["latent_dim"]).to(config["device"])
        components["discriminator"] = gan_configurations["discriminator"](embedding_dim=config["embedding_dim"]).to(config["device"])
    elif gan_type not in multi_gan_types:
        disc_args = {"embedding_dim": config["embedding_dim"]}
        if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"]:
            disc_args["num_classes"] = config["num_classes"]
        if gan_type == "InfoGAN":
            disc_args["categorical_dim"] = config["categorical_dim"]

        components["discriminator"] = gan_configurations["discriminator"](**disc_args).to(config["device"])

    if config["show_model_architecture"]:
        print(f"Initialized components: {components}")

    return components

# ==========================
# EVALUATION METRICS
# ==========================

def calculate_fid(real_embeddings, generated_embeddings):
    """Compute Fréchet Inception Distance (FID)."""
    mu1, sigma1 = np.mean(real_embeddings, axis=0), np.cov(real_embeddings, rowvar=False)
    mu2, sigma2 = np.mean(generated_embeddings, axis=0), np.cov(generated_embeddings, rowvar=False)
    diff = mu1 - mu2
    covmean = sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return np.sum(diff**2) + np.trace(sigma1 + sigma2 - 2 * covmean)

def calculate_kl_divergence(real_embeddings, generated_embeddings):
    """Compute KL divergence between real and generated embeddings."""
    real_prob = np.histogram(real_embeddings, bins=50, density=True)[0]
    gen_prob = np.histogram(generated_embeddings, bins=50, density=True)[0]
    real_prob += 1e-10  # Prevent division by zero
    gen_prob += 1e-10
    return entropy(real_prob, gen_prob)

def calculate_cosine_similarity(real_embeddings, generated_embeddings):
    """Compute cosine similarity between real and generated embeddings."""
    return np.mean(cosine_similarity(real_embeddings, generated_embeddings))

def rank_similarity(real_embeddings, generated_embeddings):
    """Compute Spearman Rank Correlation between real and generated embeddings, ensuring matched dimensions."""
    min_size = min(real_embeddings.shape[0], generated_embeddings.shape[0])

    # Randomly sample real embeddings if larger
    if real_embeddings.shape[0] > min_size:
        real_embeddings = real_embeddings[np.random.choice(real_embeddings.shape[0], min_size, replace=False)]

    # Randomly sample generated embeddings if larger
    if generated_embeddings.shape[0] > min_size:
        generated_embeddings = generated_embeddings[np.random.choice(generated_embeddings.shape[0], min_size, replace=False)]

    # Ensure the shapes are aligned for ranking
    real_rank = np.argsort(real_embeddings, axis=0)
    gen_rank = np.argsort(generated_embeddings, axis=0)

    return spearmanr(real_rank.flatten(), gen_rank.flatten()).correlation

def unique_embedding_ratio(generated_embeddings):
    """Compute the ratio of unique embeddings in the generated set."""
    unique_embeddings = np.unique(generated_embeddings, axis=0)
    return len(unique_embeddings) / len(generated_embeddings)

def aggregate_quality_score(fid, kl, cosine, rank_corr):
    """Compute a weighted quality score based on multiple metrics."""
    # Normalize metrics (assuming lower is better for FID & KL)
    fid_norm = 1 / (1 + fid)
    kl_norm = 1 / (1 + kl)
    cosine_norm = cosine  # Higher is better, no need to invert
    rank_norm = (rank_corr + 1) / 2  # Convert [-1,1] to [0,1]

    # Compute final weighted score
    return (0.4 * fid_norm) + (0.2 * kl_norm) + (0.2 * cosine_norm) + (0.2 * rank_norm)

def calculate_wasserstein_distance(real_embeddings, generated_embeddings):
    """Compute Wasserstein Distance between real and generated embeddings."""
    return wasserstein_distance(real_embeddings.flatten(), generated_embeddings.flatten())

def calculate_coverage_score(real_embeddings, generated_embeddings, n_neighbors=5):
    """Compute Coverage Score: Percentage of real embeddings with at least one close match in generated embeddings."""
    neigh = NearestNeighbors(n_neighbors=n_neighbors)
    neigh.fit(generated_embeddings)
    distances, _ = neigh.kneighbors(real_embeddings)
    return np.mean(distances[:, 0] < 0.1)  # Adjust threshold as needed

def calculate_memorization_score(real_embeddings, generated_embeddings, n_neighbors=1, tolerance=1e-3):
    """Compute Memorization Score: Percentage of generated embeddings that are close to real embeddings within a small tolerance."""
    neigh = NearestNeighbors(n_neighbors=n_neighbors)
    neigh.fit(real_embeddings)
    distances, _ = neigh.kneighbors(generated_embeddings)

    return np.mean(distances[:, 0] < tolerance)  # Allow small tolerance for near-exact matches

def generate_synthetic_samples(generator, num_samples=1000, latent_dim=100, device="cuda"):
    """Generate synthetic samples using a trained GAN generator."""
    generator.eval()
    with torch.no_grad():
        latent_vectors = torch.randn(num_samples, latent_dim).to(device)
        generated_samples = generator(latent_vectors)
    return generated_samples.cpu()

def create_dataloader_from_samples(samples, batch_size=64):
    return DataLoader(samples, batch_size=batch_size, shuffle=False)

def convert_to_float(metrics):
    """Ensure all values in the dictionary are JSON serializable."""
    return {k: float(v) for k, v in metrics.items()}

def evaluate_gan(gan_components, config):
    """Evaluate GAN with multiple metrics, ensuring JSON serialization."""
    device = config["device"]
    generator = gan_components["generator"]

    generated_samples = generate_synthetic_samples(generator, num_samples=1000, latent_dim=config["latent_dim"], device=device)
    generated_dataloader = create_dataloader_from_samples(generated_samples, batch_size=config["batch_size"])

    real_embeddings = config["original_embeddings"].cpu().numpy()
    eval_embeddings = torch.cat([batch for batch in config["eval_loader"]], dim=0).cpu().numpy()
    gen_embeddings = torch.cat([batch for batch in generated_dataloader], dim=0).cpu().numpy()

    metrics = {
        "FID (Original vs Generated)": calculate_fid(real_embeddings, gen_embeddings),
        "FID (Original vs Eval)": calculate_fid(real_embeddings, eval_embeddings),
        "KL Divergence": calculate_kl_divergence(real_embeddings, gen_embeddings),
        "Cosine Similarity": calculate_cosine_similarity(real_embeddings, gen_embeddings),
        "Spearman Rank Correlation": rank_similarity(real_embeddings, gen_embeddings),
        "Wasserstein Distance": calculate_wasserstein_distance(real_embeddings, gen_embeddings),
        "Coverage Score": calculate_coverage_score(real_embeddings, gen_embeddings),
        "Memorization Score": calculate_memorization_score(real_embeddings, gen_embeddings),
        "Unique Embedding Ratio": unique_embedding_ratio(gen_embeddings)
    }

    # Convert to native Python floats for JSON compatibility
    metrics = convert_to_float(metrics)

    # Compute aggregate quality score
    metrics["Aggregate Quality Score"] = aggregate_quality_score(
        metrics["FID (Original vs Generated)"],
        metrics["KL Divergence"],
        metrics["Cosine Similarity"],
        metrics["Spearman Rank Correlation"]
    )

    # Print results
    print(json.dumps(metrics, indent=4))

    # Save results
    with open("evaluation_results.json", "w") as f:
        json.dump(metrics, f, indent=4)

# Initialize GAN components
gan_components = initialize_gan_components(config, gan_configurations[config["gan_type"]])

# Run training
run_gan_training(config)

# Evaluate using the original embeddings and evaluation set
evaluate_gan(gan_components, config)



  data = torch.load(embedding_file)


Initialized components: {'generator': ContrastiveGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=512, bias=True)
        (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): Linear(in_features=256, out_features=50, bias=True)
  )
), 'discriminator': ContrastiveGANDiscriminator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=50, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (blo