In [32]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
import wandb
import h5py
import os
import time
import json
from datetime import datetime

In [None]:
class Generator(nn.Module):
    """
    Conditional generator: maps (noise, categorical covariates, numeric
    covariates) to a non-negative vector whose entries sum to `library_size`.

    The hidden stack is fixed at 256 -> 512 -> 1024 (Linear + BatchNorm + ReLU
    each) followed by a linear projection to `x_dim`.
    """

    def __init__(self, x_dim, vocab_sizes, nb_numeric, z_dim, library_size=10000):
        """
        Args:
            x_dim: Dimension of output data
            vocab_sizes: List of vocabulary sizes, one per categorical variable
            nb_numeric: Number of numeric covariates
            z_dim: Dimension of latent noise vector
            library_size: Target per-sample sum of the generated values (default: 10000)
        """
        super().__init__()

        self.library_size = library_size

        # One embedding table per categorical variable; width is capped at 50.
        embed_widths = [min(50, n_levels) for n_levels in vocab_sizes]
        self.embeddings = nn.ModuleList([
            nn.Embedding(n_levels, width)
            for n_levels, width in zip(vocab_sizes, embed_widths)
        ])

        # The first Linear sees noise, all embeddings and the numeric covariates.
        in_features = z_dim + sum(embed_widths) + nb_numeric

        # Hidden blocks are created in order so parameter names
        # (network.0, network.1, ...) stay the same as a hand-written Sequential.
        blocks = []
        previous_width = in_features
        for width in (256, 512, 1024):
            blocks.extend([
                nn.Linear(previous_width, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
            ])
            previous_width = width

        # Final projection has no activation: normalisation happens in forward().
        blocks.append(nn.Linear(previous_width, x_dim))
        self.network = nn.Sequential(*blocks)

    def normalize_to_library_size(self, x):
        """
        Clip `x` at zero with ReLU and rescale every row so that it sums to
        `self.library_size` (rows that are all zero stay all zero).
        """
        non_negative = torch.relu(x)

        # Tiny offset keeps the division finite for all-zero rows.
        epsilon = 1e-10
        row_totals = non_negative.sum(dim=1, keepdim=True) + epsilon

        scale = self.library_size / row_totals
        return non_negative * scale

    def forward(self, z, cat_covs, num_covs):
        """
        Args:
            z: Latent noise, concatenated first along dim 1
            cat_covs: Integer codes; column i is looked up in embedding table i
            num_covs: Numeric covariates, concatenated last along dim 1
        Returns:
            Library-size-normalised output of the network
        """
        per_variable = []
        for column, table in enumerate(self.embeddings):
            per_variable.append(table(cat_covs[:, column]))
        condition = torch.cat(per_variable, dim=1)

        network_input = torch.cat([z, condition, num_covs], dim=1)
        raw_output = self.network(network_input)
        return self.normalize_to_library_size(raw_output)

    def get_negative_penalty(self, generated_data):
        """
        Summarise negative entries of `generated_data`.

        Returns:
            (magnitude, proportion): mean absolute value of the negative entries
            taken over ALL entries, and the fraction of entries below zero.
        """
        is_negative = (generated_data < 0).float()
        proportion = is_negative.mean()
        magnitude = (generated_data * is_negative).abs().mean()
        return magnitude, proportion

In [None]:
class Discriminator(nn.Module):
    """
    Conditional critic with a fixed 1024 -> 512 -> 256 -> 1 stack
    (Linear + LeakyReLU(0.2) + Dropout(0.3) per hidden layer), plus an optional
    side branch that scores the negative part of the input.
    """

    def __init__(self, x_dim, vocab_sizes, nb_numeric, use_neg_detector=False):
        """
        Args:
            x_dim: Dimension of input data
            vocab_sizes: List of vocabulary sizes, one per categorical variable
            nb_numeric: Number of numeric covariates
            use_neg_detector: Whether to build and use the negative-value branch
        """
        super().__init__()

        self.use_neg_detector = use_neg_detector

        # One embedding table per categorical variable; width is capped at 50.
        embed_widths = [min(50, n_levels) for n_levels in vocab_sizes]
        self.embeddings = nn.ModuleList([
            nn.Embedding(n_levels, width)
            for n_levels, width in zip(vocab_sizes, embed_widths)
        ])

        # The first Linear sees the data, all embeddings and the numeric covariates.
        in_features = x_dim + sum(embed_widths) + nb_numeric

        # Hidden blocks are appended in order so parameter names match a
        # hand-written Sequential (main_network.0, main_network.3, ...).
        stack = []
        previous_width = in_features
        for width in (1024, 512, 256):
            stack.extend([
                nn.Linear(previous_width, width),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            ])
            previous_width = width
        stack.append(nn.Linear(previous_width, 1))
        self.main_network = nn.Sequential(*stack)

        # Optional branch: x_dim -> 1024 -> 1 with a sigmoid output.
        if use_neg_detector:
            self.negative_detector = nn.Sequential(
                nn.Linear(x_dim, 1024),
                nn.LeakyReLU(0.2),
                nn.Linear(1024, 1),
                nn.Sigmoid()
            )

    def forward(self, x, cat_covs, num_covs):
        """
        Score `x` conditioned on the covariates.

        Returns the main-network output; when the negative detector is enabled,
        0.1 times its score on relu(-x) is subtracted from that output.
        """
        per_variable = []
        for column, table in enumerate(self.embeddings):
            per_variable.append(table(cat_covs[:, column]))
        condition = torch.cat(per_variable, dim=1)

        critic_input = torch.cat([x, condition, num_covs], dim=1)
        validity = self.main_network(critic_input)

        if not self.use_neg_detector:
            return validity

        # relu(-x) keeps only the magnitude of negative entries.
        neg_score = self.negative_detector(torch.relu(-x))
        return validity - 0.1 * neg_score

In [4]:
def compute_gradient_penalty(discriminator, real_samples, fake_samples, cat_covs, num_covs, device):
    """
    WGAN-GP gradient penalty.

    Evaluates the discriminator on random per-sample blends of real and fake
    data and returns the mean squared deviation of the gradient L2 norm
    (taken over dim 1) from 1.

    Args:
        discriminator: Callable taking (samples, cat_covs, num_covs)
        real_samples: Real data batch
        fake_samples: Generated data batch with the same shape as real_samples
        cat_covs: Categorical covariates passed through to the discriminator
        num_covs: Numeric covariates passed through to the discriminator
        device: Device on which the blending weights are created
    Returns:
        Scalar tensor holding the penalty
    """
    n_rows = real_samples.size(0)

    # One blending weight per sample, broadcast across features.
    mix = torch.rand((n_rows, 1), device=device)
    blended = mix * real_samples + ((1 - mix) * fake_samples)
    blended.requires_grad_(True)

    critic_out = discriminator(blended, cat_covs, num_covs)

    # create_graph=True keeps the penalty differentiable w.r.t. the critic weights.
    (grads,) = torch.autograd.grad(
        outputs=critic_out,
        inputs=blended,
        grad_outputs=torch.ones_like(critic_out),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )

    per_sample_norm = grads.norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()

def train_gan(generator, discriminator, dataloader, cat_covs, num_covs, 
              config, device, score_fn=None, save_fn=None):
    """
    Train the conditional WGAN-GP with progress tracking and device handling.

    Args:
        generator: Generator module; called as generator(z, cat_covs, num_covs)
            and must expose get_negative_penalty(fake_data)
        discriminator: Critic module; called as discriminator(x, cat_covs, num_covs)
            and must expose the `use_neg_detector` attribute
        dataloader: Iterable with a length, yielding 1-tuples `(real_data,)`
        cat_covs: Array-like of integer categorical codes (converted to a long tensor)
        num_covs: Array-like of numeric covariates (converted to a float32 tensor)
        config: Dict. Required keys: 'epochs', 'lr', 'nb_critic', 'latent_dim'.
            Optional keys: 'beta1', 'beta2', 'eps', 'lambda_gp',
            'negative_penalty_start_epoch', 'negative_penalty_ramp_epochs',
            'max_negative_penalty'
        device: Device on which batches, noise and covariates are placed
        score_fn: Optional callable(generator) -> float, run every 10 epochs
        save_fn: Optional callable(generator, discriminator, epoch), run every
            20 epochs and once more for the final epoch

    Raises:
        ValueError: If config['nb_critic'] < 1 or the dataloader has no batches.
    """
    # Adam with AMSGrad for both networks (identical hyper-parameters).
    adam_kwargs = dict(
        lr=config['lr'],
        betas=(config.get('beta1', 0.5), config.get('beta2', 0.9)),
        eps=config.get('eps', 1e-8),
        amsgrad=True)
    g_optimizer = optim.Adam(generator.parameters(), **adam_kwargs)
    d_optimizer = optim.Adam(discriminator.parameters(), **adam_kwargs)
    
    # Training parameters
    n_epochs = config['epochs']
    nb_critic = config['nb_critic']
    latent_dim = config['latent_dim']
    lambda_gp = config.get('lambda_gp', 10)
    neg_penalty_start = config.get('negative_penalty_start_epoch', 50)
    neg_penalty_ramp = config.get('negative_penalty_ramp_epochs', 50)
    max_neg_penalty = config.get('max_negative_penalty', 10.0)
    
    # The progress line reports the last critic loss, so at least one critic
    # step per batch is required.
    if nb_critic < 1:
        raise ValueError(f"config['nb_critic'] must be >= 1, got {nb_critic}")
    
    # Convert covariates to tensors and move to device
    cat_covs = torch.tensor(cat_covs, dtype=torch.long).to(device)
    num_covs = torch.tensor(num_covs, dtype=torch.float32).to(device)
    
    total_batches = len(dataloader)
    if total_batches == 0:
        raise ValueError("dataloader yields no batches; nothing to train on")
    
    print(f"Starting training for {config['epochs']} epochs...")
    print(f"Total batches per epoch: {total_batches}")
    print(f"Using device: {device}")
    print(f"Using negative detector: {discriminator.use_neg_detector}")
    print(f"Negative penalty starts at epoch: {neg_penalty_start}")
    
    def get_negative_penalty_weight(epoch):
        """Curriculum weight: 0 before the start epoch, then a linear ramp up to the maximum."""
        if epoch < neg_penalty_start:
            return 0.0
        
        # A non-positive ramp length means "switch to full weight immediately"
        # (also avoids a division by zero).
        if neg_penalty_ramp <= 0:
            return max_neg_penalty
        
        ramp_progress = (epoch - neg_penalty_start) / neg_penalty_ramp
        ramp_progress = min(1.0, max(0.0, ramp_progress))
        return max_neg_penalty * ramp_progress
    
    for epoch in range(n_epochs):
        # score_fn / save_fn may leave the models in eval mode; make sure
        # BatchNorm and Dropout are in training mode for every epoch.
        generator.train()
        discriminator.train()
        
        d_losses = []
        g_losses = []
        g_losses_main = []  # Track main generator loss without penalty
        neg_metrics = []
        
        print(f"\nEpoch [{epoch+1}/{config['epochs']}]")
        curr_neg_weight = get_negative_penalty_weight(epoch)
        
        for batch_idx, (real_data,) in enumerate(dataloader):
            batch_size = real_data.size(0)
            
            # Move real data to device
            real_data = real_data.to(device)
            
            # NOTE(review): covariates are drawn at random, independently of
            # the rows in `real_data`, so real samples are not paired with
            # their own labels. The dataloader does not yield labels, so this
            # cannot be fixed here without changing the caller - confirm intent.
            batch_indices = torch.randint(0, cat_covs.size(0), (batch_size,), device=device)
            batch_cat_covs = cat_covs[batch_indices]
            batch_num_covs = num_covs[batch_indices]
            
            # Train Discriminator
            for _ in range(nb_critic):
                d_optimizer.zero_grad()
                
                # The critic step never back-propagates into the generator,
                # so build the fake batch without an autograd graph.
                z = torch.randn(batch_size, latent_dim, device=device)
                with torch.no_grad():
                    fake_data = generator(z, batch_cat_covs, batch_num_covs)
                
                # Calculate discriminator outputs
                real_validity = discriminator(real_data, batch_cat_covs, batch_num_covs)
                fake_validity = discriminator(fake_data, batch_cat_covs, batch_num_covs)
                
                # Calculate gradient penalty
                gp = compute_gradient_penalty(
                    discriminator,
                    real_data,
                    fake_data,
                    batch_cat_covs,
                    batch_num_covs,
                    device)
                
                # Wasserstein critic loss with gradient penalty
                d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gp
                
                d_loss.backward()
                d_optimizer.step()
                
                d_losses.append(d_loss.item())
            
            # Train Generator
            g_optimizer.zero_grad()
            
            # Generate fake data
            z = torch.randn(batch_size, latent_dim, device=device)
            fake_data = generator(z, batch_cat_covs, batch_num_covs)
            
            # Calculate standard generator loss
            fake_validity = discriminator(fake_data, batch_cat_covs, batch_num_covs)
            g_loss_main = -torch.mean(fake_validity)
            
            # Negative-value penalty, weighted by the curriculum schedule
            neg_magnitude, neg_proportion = generator.get_negative_penalty(fake_data)
            g_loss = g_loss_main + curr_neg_weight * neg_magnitude
            
            g_loss.backward()
            
            g_optimizer.step()
            
            # Track losses and metrics
            g_losses.append(g_loss.item())
            g_losses_main.append(g_loss_main.item())
            neg_metrics.append({
                'proportion': neg_proportion.item(),
                'magnitude': neg_magnitude.item()
            })
            
            # Print progress every 10 batches
            if batch_idx % 10 == 0:
                progress_msg = (
                    f"  Batch [{batch_idx}/{total_batches}] "
                    f"D_loss: {d_loss.item():.4f}, "
                    f"G_loss: {g_loss.item():.4f}, "
                    f"G_main: {g_loss_main.item():.4f}, "
                    f"Neg_prop: {neg_proportion.item():.3f}, "
                    f"Neg_mag: {neg_magnitude.item():.3f}, "
                    f"Neg_weight: {curr_neg_weight:.3f}"
                )
                print(progress_msg)
        
        # Print epoch summary
        avg_d_loss = np.mean(d_losses)
        avg_g_loss = np.mean(g_losses)
        avg_g_main = np.mean(g_losses_main)
        avg_neg_prop = np.mean([m['proportion'] for m in neg_metrics])
        avg_neg_mag = np.mean([m['magnitude'] for m in neg_metrics])
        
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Average D_loss: {avg_d_loss:.4f}")
        print(f"  Average G_loss: {avg_g_loss:.4f}")
        print(f"  Average G_main: {avg_g_main:.4f}")
        print(f"  Average Neg_prop: {avg_neg_prop:.4f}")
        print(f"  Average Neg_mag: {avg_neg_mag:.4f}")
        print(f"  Negative Penalty Weight: {curr_neg_weight:.4f}")
        
        # Log metrics
        if wandb.run is not None:
            metrics = {
                'epoch': epoch,
                'd_loss': avg_d_loss,
                'g_loss': avg_g_loss,
                'g_loss_main': avg_g_main,
                'negative_proportion': avg_neg_prop,
                'negative_magnitude': avg_neg_mag,
                'negative_penalty_weight': curr_neg_weight
            }
            
            if discriminator.use_neg_detector:
                metrics.update({
                    'discriminator_neg_proportion': avg_neg_prop,
                    'discriminator_neg_magnitude': avg_neg_mag
                })
            
            wandb.log(metrics)
        
        # Evaluate model if needed
        if score_fn is not None and epoch % 10 == 0:
            score = score_fn(generator)
            print(f'Epoch {epoch}: Score = {score:.4f}')
        
        # Checkpoint every 20 epochs and at the last epoch - but only once
        # when the last epoch is itself a multiple of 20.
        is_last_epoch = epoch == n_epochs - 1
        if save_fn is not None and (epoch % 20 == 0 or is_last_epoch):
            save_fn(generator, discriminator, epoch)

In [None]:
def main(selected_categories=None):
    """
    Train the GAN with selected categorical variables.

    Loads the expression matrix and metadata from `data_path`, integer-encodes
    the chosen metadata columns, builds the Generator/Discriminator, and runs
    `train_gan` while logging to wandb. Checkpoints and a `model_config.json`
    are written to a single run directory under `<data_path>/saved_models/`.

    Args:
        selected_categories: List of column names to use as categorical variables.
                           If None, uses all columns except 'cell_id'

    Raises:
        ValueError: If any name in `selected_categories` is not a metadata column.
    """
    # Configuration
    CONFIG = {
        'epochs': 300,
        'latent_dim': 64,
        'batch_size': 32,
        'lr': 1e-4,
        'beta1': 0.5,      # First moment coefficient
        'beta2': 0.9,    # Second moment coefficient
        'eps': 1e-8,       # Small constant for numerical stability
        'nb_critic': 5,
        'lambda_gp': 10,
        'negative_penalty_start_epoch': 50,
        'negative_penalty_ramp_epochs': 50,
        'max_negative_penalty': 5.0,
        'library_size': 10000 
    }
    neg_detector = False # True - use negative value detector in the discriminator
    
    # Device configuration
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    
    print(f"Using device: {device}")
    
    # Load data
    data_path = "/Users/guyshani/Documents/PHD/Aim_2/10x_data_mouse/20_1_2025__normalized/"
    # data_path = "/content/drive/MyDrive/Colab_Notebooks/data/5000_genes/"
    
    # Load expression matrix
    # NOTE(review): the original comment says "cells as columns and genes as
    # rows", but the code below uses axis 0 as samples and the last axis as
    # features - confirm the on-disk orientation.
    with h5py.File(data_path+'combined_normalized_data.h5', 'r') as f:
        x_train = f['matrix'][:]

    # Load all categorical variables from single file
    cat_data = pd.read_csv(data_path+'combined_metadata.csv', sep=';')
    print("Categorical data shape:", cat_data.shape)
    print("Available categorical variables:", [col for col in cat_data.columns if col != 'cell_id'])
    
    if x_train.shape[0] != len(cat_data):
        print(f"WARNING: expression matrix has {x_train.shape[0]} rows but "
              f"metadata has {len(cat_data)} rows; check the matrix orientation.")
    
    # Determine which categories to use
    if selected_categories is None:
        # Use all columns except cell_id
        categories_to_use = [col for col in cat_data.columns if col != 'cell_id']
    else:
        # Validate selected categories
        invalid_categories = [cat for cat in selected_categories if cat not in cat_data.columns]
        if invalid_categories:
            raise ValueError(f"Invalid categories: {invalid_categories}")
        categories_to_use = selected_categories
    
    print(f"\nUsing categorical variables: {categories_to_use}")
    
    # Create dictionaries and inverse mappings for categorical variables
    cat_dicts = []
    encoded_covs = []
    
    # Process each selected column as a categorical variable
    for column in categories_to_use:
        # Get the column data
        cat_vec = cat_data[column]
        print(f"\nProcessing categorical variable: {column}")
        
        # Sorted unique values define the integer code of each category
        dict_inv = np.array(list(sorted(set(cat_vec.values))))
        dict_map = {t: i for i, t in enumerate(dict_inv)}
        cat_dicts.append(dict_inv)
        
        # Convert categorical variables to integers
        encoded = np.vectorize(lambda t: dict_map[t])(cat_vec)
        encoded = encoded.reshape(-1, 1)  # Reshape to column vector
        encoded_covs.append(encoded)
        
        print(f"Categories in {column}:", dict_inv)
        print(f"Number of categories:", len(dict_inv))
    
    # Combine all categorical covariates
    cat_covs = np.hstack(encoded_covs)
    print("\nCombined categorical covariates shape:", cat_covs.shape)
    
    # Load numerical covariates (currently empty)
    num_covs = np.zeros((x_train.shape[0], 0))
    
    # Convert data to a PyTorch tensor
    x_train = torch.tensor(x_train, dtype=torch.float32)  # Keep on CPU for DataLoader
    
    # Create data loader
    train_dataset = TensorDataset(x_train)
    train_loader = DataLoader(
        train_dataset, 
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        drop_last=True
    )
    
    ############# Initialize models
    # Generator
    vocab_sizes = [len(c) for c in cat_dicts]
    print("\nVocabulary sizes for categorical variables:", vocab_sizes)
    nb_numeric = num_covs.shape[-1]
    x_dim = x_train.shape[-1]
    
    generator = Generator(
        x_dim=x_dim,
        vocab_sizes=vocab_sizes,
        nb_numeric=nb_numeric,
        z_dim=CONFIG['latent_dim'],
        library_size=CONFIG['library_size']).to(device)
    
    # Discriminator (honours the neg_detector switch defined above)
    discriminator = Discriminator(
        x_dim=x_dim,
        vocab_sizes=vocab_sizes,
        nb_numeric=nb_numeric,
        use_neg_detector=neg_detector).to(device)
    
    # One timestamp / directory per training run, so every checkpoint of this
    # run ends up in the same folder next to a single model_config.json.
    run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    categories_str = "+".join(categories_to_use)
    run_dir = os.path.join(data_path, "saved_models", f"run_{run_timestamp}_{categories_str}")
    
    # Define save function
    def save_models(generator, discriminator, epoch):
        """Write model_config.json plus generator/discriminator weights for `epoch` into the run directory."""
        os.makedirs(run_dir, exist_ok=True)

        # Save model initialization parameters.
        # 'h_dims' records the generator's fixed hidden widths; CONFIG has no
        # 'hdim' / 'nb_layers' entries to derive it from.
        model_config = {
            'x_dim': x_dim,
            'vocab_sizes': vocab_sizes,
            'nb_numeric': nb_numeric,
            'h_dims': [256, 512, 1024],
            'z_dim': CONFIG['latent_dim'],
            'library_size': CONFIG['library_size'],
            'categories': categories_to_use,
            'training_config': CONFIG}
        config_path = os.path.join(run_dir, 'model_config.json')
        with open(config_path, 'w') as f:
            json.dump(model_config, f, indent=4)
        
        # Save generator
        generator_path = os.path.join(run_dir, f"generator_{run_timestamp}_{categories_str}_epoch_{epoch+1}.pt")
        torch.save(generator.state_dict(), generator_path)
        
        # Save discriminator
        discriminator_path = os.path.join(run_dir, f"discriminator_{run_timestamp}_{categories_str}_epoch_{epoch+1}.pt")
        torch.save(discriminator.state_dict(), discriminator_path)
        
        print(f"\nModels saved at epoch {epoch + 1}:")
        print(f"Generator: {generator_path}")
        print(f"Discriminator: {discriminator_path}")
        
        # Log to wandb
        if wandb.run is not None:
            wandb.save(generator_path)
            wandb.save(discriminator_path)

    # Initialize wandb with unique run name
    run_name = f"run_{int(time.time())}"  # Uses timestamp for unique name
    wandb.init(
        project='adversarial_gene_expr',
        config=CONFIG,
        name=run_name,
        reinit=True  # Ensures new run each time
    )
    
    try:
        # Add selected categories to wandb config
        wandb.config.update({'selected_categories': categories_to_use})
        
        # Train model
        train_gan(
            generator=generator,
            discriminator=discriminator,
            dataloader=train_loader,
            cat_covs=cat_covs,
            num_covs=num_covs,
            config=CONFIG,
            device=device,
            save_fn=save_models
            #save_fn=None
        )
    finally:
        # Close the wandb run even when training is interrupted or fails.
        wandb.finish()

if __name__ == "__main__":
    # Train on an explicit subset of metadata columns.
    # Calling main() with no arguments uses every column except 'cell_id'.
    main(
        selected_categories=[
            'dataset',
            'singler_label',
        ]
    )

In [33]:
### Functions for data generation
def inspect_generator_dims(generator):
    """
    Report the input-side dimensions of a generator.

    Parameters:
        generator: Object exposing `embeddings` (iterable of modules with an
            `embedding_dim` attribute) and `network` (indexable; element 0 has
            an `in_features` attribute)

    Returns:
        dict with keys 'embedding_dims', 'total_embedding_dim',
        'first_layer_in_dim' and 'recommended_latent_dim'. The last one is the
        first-layer input width minus the total embedding width, so it also
        counts any numeric-covariate inputs.
    """
    widths = []
    for table in generator.embeddings:
        widths.append(table.embedding_dim)
    embed_total = sum(widths)

    entry_width = generator.network[0].in_features

    report = dict(
        embedding_dims=widths,
        total_embedding_dim=embed_total,
        first_layer_in_dim=entry_width,
    )
    report['recommended_latent_dim'] = entry_width - embed_total
    return report

def generate_expression_profiles(generator, n_samples, dataset_category, cell_type_category, device=None, debug=False):
    """
    Generate gene expression profiles using the trained cWGAN generator.

    Samples are produced by the generator's own forward pass, so they go
    through the same post-processing as during training instead of being read
    straight from `generator.network`.
    
    Parameters:
        generator: Trained Generator model (needs at least two embedding tables:
            index 0 for the dataset, index 1 for the cell type)
        n_samples: Number of profiles to generate
        dataset_category: Integer indicating which dataset category to generate
        cell_type_category: Integer indicating which cell type to generate
        device: Device to run generation on ('cuda', 'mps', or 'cpu'). If None,
            the device of the generator's parameters is used.
        debug: If True, print debugging information
    
    Returns:
        numpy array of generated expression profiles with shape (n_samples, n_genes)

    Raises:
        ValueError: If the generator has fewer than two embedding tables.
        RuntimeError: Re-raised from the forward pass after printing diagnostics.
    """
    # Set generator to eval mode
    generator.eval()
    
    # Default to wherever the generator lives, so inputs and weights always
    # share a device.
    if device is None:
        device = next(generator.parameters()).device
    
    # Inspect dimensions
    dims = inspect_generator_dims(generator)
    
    if debug:
        print("Generator dimensions:")
        for k, v in dims.items():
            print(f"{k}: {v}")
    
    # Columns 0 and 1 of the condition tensor are written below.
    num_embeddings = len(generator.embeddings)
    if num_embeddings < 2:
        raise ValueError(
            f"generator has {num_embeddings} embedding table(s); "
            "dataset and cell-type conditioning need at least 2"
        )
    
    # Create latent vectors
    latent_dim = dims['recommended_latent_dim']
    z = torch.randn(n_samples, latent_dim, device=device)
    
    if debug:
        print(f"\nLatent vector shape: {z.shape}")
    
    # Create categorical condition tensor with dataset and cell type;
    # any further categorical columns stay at code 0.
    cat_covs = torch.zeros((n_samples, num_embeddings), dtype=torch.long, device=device)
    cat_covs[:, 0] = dataset_category  # Set dataset category
    cat_covs[:, 1] = cell_type_category  # Set cell type category
    
    if debug:
        print(f"Categorical covariates shape: {cat_covs.shape}")
        print(f"Number of embedding layers: {num_embeddings}")
    
    # Create empty numeric covariates tensor
    num_covs = torch.zeros((n_samples, 0), device=device)
    
    # Generate samples
    try:
        with torch.no_grad():
            if debug:
                # Recompute the intermediate shapes purely for the debug printout.
                embeddings = [emb(cat_covs[:, i]) for i, emb in enumerate(generator.embeddings)]
                embedded = torch.cat(embeddings, dim=1)
                print(f"Embedded shape: {embedded.shape}")
                gen_input = torch.cat([z, embedded, num_covs], dim=1)
                print(f"Generator input shape: {gen_input.shape}")
                print(f"First layer input dim: {generator.network[0].in_features}")
                print(f"First layer weight shape: {generator.network[0].weight.shape}")
            
            # Full forward pass (embedding lookup, network, output post-processing)
            fake_samples = generator(z, cat_covs, num_covs)
            
    except RuntimeError as e:
        print("\nError during generation:")
        print(e)
        print("\nGenerator architecture:")
        print(generator)
        raise
    
    # Convert to numpy array
    return fake_samples.cpu().numpy()

def generate_and_save_profiles(generator, samples_per_combination, save_path, cell_type_names, device=None, debug=False):
    """
    Generate expression profiles for several (dataset, cell type) combinations
    and write them to disk.

    Two files are written: `<save_path>_profiles.npy` (stacked samples) and
    `<save_path>_categories.txt` (one `dataset<N>_<cell type name>` label per row).
    
    Args:
        generator: Trained generator model
        samples_per_combination: Dictionary with (dataset_num, cell_type_num) keys and number of samples as values
        save_path: Path prefix for the output files
        cell_type_names: Dictionary mapping cell type indices to their names
        device: Device to use for generation. If None, the device of the
            generator's parameters is used.
        debug: Whether to print debug information

    Returns:
        (all_samples, all_categories): stacked numpy array and the list of labels

    Raises:
        ValueError: If `samples_per_combination` is empty.
    """
    if not samples_per_combination:
        raise ValueError("samples_per_combination is empty; nothing to generate")
    
    # Resolve the device here so a concrete value is always passed downstream.
    if device is None:
        device = next(generator.parameters()).device
    
    all_samples = []
    all_categories = []
    
    # Generate samples for specified combinations
    for (dataset_num, cell_type_num), n_samples in samples_per_combination.items():
        # Dataset numbers are 1-based; embedding indices are 0-based.
        dataset_category = dataset_num - 1
        cell_type_category = cell_type_num
        cell_type_name = cell_type_names[cell_type_num]
        
        if debug:
            print(f"\nGenerating {n_samples} samples for dataset{dataset_num}, {cell_type_name}")
        
        samples = generate_expression_profiles(
            generator, 
            n_samples, 
            dataset_category,
            cell_type_category,
            device,
            debug=debug
        )
        all_samples.append(samples)
        all_categories.extend([f'dataset{dataset_num}_{cell_type_name}'] * n_samples)

    print("Save location: "+str(save_path))

    # Combine all samples
    all_samples = np.vstack(all_samples)
    
    # Save generated profiles
    np.save(f'{save_path}_profiles.npy', all_samples)
    
    # Save category labels, one per generated row
    with open(f'{save_path}_categories.txt', 'w') as f:
        for category in all_categories:
            f.write(f'{category}\n')

    return all_samples, all_categories

In [None]:
# Generate data
# Set directories
data_path = "~/Documents/PHD/Aim_2/test_models/5000_genes/"
#run_dir = "/Users/guyshani/Documents/PHD/Aim_2/test_models/5000_genes/run_20250206_184657_dataset_singler_label/"
#generator_model = "generator_20250206_184657_dataset_singler_label_epoch_281.pt"
####
run_dir = "/Users/guyshani/Documents/PHD/Aim_2/test_models/5000_genes/run_20250206_160000_dataset_singler_label/"
generator_model = "generator_20250206_160000_dataset_singler_label_epoch_121.pt"

# Device configuration
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Load configuration
config_path = os.path.join(run_dir, 'model_config.json')
with open(config_path, 'r') as f:
    model_config = json.load(f)

# Library size may be stored at the top level or only inside the training
# config; fall back to the Generator's default when neither is present.
library_size = model_config.get(
    'library_size',
    model_config.get('training_config', {}).get('library_size', 10000))

# Initialize the generator with the saved configuration.
# Generator has a fixed hidden architecture and takes no `h_dims` argument,
# so any saved 'h_dims' entry is informational only.
# NOTE(review): checkpoints trained with a different generator architecture
# will fail in load_state_dict below - confirm the checkpoint matches.
generator = Generator(
    x_dim=model_config['x_dim'],
    vocab_sizes=model_config['vocab_sizes'],
    nb_numeric=model_config['nb_numeric'],
    z_dim=model_config['z_dim'],
    library_size=library_size).to(device)

generator_path = os.path.join(run_dir, generator_model)
generator.load_state_dict(torch.load(generator_path, map_location=device, weights_only=True))

# Load metadata
meta = pd.read_csv(data_path+'metadata_top5000.csv')

# Get counts of each dataset-celltype combination
counts = meta.groupby(['dataset', 'singler_label']).size().to_dict()

# Get unique cell types from the data
unique_cell_types = sorted(meta['singler_label'].unique())

# Create cell type mapping automatically (sorted order -> integer code)
cell_type_map = {cell_type: idx for idx, cell_type in enumerate(unique_cell_types)}

print("Detected cell types:")
for cell_type, idx in cell_type_map.items():
    print(f"{cell_type}: {idx}")

# Convert the dict keys from tuple of strings to tuple of numbers
samples_dict = {}
for (dataset, cell_type), count in counts.items():
    # Extract dataset number
    dataset_num = int(dataset.replace('dataset', ''))
    # Get cell type number from our automatic mapping
    cell_type_num = cell_type_map[cell_type]
    # Add to new dictionary with numerical tuple as key
    samples_dict[(dataset_num, cell_type_num)] = count

# Get the reverse mapping for cell type names
cell_type_names = {v: k for k, v in cell_type_map.items()}

# Generate on the same device the generator was moved to above.
all_samples, categories = generate_and_save_profiles(
    generator,
    samples_per_combination=samples_dict,
    save_path=run_dir+'_generated_data',
    cell_type_names=cell_type_names,  # Pass the mapping to the function
    device=device,
    debug=False
)

In [None]:
## Load and analyze generated data
# Load the generated profiles
profiles = np.load(run_dir + '_generated_data_profiles.npy')

# Load categories (one "<dataset>_<cell type>" label per generated row)
with open(run_dir + '_generated_data_categories.txt', 'r') as f:
    categories = [line.strip() for line in f]

# Load gene names
gene_names = pd.read_csv(data_path+"full_matrix_top5000.csv")

# Convert to pandas DataFrame
df = pd.DataFrame(profiles)
print(df.shape)
# Set gene names as header
df.columns = gene_names['Unnamed: 0']


# Add categories as a column
df['labels'] = categories

# Add separate dataset and cell_type columns.
# Split only on the FIRST underscore so cell-type names that themselves
# contain underscores are kept whole.
df['dataset'] = df['labels'].apply(lambda x: x.split('_', 1)[0])
df['cell_type'] = df['labels'].apply(lambda x: x.split('_', 1)[1])
# Drop the combined labels column
df=df.drop(['labels'], axis=1)

# Save a csv file
df.to_csv(f'{run_dir}_generated_data.csv', index=False)

with h5py.File(f'{run_dir}_generated_data.h5', 'w') as f:
    # Save as matrix - select only numerical values
    f.create_dataset('matrix', data=df.select_dtypes(include=[np.number]).values)  # or df.to_numpy()

# Save the categorical information separately:
df_labels = df.select_dtypes(exclude=[np.number])
df_labels.to_csv(f'{run_dir}_generated_labels.csv', index=True)

In [16]:
data_path = "/Users/guyshani/Documents/PHD/Aim_2/test_models/5000_genes/"
# data_path = "/content/drive/MyDrive/Colab_Notebooks/data/"

# Read the whole 'matrix/data' dataset into memory as a NumPy array.
# NOTE(review): described as cells-as-columns / genes-as-rows - confirm orientation.
with h5py.File(f"{data_path}full_matrix_top5000.h5", mode='r') as f:
    x_train = np.array(f['matrix']['data'])