In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
import wandb
import h5py
import time

In [2]:
### Load data
# Load expression matrix (csv file with rows as cells and columns as genes)
data_path = "/Users/guyshani/Documents/PHD/Aim_2/10x_data_mouse/"
# Load expression matrix
with h5py.File(data_path+'train_data_1dataset.h5', 'r') as f:
    matrix = f['matrix'][:]

# each row is a cell
# matrix[:,0]

x_train = matrix


# Load cluster info
cluster_vec = pd.read_csv(data_path+'train_data_1dataset_cluster.csv').T
cluster_vec
# Load numerical covariates
num_covs = 0

In [None]:
#set(cluster_vec.values.flatten())
x_train

In [None]:
# Example raw data
# tissues = np.array(['liver', 'brain', 'liver', 'kidney', 'brain'])

# Create dictionaries and inverse mappings
cat_dicts = []

## Cluster covariate
# Create list of unique cluster names, sorted
cluster_dict_inv = np.array(list(sorted(set(cluster_vec.values.flatten()))))  # ['brain', 'kidney', 'liver']
cluster_dict = {t: i for i, t in enumerate(cluster_dict_inv)}  # {'brain': 0, 'kidney': 1, 'liver': 2}
cat_dicts.append(cluster_dict_inv)

# Convert categorical variables to integers
clusters_encoded = np.vectorize(lambda t: cluster_dict[t])(cluster_vec)

print("Original clusters:", cluster_vec)
print("Encoded clusters:", clusters_encoded)
print("Cluster mapping:", cluster_dict)

print("\ncat_dicts:", cat_dicts)

# This gives us vocab_sizes for model initialization
vocab_sizes = [len(c) for c in cat_dicts]  # [3, 2] (3 tissue types, 2 dataset types)
print("\nvocab_sizes:", vocab_sizes)

# assign categorical covariates
cat_covs = clusters_encoded

In [6]:
class Generator(nn.Module):
    def __init__(self, x_dim, vocab_sizes, nb_numeric, h_dims, z_dim):
        """
        Generator network for conditional GAN
        Args:
            x_dim: Dimension of output data
            vocab_sizes: List of vocabulary sizes for each categorical variable
            nb_numeric: Number of numeric covariates
            h_dims: List of hidden dimensions
            z_dim: Dimension of latent noise vector
        """
        super(Generator, self).__init__()
        
        # Embedding layers for categorical variables
        self.embeddings = nn.ModuleList([
            nn.Embedding(vocab_size, min(50, vocab_size)) 
            for vocab_size in vocab_sizes
        ])
        
        # Calculate total embedding dimension
        embedding_dim = sum(min(50, vocab_size) for vocab_size in vocab_sizes)
        
        # Input dimension is latent dim + embedding dim + numeric covariates
        input_dim = z_dim + embedding_dim + nb_numeric
        
        # Build generator network
        layers = []
        current_dim = input_dim
        
        # Add hidden layers
        for h_dim in h_dims:
            layers.extend([
                nn.Linear(current_dim, h_dim),
                nn.BatchNorm1d(h_dim),
                nn.ReLU()
            ])
            current_dim = h_dim
        
        # Output layer
        layers.append(nn.Linear(current_dim, x_dim))
        
        self.network = nn.Sequential(*layers)

    def forward(self, z, cat_covs, num_covs):
        # Process categorical covariates through embeddings
        embeddings = [emb(cat_covs[:, i]) for i, emb in enumerate(self.embeddings)]
        embedded = torch.cat(embeddings, dim=1)
        
        # Concatenate all inputs
        gen_input = torch.cat([z, embedded, num_covs], dim=1)
        
        # Generate output
        return self.network(gen_input)

In [7]:
class Discriminator(nn.Module):
    def __init__(self, x_dim, vocab_sizes, nb_numeric, h_dims):
        """
        Discriminator network for conditional GAN
        Args:
            x_dim: Dimension of input data
            vocab_sizes: List of vocabulary sizes for each categorical variable
            nb_numeric: Number of numeric covariates
            h_dims: List of hidden dimensions
        """
        super(Discriminator, self).__init__()
        
        # Embedding layers for categorical variables
        self.embeddings = nn.ModuleList([
            nn.Embedding(vocab_size, min(50, vocab_size))
            for vocab_size in vocab_sizes
        ])
        
        # Calculate total embedding dimension
        embedding_dim = sum(min(50, vocab_size) for vocab_size in vocab_sizes)
        
        # Input dimension is data dim + embedding dim + numeric covariates
        input_dim = x_dim + embedding_dim + nb_numeric
        
        # Build discriminator network
        layers = []
        current_dim = input_dim
        
        # Add hidden layers
        for h_dim in h_dims:
            layers.extend([
                nn.Linear(current_dim, h_dim),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3)
            ])
            current_dim = h_dim
        
        # Output layer
        layers.append(nn.Linear(current_dim, 1))
        
        self.network = nn.Sequential(*layers)

    def forward(self, x, cat_covs, num_covs):
        # Process categorical covariates through embeddings
        embeddings = [emb(cat_covs[:, i]) for i, emb in enumerate(self.embeddings)]
        embedded = torch.cat(embeddings, dim=1)
        
        # Concatenate all inputs
        disc_input = torch.cat([x, embedded, num_covs], dim=1)
        
        # Generate output
        return self.network(disc_input)


In [12]:
def train_gan(generator, discriminator, dataloader, cat_covs, num_covs, 
              config, device, score_fn=None, save_fn=None):
    """
    Train the conditional GAN with progress tracking and proper device handling
    """
    # Optimizers
    g_optimizer = optim.RMSprop(generator.parameters(), lr=config['lr'])
    d_optimizer = optim.RMSprop(discriminator.parameters(), lr=config['lr'])
    
    # Convert covariates to tensors and move to device
    cat_covs = torch.tensor(cat_covs, dtype=torch.long).to(device)
    num_covs = torch.tensor(num_covs, dtype=torch.float32).to(device)
    
    total_batches = len(dataloader)
    
    print(f"Starting training for {config['epochs']} epochs...")
    print(f"Total batches per epoch: {total_batches}")
    print(f"Using device: {device}")
    
    for epoch in range(config['epochs']):
        d_losses = []
        g_losses = []
        print(f"\nEpoch [{epoch+1}/{config['epochs']}]")
        
        for batch_idx, (real_data,) in enumerate(dataloader):
            batch_size = real_data.size(0)
            
            # Move real data to device
            real_data = real_data.to(device)
            
            # Get random batch of categorical and numerical covariates
            batch_indices = torch.randint(0, cat_covs.size(0), (batch_size,))
            batch_cat_covs = cat_covs[batch_indices]
            batch_num_covs = num_covs[batch_indices]
            
            # Train Discriminator
            for _ in range(config['nb_critic']):
                d_optimizer.zero_grad()
                
                # Generate fake data
                z = torch.randn(batch_size, config['latent_dim']).to(device)
                fake_data = generator(z, batch_cat_covs, batch_num_covs)
                
                # Calculate discriminator loss
                real_validity = discriminator(real_data, batch_cat_covs, batch_num_covs)
                fake_validity = discriminator(fake_data.detach(), batch_cat_covs, batch_num_covs)
                
                d_loss = -(torch.mean(real_validity) - torch.mean(fake_validity))
                d_loss.backward()
                d_optimizer.step()
                
                # Clip discriminator weights (Wasserstein GAN)
                for p in discriminator.parameters():
                    p.data.clamp_(-0.01, 0.01)
                    
                d_losses.append(d_loss.item())
            
            # Train Generator
            g_optimizer.zero_grad()
            
            # Generate fake data
            z = torch.randn(batch_size, config['latent_dim']).to(device)
            fake_data = generator(z, batch_cat_covs, batch_num_covs)
            
            # Calculate generator loss
            fake_validity = discriminator(fake_data, batch_cat_covs, batch_num_covs)
            g_loss = -torch.mean(fake_validity)
            
            g_loss.backward()
            g_optimizer.step()
            
            g_losses.append(g_loss.item())
            
            # Print progress every 10 batches
            if batch_idx % 10 == 0:
                print(f"  Batch [{batch_idx}/{total_batches}] " \
                      f"D_loss: {d_loss.item():.4f}, " \
                      f"G_loss: {g_loss.item():.4f}")
        
        # Print epoch summary
        avg_d_loss = np.mean(d_losses)
        avg_g_loss = np.mean(g_losses)
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Average D_loss: {avg_d_loss:.4f}")
        print(f"  Average G_loss: {avg_g_loss:.4f}")
        
        # Log metrics
        if wandb.run is not None:
            wandb.log({
                'epoch': epoch,
                'd_loss': np.mean(d_losses),
                'g_loss': np.mean(g_losses)
            })
        
        # Evaluate and save model if needed
        if score_fn is not None and epoch % 10 == 0:
            score = score_fn(generator)
            print(f'Epoch {epoch}: Score = {score:.4f}')
            
        if save_fn is not None and epoch % 100 == 0:
            save_fn(generator, discriminator)

In [19]:
def main(selected_categories=None):
    """
    Train the GAN with selected categorical variables
    Args:
        selected_categories: List of column names to use as categorical variables.
                           If None, uses all columns except 'cell_id'
    """
    # Configuration
    CONFIG = {
        'epochs': 200,
        'latent_dim': 64,
        'batch_size': 32,
        'nb_layers': 2,
        'hdim': 256,
        'lr': 5e-4,
        'nb_critic': 5
    }
    
    # Device configuration
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    
    print(f"Using device: {device}")
    
    # Load data
    data_path = "/Users/guyshani/Documents/PHD/Aim_2/10x_data_mouse/"
    
    # Load expression matrix
    with h5py.File(data_path+'combined_scaled_data.h5', 'r') as f:
        x_train = f['matrix'][:]
    
    # Load all categorical variables from single file
    cat_data = pd.read_csv(data_path+'combined_metadata.csv', sep=';')
    print("Categorical data shape:", cat_data.shape)
    print("Available categorical variables:", [col for col in cat_data.columns if col != 'cell_id'])
    
    # Determine which categories to use
    if selected_categories is None:
        # Use all columns except cell_id
        categories_to_use = [col for col in cat_data.columns if col != 'cell_id']
    else:
        # Validate selected categories
        invalid_categories = [cat for cat in selected_categories if cat not in cat_data.columns]
        if invalid_categories:
            raise ValueError(f"Invalid categories: {invalid_categories}")
        categories_to_use = selected_categories
    
    print(f"\nUsing categorical variables: {categories_to_use}")
    
    # Create dictionaries and inverse mappings for categorical variables
    cat_dicts = []
    encoded_covs = []
    
    # Process each selected column as a categorical variable
    for column in categories_to_use:
        # Get the column data
        cat_vec = cat_data[column]
        print(f"\nProcessing categorical variable: {column}")
        
        # Create list of unique category names, sorted
        dict_inv = np.array(list(sorted(set(cat_vec.values))))
        dict_map = {t: i for i, t in enumerate(dict_inv)}
        cat_dicts.append(dict_inv)
        
        # Convert categorical variables to integers
        encoded = np.vectorize(lambda t: dict_map[t])(cat_vec)
        encoded = encoded.reshape(-1, 1)  # Reshape to column vector
        encoded_covs.append(encoded)
        
        print(f"Categories in {column}:", dict_inv)
        print(f"Number of categories:", len(dict_inv))
    
    # Combine all categorical covariates
    cat_covs = np.hstack(encoded_covs)
    print("\nCombined categorical covariates shape:", cat_covs.shape)
    
    # Load numerical covariates (currently empty)
    num_covs = np.zeros((x_train.shape[0], 0))
    
    # Convert data to PyTorch tensors and move to device
    x_train = torch.tensor(x_train, dtype=torch.float32)  # Keep on CPU for DataLoader
    
    # Create data loader
    train_dataset = TensorDataset(x_train)
    train_loader = DataLoader(
        train_dataset, 
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        drop_last=True
    )
    
    # Initialize models
    vocab_sizes = [len(c) for c in cat_dicts]
    print("\nVocabulary sizes for categorical variables:", vocab_sizes)
    nb_numeric = num_covs.shape[-1]
    x_dim = x_train.shape[-1]
    
    generator = Generator(
        x_dim=x_dim,
        vocab_sizes=vocab_sizes,
        nb_numeric=nb_numeric,
        h_dims=[CONFIG['hdim']] * CONFIG['nb_layers'],
        z_dim=CONFIG['latent_dim']
    ).to(device)
    
    discriminator = Discriminator(
        x_dim=x_dim,
        vocab_sizes=vocab_sizes,
        nb_numeric=nb_numeric,
        h_dims=[CONFIG['hdim']] * CONFIG['nb_layers']
    ).to(device)
    
    # Initialize wandb with unique run name
    run_name = f"run_{int(time.time())}"  # Uses timestamp for unique name
    wandb.init(
        project='adversarial_gene_expr',
        config=CONFIG,
        name=run_name,
        reinit=True  # Ensures new run each time
    )
    
    # Add selected categories to wandb config
    wandb.config.update({'selected_categories': categories_to_use})
    
    # Train model
    train_gan(
        generator=generator,
        discriminator=discriminator,
        dataloader=train_loader,
        cat_covs=cat_covs,
        num_covs=num_covs,
        config=CONFIG,
        device=device
    )

if __name__ == '__main__':
    # Example usage:
    # Use specific categories:
    main(selected_categories=['dataset'])
    
    # Or use all available categories:
    #main()

Using device: mps
Categorical data shape: (41588, 3)
Available categorical variables: ['dataset', 'cluster']

Using categorical variables: ['dataset']

Processing categorical variable: dataset
Categories in dataset: ['dataset1' 'dataset2' 'dataset3' 'dataset4' 'dataset5' 'dataset6'
 'dataset7']
Number of categories: 7

Combined categorical covariates shape: (41588, 1)

Vocabulary sizes for categorical variables: [7]


0,1
d_loss,▇██████████████████████████████████▁▂▂▃▃
epoch,▁▁▁▁▂▃▃▃▃▃▃▄▄▄▄▄▄▄▄▄▅▅▅▆▆▆▇▇▇▇▇████▁▂▂▂▂
g_loss,▃▃▃▃▂▃▂▃▂▃▂▃▂▁▃▂▂▃▂▃▃▂▃▂▂▂▃▂▃▂▃▃▂▁▂█▄█▄▂

0,1
d_loss,-14.09647
epoch,99.0
g_loss,8.20882


Starting training for 200 epochs...
Total batches per epoch: 31
Using device: mps

Epoch [1/200]
  Batch [0/31] D_loss: -163.3172, G_loss: 78.1496
  Batch [10/31] D_loss: -7.6385, G_loss: 18.5354
  Batch [20/31] D_loss: -76.1167, G_loss: 54.2831
  Batch [30/31] D_loss: -53.5064, G_loss: 14.7118

Epoch 1 Summary:
  Average D_loss: -24.2392
  Average G_loss: 19.4420

Epoch [2/200]
  Batch [0/31] D_loss: -63.6787, G_loss: 28.0441
  Batch [10/31] D_loss: -55.8071, G_loss: -8.5930
  Batch [20/31] D_loss: -40.7650, G_loss: -5.0894
  Batch [30/31] D_loss: -39.9867, G_loss: -0.8309

Epoch 2 Summary:
  Average D_loss: -30.0696
  Average G_loss: 5.1760

Epoch [3/200]
  Batch [0/31] D_loss: -31.2131, G_loss: -1.8148
  Batch [10/31] D_loss: -31.2604, G_loss: 13.7603
  Batch [20/31] D_loss: -52.0506, G_loss: 10.1782
  Batch [30/31] D_loss: -31.7023, G_loss: 0.6628

Epoch 3 Summary:
  Average D_loss: -21.7641
  Average G_loss: 5.8327

Epoch [4/200]
  Batch [0/31] D_loss: -52.2966, G_loss: 12.2898
  

In [None]:
def main():
    # Configuration
    CONFIG = {
        'epochs': 500,
        'latent_dim': 64,
        'batch_size': 32,
        'nb_layers': 2,
        'hdim': 256,
        'lr': 5e-4,
        'nb_critic': 5
    }
    
    # Device configuration
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    
    print(f"Using device: {device}")
    
    # Load data
    data_path = "/Users/guyshani/Documents/PHD/Aim_2/10x_data_mouse/"
    
    # Load expression matrix
    with h5py.File(data_path+'train_data_1dataset.h5', 'r') as f:
        x_train = f['matrix'][:]
    
    # Load cluster info
    cluster_vec = pd.read_csv(data_path+'train_data_1dataset_cluster.csv').T
    
    # Load numerical covariates (currently empty)
    num_covs = np.zeros((x_train.shape[0], 0))  # Changed to create empty numpy array
    
    # Create dictionaries and inverse mappings for categorical variables
    cat_dicts = []
    
    # Create list of unique cluster names, sorted
    cluster_dict_inv = np.array(list(sorted(set(cluster_vec.values.flatten()))))
    cluster_dict = {t: i for i, t in enumerate(cluster_dict_inv)}
    cat_dicts.append(cluster_dict_inv)
    
    # Convert categorical variables to integers
    clusters_encoded = np.vectorize(lambda t: cluster_dict[t])(cluster_vec)
    
    # Assign categorical covariates
    cat_covs = clusters_encoded
    
    # Convert data to PyTorch tensors and move to device
    x_train = torch.tensor(x_train, dtype=torch.float32)  # Keep on CPU for DataLoader
    
    # Create data loader
    train_dataset = TensorDataset(x_train)
    train_loader = DataLoader(
        train_dataset, 
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        drop_last=True
    )
    
    # Initialize models
    vocab_sizes = [len(c) for c in cat_dicts]
    nb_numeric = num_covs.shape[-1]
    x_dim = x_train.shape[-1]
    
    generator = Generator(
        x_dim=x_dim,
        vocab_sizes=vocab_sizes,
        nb_numeric=nb_numeric,
        h_dims=[CONFIG['hdim']] * CONFIG['nb_layers'],
        z_dim=CONFIG['latent_dim']
    ).to(device)
    
    discriminator = Discriminator(
        x_dim=x_dim,
        vocab_sizes=vocab_sizes,
        nb_numeric=nb_numeric,
        h_dims=[CONFIG['hdim']] * CONFIG['nb_layers']
    ).to(device)
    
    # Initialize wandb
    wandb.init(project='adversarial_gene_expr', config=CONFIG)
    
    # Train model
    train_gan(
        generator=generator,
        discriminator=discriminator,
        dataloader=train_loader,
        cat_covs=cat_covs,
        num_covs=num_covs,
        config=CONFIG,
        device=device
    )

if __name__ == '__main__':
    main()