In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
import wandb
import h5py

In [2]:
### Load data
# Load expression matrix (csv file with rows as cells and columns as genes)
data_path = "/Users/guyshani/Documents/PHD/Aim_2/10x_data_mouse/"
# Load expression matrix
with h5py.File(data_path+'train_data_1dataset.h5', 'r') as f:
    matrix = f['matrix'][:]

# each row is a cell
# matrix[:,0]

x_train = matrix


# Load cluster info
cluster_vec = pd.read_csv(data_path+'train_data_1dataset_cluster.csv').T
cluster_vec
# Load numerical covariates
num_covs = 0

In [8]:
#set(cluster_vec.values.flatten())
x_train

array([[ 0.61580863, -0.28865543, -0.28865543, ...,  5.01038512,
         0.07573233, -0.28865543],
       [-0.28521495, -0.28521495, -0.28521495, ...,  5.22700772,
        -0.28521495, -0.28521495],
       [-0.199723  , -0.199723  , -0.199723  , ..., -0.199723  ,
        -0.199723  , -0.199723  ],
       ...,
       [-0.09257967, -0.09257967, -0.09257967, ..., -0.09257967,
        -0.09257967, -0.09257967],
       [-0.04014808, -0.04014808, -0.04014808, ..., -0.04014808,
        -0.04014808, -0.04014808],
       [ 0.42168485, -1.07463841, -1.07463841, ...,  1.16523765,
         1.36185656, -1.07463841]])

In [3]:
# Example raw data
# tissues = np.array(['liver', 'brain', 'liver', 'kidney', 'brain'])

# Create dictionaries and inverse mappings
cat_dicts = []

## Cluster covariate
# Create list of unique cluster names, sorted
cluster_dict_inv = np.array(list(sorted(set(cluster_vec.values.flatten()))))  # ['brain', 'kidney', 'liver']
cluster_dict = {t: i for i, t in enumerate(cluster_dict_inv)}  # {'brain': 0, 'kidney': 1, 'liver': 2}
cat_dicts.append(cluster_dict_inv)

# Convert categorical variables to integers
clusters_encoded = np.vectorize(lambda t: cluster_dict[t])(cluster_vec)

print("Original clusters:", cluster_vec)
print("Encoded clusters:", clusters_encoded)
print("Cluster mapping:", cluster_dict)

print("\ncat_dicts:", cat_dicts)

# This gives us vocab_sizes for model initialization
vocab_sizes = [len(c) for c in cat_dicts]  # [3, 2] (3 tissue types, 2 dataset types)
print("\nvocab_sizes:", vocab_sizes)

# assign categorical covariates
cat_covs = clusters_encoded

Original clusters:    0     1     2     3     4     5     6     7     8     9     ...  4084  \
4     1     1     1     2     0     1     2     0     1     0  ...     1   

   4085  4086  4087  4088  4089  4090  4091  4092  4093  
4     3     2     1     0     1     0     5     3     1  

[1 rows x 4094 columns]
Encoded clusters: [[1 1 1 ... 5 3 1]]
Cluster mapping: {np.int64(0): 0, np.int64(1): 1, np.int64(2): 2, np.int64(3): 3, np.int64(4): 4, np.int64(5): 5, np.int64(6): 6}

cat_dicts: [array([0, 1, 2, 3, 4, 5, 6])]

vocab_sizes: [7]


In [4]:
class Generator(nn.Module):
    def __init__(self, x_dim, vocab_sizes, nb_numeric, h_dims, z_dim):
        """
        Generator network for conditional GAN
        Args:
            x_dim: Dimension of output data
            vocab_sizes: List of vocabulary sizes for each categorical variable
            nb_numeric: Number of numeric covariates
            h_dims: List of hidden dimensions
            z_dim: Dimension of latent noise vector
        """
        super(Generator, self).__init__()
        
        # Embedding layers for categorical variables
        self.embeddings = nn.ModuleList([
            nn.Embedding(vocab_size, min(50, vocab_size)) 
            for vocab_size in vocab_sizes
        ])
        
        # Calculate total embedding dimension
        embedding_dim = sum(min(50, vocab_size) for vocab_size in vocab_sizes)
        
        # Input dimension is latent dim + embedding dim + numeric covariates
        input_dim = z_dim + embedding_dim + nb_numeric
        
        # Build generator network
        layers = []
        current_dim = input_dim
        
        # Add hidden layers
        for h_dim in h_dims:
            layers.extend([
                nn.Linear(current_dim, h_dim),
                nn.BatchNorm1d(h_dim),
                nn.ReLU()
            ])
            current_dim = h_dim
        
        # Output layer
        layers.append(nn.Linear(current_dim, x_dim))
        
        self.network = nn.Sequential(*layers)

    def forward(self, z, cat_covs, num_covs):
        # Process categorical covariates through embeddings
        embeddings = [emb(cat_covs[:, i]) for i, emb in enumerate(self.embeddings)]
        embedded = torch.cat(embeddings, dim=1)
        
        # Concatenate all inputs
        gen_input = torch.cat([z, embedded, num_covs], dim=1)
        
        # Generate output
        return self.network(gen_input)

In [5]:
class Discriminator(nn.Module):
    def __init__(self, x_dim, vocab_sizes, nb_numeric, h_dims):
        """
        Discriminator network for conditional GAN
        Args:
            x_dim: Dimension of input data
            vocab_sizes: List of vocabulary sizes for each categorical variable
            nb_numeric: Number of numeric covariates
            h_dims: List of hidden dimensions
        """
        super(Discriminator, self).__init__()
        
        # Embedding layers for categorical variables
        self.embeddings = nn.ModuleList([
            nn.Embedding(vocab_size, min(50, vocab_size))
            for vocab_size in vocab_sizes
        ])
        
        # Calculate total embedding dimension
        embedding_dim = sum(min(50, vocab_size) for vocab_size in vocab_sizes)
        
        # Input dimension is data dim + embedding dim + numeric covariates
        input_dim = x_dim + embedding_dim + nb_numeric
        
        # Build discriminator network
        layers = []
        current_dim = input_dim
        
        # Add hidden layers
        for h_dim in h_dims:
            layers.extend([
                nn.Linear(current_dim, h_dim),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3)
            ])
            current_dim = h_dim
        
        # Output layer
        layers.append(nn.Linear(current_dim, 1))
        
        self.network = nn.Sequential(*layers)

    def forward(self, x, cat_covs, num_covs):
        # Process categorical covariates through embeddings
        embeddings = [emb(cat_covs[:, i]) for i, emb in enumerate(self.embeddings)]
        embedded = torch.cat(embeddings, dim=1)
        
        # Concatenate all inputs
        disc_input = torch.cat([x, embedded, num_covs], dim=1)
        
        # Generate output
        return self.network(disc_input)


In [6]:
def train_gan(generator, discriminator, dataloader, cat_covs, num_covs, 
              config, device, score_fn=None, save_fn=None):
    """
    Train the conditional GAN
    """
    # Optimizers
    g_optimizer = optim.RMSprop(generator.parameters(), lr=config['lr'])
    d_optimizer = optim.RMSprop(discriminator.parameters(), lr=config['lr'])
    
    # Convert covariates to tensors
    cat_covs = torch.tensor(cat_covs, dtype=torch.long).to(device)
    num_covs = torch.tensor(num_covs, dtype=torch.float32).to(device)
    
    for epoch in range(config['epochs']):
        d_losses = []
        g_losses = []
        
        for batch_idx, (real_data,) in enumerate(dataloader):
            batch_size = real_data.size(0)
            
            # Get random batch of categorical and numerical covariates
            batch_indices = torch.randint(0, cat_covs.size(0), (batch_size,))
            batch_cat_covs = cat_covs[batch_indices]
            batch_num_covs = num_covs[batch_indices]
            
            # Train Discriminator
            for _ in range(config['nb_critic']):
                d_optimizer.zero_grad()
                
                # Generate fake data
                z = torch.randn(batch_size, config['latent_dim']).to(device)
                fake_data = generator(z, batch_cat_covs, batch_num_covs)
                
                # Calculate discriminator loss
                real_validity = discriminator(real_data, batch_cat_covs, batch_num_covs)
                fake_validity = discriminator(fake_data.detach(), batch_cat_covs, batch_num_covs)
                
                d_loss = -(torch.mean(real_validity) - torch.mean(fake_validity))
                d_loss.backward()
                d_optimizer.step()
                
                # Clip discriminator weights (Wasserstein GAN)
                for p in discriminator.parameters():
                    p.data.clamp_(-0.01, 0.01)
                    
                d_losses.append(d_loss.item())
            
            # Train Generator
            g_optimizer.zero_grad()
            
            # Generate fake data
            z = torch.randn(batch_size, config['latent_dim']).to(device)
            fake_data = generator(z, batch_cat_covs, batch_num_covs)
            
            # Calculate generator loss
            fake_validity = discriminator(fake_data, batch_cat_covs, batch_num_covs)
            g_loss = -torch.mean(fake_validity)
            
            g_loss.backward()
            g_optimizer.step()
            
            g_losses.append(g_loss.item())
        
        # Log metrics
        if wandb.run is not None:
            wandb.log({
                'epoch': epoch,
                'd_loss': np.mean(d_losses),
                'g_loss': np.mean(g_losses)
            })
        
        # Evaluate and save model if needed
        if score_fn is not None and epoch % 10 == 0:
            score = score_fn(generator)
            print(f'Epoch {epoch}: Score = {score:.4f}')
            
        if save_fn is not None and epoch % 100 == 0:
            save_fn(generator, discriminator)

In [7]:
def main():
    # Configuration
    CONFIG = {
        #'epochs': 2000,
        'epochs': 10,
        'latent_dim': 64,
        'batch_size': 32,
        'nb_layers': 2,
        'hdim': 256,
        'lr': 5e-4,
        'nb_critic': 5
    }
    
    # Device configuration
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    
    # Load and preprocess data (similar to original code)
    # ... (data loading code remains the same until standardization)
    
    # Convert data to PyTorch tensors
    x_train = torch.tensor(x_train, dtype=torch.float32)
    #x_test = torch.tensor(x_test, dtype=torch.float32)
    
    # Create data loader
    train_dataset = TensorDataset(x_train)
    train_loader = DataLoader(
        train_dataset, 
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        drop_last=True
    )
    
    # Initialize models
    vocab_sizes = [len(c) for c in cat_dicts]
    nb_numeric = num_covs.shape[-1]
    x_dim = x_train.shape[-1]
    
    generator = Generator(
        x_dim=x_dim,
        vocab_sizes=vocab_sizes,
        nb_numeric=nb_numeric,
        h_dims=[CONFIG['hdim']] * CONFIG['nb_layers'],
        z_dim=CONFIG['latent_dim']
    ).to(device)
    
    discriminator = Discriminator(
        x_dim=x_dim,
        vocab_sizes=vocab_sizes,
        nb_numeric=nb_numeric,
        h_dims=[CONFIG['hdim']] * CONFIG['nb_layers']
    ).to(device)
    
    # Initialize wandb
    wandb.init(project='adversarial_gene_expr', config=CONFIG)
    
    # Train model
    train_gan(
        generator=generator,
        discriminator=discriminator,
        dataloader=train_loader,
        #cat_covs=cat_covs_train,
        #num_covs=num_covs_train,
        cat_covs=cat_covs,
        num_covs=num_covs,
        config=CONFIG,
        device=device
    )

if __name__ == '__main__':
    main()

UnboundLocalError: cannot access local variable 'x_train' where it is not associated with a value