In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
import wandb
import h5py
import os
import time
import json
from datetime import datetime

In [9]:
class Generator(nn.Module):
    """
    Conditional generator for the cWGAN.

    Every categorical covariate is looked up in its own ``nn.Embedding`` table
    of width ``min(50, vocab_size)``. The noise vector, the embeddings and the
    numeric covariates are concatenated (in that order) and mapped through
    ``len(h_dims)`` blocks of Linear -> BatchNorm1d -> ReLU, then a final Linear
    layer to ``x_dim`` outputs with no activation.

    The attribute names ``embeddings`` / ``network`` and the layer order inside
    ``network`` determine the ``state_dict`` keys, so saved checkpoints only
    load into a model built the same way.
    """

    def __init__(self, x_dim, vocab_sizes, nb_numeric, h_dims, z_dim):
        """
        Args:
            x_dim: Dimension of output data
            vocab_sizes: Vocabulary size of each categorical variable, in the
                column order of ``cat_covs``
            nb_numeric: Number of numeric covariates
            h_dims: Widths of the hidden layers (may be empty)
            z_dim: Dimension of the latent noise vector
        """
        super().__init__()

        # Embedding width per categorical variable, capped at 50.
        emb_widths = [min(50, n_levels) for n_levels in vocab_sizes]
        self.embeddings = nn.ModuleList(
            [nn.Embedding(n_levels, width) for n_levels, width in zip(vocab_sizes, emb_widths)]
        )

        # Feature width entering each Linear layer: first the concatenated
        # input, then every hidden width in turn.
        widths = [z_dim + sum(emb_widths) + nb_numeric, *h_dims]

        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.BatchNorm1d(fan_out), nn.ReLU()]
        stack.append(nn.Linear(widths[-1], x_dim))

        self.network = nn.Sequential(*stack)

    def forward(self, z, cat_covs, num_covs):
        """
        Map noise plus conditioning to generated samples.

        Args:
            z: noise tensor, concatenated along dim 1 (presumably shape
                (batch, z_dim) - confirm against callers).
            cat_covs: integer tensor; column ``i`` holds the codes looked up
                in ``self.embeddings[i]``.
            num_covs: numeric covariates, concatenated along dim 1.

        Returns:
            Output of ``self.network`` applied to ``[z, embeddings, num_covs]``.
        """
        looked_up = [table(cat_covs[:, col]) for col, table in enumerate(self.embeddings)]
        features = torch.cat([z, torch.cat(looked_up, dim=1), num_covs], dim=1)
        return self.network(features)

In [10]:
class Discriminator(nn.Module):
    """
    Conditional critic for the cWGAN.

    Categorical covariates are looked up in per-variable ``nn.Embedding``
    tables of width ``min(50, vocab_size)``. The data, the embeddings and the
    numeric covariates are concatenated (in that order) and passed through
    ``len(h_dims)`` blocks of Linear -> LeakyReLU(0.2) -> Dropout(0.3), then a
    final Linear layer giving one unbounded score per sample. There is no
    sigmoid, matching the Wasserstein loss used in ``train_gan``.

    The attribute names ``embeddings`` / ``network`` and the layer order inside
    ``network`` determine the ``state_dict`` keys of saved checkpoints.
    """

    def __init__(self, x_dim, vocab_sizes, nb_numeric, h_dims):
        """
        Args:
            x_dim: Dimension of input data
            vocab_sizes: Vocabulary size of each categorical variable, in the
                column order of ``cat_covs``
            nb_numeric: Number of numeric covariates
            h_dims: Widths of the hidden layers (may be empty)
        """
        super().__init__()

        # Embedding width per categorical variable, capped at 50.
        emb_widths = [min(50, n_levels) for n_levels in vocab_sizes]
        self.embeddings = nn.ModuleList(
            [nn.Embedding(n_levels, width) for n_levels, width in zip(vocab_sizes, emb_widths)]
        )

        # Feature width entering each Linear layer.
        widths = [x_dim + sum(emb_widths) + nb_numeric, *h_dims]

        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2), nn.Dropout(0.3)]
        stack.append(nn.Linear(widths[-1], 1))

        self.network = nn.Sequential(*stack)

    def forward(self, x, cat_covs, num_covs):
        """
        Score samples under their conditioning.

        Args:
            x: data tensor, concatenated along dim 1.
            cat_covs: integer tensor; column ``i`` holds the codes looked up
                in ``self.embeddings[i]``.
            num_covs: numeric covariates, concatenated along dim 1.

        Returns:
            Critic scores from ``self.network``, one output column.
        """
        looked_up = [table(cat_covs[:, col]) for col, table in enumerate(self.embeddings)]
        features = torch.cat([x, torch.cat(looked_up, dim=1), num_covs], dim=1)
        return self.network(features)


In [7]:
def compute_gradient_penalty(discriminator, real_samples, fake_samples, cat_covs, num_covs, device):
    """
    Calculate the WGAN-GP gradient penalty.

    Draws one random point per sample on the straight line between the paired
    real and fake samples, evaluates the critic there, and penalises the
    squared deviation of each sample's input-gradient L2 norm from 1.

    Args:
        discriminator: callable ``discriminator(x, cat_covs, num_covs)``
            returning scores for the batch.
        real_samples: real data tensor; dim 0 is the batch, any number of
            feature dimensions.
        fake_samples: generated data with the same shape as ``real_samples``
            (callers pass it detached from the generator graph).
        cat_covs: categorical conditioning, passed unchanged to the critic.
        num_covs: numeric conditioning, passed unchanged to the critic.
        device: device on which the interpolation weights are drawn.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2)`` over the batch. It is built
        with ``create_graph=True``, so it can be back-propagated into the
        critic's parameters.
    """
    batch_size = real_samples.size(0)

    # One weight per sample, shaped to broadcast over every non-batch
    # dimension (a fixed (batch, 1) shape only works for 2-D inputs).
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device, dtype=real_samples.dtype)

    # Random interpolation between real and fake samples
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

    # Critic output for the interpolated samples
    d_interpolates = discriminator(interpolates, cat_covs, num_covs)

    # Gradients w.r.t. the interpolates. `only_inputs` is deprecated and
    # ignored by torch.autograd.grad, so it is not passed.
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten so the norm covers all features of each sample.
    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

def train_gan(generator, discriminator, dataloader, cat_covs, num_covs, 
              config, device, score_fn=None, save_fn=None):
    """
    Train the conditional WGAN-GP generator/critic pair with progress tracking.

    For every batch, the critic (``discriminator``) is updated
    ``config['nb_critic']`` times, then the generator is updated once. Both use
    RMSprop with learning rate ``config['lr']``.

    Args:
        generator: model called as ``generator(z, cat_covs, num_covs)``.
        discriminator: critic called as ``discriminator(x, cat_covs, num_covs)``.
        dataloader: iterable of batches, each either ``(x,)`` or
            ``(x, cat_covs, num_covs)``.
            * 3-element batches: the batch's own covariates condition both the
              real and the fake samples, so each real sample is scored with its
              true labels.
            * 1-element batches (original behaviour): covariates are taken from
              random rows of ``cat_covs``/``num_covs``. They are NOT the labels
              of the real samples in the batch, and a warning is emitted once.
        cat_covs: array-like of integer category codes, one column per
            categorical variable. Only used for 1-element batches (may be None
            otherwise).
        num_covs: array-like of numeric covariates. Only used for 1-element
            batches (may be None otherwise).
        config: dict with 'epochs', 'lr', 'nb_critic' (>= 1) and 'latent_dim'.
            Optional keys: 'lambda_gp' (gradient-penalty weight, default 10),
            'score_every' (default 10) and 'save_every' (default 20).
        device: torch device on which training runs.
        score_fn: optional ``score_fn(generator)`` returning a number. Called
            every ``score_every`` epochs and after the last epoch.
        save_fn: optional ``save_fn(generator, discriminator, epoch)``. Called
            every ``save_every`` epochs and after the last epoch, so the final
            weights are always saved.

    Raises:
        ValueError: if ``config['nb_critic'] < 1`` or a batch has an
            unsupported number of elements.
    """
    import warnings

    nb_critic = config['nb_critic']
    if nb_critic < 1:
        raise ValueError(f"config['nb_critic'] must be >= 1, got {nb_critic}")

    n_epochs = config['epochs']
    latent_dim = config['latent_dim']
    # CONFIG exposes 'lambda_gp'; honour it instead of a hard-coded 10.
    lambda_gp = config.get('lambda_gp', 10)
    score_every = config.get('score_every', 10)
    save_every = config.get('save_every', 20)

    # Optimizers
    g_optimizer = optim.RMSprop(generator.parameters(), lr=config['lr'])
    d_optimizer = optim.RMSprop(discriminator.parameters(), lr=config['lr'])

    # Full covariate tables, only needed for unpaired (x,) batches.
    # as_tensor also accepts tensors without the copy-construct warning.
    if cat_covs is not None:
        cat_covs = torch.as_tensor(cat_covs, dtype=torch.long).to(device)
    if num_covs is not None:
        num_covs = torch.as_tensor(num_covs, dtype=torch.float32).to(device)

    # BatchNorm/Dropout must be in training mode even if the models were
    # switched to eval() earlier in the session.
    generator.train()
    discriminator.train()

    total_batches = len(dataloader)
    warned_unpaired = False

    print(f"Starting training for {n_epochs} epochs...")
    print(f"Total batches per epoch: {total_batches}")
    print(f"Using device: {device}")

    for epoch in range(n_epochs):
        d_losses = []
        g_losses = []
        print(f"\nEpoch [{epoch+1}/{n_epochs}]")

        for batch_idx, batch in enumerate(dataloader):
            real_data = batch[0].to(device)
            batch_size = real_data.size(0)

            if len(batch) == 3:
                # Covariates travel with their samples: correct conditional pairing.
                batch_cat_covs = batch[1].to(device=device, dtype=torch.long)
                batch_num_covs = batch[2].to(device=device, dtype=torch.float32)
            elif len(batch) == 1:
                if not warned_unpaired:
                    warnings.warn(
                        "Batches contain only expression data; conditioning covariates are "
                        "sampled at random and are not the labels of the real samples. "
                        "Yield (x, cat_covs, num_covs) batches to pair them."
                    )
                    warned_unpaired = True
                batch_indices = torch.randint(0, cat_covs.size(0), (batch_size,))
                batch_cat_covs = cat_covs[batch_indices]
                batch_num_covs = num_covs[batch_indices]
            else:
                raise ValueError(
                    f"Expected batches of 1 or 3 tensors, got {len(batch)}"
                )

            # Train critic
            for _ in range(nb_critic):
                d_optimizer.zero_grad()

                z = torch.randn(batch_size, latent_dim, device=device)
                # The critic step never back-propagates into the generator, so
                # skip building its autograd graph.
                with torch.no_grad():
                    fake_data = generator(z, batch_cat_covs, batch_num_covs)

                real_validity = discriminator(real_data, batch_cat_covs, batch_num_covs)
                fake_validity = discriminator(fake_data, batch_cat_covs, batch_num_covs)

                gp = compute_gradient_penalty(
                    discriminator,
                    real_data,
                    fake_data,
                    batch_cat_covs,
                    batch_num_covs,
                    device)

                # Wasserstein critic loss with gradient penalty
                d_loss = -real_validity.mean() + fake_validity.mean() + lambda_gp * gp

                d_loss.backward()
                d_optimizer.step()

                d_losses.append(d_loss.item())

            # Train generator
            g_optimizer.zero_grad()

            z = torch.randn(batch_size, latent_dim, device=device)
            fake_data = generator(z, batch_cat_covs, batch_num_covs)

            fake_validity = discriminator(fake_data, batch_cat_covs, batch_num_covs)
            g_loss = -fake_validity.mean()

            g_loss.backward()
            g_optimizer.step()

            g_losses.append(g_loss.item())

            # Print progress every 10 batches
            if batch_idx % 10 == 0:
                print(f"  Batch [{batch_idx}/{total_batches}] "
                      f"D_loss: {d_loss.item():.4f}, "
                      f"G_loss: {g_loss.item():.4f}")

        # Epoch summary (NaN if the dataloader yielded no batches)
        avg_d_loss = float(np.mean(d_losses)) if d_losses else float('nan')
        avg_g_loss = float(np.mean(g_losses)) if g_losses else float('nan')
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Average D_loss: {avg_d_loss:.4f}")
        print(f"  Average G_loss: {avg_g_loss:.4f}")

        if wandb.run is not None:
            wandb.log({
                'epoch': epoch,
                'd_loss': avg_d_loss,
                'g_loss': avg_g_loss
            })

        # Always evaluate/save after the last epoch too; with a fixed
        # "every N epochs" rule the final weights were otherwise dropped.
        is_last_epoch = epoch == n_epochs - 1

        if score_fn is not None and (epoch % score_every == 0 or is_last_epoch):
            score = score_fn(generator)
            print(f'Epoch {epoch}: Score = {score:.4f}')

        if save_fn is not None and (epoch % save_every == 0 or is_last_epoch):
            save_fn(generator, discriminator, epoch)

In [8]:
def _encode_categorical_columns(cat_data, columns):
    """
    Integer-encode metadata columns for the embedding layers.

    Each column is encoded as the index of its label in the sorted list of the
    column's unique labels. Progress and the label list are printed per column.

    Args:
        cat_data: pandas DataFrame of metadata.
        columns: column names to encode, in order.

    Returns:
        (cat_dicts, cat_covs): ``cat_dicts[i]`` is the numpy array of sorted
        unique labels of ``columns[i]``, so ``cat_dicts[i][code]`` recovers the
        label. ``cat_covs`` is an integer array with one row per metadata row
        and one column per entry of ``columns``.
    """
    cat_dicts = []
    encoded_covs = []
    for column in columns:
        cat_vec = cat_data[column]
        print(f"\nProcessing categorical variable: {column}")

        # Sorted unique labels: the position in this array is the code.
        dict_inv = np.array(sorted(set(cat_vec.values)))
        dict_map = {label: code for code, label in enumerate(dict_inv.tolist())}
        cat_dicts.append(dict_inv)

        # Vectorised lookup; every label is in dict_map by construction.
        encoded_covs.append(cat_vec.map(dict_map).to_numpy().reshape(-1, 1))

        print(f"Categories in {column}:", dict_inv)
        print("Number of categories:", len(dict_inv))
    return cat_dicts, np.hstack(encoded_covs)


def main(selected_categories=None, data_path=None):
    """
    Train the GAN with selected categorical variables.

    Args:
        selected_categories: List of metadata column names to use as
            categorical variables. If None, uses all columns except 'cell_id'.
        data_path: Directory containing 'combined_normalized_data.h5' and
            'combined_metadata.csv' (';'-separated). Defaults to the original
            hard-coded location; '~' is expanded.

    Raises:
        ValueError: if a selected column is missing, no categorical variable
            is selected, or the metadata and expression matrix have different
            row counts.
    """
    # Configuration
    CONFIG = {
        'epochs': 100,
        'latent_dim': 64,
        'batch_size': 32,
        'nb_layers': 3,
        'hdim': 256,
        'lr': 1e-4,
        'nb_critic': 5,
        'lambda_gp': 10  # Gradient penalty coefficient
    }

    # Device configuration
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')

    print(f"Using device: {device}")

    if data_path is None:
        data_path = "/Users/guyshani/Documents/PHD/Aim_2/10x_data_mouse/20_1_2025__normalized/"
    data_path = os.path.expanduser(data_path)

    # Load expression matrix. Training batches over dim 0 and uses the last
    # dim as x_dim, so the stored matrix must be cells x genes (one row per cell).
    with h5py.File(os.path.join(data_path, 'combined_normalized_data.h5'), 'r') as f:
        x_train = f['matrix'][:]

    # Load all categorical variables from a single file
    cat_data = pd.read_csv(os.path.join(data_path, 'combined_metadata.csv'), sep=';')
    print("Categorical data shape:", cat_data.shape)
    available_categories = [col for col in cat_data.columns if col != 'cell_id']
    print("Available categorical variables:", available_categories)

    # Determine which categories to use
    if selected_categories is None:
        categories_to_use = available_categories
    else:
        invalid_categories = [cat for cat in selected_categories if cat not in cat_data.columns]
        if invalid_categories:
            raise ValueError(f"Invalid categories: {invalid_categories}")
        categories_to_use = list(selected_categories)
    if not categories_to_use:
        raise ValueError("At least one categorical variable is required")

    print(f"\nUsing categorical variables: {categories_to_use}")

    cat_dicts, cat_covs = _encode_categorical_columns(cat_data, categories_to_use)
    print("\nCombined categorical covariates shape:", cat_covs.shape)

    # Metadata rows are assumed to be in the same order as expression rows;
    # at minimum the counts must agree.
    if cat_covs.shape[0] != x_train.shape[0]:
        raise ValueError(
            f"Metadata has {cat_covs.shape[0]} rows but the expression matrix has "
            f"{x_train.shape[0]} rows; expected one metadata row per cell."
        )

    # Numerical covariates (currently none)
    num_covs = np.zeros((x_train.shape[0], 0))

    x_train = torch.as_tensor(x_train, dtype=torch.float32)  # Keep on CPU for DataLoader

    # NOTE(review): train_gan receives only x batches here, so it chooses each
    # batch's conditioning itself instead of using the metadata of that batch's
    # cells. TODO: put cat_covs/num_covs into the dataset to pair them.
    train_dataset = TensorDataset(x_train)
    train_loader = DataLoader(
        train_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        drop_last=True
    )

    # Initialize models
    vocab_sizes = [len(c) for c in cat_dicts]
    print("\nVocabulary sizes for categorical variables:", vocab_sizes)
    nb_numeric = num_covs.shape[-1]
    x_dim = x_train.shape[-1]
    h_dims = [CONFIG['hdim']] * CONFIG['nb_layers']

    generator = Generator(
        x_dim=x_dim,
        vocab_sizes=vocab_sizes,
        nb_numeric=nb_numeric,
        h_dims=h_dims,
        z_dim=CONFIG['latent_dim']).to(device)

    discriminator = Discriminator(
        x_dim=x_dim,
        vocab_sizes=vocab_sizes,
        nb_numeric=nb_numeric,
        h_dims=h_dims).to(device)

    # One timestamp per training run, so every checkpoint of this run lands in
    # the same directory next to a single model_config.json.
    run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    categories_str = "+".join(categories_to_use)
    run_dir = os.path.join(data_path, "saved_models", f"run_{run_timestamp}_{categories_str}")

    # Model initialization parameters needed to rebuild the networks later
    model_config = {
        'x_dim': x_dim,
        'vocab_sizes': vocab_sizes,
        'nb_numeric': nb_numeric,
        'h_dims': h_dims,
        'z_dim': CONFIG['latent_dim'],
        'categories': categories_to_use,
        'training_config': CONFIG}

    def save_models(generator, discriminator, epoch):
        """Write model_config.json and both state_dicts for `epoch` into run_dir."""
        os.makedirs(run_dir, exist_ok=True)

        config_path = os.path.join(run_dir, 'model_config.json')
        with open(config_path, 'w') as f:
            json.dump(model_config, f, indent=4)

        generator_path = os.path.join(
            run_dir, f"generator_{run_timestamp}_{categories_str}_epoch_{epoch+1}.pt")
        torch.save(generator.state_dict(), generator_path)

        discriminator_path = os.path.join(
            run_dir, f"discriminator_{run_timestamp}_{categories_str}_epoch_{epoch+1}.pt")
        torch.save(discriminator.state_dict(), discriminator_path)

        print(f"\nModels saved at epoch {epoch + 1}:")
        print(f"Generator: {generator_path}")
        print(f"Discriminator: {discriminator_path}")

        if wandb.run is not None:
            wandb.save(generator_path)
            wandb.save(discriminator_path)

    # Initialize wandb with a unique run name
    run_name = f"run_{int(time.time())}"
    wandb.init(
        project='adversarial_gene_expr',
        config=CONFIG,
        name=run_name,
        reinit=True  # Ensures new run each time
    )
    wandb.config.update({'selected_categories': categories_to_use})

    try:
        train_gan(
            generator=generator,
            discriminator=discriminator,
            dataloader=train_loader,
            cat_covs=cat_covs,
            num_covs=num_covs,
            config=CONFIG,
            device=device,
            save_fn=save_models
        )
    finally:
        # Close the run, even on error/interrupt, so the next call in this
        # kernel starts from a clean wandb state.
        wandb.finish()

if __name__ == '__main__':
    # Train conditioned on an explicit subset of metadata columns.
    # Call main() with no arguments to use every column except 'cell_id'.
    chosen_categories = ['dataset', 'cell_type']
    main(selected_categories=chosen_categories)

Using device: mps
Categorical data shape: (41588, 4)
Available categorical variables: ['dataset', 'cluster', 'cell_type']

Using categorical variables: ['dataset', 'cell_type']

Processing categorical variable: dataset
Categories in dataset: ['dataset1' 'dataset2' 'dataset3' 'dataset4' 'dataset5' 'dataset6'
 'dataset7']
Number of categories: 7

Processing categorical variable: cell_type
Categories in cell_type: ['B cells' 'Dendritic cells' 'Endothelial cells' 'Erythrocytes'
 'Fibroblasts' 'Granulocytes' 'Macrophages' 'Monocytes' 'NK cells'
 'T cells']
Number of categories: 10

Combined categorical covariates shape: (41588, 2)

Vocabulary sizes for categorical variables: [7, 10]


Starting training for 100 epochs...
Total batches per epoch: 1299
Using device: mps

Epoch [1/100]
  Batch [0/1299] D_loss: -0.4650, G_loss: 0.9766
  Batch [10/1299] D_loss: -11.8101, G_loss: -5.3853
  Batch [20/1299] D_loss: -9.4431, G_loss: -7.2299
  Batch [30/1299] D_loss: -9.3688, G_loss: -8.2170
  Batch [40/1299] D_loss: -5.9184, G_loss: -5.2426
  Batch [50/1299] D_loss: -6.5075, G_loss: -9.0455
  Batch [60/1299] D_loss: -6.4068, G_loss: -7.1682
  Batch [70/1299] D_loss: -4.7488, G_loss: -5.1376
  Batch [80/1299] D_loss: -3.2610, G_loss: -6.4150
  Batch [90/1299] D_loss: -4.3258, G_loss: -4.0177
  Batch [100/1299] D_loss: -4.1959, G_loss: -7.8522
  Batch [110/1299] D_loss: -3.8843, G_loss: -6.2385
  Batch [120/1299] D_loss: -3.6678, G_loss: -6.3344
  Batch [130/1299] D_loss: -4.6632, G_loss: -3.0054
  Batch [140/1299] D_loss: -3.2591, G_loss: -3.4850
  Batch [150/1299] D_loss: -4.9756, G_loss: -4.9019
  Batch [160/1299] D_loss: -3.5521, G_loss: -4.1244
  Batch [170/1299] D_loss: -




Epoch 1 Summary:
  Average D_loss: -2.9655
  Average G_loss: -1.1891

Models saved at epoch 1:
Generator: /Users/guyshani/Documents/PHD/Aim_2/10x_data_mouse/20_1_2025__normalized/saved_models/run_20250121_121949_dataset+cell_type/generator_20250121_121949_dataset+cell_type_epoch_1.pt
Discriminator: /Users/guyshani/Documents/PHD/Aim_2/10x_data_mouse/20_1_2025__normalized/saved_models/run_20250121_121949_dataset+cell_type/discriminator_20250121_121949_dataset+cell_type_epoch_1.pt

Epoch [2/100]
  Batch [0/1299] D_loss: -2.1918, G_loss: 0.3683
  Batch [10/1299] D_loss: -3.2372, G_loss: -1.5384
  Batch [20/1299] D_loss: -2.7977, G_loss: -0.4146
  Batch [30/1299] D_loss: -4.1501, G_loss: -1.4355
  Batch [40/1299] D_loss: -3.3868, G_loss: 1.7798
  Batch [50/1299] D_loss: -4.0314, G_loss: -0.5715
  Batch [60/1299] D_loss: -2.1887, G_loss: 0.4195
  Batch [70/1299] D_loss: -3.9844, G_loss: 1.5717
  Batch [80/1299] D_loss: -2.6740, G_loss: -1.4374
  Batch [90/1299] D_loss: -3.8382, G_loss: -0.7

In [28]:
### Functions for data generation
def inspect_generator_dims(generator, nb_numeric=0):
    """
    Inspect the generator's embedding widths and first-layer input size.

    Assumes the Generator layout used in this notebook: ``generator.embeddings``
    is a list of embedding layers and ``generator.network[0]`` is the first
    Linear layer. That layer is fed the concatenation
    ``[z, embeddings, numeric covariates]``.

    Parameters:
        generator: Generator model
        nb_numeric: Number of numeric covariates the generator was built with
            (0 in this notebook). It is subtracted along with the embedding
            widths when recovering the latent size.

    Returns:
        dict containing 'embedding_dims', 'total_embedding_dim',
        'first_layer_in_dim' and 'recommended_latent_dim'
        (= first_layer_in_dim - total_embedding_dim - nb_numeric).

    Raises:
        ValueError: if the implied latent dimension would be negative.
    """
    embedding_dims = [emb.embedding_dim for emb in generator.embeddings]
    total_embedding_dim = sum(embedding_dims)

    first_layer_in_dim = generator.network[0].in_features

    # Whatever is left after embeddings and numeric covariates is the noise width.
    latent_dim = first_layer_in_dim - total_embedding_dim - nb_numeric
    if latent_dim < 0:
        raise ValueError(
            f"First layer takes {first_layer_in_dim} features, fewer than the "
            f"{total_embedding_dim} embedding + {nb_numeric} numeric features"
        )

    return {
        'embedding_dims': embedding_dims,
        'total_embedding_dim': total_embedding_dim,
        'first_layer_in_dim': first_layer_in_dim,
        'recommended_latent_dim': latent_dim
    }

def generate_expression_profiles(generator, n_samples, dataset_category, cell_type_category, device=None, debug=False):
    """
    Generate gene expression profiles using the trained cWGAN generator.

    All ``n_samples`` profiles share the same conditioning: embedding 0 gets
    ``dataset_category``, embedding 1 gets ``cell_type_category``, and any
    further embeddings get code 0. No numeric covariates are supplied.

    The generator is put in eval mode for the forward pass and is restored to
    its previous train/eval mode afterwards.

    Parameters:
        generator: Trained Generator model, called as
            ``generator(z, cat_covs, num_covs)``.
        n_samples: Number of profiles to generate
        dataset_category: Integer code for the first categorical variable
            (dataset), in ``[0, vocab_size)``.
        cell_type_category: Integer code for the second categorical variable
            (cell type), in ``[0, vocab_size)``.
        device: Device to run generation on ('cuda', 'mps', 'cpu' or a
            torch.device). Defaults to the device of the generator's
            parameters, which it must match.
        debug: If True, print debugging information

    Returns:
        numpy array of generated expression profiles with shape (n_samples, n_genes)

    Raises:
        ValueError: if the generator has fewer than two embeddings or a
            category code is outside its embedding's vocabulary.
        RuntimeError: re-raised from the forward pass after printing the
            generator architecture.
    """
    if device is None:
        # Generate where the weights live; a fixed 'mps' default fails on other machines.
        device = next(generator.parameters()).device

    # Inspect dimensions
    dims = inspect_generator_dims(generator)
    if debug:
        print("Generator dimensions:")
        for k, v in dims.items():
            print(f"{k}: {v}")

    num_embeddings = len(generator.embeddings)
    if num_embeddings < 2:
        raise ValueError(
            f"Generator has {num_embeddings} categorical embedding(s); "
            "dataset and cell type conditioning needs at least 2"
        )
    # Out-of-range codes would otherwise surface as opaque index errors.
    for col, (arg_name, code) in enumerate(
            (("dataset_category", dataset_category), ("cell_type_category", cell_type_category))):
        vocab_size = generator.embeddings[col].num_embeddings
        if not 0 <= code < vocab_size:
            raise ValueError(f"{arg_name}={code} is outside [0, {vocab_size})")

    # Create latent vectors
    latent_dim = dims['recommended_latent_dim']
    z = torch.randn(n_samples, latent_dim, device=device)

    # Same condition for every sample; any extra embeddings keep code 0.
    cat_covs = torch.zeros((n_samples, num_embeddings), dtype=torch.long, device=device)
    cat_covs[:, 0] = dataset_category
    cat_covs[:, 1] = cell_type_category

    # This notebook trains without numeric covariates.
    num_covs = torch.zeros((n_samples, 0), device=device)

    if debug:
        print(f"\nLatent vector shape: {z.shape}")
        print(f"Categorical covariates shape: {cat_covs.shape}")
        print(f"Number of embedding layers: {num_embeddings}")
        print(f"Generator input width: {z.shape[1] + dims['total_embedding_dim'] + num_covs.shape[1]}")
        print(f"First layer input dim: {generator.network[0].in_features}")
        print(f"First layer weight shape: {generator.network[0].weight.shape}")

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            # Use the model's own forward rather than re-implementing it here.
            fake_samples = generator(z, cat_covs, num_covs)
    except RuntimeError as e:
        print("\nError during generation:")
        print(e)
        print("\nGenerator architecture:")
        print(generator)
        raise
    finally:
        # Leave the model in the mode the caller had it in.
        generator.train(was_training)

    return fake_samples.cpu().numpy()

def generate_and_save_profiles(generator, samples_per_combination, save_path, device=None, debug=False,
                               cell_type_names=None):
    """
    Generate expression profiles with custom sample counts for each dataset and cell type combination.

    Writes ``{save_path}_profiles.npy`` (stacked profiles, one row per sample)
    and ``{save_path}_categories.txt`` (one 'dataset<k>_<cell type>' label per
    line, in the same row order).

    Parameters:
        generator: Trained Generator model
        samples_per_combination: Dictionary with tuples as keys (dataset_num, cell_type_num)
                               and number of samples as values. dataset_num is
                               1-based ('dataset1' -> code 0).
        save_path: Path prefix for the generated files
        device: Device to run generation on ('cuda', 'mps', 'cpu' or a
            torch.device). Defaults to the device of the generator's parameters.
        debug: If True, print debugging information
        cell_type_names: Optional mapping cell-type code -> name used in the
            labels. Defaults to the 10 cell types of this notebook's data.

    Returns:
        (all_samples, all_categories): the stacked numpy array and the list of labels.

    Raises:
        ValueError: if ``samples_per_combination`` is empty.
    """
    if not samples_per_combination:
        raise ValueError("samples_per_combination is empty; nothing to generate")

    if device is None:
        device = next(generator.parameters()).device

    if cell_type_names is None:
        # Cell type code -> name (sorted label order, as encoded at training time)
        cell_type_names = {
            0: 'B cells',
            1: 'Dendritic cells',
            2: 'Endothelial cells',
            3: 'Erythrocytes',
            4: 'Fibroblasts',
            5: 'Granulocytes',
            6: 'Macrophages',
            7: 'Monocytes',
            8: 'NK cells',
            9: 'T cells'
        }

    all_samples = []
    all_categories = []

    for (dataset_num, cell_type_num), n_samples in samples_per_combination.items():
        # 'dataset<k>' -> code k-1. This matches the sorted-label encoding used
        # in training only while dataset names sort numerically (<= 9 datasets)
        # - TODO confirm for other metadata.
        dataset_category = dataset_num - 1
        cell_type_name = cell_type_names[cell_type_num]

        if debug:
            print(f"\nGenerating {n_samples} samples for dataset{dataset_num}, {cell_type_name}")

        samples = generate_expression_profiles(
            generator,
            n_samples,
            dataset_category,
            cell_type_num,
            device,
            debug=debug
        )
        all_samples.append(samples)
        all_categories.extend([f'dataset{dataset_num}_{cell_type_name}'] * n_samples)

    print("Save location: "+str(save_path))

    all_samples = np.vstack(all_samples)
    np.save(f'{save_path}_profiles.npy', all_samples)

    # One label per line, aligned with the rows of the .npy file
    with open(f'{save_path}_categories.txt', 'w') as f:
        f.writelines(f'{category}\n' for category in all_categories)

    return all_samples, all_categories

In [None]:
# Generate data from a saved generator checkpoint

# Set directories (checkpoint trained with 3 hidden layers)
data_path = os.path.expanduser("~/Documents/PHD/Aim_2/10x_data_mouse/20_1_2025__normalized/")
run_dir = "/Users/guyshani/Documents/PHD/Aim_2/10x_data_mouse/20_1_2025__normalized/saved_models/run_20250121_150654_dataset_cell_type"
generator_model = "generator_20250121_150654_dataset_cell_type_epoch_61.pt"
# Seed for the latent draws so re-running this cell regenerates the same data
SEED = 0

# Device configuration
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

print(f"Using device: {device}")
torch.manual_seed(SEED)

# Load the architecture saved next to the checkpoint
config_path = os.path.join(run_dir, 'model_config.json')
with open(config_path, 'r') as f:
    model_config = json.load(f)

# Only the generator is needed for sampling
generator = Generator(
    x_dim=model_config['x_dim'],
    vocab_sizes=model_config['vocab_sizes'],
    nb_numeric=model_config['nb_numeric'],
    h_dims=model_config['h_dims'],
    z_dim=model_config['z_dim']).to(device)

generator_path = os.path.join(run_dir, generator_model)
generator.load_state_dict(torch.load(generator_path, map_location=device, weights_only=True))

meta = pd.read_csv(os.path.join(data_path, 'combined_metadata.csv'), sep=';')
# Number of real cells per (dataset, cell_type); the generated data mirrors these counts
counts = meta.groupby(['dataset', 'cell_type']).size().to_dict()

# Cell type name -> integer code used for generation
cell_type_map = {
    'B cells': 0,
    'Dendritic cells': 1,
    'Endothelial cells': 2,
    'Erythrocytes': 3,
    'Fibroblasts': 4,
    'Granulocytes': 5,
    'Macrophages': 6,
    'Monocytes': 7,
    'NK cells': 8,
    'T cells': 9
}

# Training encoded each column as the index of the label among its sorted
# unique values. Check that the codes used here agree with that encoding and
# with the checkpoint before generating anything.
dataset_codes = {name: code for code, name in enumerate(sorted(set(meta['dataset'])))}
training_cell_types = sorted(set(meta['cell_type']))
if training_cell_types != sorted(cell_type_map, key=cell_type_map.get):
    raise ValueError("cell_type_map does not match the sorted cell types used at training time")
if model_config.get('categories', ['dataset', 'cell_type'])[:2] != ['dataset', 'cell_type']:
    raise ValueError(f"Checkpoint was trained on categories {model_config.get('categories')}")
if model_config['vocab_sizes'][:2] != [len(dataset_codes), len(training_cell_types)]:
    raise ValueError(
        f"Checkpoint vocab sizes {model_config['vocab_sizes']} do not match the metadata "
        f"({len(dataset_codes)} datasets, {len(training_cell_types)} cell types)"
    )

# Convert (dataset name, cell type name) keys to (dataset number, cell type code)
samples_dict = {}
for (dataset, cell_type), count in counts.items():
    # Dataset labels look like 'dataset1', 'dataset2', ...
    dataset_num = int(dataset.replace('dataset', ''))
    if dataset_codes[dataset] != dataset_num - 1:
        raise ValueError(
            f"'{dataset}' was encoded as {dataset_codes[dataset]} during training, "
            f"not {dataset_num - 1}"
        )
    samples_dict[(dataset_num, cell_type_map[cell_type])] = count

# samples_dict maps (dataset number, cell type code) -> number of profiles,
# e.g. {(1, 0): 1000, (3, 2): 750}.
# generate_and_save_profiles writes <save_path>_profiles.npy (one expression
# profile per sample) and <save_path>_categories.txt (one label per sample).
all_samples, categories = generate_and_save_profiles(
    generator,
    samples_per_combination=samples_dict,
    save_path=run_dir+'_generated_data',
    device=device,
    debug=False
)

Using device: mps
Save location: /Users/guyshani/Documents/PHD/Aim_2/10x_data_mouse/20_1_2025__normalized/saved_models/run_20250121_150654_dataset_cell_type_generated_data


In [30]:
## Load and analyze generated data
# Uses `run_dir` from the generation cell above.
profiles = np.load(run_dir + '_generated_data_profiles.npy')

# One label per line, aligned with the rows of `profiles`
with open(run_dir + '_generated_data_categories.txt', 'r') as f:
    categories = [line.strip() for line in f]

if len(categories) != profiles.shape[0]:
    raise ValueError(
        f"{len(categories)} labels for {profiles.shape[0]} generated profiles"
    )

df = pd.DataFrame(profiles)

# Labels look like 'dataset<k>_<cell type>'. Split on the first '_' only so
# cell type names that themselves contain '_' are kept whole.
df['dataset'] = [label.split('_', 1)[0] for label in categories]
df['cell_type'] = [label.split('_', 1)[1] for label in categories]

# Save a csv file (expression columns + dataset + cell_type)
df.to_csv(f'{run_dir}_generated_data.csv', index=False)

# Save only the numeric expression matrix
with h5py.File(f'{run_dir}_generated_data.h5', 'w') as f:
    f.create_dataset('matrix', data=profiles)

# Save the categorical information separately
df_labels = df[['dataset', 'cell_type']]
df_labels.to_csv(f'{run_dir}_generated_labels.csv', index=True)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,991,992,993,994,995,996,997,998,999,labels
0,-0.231811,5.012554,4.040084,3.683613,0.436647,3.761825,0.943701,3.246516,3.937541,1.443107,...,0.149865,0.011442,0.036922,-0.018661,-0.022489,0.022302,0.062507,0.022485,-0.021453,dataset1_B cells
1,0.001106,6.723053,4.880553,4.649127,0.605724,-1.628438,0.288405,4.578445,3.061715,1.110912,...,0.019799,0.044679,-0.010733,0.033751,0.013674,0.077913,-0.068833,-0.030298,-0.013414,dataset1_B cells
2,3.027212,5.233882,3.615242,3.108646,1.175094,3.089045,1.575116,2.513422,3.686309,2.185993,...,0.097830,0.009697,0.010542,0.000982,-0.001769,0.073058,0.056470,0.017281,0.005731,dataset1_B cells
3,0.265750,5.726051,4.193067,4.037908,0.014425,2.711107,-0.021076,3.636868,2.291311,2.954214,...,0.052932,0.059226,0.050422,-0.082419,0.037294,-0.017001,-0.044325,-0.001857,0.038056,dataset1_B cells
4,4.428651,1.383675,0.202624,0.219082,0.698924,4.395520,3.265499,0.442863,3.801274,-0.605415,...,0.140012,0.399742,0.205119,0.209274,-0.028910,0.215225,0.685272,0.397794,0.269825,dataset1_B cells
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
41583,-0.043463,4.471334,3.290458,2.960582,0.126797,4.563755,1.354325,2.588662,4.697986,1.769787,...,0.151689,0.023442,0.021533,-0.071441,0.003788,-0.011879,-0.017619,0.025726,0.027298,dataset7_T cells
41584,6.171608,0.434414,-0.085250,-0.186500,3.630826,2.031105,3.854700,-0.312018,3.564202,-0.271754,...,0.071843,0.059604,0.076521,0.019365,-0.059938,0.051268,-0.016612,0.197256,0.061592,dataset7_T cells
41585,0.575687,0.140930,-0.099514,-0.132665,-0.021235,0.022371,-0.057032,-0.097039,1.081624,-0.245499,...,0.066604,0.021906,0.117857,0.044744,-0.019693,0.005893,0.024329,-0.012449,-0.008558,dataset7_T cells
41586,0.902666,5.474178,3.976243,3.669129,0.097642,-0.206113,0.381000,3.302111,1.744207,2.544704,...,0.046760,0.017676,0.001327,0.035773,0.013885,0.066925,-0.038613,0.012564,-0.001759,dataset7_T cells
