# CIRRL Quick Start Guide

This notebook demonstrates how to use CIRRL for causal representation learning on single-cell data.

## What is CIRRL?

CIRRL combines:
- **DPA (Distributional Principal Autoencoder)**: Learns latent representations across multiple environments
- **DRIG (Distributional Robustness via Invariant Gradients)**: Performs distributionally robust regression on the learned representations

## Workflow

1. Load and preprocess data
2. Train the DPA model to learn causal representations
3. Apply the DRIG estimator to the learned representations
4. Evaluate performance

In [None]:
# Import libraries
import sys
sys.path.append('..')  # Add parent directory to path

import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns

from cirrl import (
    DPA, OnlyRelu,
    load_singlecell_data,
    train_cirrl_model,
    est_drig_gd_auto,
    compare_latent_dimensions
)

# Notebook-wide plotting defaults: seaborn's white-grid style and a wide
# default figure size (individual cells may still pass their own figsize).
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (12, 6)})

print("Libraries imported successfully!")

## 1. Load Data

In [None]:
# Load the single-cell data from the three pickle files under ../data.
# The loader hands back six objects; this notebook uses X, Y, E for training
# and X_test, Y_test for evaluation.
X, Y, E, X_test, Y_test, test_envs = load_singlecell_data(
    train_path='../data/singlecell.pkl',
    test_path='../data/singlecelltest.pkl',
    separate_test_path='../data/testenvs_separate.pkl',
)

# Size summary: rows of X / X_test are samples, columns of X are features,
# and columns of E are environments.
print(
    f"\nData loaded:",
    f"  Training samples: {X.shape[0]}",
    f"  Features: {X.shape[1]}",
    f"  Environments: {E.shape[1]}",
    f"  Test samples: {X_test.shape[0]}",
    sep="\n",
)

## 2. Initialize and Train DPA Model

In [None]:
# Fix the NumPy and PyTorch global seeds for reproducibility.
np.random.seed(123)
torch.manual_seed(123)

# Variance regularizer handed to the model as `priorvar`.
SIGMA = OnlyRelu(epsilon=0.1)

# Build the DPA model. Input and condition widths are read off the data:
# one input unit per feature column of X, one condition unit per column of E.
dpa = DPA(
    # --- sizes ---
    data_dim=X.shape[1],
    condition_dim=E.shape[1],
    latent_dims=[3],  # a single latent dimension setting: 3
    hidden_dim=400,
    num_layer=2,
    # --- architecture switches ---
    bn_enc=True,
    bn_dec=True,
    resblock=True,
    totalvar=True,
    priorvar=SIGMA,
    # --- optimisation ---
    lr=1e-4,
    seed=123,
)

print("DPA model initialized!")
print(f"Device: {dpa.device}")

In [None]:
# Fit the DPA model and keep the returned training history for plotting.
#   alpha -> GMM loss weight
#   beta  -> regularization weight (switched off here)
#   gamma -> DRIG gamma used for the evaluation during training
# batch_size equals the number of training rows, i.e. full-batch updates.
history = train_cirrl_model(
    dpa,
    X, Y, E,
    X_test, Y_test,
    alpha=0.1,
    beta=0,
    gamma=5,
    epochs=1000,
    batch_size=X.shape[0],
    print_every=100,
    verbose=True,
)

print("\nTraining completed!")

## 3. Visualize Training Progress

In [None]:
# Two side-by-side panels: MSE curves on the left, loss terms on the right.
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
mse_ax, loss_ax = axes

# Left panel: drawn only when both MSE series are non-empty.
if history['train_mse'] and history['test_mse']:
    for key, label in (('train_mse', 'Train MSE'), ('test_mse', 'Test MSE')):
        mse_ax.plot(history[key], label=label, linewidth=2)
    mse_ax.set_xlabel('Evaluation Step', fontsize=12)
    mse_ax.set_ylabel('MSE', fontsize=12)
    mse_ax.set_title('Training Progress', fontsize=14, fontweight='bold')
    mse_ax.legend(fontsize=10)
    mse_ax.grid(True, alpha=0.3)

# Right panel: gated on the DPA loss series only; the GMM loss is plotted
# alongside it, on a logarithmic y-axis.
if history['dpa_loss']:
    for key, label in (('dpa_loss', 'DPA Loss'), ('gmm_loss', 'GMM Loss')):
        loss_ax.plot(history[key], label=label, linewidth=2)
    loss_ax.set_xlabel('Evaluation Step', fontsize=12)
    loss_ax.set_ylabel('Loss', fontsize=12)
    loss_ax.set_title('Loss Components', fontsize=14, fontweight='bold')
    loss_ax.legend(fontsize=10)
    loss_ax.grid(True, alpha=0.3)
    loss_ax.set_yscale('log')

plt.tight_layout()
plt.show()

## 4. Extract Latent Representations

In [None]:
# Call eval() on the wrapped model, then encode without tracking gradients.
dpa.model.eval()

latent_k = dpa.latent_dims[0]  # the single latent size configured above

with torch.no_grad():
    # Training inputs, conditioned on their environment matrix E.
    _, _, z = dpa.model(
        x=X.to(dpa.device),
        k=latent_k,
        c=E.to(dpa.device),
        return_latent=True,
        double=True,
    )

    # NOTE(review): the test inputs are conditioned on the *training*
    # matrix E. This only lines up if the model tolerates that (e.g. the
    # latent does not depend on `c`, or X_test has as many rows as E).
    # Confirm, or pass a test-set environment matrix instead.
    _, _, z_test = dpa.model(
        x=X_test.to(dpa.device),
        k=latent_k,
        c=E.to(dpa.device),
        return_latent=True,
        double=True,
    )

# Size of the reference environment = sum of the first column of E.
# NOTE(review): the [:n_in_ref] slices below assume the reference-environment
# samples are stored as the first rows of X / Y; verify against the loader.
n_in_ref = int(E[:, 0].sum())
center_of_z_ref = z[:n_in_ref].mean(dim=0)
center_of_y_ref = Y[:n_in_ref].mean()

# Shift train and test by the *same* reference means, on the CPU.
z_centered = z.cpu() - center_of_z_ref.cpu()
z_test_cen = z_test.cpu() - center_of_z_ref.cpu()

y_centered = Y.cpu() - center_of_y_ref.cpu()
y_test_cen = Y_test.cpu() - center_of_y_ref.cpu()

print(
    f"Latent representations extracted:",
    f"  Shape: {z_centered.shape}",
    f"  Mean: {z_centered.mean():.4f}",
    f"  Std: {z_centered.std():.4f}",
    sep="\n",
)

## 5. Apply DRIG and Evaluate

In [None]:
# Prepare data for DRIG
from cirrl.utils.data import prepare_drig_data_from_latents

train_data = prepare_drig_data_from_latents(z_centered, y_centered, E)

# Loop-invariant inputs, computed once instead of once per gamma.
drig_device = 'cuda' if torch.cuda.is_available() else 'cpu'
z_test_np = z_test_cen.numpy()
# Flatten the targets to 1-D. Without this, a column-vector target of shape
# (n, 1) subtracted from the flattened (n,) predictions would broadcast to an
# (n, n) matrix and the "MSE" would silently be the mean over all pairs.
y_test_flat = y_test_cen.numpy().flatten()

# Sweep gamma over 0, 1, ..., 15 and record the test MSE for each value.
gamma_values = np.linspace(0, 15, 16)
test_mses = []

for gamma in gamma_values:
    # Estimate DRIG coefficients on the centered training latents
    coef = est_drig_gd_auto(
        train_data,
        gamma=gamma,
        device=drig_device,
        verbose=False
    )

    # Linear prediction on the centered test latents, compared element-wise
    # against the flattened targets.
    y_pred = z_test_np @ coef
    mse = np.mean((y_pred.flatten() - y_test_flat) ** 2)
    test_mses.append(mse)

# Pick the gamma with the lowest test MSE.
# Note: gamma is selected on the same test data it is scored on, so the
# reported best MSE is an optimistic estimate.
best_idx = np.argmin(test_mses)
best_gamma = gamma_values[best_idx]
best_mse = test_mses[best_idx]

print(f"\nBest gamma: {best_gamma:.2f}")
print(f"Best test MSE: {best_mse:.4f}")

In [None]:
# Test MSE as a function of gamma, with the best gamma marked by a dashed
# red vertical line.
gamma_fig, gamma_ax = plt.subplots(figsize=(10, 6))
gamma_ax.plot(gamma_values, test_mses, marker='o', linewidth=2, markersize=8)
gamma_ax.axvline(
    best_gamma,
    color='r',
    linestyle='--',
    label=f'Best gamma = {best_gamma:.2f}',
)
gamma_ax.set_xlabel('Gamma', fontsize=12)
gamma_ax.set_ylabel('Test MSE', fontsize=12)
gamma_ax.set_title('Test MSE vs DRIG Gamma Parameter', fontsize=14, fontweight='bold')
gamma_ax.legend(fontsize=10)
gamma_ax.grid(True, alpha=0.3)
gamma_fig.tight_layout()
plt.show()

## 6. Compare Different Latent Dimensions (Optional)

In [None]:
# Slow cell: runs the comparison for every latent size in latent_dims_list
# and every seed in seeds. epochs is kept at 500 (reduced for the demo).
comparison_results = compare_latent_dimensions(
    X, Y, E,
    X_test, Y_test,
    latent_dims_list=[2, 3, 5],
    seeds=[123, 456],
    epochs=500,
    verbose=True,
)

# Collapse each result's list of test MSEs into a mean and a standard
# deviation, keyed by its latent dimension.
latent_dims, mean_mses, std_mses = [], [], []
for result in comparison_results:
    latent_dims.append(result['latent_dim'])
    mean_mses.append(np.mean(result['test_mses']))
    std_mses.append(np.std(result['test_mses']))

# Mean test MSE per latent dimension, with +/- one standard deviation bars.
dims_fig, dims_ax = plt.subplots(figsize=(10, 6))
dims_ax.errorbar(
    latent_dims,
    mean_mses,
    yerr=std_mses,
    marker='o',
    capsize=5,
    capthick=2,
    linewidth=2,
)
dims_ax.set_xlabel('Latent Dimension', fontsize=12)
dims_ax.set_ylabel('Test MSE', fontsize=12)
dims_ax.set_title('Performance vs Latent Dimension', fontsize=14, fontweight='bold')
dims_ax.grid(True, alpha=0.3)
dims_fig.tight_layout()
plt.show()

## Summary

In this notebook, you learned how to:

1. Load single-cell data with multiple environments
2. Initialize and train a DPA model for representation learning
3. Extract and center latent representations
4. Apply the DRIG estimator with different gamma values
5. Evaluate performance and find optimal hyperparameters
6. Compare different latent dimensions

## Next Steps

- Try different model architectures (hidden dimensions, number of layers)
- Experiment with different training hyperparameters (alpha, beta)
- Apply CIRRL to your own datasets
- Explore the learned latent representations