# Plan 2: GAN Training with Embeddings

## Objective
The goal of this plan is to train GANs using embeddings generated from Plan 1. This involves:
1. Setting up a GAN architecture that incorporates embeddings.
2. Training the GAN to generate high-quality synthetic embeddings or data.
3. Evaluating the GAN's performance using relevant metrics.

## Key Steps

### 1. Define GAN Architecture
- **Generator**: Accepts embeddings and noise as inputs, producing synthetic data.
- **Discriminator**: Evaluates the authenticity of generated data, optionally conditioned on embeddings.

### 2. Load Pre-Generated Embeddings
- Load embeddings created in Plan 1 from their respective directories.
- Normalize and preprocess embeddings for GAN training.

### 3. Train the GAN
- Set up training loops for the generator and discriminator.
- Use appropriate loss functions, such as adversarial loss (e.g., Wasserstein loss).
- Optionally, include auxiliary tasks (e.g., reconstruction loss) for better embedding alignment.

### 4. Save Outputs
- Save trained GAN models.
- Save generated embeddings or data for downstream evaluation.

### 5. Evaluate GAN Performance
- Use metrics like FID, IS, and qualitative visualization.
- Compare performance with baseline models or methods.

---

## Starting Point

**Files Available**:
- `plan2_gan_models.py`: Contains the GAN architecture.
- `plan2_gan_training.py`: Implements the training pipeline.
- `main_plan2_gan_training.ipynb`: High-level control notebook.
- `plan2_experiments.ipynb`: For experimentation and evaluation.

## Next Steps
1. **Review Architecture**:
   - Inspect `plan2_gan_models.py` to understand the generator and discriminator setup.
2. **Pipeline Setup**:
   - Review and prepare the training pipeline in `plan2_gan_training.py`.
3. **Experimentation**:
   - Use `plan2_experiments.ipynb` for controlled experiments.


In [1]:
# Plan 2 GAN Training Notebook

import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import logging
from datetime import datetime
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset


from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Repository path (adjust if needed)
repo_path = "/content/drive/MyDrive/GAN-thesis-project"

# Add repository path to sys.path for module imports
if repo_path not in sys.path:
    sys.path.append(repo_path)

# Change working directory to the repository
os.chdir(repo_path)

# Verify the working directory
print(f"Current working directory: {os.getcwd()}")

# GAN Models and Training Functions
from src.gan_workflows.plan2.plan2_gan_models import (
    SimpleGANGenerator, SimpleGANDiscriminator,
    ContrastiveGANGenerator, ContrastiveGANDiscriminator,
    VAEGANEncoder, VAEGANGenerator, VAEGANDiscriminator,
    WGANGenerator, WGANCritic,
    CrossDomainGenerator, CrossDomainDiscriminator,
    CycleGenerator, CycleDiscriminator,
    DualGANGenerator, DualGANDiscriminator,
    ContrastiveDualGANGenerator, ContrastiveDualGANDiscriminator,
    SemiSupervisedGANDiscriminator,
    ConditionalGANGenerator, ConditionalGANDiscriminator,
    InfoGANGenerator, InfoGANDiscriminator,
    compute_gradient_penalty
)
from src.gan_workflows.plan2.plan2_gan_training import (
    train_wgan_gp, train_vae_gan, train_contrastive_gan,
    train_cross_domain_gan, train_cycle_gan,
    train_dual_gan, train_contrastive_dual_gan,
    train_semi_supervised_gan,
    train_conditional_gan,
    train_infogan
)

from src.data_utils import (
    load_embeddings, analyze_embeddings
)


# Set random seed and device
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Mounted at /content/drive
Current working directory: /content/drive/MyDrive/GAN-thesis-project
Using device: cpu


In [2]:
embedding_dir = "./saved_embeddings/embeddings/autoencoders_BasicAutoencoder"  # Example embedding path
embedding_file = os.path.join(embedding_dir, "BasicAutoencoder_embeddings.pt")

# Load embeddings and labels using the new function
embeddings, labels, data_loader = load_embeddings(embedding_file, device)

# Analyze the embeddings (optional: set expected_dim if necessary)
analyze_embeddings(embeddings, expected_dim=50, labels=labels)

Loading embeddings from: ./saved_embeddings/embeddings/autoencoders_BasicAutoencoder/BasicAutoencoder_embeddings.pt


  data = torch.load(embedding_file)
INFO - Embeddings are of shape: torch.Size([7000, 50])
INFO - Data type: torch.float32
INFO - Device: cpu
INFO - Mean: -0.10591454058885574
INFO - Standard Deviation: 4.491614818572998
INFO - Min: -22.35531997680664
INFO - Max: 15.808435440063477
INFO - Median: 0.15539658069610596
INFO - Mean L2 Norm: 31.488544464111328
INFO - Standard Deviation of L2 Norms: 4.21454381942749
INFO - Each embedding has 50 dimensions.
INFO - Sparsity (proportion of non-zero elements): 1.0
INFO - Skewness of embeddings: -0.6672645237194799
INFO - Kurtosis of embeddings: 1.5908177243367119
INFO - Pairwise distance (mean): 25.41992176637177
INFO - Average Cosine similarity with true labels: -0.00153539318125695
INFO - Number of outliers detected in embeddings: 30
INFO - Embeddings analysis completed.


In [3]:
# ## OLD VERSION

# def validate_embeddings(embeddings):
#     """
#     Validate and provide information about the shape and data type of embeddings.
#     """
#     if embeddings is None or len(embeddings) == 0:
#         raise ValueError("Embeddings are empty or not properly generated.")
#     print(f"Embeddings are of shape: {embeddings.shape}")
#     print(f"Data type: {embeddings.dtype}")
#     print(f"Device: {embeddings.device}")
#     if torch.isnan(embeddings).any():
#         raise ValueError("Embeddings contain NaN values.")
#     if torch.isinf(embeddings).any():
#         raise ValueError("Embeddings contain infinite values.")
#     if embeddings.ndim != 2:
#         raise ValueError("Embeddings should be a 2D tensor.")
#     print("Embeddings validation passed.")

# validate_embeddings(embeddings)

In [4]:
# Get a batch of training data
embedding_batch = next(iter(data_loader))
print('embedding batches', embedding_batch.shape)

embedding batches torch.Size([64, 50])


In [5]:
# GAN Configuration Dictionary
gan_configurations = {
    "WGAN-GP": {
        "generator": WGANGenerator,
        "critic": WGANCritic,
        "train_function": train_wgan_gp,
        "train_kwargs": {"lambda_gp": 10}
    },
    "VAE-GAN": {
        "encoder": VAEGANEncoder,
        "generator": VAEGANGenerator,
        "discriminator": VAEGANDiscriminator,
        "train_function": train_vae_gan
    },
    "Contrastive-GAN": {
        "generator": ContrastiveGANGenerator,
        "discriminator": ContrastiveGANDiscriminator,
        "train_function": train_contrastive_gan
    },
    "Cross-Domain-GAN": {
        "generator": CrossDomainGenerator,
        "discriminator": CrossDomainDiscriminator,
        "train_function": train_cross_domain_gan
    },
    "Cycle-GAN": {
        "generator_a": CycleGenerator,
        "generator_b": CycleGenerator,
        "discriminator_a": CycleDiscriminator,
        "discriminator_b": CycleDiscriminator,
        "train_function": train_cycle_gan
    },
    "Dual-GAN": {
        "generator_a": DualGANGenerator,
        "generator_b": DualGANGenerator,
        "discriminator_a": DualGANDiscriminator,
        "discriminator_b": DualGANDiscriminator,
        "train_function": train_dual_gan
    },
    "Contrastive-Dual-GAN": {
        "generator_a": ContrastiveDualGANGenerator,
        "generator_b": ContrastiveDualGANGenerator,
        "discriminator_a": ContrastiveDualGANDiscriminator,
        "discriminator_b": ContrastiveDualGANDiscriminator,
        "train_function": train_contrastive_dual_gan
    },
    "Semi-Supervised-GAN": {
        "generator": SimpleGANGenerator,
        "discriminator": SemiSupervisedGANDiscriminator,
        "train_function": train_semi_supervised_gan
    },
    "Conditional-GAN": {
        "generator": ConditionalGANGenerator,
        "discriminator": ConditionalGANDiscriminator,
        "train_function": train_conditional_gan
    },
    "InfoGAN": {
        "generator": InfoGANGenerator,
        "discriminator": InfoGANDiscriminator,
        "train_function": train_infogan
    }
}

In [6]:
# GAN Configuration
gan_type = "Cycle-GAN"  # Change to the desired GAN type
latent_dim = 100  # Latent dimension for the generator
embedding_dim = embeddings.size(1)  # Embedding dimension based on loaded embeddings
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10
epochs = 20
learning_rate = 0.0001
# Ensure device is properly set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the correct GAN models based on `gan_type`
if gan_type in gan_configurations:
    config = gan_configurations[gan_type]
    print(f"Initializing {gan_type} configuration: {config}")

    # Generator initialization for Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        if gan_type == "Cycle-GAN":
            # Cycle-GAN requires two generators and two discriminators
            generator_a = config["generator_a"](embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)  # Use CycleDiscriminator for CycleGAN
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)  # Same for second domain
            print(f"CycleGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
        else:
            # Dual-GAN or Contrastive-Dual-GAN requires both latent_dim and embedding_dim for generators
            generator_a = config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"DualGAN/ContrastiveDualGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
    else:
        # For other GANs (WGAN-GP, VAE-GAN, etc.), initialize a single generator
        generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            generator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            generator_args["categorical_dim"] = categorical_dim
        generator = config["generator"](**generator_args).to(device)
        print(f"Generator initialized on {generator}")

    # Discriminator initialization for other GAN types
    if gan_type not in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        # For other models (WGAN-GP, VAE-GAN, etc.), initialize a single discriminator
        discriminator_args = {"embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            discriminator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            discriminator_args["categorical_dim"] = categorical_dim
        discriminator = config["discriminator"](**discriminator_args).to(device)
        print(f"Discriminator initialized on {discriminator}")

    # Select the appropriate training function
    train_function = config["train_function"]

    # Handle any additional configurations (like learning rate, lambda_gp)
    train_kwargs = config.get("train_kwargs", {})  # Default to empty dict if not specified
    train_kwargs.update({
        "latent_dim": latent_dim,
        "epochs": epochs,  # Pass epochs as part of train_kwargs
        "device": device,
        "learning_rate": learning_rate,
        "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN"] else None,  # Add only for Conditional and InfoGAN
        "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,  # Add only for InfoGAN
        "lambda_gp": 10 if gan_type == "WGAN-GP" else None,  # Add only for WGAN-GP
    })

    # Add data loaders conditionally
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        # Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN: pass two data loaders
        train_kwargs.update({
            "data_loader_a": data_loader,  # Two loaders for these models
            "data_loader_b": data_loader
        })
    else:
        # For all other models (WGAN-GP, VAE-GAN, etc.), use a single loader
        train_kwargs.update({
            "data_loader": data_loader  # Single data loader for other models
        })

    # Training loop (now handled inside the train_function)
    train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs) if gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"] else train_function(generator, discriminator if "discriminator" in locals() else critic, **train_kwargs)

    print(f"{gan_type} training test passed!")

else:
    raise ValueError(f"Unsupported GAN type: {gan_type}")


Initializing Cycle-GAN configuration: {'generator_a': <class 'src.gan_workflows.plan2.plan2_gan_models.CycleGenerator'>, 'generator_b': <class 'src.gan_workflows.plan2.plan2_gan_models.CycleGenerator'>, 'discriminator_a': <class 'src.gan_workflows.plan2.plan2_gan_models.CycleDiscriminator'>, 'discriminator_b': <class 'src.gan_workflows.plan2.plan2_gan_models.CycleDiscriminator'>, 'train_function': <function train_cycle_gan at 0x7c0ebbc672e0>}
CycleGAN generators and discriminators initialized on CycleGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=50, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): Linear(in_features=256, out_features=50, bias=True)
  )
), CycleGenerator(
  (mo

In [7]:
# GAN Configuration
gan_type = "Cycle-GAN"  # Change to the desired GAN type
latent_dim = 100  # Latent dimension for the generator
embedding_dim = embeddings.size(1)  # Embedding dimension based on loaded embeddings
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10
epochs = 20
learning_rate = 0.0001
# Ensure device is properly set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the data loaders before the GAN models
embedding_loader_a = data_loader
embedding_loader_b = data_loader
embedding_loader = data_loader

# Initialize the correct GAN models based on `gan_type`
if gan_type in gan_configurations:
    config = gan_configurations[gan_type]
    print(f"Initializing {gan_type} configuration: {config}")

    # Generator initialization for Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        if gan_type == "Cycle-GAN":
            # Cycle-GAN requires two generators and two discriminators
            generator_a = config["generator_a"](embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"CycleGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
        else:
            # Dual-GAN or Contrastive-Dual-GAN requires both latent_dim and embedding_dim for generators
            generator_a = config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"DualGAN/ContrastiveDualGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
    else:
        # For other GANs (WGAN-GP, VAE-GAN, etc.), initialize a single generator
        generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
        if gan_type == "InfoGAN":
            # Correctly initialize InfoGAN generator with latent_dim and categorical_dim
            generator_args["categorical_dim"] = categorical_dim
        generator = config["generator"](**generator_args).to(device)
        print(f"Generator initialized on {generator}")

    # Handle critic/discriminator initialization conditionally
    if gan_type == "WGAN-GP":
        # WGAN-GP uses a critic instead of a discriminator
        critic = config["critic"](embedding_dim=embedding_dim).to(device)
        print(f"Critic initialized on {critic}")
    elif gan_type == "Semi-Supervised-GAN":
        # Semi-Supervised-GAN requires num_classes for the discriminator
        discriminator = config["discriminator"](embedding_dim=embedding_dim, num_classes=num_classes).to(device)
        print(f"Semi-Supervised-GAN discriminator initialized on {discriminator}")
    elif gan_type == "InfoGAN":
        # InfoGAN discriminator requires embedding_dim and categorical_dim
        discriminator = config["discriminator"](embedding_dim=embedding_dim, categorical_dim=categorical_dim).to(device)
        print(f"InfoGAN discriminator initialized on {discriminator}")
    elif gan_type not in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        # For other models (VAE-GAN, Conditional-GAN, etc.), initialize a single discriminator
        discriminator_args = {"embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            discriminator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            discriminator_args["categorical_dim"] = categorical_dim
        discriminator = config["discriminator"](**discriminator_args).to(device)
        print(f"Discriminator initialized on {discriminator}")
    else:
        # If discriminator is part of the configuration (Cycle-GAN, Dual-GAN, etc.), use the ones in the config
        discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
        discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
        print(f"Discriminators initialized on {discriminator_a}, {discriminator_b}")

    # Initialize encoder for VAE-GAN
    if gan_type == "VAE-GAN":
        encoder = config["encoder"](embedding_dim=embedding_dim, latent_dim=latent_dim).to(device)
        print(f"Encoder initialized on {encoder}")

    # Select the appropriate training function
    train_function = config["train_function"]

    # Handle any additional configurations (like learning rate, lambda_gp)
    train_kwargs = config.get("train_kwargs", {})  # Default to empty dict if not specified
    train_kwargs.update({
        "latent_dim": latent_dim,
        "epochs": epochs,  # Pass epochs as part of train_kwargs
        "device": device,
        "learning_rate": learning_rate,
        "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"] else None,  # Add for Conditional, InfoGAN, and Semi-Supervised-GAN
        "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,  # Add only for InfoGAN
        "lambda_gp": 10 if gan_type == "WGAN-GP" else None,  # Add only for WGAN-GP
    })

    # Add data loaders conditionally
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        # Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN: pass two data loaders
        train_kwargs.update({
            "data_loader_a": embedding_loader_a,  # Two loaders for these models
            "data_loader_b": embedding_loader_b
        })
    else:
        # For all other models (WGAN-GP, VAE-GAN, etc.), use a single loader
        train_kwargs.update({
            "data_loader": embedding_loader  # Single data loader for other models
        })

    # Training loop (now handled inside the train_function)
    if gan_type == "VAE-GAN":
        # VAE-GAN requires encoder, generator, and discriminator
        train_function(encoder, generator, discriminator, **train_kwargs)
    else:
        # For other models, pass the appropriate generator/discriminator/critic
        train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs) if gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"] else train_function(generator, discriminator if "discriminator" in locals() else critic, **train_kwargs)

    print(f"{gan_type} training test passed!")

else:
    raise ValueError(f"Unsupported GAN type: {gan_type}")


Initializing Cycle-GAN configuration: {'generator_a': <class 'src.gan_workflows.plan2.plan2_gan_models.CycleGenerator'>, 'generator_b': <class 'src.gan_workflows.plan2.plan2_gan_models.CycleGenerator'>, 'discriminator_a': <class 'src.gan_workflows.plan2.plan2_gan_models.CycleDiscriminator'>, 'discriminator_b': <class 'src.gan_workflows.plan2.plan2_gan_models.CycleDiscriminator'>, 'train_function': <function train_cycle_gan at 0x7c0ebbc672e0>}
CycleGAN generators and discriminators initialized on CycleGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=50, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): Linear(in_features=256, out_features=50, bias=True)
  )
), CycleGenerator(
  (mo

In [8]:
# GAN Configuration
gan_type = "WGAN-GP"  # Change to the desired GAN type
latent_dim = 100  # Latent dimension for the generator
embedding_dim = embeddings.size(1)  # Embedding dimension based on loaded embeddings
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10
epochs = 20
learning_rate = 0.0001
# Ensure device is properly set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the correct GAN models based on `gan_type`
if gan_type in gan_configurations:
    config = gan_configurations[gan_type]
    print(f"Initializing {gan_type} configuration: {config}")

    # Generator initialization for Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        if gan_type == "Cycle-GAN":
            # Cycle-GAN requires two generators and two discriminators
            generator_a = config["generator_a"](embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"CycleGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
        else:
            # Dual-GAN or Contrastive-Dual-GAN requires both latent_dim and embedding_dim for generators
            generator_a = config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"DualGAN/ContrastiveDualGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
    else:
        # For other GANs (WGAN-GP, VAE-GAN, etc.), initialize a single generator
        generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            generator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            generator_args["categorical_dim"] = categorical_dim
        generator = config["generator"](**generator_args).to(device)
        print(f"Generator initialized on {generator}")

    # Handle critic/discriminator initialization conditionally
    if gan_type == "WGAN-GP":
        # WGAN-GP uses a critic instead of a discriminator
        critic = config["critic"](embedding_dim=embedding_dim).to(device)
        print(f"Critic initialized on {critic}")
    elif gan_type not in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        # For other models (WGAN-GP, VAE-GAN, etc.), initialize a single discriminator
        discriminator_args = {"embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            discriminator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            discriminator_args["categorical_dim"] = categorical_dim
        discriminator = config["discriminator"](**discriminator_args).to(device)
        print(f"Discriminator initialized on {discriminator}")
    else:
        # If discriminator is part of the configuration (Cycle-GAN, Dual-GAN, etc.), use the ones in the config
        discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
        discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
        print(f"Discriminators initialized on {discriminator_a}, {discriminator_b}")

    # Select the appropriate training function
    train_function = config["train_function"]

    # Handle any additional configurations (like learning rate, lambda_gp)
    train_kwargs = config.get("train_kwargs", {})  # Default to empty dict if not specified
    train_kwargs.update({
        "latent_dim": latent_dim,
        "epochs": epochs,  # Pass epochs as part of train_kwargs
        "device": device,
        "learning_rate": learning_rate,
        "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN"] else None,  # Add only for Conditional and InfoGAN
        "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,  # Add only for InfoGAN
        "lambda_gp": 10 if gan_type == "WGAN-GP" else None,  # Add only for WGAN-GP
    })

    # Add data loaders conditionally
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        # Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN: pass two data loaders
        train_kwargs.update({
            "data_loader_a": data_loader,  # Two loaders for these models
            "data_loader_b": data_loader
        })
    else:
        # For all other models (WGAN-GP, VAE-GAN, etc.), use a single loader
        train_kwargs.update({
            "data_loader": data_loader  # Single data loader for other models
        })

    # Training loop (now handled inside the train_function)
    train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs) if gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"] else train_function(generator, critic if gan_type == "WGAN-GP" else discriminator, **train_kwargs)

    print(f"{gan_type} training test passed!")

else:
    raise ValueError(f"Unsupported GAN type: {gan_type}")


Initializing WGAN-GP configuration: {'generator': <class 'src.gan_workflows.plan2.plan2_gan_models.WGANGenerator'>, 'critic': <class 'src.gan_workflows.plan2.plan2_gan_models.WGANCritic'>, 'train_function': <function train_wgan_gp at 0x7c0ebbc66a20>, 'train_kwargs': {'lambda_gp': 10}}
Generator initialized on WGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): Linear(in_features=256, out_features=50, bias=True)
  )
)
Critic initialized on WGANCritic(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=50, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2

In [9]:
# GAN Configuration
gan_type = "WGAN-GP"  # Change to the desired GAN type
latent_dim = 100  # Latent dimension for the generator
embedding_dim = embeddings.size(1)  # Embedding dimension based on loaded embeddings
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10
epochs = 20
learning_rate = 0.0001
# Ensure device is properly set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the data loaders before the GAN models
embedding_loader_a = data_loader
embedding_loader_b = data_loader
embedding_loader = data_loader

# Initialize the correct GAN models based on `gan_type`
if gan_type in gan_configurations:
    config = gan_configurations[gan_type]
    print(f"Initializing {gan_type} configuration: {config}")

    # Generator initialization for Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        if gan_type == "Cycle-GAN":
            # Cycle-GAN requires two generators and two discriminators
            generator_a = config["generator_a"](embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"CycleGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
        else:
            # Dual-GAN or Contrastive-Dual-GAN requires both latent_dim and embedding_dim for generators
            generator_a = config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"DualGAN/ContrastiveDualGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
    else:
        # For other GANs (WGAN-GP, VAE-GAN, etc.), initialize a single generator
        generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
        if gan_type == "InfoGAN":
            # Correctly initialize InfoGAN generator with latent_dim and categorical_dim
            generator_args["categorical_dim"] = categorical_dim
        generator = config["generator"](**generator_args).to(device)
        print(f"Generator initialized on {generator}")

    # Handle critic/discriminator initialization conditionally
    if gan_type == "WGAN-GP":
        # WGAN-GP uses a critic instead of a discriminator
        critic = config["critic"](embedding_dim=embedding_dim).to(device)
        print(f"Critic initialized on {critic}")
    elif gan_type == "Semi-Supervised-GAN":
        # Semi-Supervised-GAN requires num_classes for the discriminator
        discriminator = config["discriminator"](embedding_dim=embedding_dim, num_classes=num_classes).to(device)
        print(f"Semi-Supervised-GAN discriminator initialized on {discriminator}")
    elif gan_type == "InfoGAN":
        # InfoGAN discriminator requires embedding_dim and categorical_dim
        discriminator = config["discriminator"](embedding_dim=embedding_dim, categorical_dim=categorical_dim).to(device)
        print(f"InfoGAN discriminator initialized on {discriminator}")
    elif gan_type not in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        # For other models (VAE-GAN, Conditional-GAN, etc.), initialize a single discriminator
        discriminator_args = {"embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            discriminator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            discriminator_args["categorical_dim"] = categorical_dim
        discriminator = config["discriminator"](**discriminator_args).to(device)
        print(f"Discriminator initialized on {discriminator}")
    else:
        # If discriminator is part of the configuration (Cycle-GAN, Dual-GAN, etc.), use the ones in the config
        discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
        discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
        print(f"Discriminators initialized on {discriminator_a}, {discriminator_b}")

    # Initialize encoder for VAE-GAN
    if gan_type == "VAE-GAN":
        encoder = config["encoder"](embedding_dim=embedding_dim, latent_dim=latent_dim).to(device)
        print(f"Encoder initialized on {encoder}")

    # Select the appropriate training function
    train_function = config["train_function"]

    # Handle any additional configurations (like learning rate, lambda_gp)
    train_kwargs = config.get("train_kwargs", {})  # Default to empty dict if not specified
    train_kwargs.update({
        "latent_dim": latent_dim,
        "epochs": epochs,  # Pass epochs as part of train_kwargs
        "device": device,
        "learning_rate": learning_rate,
        "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"] else None,  # Add for Conditional, InfoGAN, and Semi-Supervised-GAN
        "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,  # Add only for InfoGAN
        "lambda_gp": 10 if gan_type == "WGAN-GP" else None,  # Add only for WGAN-GP
    })

    # Add data loaders conditionally
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        # Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN: pass two data loaders
        train_kwargs.update({
            "data_loader_a": embedding_loader_a,  # Two loaders for these models
            "data_loader_b": embedding_loader_b
        })
    else:
        # For all other models (WGAN-GP, VAE-GAN, etc.), use a single loader
        train_kwargs.update({
            "data_loader": embedding_loader  # Single data loader for other models
        })

    # Training loop (now handled inside the train_function)
    if gan_type == "VAE-GAN":
        # VAE-GAN requires encoder, generator, and discriminator
        train_function(encoder, generator, discriminator, **train_kwargs)
    else:
        # For other models, pass the appropriate generator/discriminator/critic
        train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs) if gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"] else train_function(generator, discriminator if "discriminator" in locals() else critic, **train_kwargs)

    print(f"{gan_type} training test passed!")

else:
    raise ValueError(f"Unsupported GAN type: {gan_type}")


Initializing WGAN-GP configuration: {'generator': <class 'src.gan_workflows.plan2.plan2_gan_models.WGANGenerator'>, 'critic': <class 'src.gan_workflows.plan2.plan2_gan_models.WGANCritic'>, 'train_function': <function train_wgan_gp at 0x7c0ebbc66a20>, 'train_kwargs': {'lambda_gp': 10, 'latent_dim': 100, 'epochs': 20, 'device': device(type='cpu'), 'learning_rate': 0.0001, 'num_classes': None, 'categorical_dim': None, 'data_loader': <torch.utils.data.dataloader.DataLoader object at 0x7c0eac70da50>}}
Generator initialized on WGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): Linear(in_features=256, out_features=50, bias=True)
  )
)
Critic initialized on WGANCritic(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=50, out_features=512, bias=True)
        (1)

In [10]:
# GAN Configuration
gan_type = "VAE-GAN"  # Change to the desired GAN type
latent_dim = 100  # Latent dimension for the generator
embedding_dim = embeddings.size(1)  # Embedding dimension based on loaded embeddings
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10
epochs = 20
learning_rate = 0.0001
# Ensure device is properly set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the correct GAN models based on `gan_type`
if gan_type in gan_configurations:
    config = gan_configurations[gan_type]
    print(f"Initializing {gan_type} configuration: {config}")

    # Generator initialization for Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        if gan_type == "Cycle-GAN":
            # Cycle-GAN requires two generators and two discriminators
            generator_a = config["generator_a"](embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"CycleGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
        else:
            # Dual-GAN or Contrastive-Dual-GAN requires both latent_dim and embedding_dim for generators
            generator_a = config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"DualGAN/ContrastiveDualGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
    else:
        # For other GANs (WGAN-GP, VAE-GAN, etc.), initialize a single generator
        generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            generator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            generator_args["categorical_dim"] = categorical_dim
        generator = config["generator"](**generator_args).to(device)
        print(f"Generator initialized on {generator}")

    # Discriminator initialization for other GAN types
    if gan_type == "WGAN-GP":
        # WGAN-GP uses a critic instead of a discriminator
        critic = config["critic"](embedding_dim=embedding_dim).to(device)
        print(f"Critic initialized on {critic}")
    elif gan_type not in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        # For other models (VAE-GAN, Conditional-GAN, etc.), initialize a single discriminator
        discriminator_args = {"embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            discriminator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            discriminator_args["categorical_dim"] = categorical_dim
        discriminator = config["discriminator"](**discriminator_args).to(device)
        print(f"Discriminator initialized on {discriminator}")
    else:
        # If discriminator is part of the configuration (Cycle-GAN, Dual-GAN, etc.), use the ones in the config
        discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
        discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
        print(f"Discriminators initialized on {discriminator_a}, {discriminator_b}")

    # Initialize encoder for VAE-GAN
    if gan_type == "VAE-GAN":
        encoder = config["encoder"](embedding_dim=embedding_dim, latent_dim=latent_dim).to(device)
        print(f"Encoder initialized on {encoder}")

    # Select the appropriate training function
    train_function = config["train_function"]

    # Handle any additional configurations (like learning rate, lambda_gp)
    train_kwargs = config.get("train_kwargs", {})  # Default to empty dict if not specified
    train_kwargs.update({
        "latent_dim": latent_dim,
        "epochs": epochs,  # Pass epochs as part of train_kwargs
        "device": device,
        "learning_rate": learning_rate,
        "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN"] else None,  # Add only for Conditional and InfoGAN
        "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,  # Add only for InfoGAN
        "lambda_gp": 10 if gan_type == "WGAN-GP" else None,  # Add only for WGAN-GP
    })

    # Add data loaders conditionally
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        # Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN: pass two data loaders
        train_kwargs.update({
            "data_loader_a": data_loader,  # Two loaders for these models
            "data_loader_b": data_loader
        })
    else:
        # For all other models (WGAN-GP, VAE-GAN, etc.), use a single loader
        train_kwargs.update({
            "data_loader": data_loader  # Single data loader for other models
        })

    # Training loop (now handled inside the train_function)
    if gan_type == "VAE-GAN":
        # VAE-GAN requires encoder, generator, and discriminator
        train_function(encoder, generator, discriminator, **train_kwargs)
    else:
        # For other models, pass the appropriate generator/discriminator/critic
        train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs) if gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"] else train_function(generator, discriminator if "discriminator" in locals() else critic, **train_kwargs)

    print(f"{gan_type} training test passed!")

else:
    raise ValueError(f"Unsupported GAN type: {gan_type}")


Initializing VAE-GAN configuration: {'encoder': <class 'src.gan_workflows.plan2.plan2_gan_models.VAEGANEncoder'>, 'generator': <class 'src.gan_workflows.plan2.plan2_gan_models.VAEGANGenerator'>, 'discriminator': <class 'src.gan_workflows.plan2.plan2_gan_models.VAEGANDiscriminator'>, 'train_function': <function train_vae_gan at 0x7c0ebbc67100>}
Generator initialized on VAEGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): Linear(in_features=256, out_features=50, bias=True)
  )
)
Discriminator initialized on VAEGANDiscriminator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=50, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True

In [11]:
# GAN Configuration
gan_type = "VAE-GAN"  # Change to the desired GAN type
latent_dim = 100  # Latent dimension for the generator
embedding_dim = embeddings.size(1)  # Embedding dimension based on loaded embeddings
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10
epochs = 20
learning_rate = 0.0001
# Ensure device is properly set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the data loaders before the GAN models
embedding_loader_a = data_loader
embedding_loader_b = data_loader
embedding_loader = data_loader

# Initialize the correct GAN models based on `gan_type`
if gan_type in gan_configurations:
    config = gan_configurations[gan_type]
    print(f"Initializing {gan_type} configuration: {config}")

    # Generator initialization for Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        if gan_type == "Cycle-GAN":
            # Cycle-GAN requires two generators and two discriminators
            generator_a = config["generator_a"](embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"CycleGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
        else:
            # Dual-GAN or Contrastive-Dual-GAN requires both latent_dim and embedding_dim for generators
            generator_a = config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"DualGAN/ContrastiveDualGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
    else:
        # For other GANs (WGAN-GP, VAE-GAN, etc.), initialize a single generator
        generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
        if gan_type == "InfoGAN":
            # Correctly initialize InfoGAN generator with latent_dim and categorical_dim
            generator_args["categorical_dim"] = categorical_dim
        generator = config["generator"](**generator_args).to(device)
        print(f"Generator initialized on {generator}")

    # Handle critic/discriminator initialization conditionally
    if gan_type == "WGAN-GP":
        # WGAN-GP uses a critic instead of a discriminator
        critic = config["critic"](embedding_dim=embedding_dim).to(device)
        print(f"Critic initialized on {critic}")
    elif gan_type == "Semi-Supervised-GAN":
        # Semi-Supervised-GAN requires num_classes for the discriminator
        discriminator = config["discriminator"](embedding_dim=embedding_dim, num_classes=num_classes).to(device)
        print(f"Semi-Supervised-GAN discriminator initialized on {discriminator}")
    elif gan_type == "InfoGAN":
        # InfoGAN discriminator requires embedding_dim and categorical_dim
        discriminator = config["discriminator"](embedding_dim=embedding_dim, categorical_dim=categorical_dim).to(device)
        print(f"InfoGAN discriminator initialized on {discriminator}")
    elif gan_type not in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        # For other models (VAE-GAN, Conditional-GAN, etc.), initialize a single discriminator
        discriminator_args = {"embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            discriminator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            discriminator_args["categorical_dim"] = categorical_dim
        discriminator = config["discriminator"](**discriminator_args).to(device)
        print(f"Discriminator initialized on {discriminator}")
    else:
        # If discriminator is part of the configuration (Cycle-GAN, Dual-GAN, etc.), use the ones in the config
        discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
        discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
        print(f"Discriminators initialized on {discriminator_a}, {discriminator_b}")

    # Initialize encoder for VAE-GAN
    if gan_type == "VAE-GAN":
        encoder = config["encoder"](embedding_dim=embedding_dim, latent_dim=latent_dim).to(device)
        print(f"Encoder initialized on {encoder}")

    # Select the appropriate training function
    train_function = config["train_function"]

    # Handle any additional configurations (like learning rate, lambda_gp)
    train_kwargs = config.get("train_kwargs", {})  # Default to empty dict if not specified
    train_kwargs.update({
        "latent_dim": latent_dim,
        "epochs": epochs,  # Pass epochs as part of train_kwargs
        "device": device,
        "learning_rate": learning_rate,
        "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"] else None,  # Add for Conditional, InfoGAN, and Semi-Supervised-GAN
        "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,  # Add only for InfoGAN
        "lambda_gp": 10 if gan_type == "WGAN-GP" else None,  # Add only for WGAN-GP
    })

    # Add data loaders conditionally
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        # Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN: pass two data loaders
        train_kwargs.update({
            "data_loader_a": embedding_loader_a,  # Two loaders for these models
            "data_loader_b": embedding_loader_b
        })
    else:
        # For all other models (WGAN-GP, VAE-GAN, etc.), use a single loader
        train_kwargs.update({
            "data_loader": embedding_loader  # Single data loader for other models
        })

    # Training loop (now handled inside the train_function)
    if gan_type == "VAE-GAN":
        # VAE-GAN requires encoder, generator, and discriminator
        train_function(encoder, generator, discriminator, **train_kwargs)
    else:
        # For other models, pass the appropriate generator/discriminator/critic
        train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs) if gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"] else train_function(generator, discriminator if "discriminator" in locals() else critic, **train_kwargs)

    print(f"{gan_type} training test passed!")

else:
    raise ValueError(f"Unsupported GAN type: {gan_type}")


Initializing VAE-GAN configuration: {'encoder': <class 'src.gan_workflows.plan2.plan2_gan_models.VAEGANEncoder'>, 'generator': <class 'src.gan_workflows.plan2.plan2_gan_models.VAEGANGenerator'>, 'discriminator': <class 'src.gan_workflows.plan2.plan2_gan_models.VAEGANDiscriminator'>, 'train_function': <function train_vae_gan at 0x7c0ebbc67100>}
Generator initialized on VAEGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): Linear(in_features=256, out_features=50, bias=True)
  )
)
Discriminator initialized on VAEGANDiscriminator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=50, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True

In [12]:
# GAN Configuration
gan_type = "Contrastive-GAN"  # Change to the desired GAN type
latent_dim = 100  # Latent dimension for the generator
embedding_dim = embeddings.size(1)  # Embedding dimension based on loaded embeddings
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10
epochs = 20
learning_rate = 0.0001
# Ensure device is properly set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the correct GAN models based on `gan_type`
if gan_type in gan_configurations:
    config = gan_configurations[gan_type]
    print(f"Initializing {gan_type} configuration: {config}")

    # Generator initialization for Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        if gan_type == "Cycle-GAN":
            # Cycle-GAN requires two generators and two discriminators
            generator_a = config["generator_a"](embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"CycleGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
        else:
            # Dual-GAN or Contrastive-Dual-GAN requires both latent_dim and embedding_dim for generators
            generator_a = config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"DualGAN/ContrastiveDualGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
    else:
        # For other GANs (WGAN-GP, VAE-GAN, etc.), initialize a single generator
        generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            generator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            generator_args["categorical_dim"] = categorical_dim
        generator = config["generator"](**generator_args).to(device)
        print(f"Generator initialized on {generator}")

    # Discriminator initialization for other GAN types
    if gan_type == "WGAN-GP":
        # WGAN-GP uses a critic instead of a discriminator
        critic = config["critic"](embedding_dim=embedding_dim).to(device)
        print(f"Critic initialized on {critic}")
    elif gan_type not in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        # For other models (VAE-GAN, Conditional-GAN, etc.), initialize a single discriminator
        discriminator_args = {"embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            discriminator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            discriminator_args["categorical_dim"] = categorical_dim
        discriminator = config["discriminator"](**discriminator_args).to(device)
        print(f"Discriminator initialized on {discriminator}")
    else:
        # If discriminator is part of the configuration (Cycle-GAN, Dual-GAN, etc.), use the ones in the config
        discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
        discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
        print(f"Discriminators initialized on {discriminator_a}, {discriminator_b}")

    # Initialize encoder for VAE-GAN
    if gan_type == "VAE-GAN":
        encoder = config["encoder"](embedding_dim=embedding_dim, latent_dim=latent_dim).to(device)
        print(f"Encoder initialized on {encoder}")

    # Select the appropriate training function
    train_function = config["train_function"]

    # Handle any additional configurations (like learning rate, lambda_gp)
    train_kwargs = config.get("train_kwargs", {})  # Default to empty dict if not specified
    train_kwargs.update({
        "latent_dim": latent_dim,
        "epochs": epochs,  # Pass epochs as part of train_kwargs
        "device": device,
        "learning_rate": learning_rate,
        "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN"] else None,  # Add only for Conditional and InfoGAN
        "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,  # Add only for InfoGAN
        "lambda_gp": 10 if gan_type == "WGAN-GP" else None,  # Add only for WGAN-GP
    })

    # Add data loaders conditionally
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        # Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN: pass two data loaders
        train_kwargs.update({
            "data_loader_a": data_loader,  # Two loaders for these models
            "data_loader_b": data_loader
        })
    else:
        # For all other models (WGAN-GP, VAE-GAN, etc.), use a single loader
        train_kwargs.update({
            "data_loader": data_loader  # Single data loader for other models
        })

    # Training loop (now handled inside the train_function)
    if gan_type == "VAE-GAN":
        # VAE-GAN requires encoder, generator, and discriminator
        train_function(encoder, generator, discriminator, **train_kwargs)
    else:
        # For other models, pass the appropriate generator/discriminator/critic
        train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs) if gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"] else train_function(generator, discriminator if "discriminator" in locals() else critic, **train_kwargs)

    print(f"{gan_type} training test passed!")

else:
    raise ValueError(f"Unsupported GAN type: {gan_type}")


Initializing Contrastive-GAN configuration: {'generator': <class 'src.gan_workflows.plan2.plan2_gan_models.ContrastiveGANGenerator'>, 'discriminator': <class 'src.gan_workflows.plan2.plan2_gan_models.ContrastiveGANDiscriminator'>, 'train_function': <function train_contrastive_gan at 0x7c0ebbc671a0>}
Generator initialized on ContrastiveGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=512, bias=True)
        (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): Linear(in_features=256, out_features=50, bias=True)
  )
)

In [13]:
# GAN Configuration
gan_type = "Contrastive-GAN"  # Change to the desired GAN type
latent_dim = 100  # Latent dimension for the generator
embedding_dim = embeddings.size(1)  # Embedding dimension based on loaded embeddings
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10
epochs = 20
learning_rate = 0.0001
# Ensure device is properly set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the data loaders before the GAN models
embedding_loader_a = data_loader
embedding_loader_b = data_loader
embedding_loader = data_loader

# Initialize the correct GAN models based on `gan_type`
if gan_type in gan_configurations:
    config = gan_configurations[gan_type]
    print(f"Initializing {gan_type} configuration: {config}")

    # Generator initialization for Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        if gan_type == "Cycle-GAN":
            # Cycle-GAN requires two generators and two discriminators
            generator_a = config["generator_a"](embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"CycleGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
        else:
            # Dual-GAN or Contrastive-Dual-GAN requires both latent_dim and embedding_dim for generators
            generator_a = config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"DualGAN/ContrastiveDualGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
    else:
        # For other GANs (WGAN-GP, VAE-GAN, etc.), initialize a single generator
        generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
        if gan_type == "InfoGAN":
            # Correctly initialize InfoGAN generator with latent_dim and categorical_dim
            generator_args["categorical_dim"] = categorical_dim
        generator = config["generator"](**generator_args).to(device)
        print(f"Generator initialized on {generator}")

    # Handle critic/discriminator initialization conditionally
    if gan_type == "WGAN-GP":
        # WGAN-GP uses a critic instead of a discriminator
        critic = config["critic"](embedding_dim=embedding_dim).to(device)
        print(f"Critic initialized on {critic}")
    elif gan_type == "Semi-Supervised-GAN":
        # Semi-Supervised-GAN requires num_classes for the discriminator
        discriminator = config["discriminator"](embedding_dim=embedding_dim, num_classes=num_classes).to(device)
        print(f"Semi-Supervised-GAN discriminator initialized on {discriminator}")
    elif gan_type == "InfoGAN":
        # InfoGAN discriminator requires embedding_dim and categorical_dim
        discriminator = config["discriminator"](embedding_dim=embedding_dim, categorical_dim=categorical_dim).to(device)
        print(f"InfoGAN discriminator initialized on {discriminator}")
    elif gan_type not in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        # For other models (VAE-GAN, Conditional-GAN, etc.), initialize a single discriminator
        discriminator_args = {"embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            discriminator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            discriminator_args["categorical_dim"] = categorical_dim
        discriminator = config["discriminator"](**discriminator_args).to(device)
        print(f"Discriminator initialized on {discriminator}")
    else:
        # If discriminator is part of the configuration (Cycle-GAN, Dual-GAN, etc.), use the ones in the config
        discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
        discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
        print(f"Discriminators initialized on {discriminator_a}, {discriminator_b}")

    # Initialize encoder for VAE-GAN
    if gan_type == "VAE-GAN":
        encoder = config["encoder"](embedding_dim=embedding_dim, latent_dim=latent_dim).to(device)
        print(f"Encoder initialized on {encoder}")

    # Select the appropriate training function
    train_function = config["train_function"]

    # Handle any additional configurations (like learning rate, lambda_gp)
    train_kwargs = config.get("train_kwargs", {})  # Default to empty dict if not specified
    train_kwargs.update({
        "latent_dim": latent_dim,
        "epochs": epochs,  # Pass epochs as part of train_kwargs
        "device": device,
        "learning_rate": learning_rate,
        "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"] else None,  # Add for Conditional, InfoGAN, and Semi-Supervised-GAN
        "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,  # Add only for InfoGAN
        "lambda_gp": 10 if gan_type == "WGAN-GP" else None,  # Add only for WGAN-GP
    })

    # Add data loaders conditionally
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        # Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN: pass two data loaders
        train_kwargs.update({
            "data_loader_a": embedding_loader_a,  # Two loaders for these models
            "data_loader_b": embedding_loader_b
        })
    else:
        # For all other models (WGAN-GP, VAE-GAN, etc.), use a single loader
        train_kwargs.update({
            "data_loader": embedding_loader  # Single data loader for other models
        })

    # Training loop (now handled inside the train_function)
    if gan_type == "VAE-GAN":
        # VAE-GAN requires encoder, generator, and discriminator
        train_function(encoder, generator, discriminator, **train_kwargs)
    else:
        # For other models, pass the appropriate generator/discriminator/critic
        train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs) if gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"] else train_function(generator, discriminator if "discriminator" in locals() else critic, **train_kwargs)

    print(f"{gan_type} training test passed!")

else:
    raise ValueError(f"Unsupported GAN type: {gan_type}")


Initializing Contrastive-GAN configuration: {'generator': <class 'src.gan_workflows.plan2.plan2_gan_models.ContrastiveGANGenerator'>, 'discriminator': <class 'src.gan_workflows.plan2.plan2_gan_models.ContrastiveGANDiscriminator'>, 'train_function': <function train_contrastive_gan at 0x7c0ebbc671a0>}
Generator initialized on ContrastiveGANGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=512, bias=True)
        (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): Linear(in_features=256, out_features=50, bias=True)
  )
)

In [14]:
# GAN Configuration
gan_type = "Cross-Domain-GAN"  # Change to the desired GAN type
latent_dim = 100  # Latent dimension for the generator
embedding_dim = embeddings.size(1)  # Embedding dimension based on loaded embeddings
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10
epochs = 20
learning_rate = 0.0001
# Ensure device is properly set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the correct GAN models based on `gan_type`
if gan_type in gan_configurations:
    config = gan_configurations[gan_type]
    print(f"Initializing {gan_type} configuration: {config}")

    # Generator initialization for Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        if gan_type == "Cycle-GAN":
            # Cycle-GAN requires two generators and two discriminators
            generator_a = config["generator_a"](embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"CycleGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
        else:
            # Dual-GAN or Contrastive-Dual-GAN requires both latent_dim and embedding_dim for generators
            generator_a = config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"DualGAN/ContrastiveDualGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
    else:
        # For other GANs (WGAN-GP, VAE-GAN, etc.), initialize a single generator
        generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            generator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            generator_args["categorical_dim"] = categorical_dim
        generator = config["generator"](**generator_args).to(device)
        print(f"Generator initialized on {generator}")

    # Discriminator initialization for other GAN types
    if gan_type == "WGAN-GP":
        # WGAN-GP uses a critic instead of a discriminator
        critic = config["critic"](embedding_dim=embedding_dim).to(device)
        print(f"Critic initialized on {critic}")
    elif gan_type not in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        # For other models (VAE-GAN, Conditional-GAN, etc.), initialize a single discriminator
        discriminator_args = {"embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            discriminator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            discriminator_args["categorical_dim"] = categorical_dim
        discriminator = config["discriminator"](**discriminator_args).to(device)
        print(f"Discriminator initialized on {discriminator}")
    else:
        # If discriminator is part of the configuration (Cycle-GAN, Dual-GAN, etc.), use the ones in the config
        discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
        discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
        print(f"Discriminators initialized on {discriminator_a}, {discriminator_b}")

    # Initialize encoder for VAE-GAN
    if gan_type == "VAE-GAN":
        encoder = config["encoder"](embedding_dim=embedding_dim, latent_dim=latent_dim).to(device)
        print(f"Encoder initialized on {encoder}")

    # Select the appropriate training function
    train_function = config["train_function"]

    # Handle any additional configurations (like learning rate, lambda_gp)
    train_kwargs = config.get("train_kwargs", {})  # Default to empty dict if not specified
    train_kwargs.update({
        "latent_dim": latent_dim,
        "epochs": epochs,  # Pass epochs as part of train_kwargs
        "device": device,
        "learning_rate": learning_rate,
        "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN"] else None,  # Add only for Conditional and InfoGAN
        "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,  # Add only for InfoGAN
        "lambda_gp": 10 if gan_type == "WGAN-GP" else None,  # Add only for WGAN-GP
    })

    # Add data loaders conditionally
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        # Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN: pass two data loaders
        train_kwargs.update({
            "data_loader_a": data_loader,  # Two loaders for these models
            "data_loader_b": data_loader
        })
    else:
        # For all other models (WGAN-GP, VAE-GAN, etc.), use a single loader
        train_kwargs.update({
            "data_loader": data_loader  # Single data loader for other models
        })

    # Training loop (now handled inside the train_function)
    if gan_type == "VAE-GAN":
        # VAE-GAN requires encoder, generator, and discriminator
        train_function(encoder, generator, discriminator, **train_kwargs)
    else:
        # For other models, pass the appropriate generator/discriminator/critic
        train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs) if gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"] else train_function(generator, discriminator if "discriminator" in locals() else critic, **train_kwargs)

    print(f"{gan_type} training test passed!")

else:
    raise ValueError(f"Unsupported GAN type: {gan_type}")


Initializing Cross-Domain-GAN configuration: {'generator': <class 'src.gan_workflows.plan2.plan2_gan_models.CrossDomainGenerator'>, 'discriminator': <class 'src.gan_workflows.plan2.plan2_gan_models.CrossDomainDiscriminator'>, 'train_function': <function train_cross_domain_gan at 0x7c0ebbc67240>}
Generator initialized on CrossDomainGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): Linear(in_features=256, out_features=50, bias=True)
  )
)
Discriminator initialized on CrossDomainDiscriminator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=50, out_features=512, bias=True)
  

In [15]:
# GAN Configuration
gan_type = "Cross-Domain-GAN"  # Change to the desired GAN type
latent_dim = 100  # Latent dimension for the generator
embedding_dim = embeddings.size(1)  # Embedding dimension based on loaded embeddings
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10
epochs = 20
learning_rate = 0.0001
# Ensure device is properly set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the data loaders before the GAN models
embedding_loader_a = data_loader
embedding_loader_b = data_loader
embedding_loader = data_loader

# Initialize the correct GAN models based on `gan_type`
if gan_type in gan_configurations:
    config = gan_configurations[gan_type]
    print(f"Initializing {gan_type} configuration: {config}")

    # Generator initialization for Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        if gan_type == "Cycle-GAN":
            # Cycle-GAN requires two generators and two discriminators
            generator_a = config["generator_a"](embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"CycleGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
        else:
            # Dual-GAN or Contrastive-Dual-GAN requires both latent_dim and embedding_dim for generators
            generator_a = config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"DualGAN/ContrastiveDualGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
    else:
        # For other GANs (WGAN-GP, VAE-GAN, etc.), initialize a single generator
        generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
        if gan_type == "InfoGAN":
            # Correctly initialize InfoGAN generator with latent_dim and categorical_dim
            generator_args["categorical_dim"] = categorical_dim
        generator = config["generator"](**generator_args).to(device)
        print(f"Generator initialized on {generator}")

    # Handle critic/discriminator initialization conditionally
    if gan_type == "WGAN-GP":
        # WGAN-GP uses a critic instead of a discriminator
        critic = config["critic"](embedding_dim=embedding_dim).to(device)
        print(f"Critic initialized on {critic}")
    elif gan_type == "Semi-Supervised-GAN":
        # Semi-Supervised-GAN requires num_classes for the discriminator
        discriminator = config["discriminator"](embedding_dim=embedding_dim, num_classes=num_classes).to(device)
        print(f"Semi-Supervised-GAN discriminator initialized on {discriminator}")
    elif gan_type == "InfoGAN":
        # InfoGAN discriminator requires embedding_dim and categorical_dim
        discriminator = config["discriminator"](embedding_dim=embedding_dim, categorical_dim=categorical_dim).to(device)
        print(f"InfoGAN discriminator initialized on {discriminator}")
    elif gan_type not in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        # For other models (VAE-GAN, Conditional-GAN, etc.), initialize a single discriminator
        discriminator_args = {"embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            discriminator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            discriminator_args["categorical_dim"] = categorical_dim
        discriminator = config["discriminator"](**discriminator_args).to(device)
        print(f"Discriminator initialized on {discriminator}")
    else:
        # If discriminator is part of the configuration (Cycle-GAN, Dual-GAN, etc.), use the ones in the config
        discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
        discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
        print(f"Discriminators initialized on {discriminator_a}, {discriminator_b}")

    # Initialize encoder for VAE-GAN
    if gan_type == "VAE-GAN":
        encoder = config["encoder"](embedding_dim=embedding_dim, latent_dim=latent_dim).to(device)
        print(f"Encoder initialized on {encoder}")

    # Select the appropriate training function
    train_function = config["train_function"]

    # Handle any additional configurations (like learning rate, lambda_gp)
    train_kwargs = config.get("train_kwargs", {})  # Default to empty dict if not specified
    train_kwargs.update({
        "latent_dim": latent_dim,
        "epochs": epochs,  # Pass epochs as part of train_kwargs
        "device": device,
        "learning_rate": learning_rate,
        "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"] else None,  # Add for Conditional, InfoGAN, and Semi-Supervised-GAN
        "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,  # Add only for InfoGAN
        "lambda_gp": 10 if gan_type == "WGAN-GP" else None,  # Add only for WGAN-GP
    })

    # Add data loaders conditionally
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        # Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN: pass two data loaders
        train_kwargs.update({
            "data_loader_a": embedding_loader_a,  # Two loaders for these models
            "data_loader_b": embedding_loader_b
        })
    else:
        # For all other models (WGAN-GP, VAE-GAN, etc.), use a single loader
        train_kwargs.update({
            "data_loader": embedding_loader  # Single data loader for other models
        })

    # Training loop (now handled inside the train_function)
    if gan_type == "VAE-GAN":
        # VAE-GAN requires encoder, generator, and discriminator
        train_function(encoder, generator, discriminator, **train_kwargs)
    else:
        # For other models, pass the appropriate generator/discriminator/critic
        train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs) if gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"] else train_function(generator, discriminator if "discriminator" in locals() else critic, **train_kwargs)

    print(f"{gan_type} training test passed!")

else:
    raise ValueError(f"Unsupported GAN type: {gan_type}")


Initializing Cross-Domain-GAN configuration: {'generator': <class 'src.gan_workflows.plan2.plan2_gan_models.CrossDomainGenerator'>, 'discriminator': <class 'src.gan_workflows.plan2.plan2_gan_models.CrossDomainDiscriminator'>, 'train_function': <function train_cross_domain_gan at 0x7c0ebbc67240>}
Generator initialized on CrossDomainGenerator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=100, out_features=512, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): Linear(in_features=256, out_features=50, bias=True)
  )
)
Discriminator initialized on CrossDomainDiscriminator(
  (model): Sequential(
    (0): LinearBlock(
      (block): Sequential(
        (0): Linear(in_features=50, out_features=512, bias=True)
  

In [16]:
# interesting way to load the embeddings...

from torch.utils.data import DataLoader, TensorDataset

def load_embeddings(embedding_file, device, batch_size=64, return_labels=True):
    """
    Loads embeddings and their associated labels from a specified file and
    creates a DataLoader for batching the embeddings (and optionally labels).

    Args:
        embedding_file (str): Path to the file containing embeddings and labels.
        device (torch.device): The device (CPU/GPU) to load the tensors onto.
        batch_size (int, optional): The batch size for DataLoader. Default is 64.
        return_labels (bool, optional): Whether to include labels in the DataLoader. Default is True.

    Returns:
        tuple: A tuple containing:
            - embeddings (torch.Tensor): Loaded embeddings.
            - labels (torch.Tensor, optional): Corresponding labels for the embeddings (if return_labels=True).
            - data_loader (DataLoader): DataLoader for batching embeddings (and labels if required).
    """
    print(f"Loading embeddings from: {embedding_file}")
    data = torch.load(embedding_file)
    embeddings = data["embeddings"].to(device)

    # Initialize the DataLoader only with embeddings if labels are not required
    if return_labels:
        labels = data["labels"].to(device)
        # Create a TensorDataset containing both embeddings and labels
        dataset = TensorDataset(embeddings, labels)
        return embeddings, labels, DataLoader(dataset, batch_size=batch_size, shuffle=True)
    else:
        # Create DataLoader for embeddings only
        dataset = TensorDataset(embeddings)  # Just embeddings
        return embeddings, DataLoader(dataset, batch_size=batch_size, shuffle=True)

embeddings, labels, data_loader_v2 = load_embeddings(embedding_file, device)


Loading embeddings from: ./saved_embeddings/embeddings/autoencoders_BasicAutoencoder/BasicAutoencoder_embeddings.pt


  data = torch.load(embedding_file)


In [17]:
# GAN Configuration
gan_type = "InfoGAN"  # Change to the desired GAN type
latent_dim = 100  # Latent dimension for the generator
embedding_dim = embeddings.size(1)  # Embedding dimension based on loaded embeddings
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10
epochs = 20
learning_rate = 0.0001
# Ensure device is properly set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the data loaders before the GAN models
embedding_loader_a = data_loader
embedding_loader_b = data_loader
embedding_loader = data_loader_v2

# Initialize the correct GAN models based on `gan_type`
if gan_type in gan_configurations:
    config = gan_configurations[gan_type]
    print(f"Initializing {gan_type} configuration: {config}")

    # Generator initialization for Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        if gan_type == "Cycle-GAN":
            # Cycle-GAN requires two generators and two discriminators
            generator_a = config["generator_a"](embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"CycleGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
        else:
            # Dual-GAN or Contrastive-Dual-GAN requires both latent_dim and embedding_dim for generators
            generator_a = config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"DualGAN/ContrastiveDualGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
    else:
        # For other GANs (WGAN-GP, VAE-GAN, etc.), initialize a single generator
        generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
        if gan_type == "InfoGAN":
            # Correctly initialize InfoGAN generator with latent_dim and categorical_dim
            generator_args["categorical_dim"] = categorical_dim
        generator = config["generator"](**generator_args).to(device)
        print(f"Generator initialized on {generator}")

    # Handle critic/discriminator initialization conditionally
    if gan_type == "WGAN-GP":
        # WGAN-GP uses a critic instead of a discriminator
        critic = config["critic"](embedding_dim=embedding_dim).to(device)
        print(f"Critic initialized on {critic}")
    elif gan_type == "Semi-Supervised-GAN":
        # Semi-Supervised-GAN requires num_classes for the discriminator
        discriminator = config["discriminator"](embedding_dim=embedding_dim, num_classes=num_classes).to(device)
        print(f"Semi-Supervised-GAN discriminator initialized on {discriminator}")
    elif gan_type == "InfoGAN":
        # InfoGAN discriminator requires embedding_dim and categorical_dim
        discriminator = config["discriminator"](embedding_dim=embedding_dim, categorical_dim=categorical_dim).to(device)
        print(f"InfoGAN discriminator initialized on {discriminator}")
    elif gan_type not in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        # For other models (VAE-GAN, Conditional-GAN, etc.), initialize a single discriminator
        discriminator_args = {"embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            discriminator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            discriminator_args["categorical_dim"] = categorical_dim
        discriminator = config["discriminator"](**discriminator_args).to(device)
        print(f"Discriminator initialized on {discriminator}")
    else:
        # If discriminator is part of the configuration (Cycle-GAN, Dual-GAN, etc.), use the ones in the config
        discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
        discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
        print(f"Discriminators initialized on {discriminator_a}, {discriminator_b}")

    # Initialize encoder for VAE-GAN
    if gan_type == "VAE-GAN":
        encoder = config["encoder"](embedding_dim=embedding_dim, latent_dim=latent_dim).to(device)
        print(f"Encoder initialized on {encoder}")

    # Select the appropriate training function
    train_function = config["train_function"]

    # Handle any additional configurations (like learning rate, lambda_gp)
    train_kwargs = config.get("train_kwargs", {})  # Default to empty dict if not specified
    train_kwargs.update({
        "latent_dim": latent_dim,
        "epochs": epochs,  # Pass epochs as part of train_kwargs
        "device": device,
        "learning_rate": learning_rate,
        "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"] else None,  # Add for Conditional, InfoGAN, and Semi-Supervised-GAN
        "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,  # Add only for InfoGAN
        "lambda_gp": 10 if gan_type == "WGAN-GP" else None,  # Add only for WGAN-GP
    })

    # Add data loaders conditionally
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        # Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN: pass two data loaders
        train_kwargs.update({
            "data_loader_a": embedding_loader_a,  # Two loaders for these models
            "data_loader_b": embedding_loader_b
        })
    else:
        # For all other models (WGAN-GP, VAE-GAN, etc.), use a single loader
        train_kwargs.update({
            "data_loader": embedding_loader  # Single data loader for other models
        })

    # Training loop (now handled inside the train_function)
    if gan_type == "VAE-GAN":
        # VAE-GAN requires encoder, generator, and discriminator
        train_function(encoder, generator, discriminator, **train_kwargs)
    else:
        # For other models, pass the appropriate generator/discriminator/critic
        train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs) if gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"] else train_function(generator, discriminator if "discriminator" in locals() else critic, **train_kwargs)

    print(f"{gan_type} training test passed!")

else:
    raise ValueError(f"Unsupported GAN type: {gan_type}")


Initializing InfoGAN configuration: {'generator': <class 'src.gan_workflows.plan2.plan2_gan_models.InfoGANGenerator'>, 'discriminator': <class 'src.gan_workflows.plan2.plan2_gan_models.InfoGANDiscriminator'>, 'train_function': <function train_infogan at 0x7c0ebbc67600>}
Generator initialized on InfoGANGenerator(
  (model): Sequential(
    (0): Linear(in_features=110, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Linear(in_features=256, out_features=50, bias=True)
  )
)
InfoGAN discriminator initialized on InfoGANDiscriminator(
  (model): Sequential(
    (0): Linear(in_features=50, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Linear(in_features=256, out_features=1, bias=T

RuntimeError: all elements of input should be between 0 and 1

In [24]:
# GAN Configuration
gan_type = "Conditional-GAN"  # Change to the desired GAN type
latent_dim = 100  # Latent dimension for the generator
embedding_dim = embeddings.size(1)  # Embedding dimension based on loaded embeddings
num_classes = 10  # For Conditional GAN and InfoGAN (e.g., 10 classes for MNIST)
categorical_dim = 10
epochs = 20
learning_rate = 0.0001
# Ensure device is properly set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the data loaders before the GAN models
embedding_loader_a = data_loader
embedding_loader_b = data_loader
embedding_loader = data_loader_v2

# Initialize the correct GAN models based on `gan_type`
if gan_type in gan_configurations:
    config = gan_configurations[gan_type]
    print(f"Initializing {gan_type} configuration: {config}")

    # Generator initialization for Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        if gan_type == "Cycle-GAN":
            # Cycle-GAN requires two generators and two discriminators
            generator_a = config["generator_a"](embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"CycleGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
        else:
            # Dual-GAN or Contrastive-Dual-GAN requires both latent_dim and embedding_dim for generators
            generator_a = config["generator_a"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            generator_b = config["generator_b"](latent_dim=latent_dim, embedding_dim=embedding_dim).to(device)
            discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
            discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
            print(f"DualGAN/ContrastiveDualGAN generators and discriminators initialized on {generator_a}, {generator_b}, {discriminator_a}, {discriminator_b}")
    else:
        # For other GANs (WGAN-GP, VAE-GAN, etc.), initialize a single generator
        generator_args = {"latent_dim": latent_dim, "embedding_dim": embedding_dim}
        if gan_type == "InfoGAN":
            # Correctly initialize InfoGAN generator with latent_dim and categorical_dim
            generator_args["categorical_dim"] = categorical_dim
        if gan_type == "Conditional-GAN":
            # Conditional-GAN generator requires num_classes
            generator_args["num_classes"] = num_classes
        generator = config["generator"](**generator_args).to(device)
        print(f"Generator initialized on {generator}")

    # Handle critic/discriminator initialization conditionally
    if gan_type == "WGAN-GP":
        # WGAN-GP uses a critic instead of a discriminator
        critic = config["critic"](embedding_dim=embedding_dim).to(device)
        print(f"Critic initialized on {critic}")
    elif gan_type == "Semi-Supervised-GAN":
        # Semi-Supervised-GAN requires num_classes for the discriminator
        discriminator = config["discriminator"](embedding_dim=embedding_dim, num_classes=num_classes).to(device)
        print(f"Semi-Supervised-GAN discriminator initialized on {discriminator}")
    elif gan_type == "InfoGAN":
        # InfoGAN discriminator requires embedding_dim and categorical_dim
        discriminator = config["discriminator"](embedding_dim=embedding_dim, categorical_dim=categorical_dim).to(device)
        print(f"InfoGAN discriminator initialized on {discriminator}")
    elif gan_type == "Conditional-GAN":
        # Conditional-GAN discriminator requires embedding_dim and num_classes
        discriminator = config["discriminator"](embedding_dim=embedding_dim, num_classes=num_classes).to(device)
        print(f"Conditional-GAN discriminator initialized on {discriminator}")
    elif gan_type not in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"]:
        # For other models (VAE-GAN, Conditional-GAN, etc.), initialize a single discriminator
        discriminator_args = {"embedding_dim": embedding_dim}
        if gan_type in ["Conditional-GAN", "InfoGAN"]:
            discriminator_args["num_classes"] = num_classes
        if gan_type == "InfoGAN":
            discriminator_args["categorical_dim"] = categorical_dim
        discriminator = config["discriminator"](**discriminator_args).to(device)
        print(f"Discriminator initialized on {discriminator}")
    else:
        # If discriminator is part of the configuration (Cycle-GAN, Dual-GAN, etc.), use the ones in the config
        discriminator_a = config["discriminator_a"](embedding_dim=embedding_dim).to(device)
        discriminator_b = config["discriminator_b"](embedding_dim=embedding_dim).to(device)
        print(f"Discriminators initialized on {discriminator_a}, {discriminator_b}")

    # Initialize encoder for VAE-GAN
    if gan_type == "VAE-GAN":
        encoder = config["encoder"](embedding_dim=embedding_dim, latent_dim=latent_dim).to(device)
        print(f"Encoder initialized on {encoder}")

    # Select the appropriate training function
    train_function = config["train_function"]

    # Handle any additional configurations (like learning rate, lambda_gp)
    train_kwargs = config.get("train_kwargs", {})  # Default to empty dict if not specified
    train_kwargs.update({
        "latent_dim": latent_dim,
        "epochs": epochs,  # Pass epochs as part of train_kwargs
        "device": device,
        "learning_rate": learning_rate,
        "num_classes": num_classes if gan_type in ["Conditional-GAN", "InfoGAN", "Semi-Supervised-GAN"] else None,  # Add for Conditional, InfoGAN, and Semi-Supervised-GAN
        "categorical_dim": categorical_dim if gan_type == "InfoGAN" else None,  # Add only for InfoGAN
        "lambda_gp": 10 if gan_type == "WGAN-GP" else None,  # Add only for WGAN-GP
    })

    # Add data loaders conditionally
    if gan_type in ["Dual-GAN", "Cycle-GAN", "Contrastive-Dual-GAN"]:
        # Dual-GAN, Cycle-GAN, and Contrastive-Dual-GAN: pass two data loaders
        train_kwargs.update({
            "data_loader_a": embedding_loader_a,  # Two loaders for these models
            "data_loader_b": embedding_loader_b
        })
    else:
        # For all other models (WGAN-GP, VAE-GAN, etc.), use a single loader
        train_kwargs.update({
            "data_loader": embedding_loader  # Single data loader for other models
        })

    # Training loop (now handled inside the train_function)
    if gan_type == "VAE-GAN":
        # VAE-GAN requires encoder, generator, and discriminator
        train_function(encoder, generator, discriminator, **train_kwargs)
    else:
        # For other models, pass the appropriate generator/discriminator/critic
        train_function(generator_a, generator_b, discriminator_a, discriminator_b, **train_kwargs) if gan_type in ["Cycle-GAN", "Dual-GAN", "Contrastive-Dual-GAN"] else train_function(generator, discriminator if "discriminator" in locals() else critic, **train_kwargs)

    print(f"{gan_type} training test passed!")

else:
    raise ValueError(f"Unsupported GAN type: {gan_type}")



Initializing Conditional-GAN configuration: {'generator': <class 'src.gan_workflows.plan2.plan2_gan_models.ConditionalGANGenerator'>, 'discriminator': <class 'src.gan_workflows.plan2.plan2_gan_models.ConditionalGANDiscriminator'>, 'train_function': <function train_conditional_gan at 0x7c0ebbc67560>}
Generator initialized on ConditionalGANGenerator(
  (model): Sequential(
    (0): Linear(in_features=110, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Linear(in_features=256, out_features=50, bias=True)
  )
)
Conditional-GAN discriminator initialized on ConditionalGANDiscriminator(
  (model): Sequential(
    (0): Linear(in_features=60, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
   

RuntimeError: Tensors must have same number of dimensions: got 2 and 1