In [1]:
import os

import h5py
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader, TensorDataset

In [2]:
### Load data
# Data directory; override with the GENE_EXPR_DATA_PATH environment variable
# instead of editing the hard-coded default.
data_path = os.environ.get(
    'GENE_EXPR_DATA_PATH',
    "/Users/guyshani/Documents/PHD/Aim_2/10x_data_mouse/"
)

# Expression matrix: HDF5 dataset 'matrix', one row per cell
# (columns presumably genes — confirm against the file that produced it).
with h5py.File(os.path.join(data_path, 'train_data_1dataset.h5'), 'r') as f:
    matrix = f['matrix'][:]

x_train = matrix

# Cluster labels, transposed after loading.
# NOTE(review): downstream code needs one row per cell — confirm the orientation.
cluster_vec = pd.read_csv(os.path.join(data_path, 'train_data_1dataset_cluster.csv')).T

# No numerical covariates yet. An (n_cells, 0) array, not a scalar 0, so it can be
# row-indexed and concatenated exactly like real numeric covariates.
num_covs = np.zeros((x_train.shape[0], 0))

In [None]:
# Sanity check: matrix shape and a small corner of the data (rows are cells).
x_train.shape, x_train[:5, :5]

In [None]:
# Integer-encode the categorical covariates.
# cat_dicts[k] is the sorted label vocabulary of covariate k, so
# cat_dicts[k][code] maps an integer code back to its original label.
cat_dicts = []

## Cluster covariate
# Sorted unique cluster labels -> code = position in the sorted array
cluster_dict_inv = np.array(sorted(set(cluster_vec.values.flatten())))
cluster_dict = {label: code for code, label in enumerate(cluster_dict_inv)}
cat_dicts.append(cluster_dict_inv)

# Replace every label with its integer code (result keeps cluster_vec's shape)
clusters_encoded = np.vectorize(lambda t: cluster_dict[t])(cluster_vec)

print("Original clusters:", cluster_vec)
print("Encoded clusters:", clusters_encoded)
print("Cluster mapping:", cluster_dict)

print("\ncat_dicts:", cat_dicts)

# Number of distinct codes per categorical covariate, used to size the embeddings
vocab_sizes = [len(c) for c in cat_dicts]
print("\nvocab_sizes:", vocab_sizes)

# Training draws covariates by row index, so they must have one row per cell.
assert clusters_encoded.shape[0] == x_train.shape[0], (
    f"encoded clusters have {clusters_encoded.shape[0]} rows but x_train has "
    f"{x_train.shape[0]} cells; check the orientation of the transposed cluster CSV"
)

# assign categorical covariates
cat_covs = clusters_encoded

In [6]:
class Generator(nn.Module):
    def __init__(self, x_dim, vocab_sizes, nb_numeric, h_dims, z_dim):
        """
        Generator network for conditional GAN.

        Concatenates the latent noise, one learned embedding per categorical
        covariate and the numeric covariates, then maps them through an MLP
        (Linear -> BatchNorm1d -> ReLU per hidden layer) to an x_dim output.

        Args:
            x_dim: Dimension of output data
            vocab_sizes: List of vocabulary sizes for each categorical variable
                (may be empty when there are no categorical covariates)
            nb_numeric: Number of numeric covariates
            h_dims: List of hidden dimensions
            z_dim: Dimension of latent noise vector
        """
        super(Generator, self).__init__()

        # One embedding table per categorical variable; width capped at 50
        self.embeddings = nn.ModuleList([
            nn.Embedding(vocab_size, min(50, vocab_size))
            for vocab_size in vocab_sizes
        ])

        # Total embedding width, read back from the layers so it cannot drift
        # from the sizing rule above
        embedding_dim = sum(emb.embedding_dim for emb in self.embeddings)

        # Input dimension is latent dim + embedding dim + numeric covariates
        input_dim = z_dim + embedding_dim + nb_numeric

        layers = []
        current_dim = input_dim
        for h_dim in h_dims:
            layers.extend([
                nn.Linear(current_dim, h_dim),
                nn.BatchNorm1d(h_dim),
                nn.ReLU()
            ])
            current_dim = h_dim

        # Linear output layer (no output activation)
        layers.append(nn.Linear(current_dim, x_dim))

        self.network = nn.Sequential(*layers)

    def forward(self, z, cat_covs, num_covs):
        """
        Generate samples conditioned on the covariates.

        Args:
            z: latent noise, (batch, z_dim)
            cat_covs: integer codes, (batch, n_categorical); a 1-D (batch,)
                tensor is accepted for a single categorical variable
            num_covs: numeric covariates, (batch, nb_numeric)

        Returns:
            Tensor of shape (batch, x_dim).
        """
        if cat_covs.dim() == 1:
            cat_covs = cat_covs.unsqueeze(1)

        # Build the input as [z, embeddings..., num_covs]. Building one list avoids
        # torch.cat on an empty list when there are no categorical covariates.
        parts = [z]
        parts.extend(emb(cat_covs[:, i]) for i, emb in enumerate(self.embeddings))
        parts.append(num_covs)

        return self.network(torch.cat(parts, dim=1))

In [7]:
class Discriminator(nn.Module):
    def __init__(self, x_dim, vocab_sizes, nb_numeric, h_dims):
        """
        Discriminator (critic) network for conditional GAN.

        Concatenates the data, one learned embedding per categorical covariate
        and the numeric covariates, then maps them through an MLP
        (Linear -> LeakyReLU(0.2) -> Dropout(0.3) per hidden layer) to one
        unbounded score per sample.

        Args:
            x_dim: Dimension of input data
            vocab_sizes: List of vocabulary sizes for each categorical variable
                (may be empty when there are no categorical covariates)
            nb_numeric: Number of numeric covariates
            h_dims: List of hidden dimensions
        """
        super(Discriminator, self).__init__()

        # One embedding table per categorical variable; width capped at 50
        self.embeddings = nn.ModuleList([
            nn.Embedding(vocab_size, min(50, vocab_size))
            for vocab_size in vocab_sizes
        ])

        # Total embedding width, read back from the layers so it cannot drift
        # from the sizing rule above
        embedding_dim = sum(emb.embedding_dim for emb in self.embeddings)

        # Input dimension is data dim + embedding dim + numeric covariates
        input_dim = x_dim + embedding_dim + nb_numeric

        layers = []
        current_dim = input_dim
        for h_dim in h_dims:
            layers.extend([
                nn.Linear(current_dim, h_dim),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3)
            ])
            current_dim = h_dim

        # Single linear score, no sigmoid
        layers.append(nn.Linear(current_dim, 1))

        self.network = nn.Sequential(*layers)

    def forward(self, x, cat_covs, num_covs):
        """
        Score samples given their covariates.

        Args:
            x: data, (batch, x_dim)
            cat_covs: integer codes, (batch, n_categorical); a 1-D (batch,)
                tensor is accepted for a single categorical variable
            num_covs: numeric covariates, (batch, nb_numeric)

        Returns:
            Tensor of shape (batch, 1) with the critic scores.
        """
        if cat_covs.dim() == 1:
            cat_covs = cat_covs.unsqueeze(1)

        # Build the input as [x, embeddings..., num_covs]. Building one list avoids
        # torch.cat on an empty list when there are no categorical covariates.
        parts = [x]
        parts.extend(emb(cat_covs[:, i]) for i, emb in enumerate(self.embeddings))
        parts.append(num_covs)

        # Score the conditioned input
        return self.network(torch.cat(parts, dim=1))


In [8]:
def train_gan(generator, discriminator, dataloader, cat_covs, num_covs, 
              config, device, score_fn=None, save_fn=None):
    """
    Train the conditional GAN with a Wasserstein critic loss and weight clipping.

    Args:
        generator: module called as ``generator(z, cat_covs, num_covs)``.
        discriminator: module called as ``discriminator(x, cat_covs, num_covs)``.
        dataloader: iterable of batches. Each batch is either ``(real_data,)`` or
            ``(real_data, batch_cat_covs, batch_num_covs)``.
            - 3-element form: each real sample is scored with its own covariates,
              and the same covariates condition the fake samples.
            - 1-element form (legacy): covariate rows are drawn at random from
              ``cat_covs``/``num_covs``, so real samples are NOT paired with
              their true conditions.
        cat_covs: integer-coded categorical covariates, one row per sample. A
            1-D array is treated as a single categorical column.
        num_covs: numeric covariates, one row per sample (may have 0 columns).
        config: dict with keys 'lr', 'epochs', 'nb_critic' and 'latent_dim'.
        device: torch device used for training.
        score_fn: optional ``score_fn(generator) -> float``, run every 10 epochs.
        save_fn: optional ``save_fn(generator, discriminator)``, run every 100
            epochs and after the final epoch.

    Raises:
        ValueError: if ``cat_covs`` and ``num_covs`` have different row counts,
            or a batch holds neither 1 nor 3 tensors.
    """
    # Optimizers
    g_optimizer = optim.RMSprop(generator.parameters(), lr=config['lr'])
    d_optimizer = optim.RMSprop(discriminator.parameters(), lr=config['lr'])

    # Convert covariates to int64/float32 tensors, then move them to the device.
    # as_tensor also accepts tensor inputs without a copy-construct warning.
    cat_covs = torch.as_tensor(cat_covs, dtype=torch.long).to(device)
    num_covs = torch.as_tensor(num_covs, dtype=torch.float32).to(device)
    if cat_covs.dim() == 1:
        cat_covs = cat_covs.unsqueeze(1)
    if cat_covs.size(0) != num_covs.size(0):
        raise ValueError(
            f"cat_covs has {cat_covs.size(0)} rows but num_covs has "
            f"{num_covs.size(0)}; both need one row per sample"
        )
    n_cov_rows = cat_covs.size(0)

    total_batches = len(dataloader)

    print(f"Starting training for {config['epochs']} epochs...")
    print(f"Total batches per epoch: {total_batches}")
    print(f"Using device: {device}")

    for epoch in range(config['epochs']):
        d_losses = []
        g_losses = []
        print(f"\nEpoch [{epoch+1}/{config['epochs']}]")

        for batch_idx, batch in enumerate(dataloader):
            # A bare tensor is shorthand for a 1-element batch
            if isinstance(batch, torch.Tensor):
                batch = (batch,)

            real_data = batch[0].to(device)
            batch_size = real_data.size(0)

            if len(batch) == 3:
                # Per-sample covariates from the loader keep (x, condition) pairs intact.
                # Cast on the CPU before moving, as the original did for the full arrays.
                batch_cat_covs = batch[1].long().to(device)
                batch_num_covs = batch[2].float().to(device)
                if batch_cat_covs.dim() == 1:
                    batch_cat_covs = batch_cat_covs.unsqueeze(1)
            elif len(batch) == 1:
                # No per-sample covariates supplied: draw random covariate rows
                batch_indices = torch.randint(0, n_cov_rows, (batch_size,))
                batch_cat_covs = cat_covs[batch_indices]
                batch_num_covs = num_covs[batch_indices]
            else:
                raise ValueError(
                    f"Expected batches of 1 or 3 tensors, got {len(batch)}"
                )

            # Train Discriminator (critic) nb_critic times on this batch
            for _ in range(config['nb_critic']):
                d_optimizer.zero_grad()

                z = torch.randn(batch_size, config['latent_dim']).to(device)
                fake_data = generator(z, batch_cat_covs, batch_num_covs)

                real_validity = discriminator(real_data, batch_cat_covs, batch_num_covs)
                # detach: the critic step must not backpropagate into the generator
                fake_validity = discriminator(fake_data.detach(), batch_cat_covs, batch_num_covs)

                # Critic maximizes mean(D(real)) - mean(D(fake)); minimize the negation
                d_loss = -(torch.mean(real_validity) - torch.mean(fake_validity))
                d_loss.backward()
                d_optimizer.step()

                # Clip discriminator weights (Wasserstein GAN)
                for p in discriminator.parameters():
                    p.data.clamp_(-0.01, 0.01)

                d_losses.append(d_loss.item())

            # Train Generator: push the critic's score on fresh fakes upward
            g_optimizer.zero_grad()
            z = torch.randn(batch_size, config['latent_dim']).to(device)
            fake_data = generator(z, batch_cat_covs, batch_num_covs)
            fake_validity = discriminator(fake_data, batch_cat_covs, batch_num_covs)
            g_loss = -torch.mean(fake_validity)
            g_loss.backward()
            g_optimizer.step()

            g_losses.append(g_loss.item())

            # Print progress every 10 batches (D_loss is from the last critic step)
            if batch_idx % 10 == 0:
                print(f"  Batch [{batch_idx}/{total_batches}] "
                      f"D_loss: {d_loss.item():.4f}, "
                      f"G_loss: {g_loss.item():.4f}")

        # Epoch summary
        avg_d_loss = np.mean(d_losses)
        avg_g_loss = np.mean(g_losses)
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Average D_loss: {avg_d_loss:.4f}")
        print(f"  Average G_loss: {avg_g_loss:.4f}")

        # Log metrics only when a W&B run is active
        if wandb.run is not None:
            wandb.log({
                'epoch': epoch,
                'd_loss': avg_d_loss,
                'g_loss': avg_g_loss
            })

        if score_fn is not None and epoch % 10 == 0:
            score = score_fn(generator)
            print(f'Epoch {epoch}: Score = {score:.4f}')

        # Checkpoint periodically and always after the final epoch. Otherwise the
        # final weights are lost when epochs is not a multiple of 100.
        is_last_epoch = epoch == config['epochs'] - 1
        if save_fn is not None and (epoch % 100 == 0 or is_last_epoch):
            save_fn(generator, discriminator)

In [10]:
def main():
    """
    Load the training data, build the generator/discriminator and train the cGAN.

    Reads the expression matrix (HDF5 dataset 'matrix', one row per cell) and
    the cluster-label CSV. It integer-encodes the cluster labels as the only
    categorical covariate and logs the run to Weights & Biases. The data
    directory defaults to the original hard-coded path and can be overridden
    with the GENE_EXPR_DATA_PATH environment variable.

    Raises:
        ValueError: if the encoded cluster covariates do not have one row per cell.
    """
    # Configuration
    CONFIG = {
        'epochs': 500,
        'latent_dim': 64,
        'batch_size': 32,
        'nb_layers': 2,
        'hdim': 256,
        'lr': 5e-4,
        'nb_critic': 5,
        'seed': 42
    }

    # Seed torch (weight init, noise, covariate sampling, DataLoader shuffling) and numpy
    torch.manual_seed(CONFIG['seed'])
    np.random.seed(CONFIG['seed'])

    # Device configuration: prefer CUDA, then Apple MPS, else CPU
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')

    print(f"Using device: {device}")

    # Data directory; override with the GENE_EXPR_DATA_PATH environment variable
    data_path = os.environ.get(
        'GENE_EXPR_DATA_PATH',
        "/Users/guyshani/Documents/PHD/Aim_2/10x_data_mouse/"
    )

    # Load expression matrix (one row per cell)
    with h5py.File(os.path.join(data_path, 'train_data_1dataset.h5'), 'r') as f:
        x_train = f['matrix'][:]

    # Load cluster info (transposed after loading)
    cluster_vec = pd.read_csv(os.path.join(data_path, 'train_data_1dataset_cluster.csv')).T

    # No numerical covariates yet: an (n_cells, 0) array
    num_covs = np.zeros((x_train.shape[0], 0))

    # Integer-encode the cluster labels. cat_dicts[k][code] recovers the label.
    cat_dicts = []
    cluster_dict_inv = np.array(sorted(set(cluster_vec.values.flatten())))
    cluster_dict = {label: code for code, label in enumerate(cluster_dict_inv)}
    cat_dicts.append(cluster_dict_inv)
    cat_covs = np.vectorize(lambda t: cluster_dict[t])(cluster_vec)

    # Training indexes covariates by row, so they need one row per cell
    if cat_covs.shape[0] != x_train.shape[0]:
        raise ValueError(
            f"Cluster covariates have {cat_covs.shape[0]} rows but the expression "
            f"matrix has {x_train.shape[0]} cells; check the orientation of the "
            f"cluster CSV (it is transposed after loading)"
        )

    # Expression data stays on the CPU; train_gan moves each batch to the device
    x_train = torch.tensor(x_train, dtype=torch.float32)

    # NOTE(review): the loader yields only expression data, so train_gan pairs
    # each real batch with randomly drawn covariate rows rather than each
    # cell's own cluster.
    train_dataset = TensorDataset(x_train)
    train_loader = DataLoader(
        train_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        drop_last=True
    )

    # Initialize models
    vocab_sizes = [len(c) for c in cat_dicts]
    nb_numeric = num_covs.shape[-1]
    x_dim = x_train.shape[-1]
    h_dims = [CONFIG['hdim']] * CONFIG['nb_layers']

    generator = Generator(
        x_dim=x_dim,
        vocab_sizes=vocab_sizes,
        nb_numeric=nb_numeric,
        h_dims=h_dims,
        z_dim=CONFIG['latent_dim']
    ).to(device)

    discriminator = Discriminator(
        x_dim=x_dim,
        vocab_sizes=vocab_sizes,
        nb_numeric=nb_numeric,
        h_dims=h_dims
    ).to(device)

    # Initialize wandb; always close the run, even if training raises or is interrupted
    wandb.init(project='adversarial_gene_expr', config=CONFIG)
    try:
        train_gan(
            generator=generator,
            discriminator=discriminator,
            dataloader=train_loader,
            cat_covs=cat_covs,
            num_covs=num_covs,
            config=CONFIG,
            device=device
        )
    finally:
        wandb.finish()

# Script entry point. In a notebook kernel __name__ is also '__main__', so
# running this cell starts training immediately.
if '__main__' == __name__:
    main()

Using device: mps
Starting training for 500 epochs...
Total batches per epoch: 31
Using device: mps

Epoch [1/500]
  Batch [0/31] D_loss: -12.9184, G_loss: 6.4762
  Batch [10/31] D_loss: -1.8696, G_loss: -0.0597
  Batch [20/31] D_loss: -2.3705, G_loss: 0.5551
  Batch [30/31] D_loss: -0.9042, G_loss: -0.8179

Epoch 1 Summary:
  Average D_loss: -1.5992
  Average G_loss: 1.6461

Epoch [2/500]
  Batch [0/31] D_loss: -5.1307, G_loss: 3.7898
  Batch [10/31] D_loss: -6.8352, G_loss: 4.9523
  Batch [20/31] D_loss: -5.4994, G_loss: 1.7115
  Batch [30/31] D_loss: -4.7128, G_loss: 0.1949

Epoch 2 Summary:
  Average D_loss: -2.9820
  Average G_loss: 0.8077

Epoch [3/500]
  Batch [0/31] D_loss: -4.5063, G_loss: 0.3721
  Batch [10/31] D_loss: -3.5614, G_loss: 1.9659
  Batch [20/31] D_loss: -4.4366, G_loss: 0.8575
  Batch [30/31] D_loss: -3.4399, G_loss: -0.2558

Epoch 3 Summary:
  Average D_loss: -2.8837
  Average G_loss: 0.7646

Epoch [4/500]
  Batch [0/31] D_loss: -5.1230, G_loss: -0.6822
  Batch 