# scIDiff Basic Usage Tutorial (Fixed Version)

This notebook demonstrates the basic usage of scIDiff for single-cell gene expression modeling and inverse design.

## Setup Instructions

Before running this notebook, make sure to install the package in development mode:

```bash
# Navigate to the scIDiff directory
cd /path/to/scIDiff

# Install in development mode
pip install -e .
```

Alternatively, if you run the notebook from a subdirectory of the repository, the first code cell below adds the repository root (the parent of the current working directory) to `sys.path`.

In [None]:
# Add the parent directory to Python path (if not installed as package)
import sys
import os
sys.path.append(os.path.dirname(os.getcwd()))

# Import necessary libraries
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
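
The cell above always appends the parent directory to `sys.path`, even when scIDiff is already installed. A slightly more careful variant only touches `sys.path` when the package cannot be found, and prepends rather than appends so that a local checkout takes priority. This is a minimal sketch; `ensure_importable` is a hypothetical helper, not part of scIDiff.

```python
import importlib.util
import os
import sys


def ensure_importable(package_name, repo_root):
    """Return True if `package_name` is already importable.

    Otherwise prepend `repo_root` to sys.path (only once) so a local
    checkout can be imported, and return False.
    """
    if importlib.util.find_spec(package_name) is not None:
        return True
    repo_root = os.path.abspath(repo_root)
    if repo_root not in sys.path:
        # insert(0, ...) makes the local checkout win over other entries
        sys.path.insert(0, repo_root)
    return False


# Hypothetical usage from a notebook one level below the repository root:
# ensure_importable("scIDiff", os.path.dirname(os.getcwd()))
```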

In [None]:
# Import scIDiff components
try:
    # Try importing as installed package
    from scIDiff.models import ScIDiffModel
    from scIDiff.training import ScIDiffTrainer
    from scIDiff.sampling import InverseDesigner, PhenotypeTarget, GeneExpressionObjective
    from scIDiff.data import SingleCellDataset
    print("Successfully imported scIDiff components!")
except ImportError as e:
    print(f"Import error: {e}")
    print("Please install the package using: pip install -e .")
    print("Or make sure you're running from the correct directory.")
    raise  # Stop here: every later cell depends on these imports

## 1. Data Preparation

First, let's create some synthetic single-cell data for demonstration purposes.

In [None]:
# Create synthetic single-cell data
def create_synthetic_data(n_cells=1000, n_genes=2000, n_cell_types=5):
    """
    Create synthetic single-cell RNA-seq data.

    Returns an (n_cells x n_genes) expression matrix whose rows are aligned
    with the returned cell type labels.
    """
    # Generate cell type labels
    cell_types = np.random.randint(0, n_cell_types, n_cells)
    
    # Preallocate the matrix so each cell type's rows can be written in place
    expression_matrix = np.zeros((n_cells, n_genes))
    
    for cell_type in range(n_cell_types):
        type_mask = cell_types == cell_type
        n_cells_type = type_mask.sum()
        
        # Create cell type specific expression pattern
        base_expression = np.random.lognormal(0, 1, (n_cells_type, n_genes))
        
        # Add cell type specific markers
        marker_genes = np.random.choice(n_genes, 50, replace=False)
        base_expression[:, marker_genes] *= (cell_type + 1) * 2
        
        # Add sparsity (about 70% of entries are set to zero)
        sparsity_mask = np.random.random((n_cells_type, n_genes)) < 0.7
        base_expression[sparsity_mask] = 0
        
        # Write the rows back at their original positions so they match cell_types
        # (stacking per-type blocks would reorder rows relative to the labels)
        expression_matrix[type_mask] = base_expression
    
    return expression_matrix, cell_types

# Generate synthetic data
expression_data, cell_type_labels = create_synthetic_data()

print(f"Expression data shape: {expression_data.shape}")
print(f"Cell types: {np.unique(cell_type_labels)}")
print(f"Sparsity: {(expression_data == 0).mean():.2%}")
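
`create_synthetic_data` generates each cell type's rows as a separate block, so it is worth confirming that row *i* of the matrix still belongs to the cell labelled `cell_type_labels[i]`. The toy example below (all names are illustrative) shows that `np.vstack` of per-type blocks groups rows by type, which only matches the labels if they happened to be sorted, whereas writing each block back through a boolean mask keeps rows and labels aligned.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells, n_genes, n_types = 12, 4, 3
labels = rng.integers(0, n_types, n_cells)  # randomly ordered labels

# Each block holds the rows of one type; every entry is set to the type id
blocks = [np.full(((labels == t).sum(), n_genes), t, dtype=float) for t in range(n_types)]

# Stacking groups rows by type, so row i no longer matches labels[i]
stacked = np.vstack(blocks)
stacked_aligned = bool(np.all(stacked[:, 0] == labels))

# Writing each block back through the boolean mask preserves alignment
scattered = np.zeros((n_cells, n_genes))
for t, block in enumerate(blocks):
    scattered[labels == t] = block
scattered_aligned = bool(np.all(scattered[:, 0] == labels))

print(stacked_aligned, scattered_aligned)
```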

In [None]:
# Visualize the data
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Expression distribution
axes[0].hist(expression_data.flatten(), bins=50, alpha=0.7)
axes[0].set_xlabel('Expression')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Expression Distribution')

# Cell type distribution
unique, counts = np.unique(cell_type_labels, return_counts=True)
axes[1].bar(unique, counts)
axes[1].set_xlabel('Cell Type')
axes[1].set_ylabel('Number of Cells')
axes[1].set_title('Cell Type Distribution')

# Sparsity per cell
sparsity_per_cell = (expression_data == 0).mean(axis=1)
axes[2].hist(sparsity_per_cell, bins=30, alpha=0.7)
axes[2].set_xlabel('Sparsity (fraction of zeros)')
axes[2].set_ylabel('Number of Cells')
axes[2].set_title('Sparsity Distribution')

plt.tight_layout()
plt.show()

## 2. Create Dataset

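Wrap the expression matrix and per-cell metadata in scIDiff's `SingleCellDataset` class. With `normalize=True` it applies log1p normalization, and indexing it returns one dictionary per cell.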

In [None]:
# Create metadata dictionary
metadata = {
    'cell_type': cell_type_labels,
    'batch': np.random.randint(0, 3, len(cell_type_labels)),  # 3 batches
    'total_counts': expression_data.sum(axis=1)
}

# Create dataset
dataset = SingleCellDataset(
    expression_data=expression_data,
    cell_metadata=metadata,
    normalize=True  # Apply log1p normalization
)

print(f"Dataset created with {len(dataset)} cells")
print(f"Dataset statistics: {dataset.get_statistics()}")
print(f"Cell type categories: {dataset.get_metadata_categories('cell_type')}")
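
The exact transform behind `normalize=True` lives inside `SingleCellDataset`; the comment above only says it is log1p. For reference, the sketch below shows a plain elementwise log1p and, optionally, the library-size-then-log1p variant that is common in scRNA-seq pipelines. `log1p_normalize` and `target_sum` are illustrative names, not scIDiff parameters, and whether scIDiff also rescales by library size is not shown in this notebook.

```python
import numpy as np


def log1p_normalize(counts, target_sum=None):
    """Log-transform an expression matrix (cells x genes).

    With target_sum=None this is a plain elementwise log(1 + x). If target_sum
    is given, each cell is first rescaled so its total equals target_sum
    (library-size normalization) before the log transform.
    """
    x = np.asarray(counts, dtype=np.float64)
    if target_sum is not None:
        totals = x.sum(axis=1, keepdims=True)
        # Avoid division by zero for cells with no counts
        totals[totals == 0] = 1.0
        x = x / totals * target_sum
    return np.log1p(x)


toy = np.array([[0.0, 1.0, 3.0], [2.0, 2.0, 0.0]])
print(log1p_normalize(toy))
print(log1p_normalize(toy, target_sum=1e4))
```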

In [None]:
# Test dataset access
sample = dataset[0]
print("Sample keys:", list(sample.keys()))
print("Expression shape:", sample['expression'].shape)
print("Cell type:", sample['cell_type'].item())
print("Batch:", sample['batch'].item())

## 3. Model Initialization

Now let's initialize the scIDiff model. The input size is taken from the dataset's gene count; the remaining settings are demo values you can tune.

In [None]:
# Model configuration
model_config = {
    'gene_dim': dataset.n_genes,          # Number of genes
    'hidden_dim': 512,                    # Hidden dimension
    'num_layers': 6,                      # Number of layers
    'num_timesteps': 1000,                # Diffusion timesteps
    'conditioning_dim': 128,              # Conditioning dimension
    'dropout': 0.1,                       # Dropout rate
    'use_attention': True,                # Use attention layers
}

# Initialize model
model = ScIDiffModel(**model_config)
model = model.to(device)

print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"Model device: {next(model.parameters()).device}")
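
`num_timesteps=1000` sets the length of the diffusion chain. Assuming scIDiff follows the standard DDPM formulation (the notebook does not say which noise schedule it uses), the forward process corrupts a clean profile as `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`. The sketch below builds a linear beta schedule with the range used in the original DDPM paper (1e-4 to 0.02) and applies this noising step in NumPy; the function names are hypothetical.

```python
import numpy as np


def linear_beta_schedule(num_timesteps, beta_start=1e-4, beta_end=0.02):
    """Per-step noise variances beta_t and their cumulative products alpha_bar_t."""
    betas = np.linspace(beta_start, beta_end, num_timesteps)
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)
    return betas, alpha_bars


def q_sample(x0, t, alpha_bars, rng):
    """Noise clean profiles x0 (cells x genes) to timesteps t (one per cell)."""
    noise = rng.standard_normal(x0.shape)
    a = alpha_bars[t][:, None]  # broadcast each cell's alpha_bar over genes
    xt = np.sqrt(a) * x0 + np.sqrt(1.0 - a) * noise
    return xt, noise


rng = np.random.default_rng(42)
betas, alpha_bars = linear_beta_schedule(1000)
x0 = np.log1p(rng.lognormal(0.0, 1.0, size=(4, 2000)))  # toy log-expression
t = rng.integers(0, 1000, size=4)
xt, noise = q_sample(x0, t, alpha_bars, rng)
print(alpha_bars[0], alpha_bars[-1])  # close to 1 at the start, close to 0 at the end
```

Near `t = 999` almost nothing of the original profile survives, which is why sampling can start from pure Gaussian noise.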

## 4. Data Loading

Split the dataset 80/20 into training and validation sets and wrap each split in a PyTorch `DataLoader`.

In [None]:
from torch.utils.data import DataLoader, random_split

# Split dataset into train and validation
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Batch size: {train_loader.batch_size}")

# Test data loader
batch = next(iter(train_loader))
print("\nBatch keys:", list(batch.keys()))
print("Batch expression shape:", batch['expression'].shape)
print("Batch cell_type shape:", batch['cell_type'].shape)
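
`random_split` draws from PyTorch's global RNG, so the split above is reproducible only because of the `torch.manual_seed(42)` call at the top of the notebook. To pin the split on its own, `random_split` accepts a `generator` argument (for example `generator=torch.Generator().manual_seed(42)`). The NumPy sketch below does the same thing explicitly; the index arrays can then be wrapped with `torch.utils.data.Subset(dataset, train_idx.tolist())`. `train_val_indices` is a hypothetical helper.

```python
import numpy as np


def train_val_indices(n_items, train_frac=0.8, seed=42):
    """Deterministic shuffled split into train and validation index arrays."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_items)
    n_train = int(train_frac * n_items)  # same rounding as the cell above
    return perm[:n_train], perm[n_train:]


train_idx, val_idx = train_val_indices(1000)
print(len(train_idx), len(val_idx))  # 800 200
```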

## 5. Model Training

Train the scIDiff model using the prepared data.

In [None]:
# Initialize trainer
trainer = ScIDiffTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    log_dir='./logs',
    checkpoint_dir='./checkpoints',
    use_wandb=False,  # Set to True if you want to use Weights & Biases
    gradient_clip_val=1.0
)

print("Trainer initialized successfully")
print(trainer.get_model_summary())

In [None]:
# Train the model (reduced epochs for demo)
trainer.train(
    num_epochs=5,  # Use more epochs for real training
    save_every=2,
    validate_every=1,
    log_every=1
)

print("Training completed!")
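
`trainer.train` hides the loss computation. If scIDiff uses the standard DDPM noise-prediction objective (an assumption; the trainer's internals are not shown here), each step samples a timestep per cell, noises the batch, and regresses the network's output onto the injected noise. The sketch below computes one estimate of that loss; `diffusion_loss` and `predict_noise` are hypothetical names, with `predict_noise` standing in for the model.

```python
import numpy as np


def diffusion_loss(predict_noise, x0, alpha_bars, rng):
    """One Monte Carlo estimate of the DDPM noise-prediction (epsilon) loss.

    predict_noise(xt, t) stands in for the network and must return an array
    shaped like xt containing the predicted noise.
    """
    t = rng.integers(0, len(alpha_bars), size=x0.shape[0])  # one timestep per cell
    noise = rng.standard_normal(x0.shape)
    a = alpha_bars[t][:, None]
    xt = np.sqrt(a) * x0 + np.sqrt(1.0 - a) * noise
    return float(np.mean((predict_noise(xt, t) - noise) ** 2))


rng = np.random.default_rng(0)
alpha_bars = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))
x0 = rng.standard_normal((32, 50))

# A "model" that always predicts zero noise scores about 1.0 (the noise variance)
loss = diffusion_loss(lambda xt, t: np.zeros_like(xt), x0, alpha_bars, rng)
print(loss)
```

A predictor that ignores its input scores roughly 1.0, which is a useful baseline when reading the logged training loss, provided the trainer reports this objective.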

## 6. Sample Generation

Generate new single-cell expression profiles using the trained model.

In [None]:
# Generate unconditional samples
model.eval()
with torch.no_grad():
    generated_samples = model.sample(batch_size=100)
    generated_samples = generated_samples.cpu().numpy()

print(f"Generated samples shape: {generated_samples.shape}")
print(f"Generated samples sparsity: {(generated_samples == 0).mean():.2%}")
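
A Gaussian diffusion sampler produces real-valued outputs, so `generated_samples == 0` will typically report almost no zeros unless the model clamps or rounds its output, and comparing that number with the real data's sparsity can be misleading. Counting near-zero entries gives a more comparable figure. The `sparsity` helper and the 0.1 threshold below are illustrative choices, not scIDiff settings.

```python
import numpy as np


def sparsity(x, threshold=0.0):
    """Fraction of entries whose absolute value is <= threshold."""
    return float(np.mean(np.abs(x) <= threshold))


rng = np.random.default_rng(0)
continuous = rng.normal(0.0, 1.0, size=(100, 2000))  # stand-in for raw sampler output
print(sparsity(continuous))                 # exact zeros: essentially none
print(sparsity(continuous, threshold=0.1))  # near-zero entries: several percent
```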

In [None]:
# Compare real vs generated data
real_data = dataset.expression_data.numpy()

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Expression distributions
axes[0, 0].hist(real_data.flatten(), bins=50, alpha=0.7, label='Real', density=True)
axes[0, 0].hist(generated_samples.flatten(), bins=50, alpha=0.7, label='Generated', density=True)
axes[0, 0].set_xlabel('Log Expression')
axes[0, 0].set_ylabel('Density')
axes[0, 0].set_title('Expression Distribution Comparison')
axes[0, 0].legend()

# Mean expression per gene
real_mean = real_data.mean(axis=0)
gen_mean = generated_samples.mean(axis=0)
axes[0, 1].scatter(real_mean, gen_mean, alpha=0.5)
axes[0, 1].plot([0, real_mean.max()], [0, real_mean.max()], 'r--')
axes[0, 1].set_xlabel('Real Mean Expression')
axes[0, 1].set_ylabel('Generated Mean Expression')
axes[0, 1].set_title('Mean Expression Correlation')

# Variance comparison
real_var = real_data.var(axis=0)
gen_var = generated_samples.var(axis=0)
axes[1, 0].scatter(real_var, gen_var, alpha=0.5)
axes[1, 0].plot([0, real_var.max()], [0, real_var.max()], 'r--')
axes[1, 0].set_xlabel('Real Variance')
axes[1, 0].set_ylabel('Generated Variance')
axes[1, 0].set_title('Variance Correlation')

# Sparsity comparison
real_sparsity = (real_data == 0).mean(axis=0)
gen_sparsity = (generated_samples == 0).mean(axis=0)
axes[1, 1].scatter(real_sparsity, gen_sparsity, alpha=0.5)
axes[1, 1].plot([0, 1], [0, 1], 'r--')
axes[1, 1].set_xlabel('Real Sparsity')
axes[1, 1].set_ylabel('Generated Sparsity')
axes[1, 1].set_title('Sparsity Correlation')

plt.tight_layout()
plt.show()
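
The scatter plots above are easier to compare across runs when summarized as numbers. The sketch below reports the Pearson correlation between real and generated per-gene means and variances; in this notebook it would be called as `summary_agreement(real_data, generated_samples)`. The helper name is hypothetical, and since correlation ignores a global offset or scale error, it complements the plots rather than replacing them.

```python
import numpy as np


def summary_agreement(real, generated):
    """Pearson correlations between per-gene statistics of two (cells x genes) matrices."""
    stats = {
        "mean": (real.mean(axis=0), generated.mean(axis=0)),
        "variance": (real.var(axis=0), generated.var(axis=0)),
    }
    return {name: float(np.corrcoef(r, g)[0, 1]) for name, (r, g) in stats.items()}


rng = np.random.default_rng(0)
gene_mu = rng.uniform(-2.0, 2.0, size=200)  # genes differ in typical expression
real = np.log1p(rng.lognormal(gene_mu, 1.0, size=(500, 200)))
fake = np.log1p(rng.lognormal(gene_mu, 1.0, size=(300, 200)))  # stand-in "generated" cells
result = summary_agreement(real, fake)
print(result)
```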

## 7. Conditional Generation

Generate samples conditioned on specific cell types.

In [None]:
# Generate samples for each cell type (labels are 0..n_cell_types-1)
cell_type_samples = {}
n_cell_types = len(np.unique(cell_type_labels))

for cell_type in range(n_cell_types):
    conditioning = {
        'cell_type': torch.tensor([cell_type] * 50, device=device)
    }
    
    with torch.no_grad():
        samples = model.sample(batch_size=50, conditioning=conditioning)
        cell_type_samples[cell_type] = samples.cpu().numpy()

print("Generated conditional samples for all cell types")
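
One quick check that conditioning has an effect is to correlate, for each cell type, the real mean expression profile with the mean of the samples generated for that type. A mismatched pairing (type A's samples against type B's cells) should score noticeably lower. The sketch below is hypothetical; in this notebook it would be called as `type_profile_agreement(real_data, cell_type_labels, cell_type_samples)`, which requires the rows of `real_data` to line up with `cell_type_labels`.

```python
import numpy as np


def type_profile_agreement(real, labels, samples_by_type):
    """Correlate real and generated per-gene mean profiles for each cell type.

    real: (cells x genes) array; labels: per-cell type ids aligned with real's rows;
    samples_by_type: dict mapping type id -> (n x genes) generated array.
    """
    scores = {}
    for t, gen in samples_by_type.items():
        real_profile = real[labels == t].mean(axis=0)
        scores[t] = float(np.corrcoef(real_profile, gen.mean(axis=0))[0, 1])
    return scores


rng = np.random.default_rng(1)
profiles = rng.uniform(0.0, 5.0, size=(3, 100))  # one toy mean profile per type
labels = rng.integers(0, 3, 600)
real = profiles[labels] + rng.normal(0.0, 0.3, size=(600, 100))
fake = {t: profiles[t] + rng.normal(0.0, 0.3, size=(50, 100)) for t in range(3)}
scores = type_profile_agreement(real, labels, fake)
print(scores)
```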

## 8. Inverse Design

Now let's demonstrate inverse design: generating cells that match specific gene expression targets.

In [None]:
# Create gene name to index mapping (for demo purposes)
gene_names = dataset.get_gene_names()
gene_to_idx = {name: idx for idx, name in enumerate(gene_names)}

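# Pick genes to push up (markers) and down (suppressed); these names assume
# get_gene_names() returns "Gene_<index>" placeholders for the synthetic data
marker_genes = ['Gene_0', 'Gene_1', 'Gene_2', 'Gene_10', 'Gene_20']
suppressed_genes = ['Gene_100', 'Gene_200', 'Gene_300']

# Fail early if a chosen name is not in the dataset
missing = [g for g in marker_genes + suppressed_genes if g not in gene_to_idx]
assert not missing, f"Unknown gene names: {missing}"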

print(f"Marker genes: {marker_genes}")
print(f"Suppressed genes: {suppressed_genes}")

In [None]:
# Setup inverse design
objective_function = GeneExpressionObjective(gene_to_idx)
designer = InverseDesigner(
    model=model,
    objective_functions=[objective_function],
    device=device
)

# Define target phenotype
target_phenotype = PhenotypeTarget(
    gene_targets={
        'Gene_0': 5.0,   # High expression
        'Gene_1': 4.0,   # High expression
        'Gene_2': 3.0,   # Medium-high expression
    },
    marker_genes=marker_genes,
    suppressed_genes=suppressed_genes,
    cell_type='custom'
)

print("Target phenotype defined")
print(f"Gene targets: {target_phenotype.gene_targets}")

In [None]:
# Perform inverse design
designed_cells = designer.design(
    target=target_phenotype,
    num_samples=32,
    num_optimization_steps=20,  # Reduced for demo
    learning_rate=0.01
)

designed_cells_np = designed_cells.cpu().numpy()
print(f"Designed cells shape: {designed_cells_np.shape}")
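
To see whether the optimization moved the designed cells toward the requested phenotype, it helps to compare achieved and target values directly. The sketch below is a hypothetical helper; here it would be called as `design_report(designed_cells_np, gene_to_idx, target_phenotype.gene_targets, marker_genes, suppressed_genes)`. The targets (5.0, 4.0, 3.0) are only meaningful if they are on the same scale as the model's output, presumably the log1p-normalized space used for training.

```python
import numpy as np


def design_report(cells, gene_to_idx, gene_targets, marker_genes, suppressed_genes):
    """Summarize how designed cells (n x genes) compare with a requested phenotype."""
    def mean_of(genes):
        return float(cells[:, [gene_to_idx[g] for g in genes]].mean())

    achieved = {g: float(cells[:, gene_to_idx[g]].mean()) for g in gene_targets}
    return {
        "target_abs_error": {g: abs(achieved[g] - v) for g, v in gene_targets.items()},
        "marker_mean": mean_of(marker_genes),          # should be high
        "suppressed_mean": mean_of(suppressed_genes),  # should be low
    }


toy_idx = {"A": 0, "B": 1, "C": 2}
toy_cells = np.array([[5.0, 0.0, 1.0], [3.0, 0.2, 1.0]])
report = design_report(toy_cells, toy_idx, {"A": 4.0}, ["A"], ["B"])
print(report)
```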

## 9. Save and Load Model

Demonstrate how to save and load the trained model.

In [None]:
# Save the model
model_save_path = 'scidiff_demo_model.pt'

torch.save({
    'model_state_dict': model.state_dict(),
    'model_config': model_config,
    'dataset_stats': dataset.get_statistics()
}, model_save_path)

print(f"Model saved to {model_save_path}")
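
Since PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which only unpickles tensors, primitive types and standard containers. If `dataset.get_statistics()` returns NumPy arrays or scalars (its return type is not shown in this notebook), loading this checkpoint with the default will fail. Besides passing `weights_only=False` for checkpoints you trust, you can convert the statistics to plain Python types before saving, for example `'dataset_stats': to_plain(dataset.get_statistics())`. `to_plain` is a hypothetical helper.

```python
import numpy as np


def to_plain(obj):
    """Recursively convert NumPy arrays and scalars inside dicts/lists to plain Python.

    Tuples are returned as lists.
    """
    if isinstance(obj, dict):
        return {k: to_plain(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_plain(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj


stats = {"n_cells": np.int64(1000), "gene_means": np.array([0.5, 1.25]), "name": "demo"}
print(to_plain(stats))  # {'n_cells': 1000, 'gene_means': [0.5, 1.25], 'name': 'demo'}
```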

In [None]:
# Load the model
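# dataset_stats may contain objects that the default weights_only=True loader
# rejects (PyTorch >= 2.6); weights_only=False allows them. Only use it for
# checkpoints you trust.
checkpoint = torch.load(model_save_path, map_location=device, weights_only=False)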

# Create new model instance
loaded_model = ScIDiffModel(**checkpoint['model_config'])
loaded_model.load_state_dict(checkpoint['model_state_dict'])
loaded_model = loaded_model.to(device)
loaded_model.eval()

print("Model loaded successfully")
print(f"Loaded model has {sum(p.numel() for p in loaded_model.parameters()):,} parameters")

In [None]:
# Test loaded model
with torch.no_grad():
    test_samples = loaded_model.sample(batch_size=10)
    print(f"Generated test samples shape: {test_samples.shape}")
    print("Model loading and testing successful!")

## 10. Summary

In this tutorial, we demonstrated:

1. **Data Preparation**: Creating synthetic single-cell data and using SingleCellDataset
2. **Model Initialization**: Setting up the scIDiff model
3. **Training**: Training the diffusion model on single-cell data
4. **Generation**: Generating new single-cell expression profiles
5. **Conditional Generation**: Generating samples conditioned on cell types
6. **Inverse Design**: Designing cells with specific gene expression targets
7. **Model Persistence**: Saving and loading trained models

### Next Steps

- Try real single-cell datasets (e.g. loaded with scanpy)
- Experiment with different model architectures
- Implement custom objective functions for inverse design
- Explore perturbation prediction capabilities
- Scale up training with larger datasets and more epochs

### Key Takeaways

- scIDiff provides a flexible framework for single-cell generative modeling
- The inverse design capability enables targeted cellular state generation
- The model can be conditioned on various biological covariates
- Generated cells should be checked against several statistics of the real data (per-gene means, variances, sparsity), as in Section 6, before drawing conclusions