# CNV-Guided Contrastive Learning: Complete Pipeline

This notebook walks through the entire pipeline, from data loading and preprocessing through model training, evaluation, and downstream analysis.

## Steps:
1. Data loading and preprocessing
2. CNV inference (run separately in R)
3. Dataset creation
4. Model training
5. Evaluation
6. Downstream analysis

In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Import custom modules
import sys
sys.path.append('../')

from model import MultimodalEncoder, build_cnv_anchor_bank
from losses import CombinedLoss
from data_processing import MultimodalScDataset, preprocess_adata
from train import train_model
from evaluation import evaluate_alignment, plot_similarity_heatmap, plot_umap, plot_training_curves

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

# Check GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

## Step 1: Load Data

First, you need to download the data from GEO (GSE131907).

**Manual steps:**
1. Go to: https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE131907
2. Download the count matrix for Patient P0006 (LUNG_T06 and LUNG_N06)
3. Save to `data/raw/`

In [None]:
# For now, we'll create synthetic data to test the pipeline
# Replace this with actual data loading once you have the data

def create_synthetic_data(n_cells=10000, n_genes=5884, n_subclusters=78):
    """Create synthetic data for testing."""
    print("Creating synthetic data...")
    
    # Create expression matrix
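    # float32 halves the memory of the stored matrix (~235 MB instead of ~470 MB for 10,000 x 5,884)
    X = np.random.negative_binomial(5, 0.3, (n_cells, n_genes)).astype(np.float32)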
    
    # Create obs with metadata
    obs = pd.DataFrame({
        'sample': np.random.choice(['LUNG_T06', 'LUNG_N06'], n_cells),
        'cell_type': np.random.choice(['Epithelial', 'T cell', 'Myeloid'], n_cells),
        'cancer_vs_normal': np.random.choice(['Cancer', 'Normal'], n_cells)
    })
    obs.index = [f'cell_{i}' for i in range(n_cells)]
    
    # Create var
    var = pd.DataFrame(index=[f'gene_{i}' for i in range(n_genes)])
    
    # Create AnnData
    adata = sc.AnnData(X=X, obs=obs, var=var)
    
    # Add CNV subclusters (would come from inferCNV)
    adata.obs['subcluster'] = [f'cnv_cluster_{i % n_subclusters}' for i in range(n_cells)]
    
    # Create CNV profiles (would come from inferCNV)
    cnv_profiles = pd.DataFrame(
        np.random.randn(n_subclusters, n_genes),
        index=[f'cnv_cluster_{i}' for i in range(n_subclusters)],
        columns=var.index
    )
    
    return adata, cnv_profiles

# Create or load data
adata, cnv_profiles = create_synthetic_data()

print(f"\nData shape: {adata.shape}")
print(f"CNV profiles shape: {cnv_profiles.shape}")
print(f"Number of subclusters: {adata.obs['subcluster'].nunique()}")

## Step 2: Preprocessing

Standard scRNA-seq preprocessing:
- Filter cells and genes
- Normalize
- Log-transform
- (Optional) Select highly variable genes
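`preprocess_adata` lives in `data_processing.py`, which is not shown here. Assuming it follows the usual Scanpy recipe (`sc.pp.filter_cells`, `sc.pp.filter_genes`, `sc.pp.normalize_total`, `sc.pp.log1p`, in that order), the arithmetic it applies to a dense count matrix is roughly the following. The function name is hypothetical and this is a sketch, not the module's actual code:

```python
import numpy as np

def filter_normalize_log1p(X, min_genes=200, min_cells=3, target_sum=1e4):
    """Filter cells/genes, library-size normalize, and log1p a dense cells x genes count matrix.

    Returns the transformed matrix plus the boolean cell and gene masks that were applied.
    """
    X = np.asarray(X, dtype=np.float64)
    # Keep cells with at least `min_genes` detected (non-zero) genes
    cell_mask = (X > 0).sum(axis=1) >= min_genes
    X = X[cell_mask]
    # Keep genes detected in at least `min_cells` of the remaining cells
    gene_mask = (X > 0).sum(axis=0) >= min_cells
    X = X[:, gene_mask]
    # Scale each cell to `target_sum` total counts (guarding against empty cells), then log(1 + x)
    totals = X.sum(axis=1, keepdims=True)
    X = X / np.where(totals > 0, totals, 1.0) * target_sum
    return np.log1p(X), cell_mask, gene_mask
```

After this step every cell's counts sum to `target_sum` before the log, which is what makes cells with different sequencing depths comparable.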

In [None]:
# Preprocess
adata_processed = preprocess_adata(
    adata.copy(),
    min_genes=200,
    min_cells=3,
    target_sum=1e4,
    n_top_genes=None,  # None keeps all genes; if set (e.g. 5884), cnv_profiles must be restricted to the same genes
    log_transform=True
)

print(f"\nProcessed data shape: {adata_processed.shape}")
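Filtering (and HVG selection, if `n_top_genes` is set) can drop genes, so the columns of `cnv_profiles` may no longer match `adata_processed.var_names`. Whether `MultimodalScDataset` handles this internally is not shown. A defensive sketch that restricts the CNV matrix to the shared genes, in expression order, could look like this (the helper `shared_gene_order` is hypothetical):

```python
import pandas as pd

def shared_gene_order(expr_genes, cnv_profiles):
    """Return genes present in both inputs (in expression order) and the CNV matrix restricted to them."""
    cnv_genes = set(cnv_profiles.columns)
    # Preserve the expression gene order so columns line up with adata.X
    shared = [g for g in expr_genes if g in cnv_genes]
    if not shared:
        raise ValueError("No genes shared between expression data and CNV profiles")
    return shared, cnv_profiles.loc[:, shared]
```

In this notebook it would be applied as `shared, cnv_profiles = shared_gene_order(adata_processed.var_names, cnv_profiles)` followed by `adata_processed = adata_processed[:, shared].copy()`, before the dataset is built.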

## Step 3: CNV Inference (Run in R)

**This step must be done in R using inferCNV:**

```R
library(infercnv)

# Create inferCNV object (reference cells must be labelled "Normal" in the annotations file)
infercnv_obj = CreateInfercnvObject(
    raw_counts_matrix="data/processed/counts_matrix.txt",
    annotations_file="data/processed/cell_annotations.txt",
    delim="\t",
    gene_order_file="data/processed/gene_positions.txt",
    ref_group_names=c("Normal")
)

# Run inferCNV
infercnv_obj = infercnv::run(
    infercnv_obj,
    cutoff=0.1,  # inferCNV recommends 0.1 for 10x Genomics data (1 for Smart-seq2)
    out_dir="data/processed/infercnv_output",
    cluster_by_groups=TRUE,
    denoise=TRUE,
    HMM=TRUE
)
```
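The R call above reads three tab-delimited files from `data/processed/`. Following the inferCNV input conventions (a genes × cells raw counts matrix, a headerless two-column cell annotation file, and a headerless gene/chromosome/start/stop file), one way to export them from Python is sketched below. The helper name is hypothetical; gene coordinates must come from an external annotation (e.g. a GENCODE GTF); the counts should be raw UMI counts rather than the log-normalized `adata_processed.X`; and reference cells must be labelled `Normal` to match `ref_group_names=c("Normal")`.

```python
from pathlib import Path
import pandas as pd

def write_infercnv_inputs(counts, annotations, gene_positions, out_dir):
    """Write the three inputs read by CreateInfercnvObject; returns the genes written.

    counts:          raw counts, cells x genes DataFrame
    annotations:     Series mapping cell name -> group label ("Normal" for reference cells)
    gene_positions:  DataFrame indexed by gene with columns chr, start, stop
    """
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    # Only genes with a known genomic position can be placed along chromosomes
    shared = gene_positions.index.intersection(counts.columns)
    # Counts matrix: genes as rows, cells as columns, with a header row of cell names
    counts.loc[:, shared].T.to_csv(out / "counts_matrix.txt", sep="\t")
    # Annotations: two columns (cell, group), no header
    annotations.loc[counts.index].to_csv(out / "cell_annotations.txt", sep="\t", header=False)
    # Gene order: four columns (gene, chr, start, stop), no header
    gene_positions.loc[shared, ["chr", "start", "stop"]].to_csv(
        out / "gene_positions.txt", sep="\t", header=False
    )
    return shared
```

From this notebook the call would target `out_dir='../data/processed'` so the file names match those in the R script.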

After running inferCNV:
1. Load the per-cell subcluster assignments into `adata.obs['subcluster']`.
2. Load the CNV profile matrix (one row per subcluster, one column per gene) as `cnv_profiles`.

For this demo, we're using synthetic CNV data created above.
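The notebook does not show how the per-subcluster CNV profile matrix is derived from the inferCNV output. One plausible construction, assumed here rather than taken from the source, is to average the per-cell inferCNV values (e.g. from the genes × cells `infercnv.observations.txt` and `infercnv.references.txt` tables) over the cells of each subcluster. The function name and inputs are hypothetical:

```python
import pandas as pd

def subcluster_cnv_profiles(cnv_cells, cell_to_subcluster):
    """Average a genes x cells CNV matrix within each subcluster.

    cnv_cells:           DataFrame, genes as rows, cells as columns
    cell_to_subcluster:  Series mapping cell name -> subcluster label
    Returns a subclusters x genes DataFrame (same layout as the synthetic `cnv_profiles`).
    """
    # Keep only cells that have a subcluster assignment
    cells = cnv_cells.columns.intersection(cell_to_subcluster.index)
    per_cell = cnv_cells.loc[:, cells].T  # cells x genes
    # Group rows by subcluster label (aligned on cell names) and average
    return per_cell.groupby(cell_to_subcluster.loc[cells]).mean()
```

The same cell-to-subcluster mapping would then fill `adata.obs['subcluster']`, so both inputs to the dataset share one set of subcluster names.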

## Step 4: Create Dataset

In [None]:
# Create dataset
dataset = MultimodalScDataset(
    adata_processed,
    cnv_profiles,
    subcluster_col='subcluster',
    expr_layer=None  # Use .X (log-normalized)
)

print(f"\nDataset size: {len(dataset)}")

# Test dataset
sample = dataset[0]
print("Sample shapes:")
print(f"  x_expr: {sample['x_expr'].shape}")
print(f"  x_cnv: {sample['x_cnv'].shape}")
print(f"  label: {sample['label']}")
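`MultimodalScDataset` is defined in `data_processing.py` and is not shown. Judging only from the keys printed above, each item pairs one cell's expression vector with the CNV profile of its subcluster and an integer subcluster label. The hypothetical, framework-free sketch below mirrors that indexing logic; any object with `__len__` and `__getitem__` can serve as a map-style dataset for `torch.utils.data.DataLoader`, whose default collate function converts NumPy arrays to tensors. It assumes a dense `X`.

```python
import numpy as np

class PairedCellCnvDataset:
    """Hypothetical stand-in for MultimodalScDataset: one (expression, CNV profile, label) item per cell."""

    def __init__(self, X, subclusters, cnv_profiles, cnv_index):
        # X: dense cells x genes array; subclusters: per-cell subcluster names
        # cnv_profiles: subclusters x genes array whose rows follow the order of `cnv_index`
        self.X = np.asarray(X, dtype=np.float32)
        self.cnv = np.asarray(cnv_profiles, dtype=np.float32)
        row_of = {name: i for i, name in enumerate(cnv_index)}
        # Integer label = row of the cell's subcluster in the CNV profile matrix
        self.labels = np.array([row_of[s] for s in subclusters], dtype=np.int64)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, i):
        label = self.labels[i]
        return {"x_expr": self.X[i], "x_cnv": self.cnv[label], "label": label}
```

Defining the label as the row position in the CNV matrix is what keeps a tensor built from `cnv_profiles.values` (as in Step 6) consistent with the labels.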

## Step 5: Initialize Model

In [None]:
# Get number of genes
n_genes = adata_processed.n_vars

# Initialize model
model = MultimodalEncoder(
    n_genes=n_genes,
    hidden_dim=256,
    latent_dim=64,
    freeze_cnv=True
)

print(f"Model initialized with {n_genes} genes")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

## Step 6: Train Model

**Note:** This demo runs only 10 epochs because it uses synthetic data. For real data, train for 100 epochs.
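The training objective lives in `losses.py` (`CombinedLoss`, listed at the end as three loss functions) and is not reproduced here. As an illustration only, one common way to align cell embeddings with a bank of per-subcluster CNV anchors is an InfoNCE-style cross-entropy over temperature-scaled cosine similarities: each cell's positive is its own subcluster's anchor and every other anchor acts as a negative. This NumPy sketch is an assumption, not the notebook's loss, and the temperature value is arbitrary:

```python
import numpy as np

def anchor_infonce_loss(z_expr, z_anchors, labels, temperature=0.1):
    """Mean cross-entropy of each cell against all CNV anchors (positive = its own subcluster)."""
    # L2-normalize so dot products are cosine similarities
    z = z_expr / np.linalg.norm(z_expr, axis=1, keepdims=True)
    a = z_anchors / np.linalg.norm(z_anchors, axis=1, keepdims=True)
    logits = z @ a.T / temperature  # cells x anchors
    # Numerically stable log-softmax over the anchors
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Pick out the log-probability of each cell's true anchor
    return -log_probs[np.arange(len(labels)), labels].mean()
```

When all similarities are equal the loss is log(number of anchors); it falls as each cell moves closer to its own anchor than to the others.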

In [None]:
# Training parameters
n_epochs = 10  # Use 100 for real training
batch_size = 512  # Use 4096 for real training (if GPU allows)
learning_rate = 1e-3

# Convert CNV profiles to a float32 tensor; row order must match the dataset's subcluster label encoding
cnv_profiles_tensor = torch.tensor(cnv_profiles.values, dtype=torch.float32)

# Train
model_trained, history = train_model(
    model=model,
    dataset=dataset,
    cnv_profiles_tensor=cnv_profiles_tensor,
    n_epochs=n_epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    weight_decay=1e-4,
    device=device,
    save_dir='../checkpoints',
    save_every=5
)

## Step 7: Plot Training Curves

In [None]:
Path('../figures').mkdir(parents=True, exist_ok=True)  # make sure the output folder exists
plot_training_curves(history, save_path='../figures/training_curves.png')

## Step 8: Evaluate Alignment

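Compute top-k retrieval accuracy. The 97.4% top-5 (z-space) figure is the paper's result on real data and is not a meaningful target for the synthetic demo.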
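`evaluate_alignment` is defined in `evaluation.py` and is not shown. Assuming it scores retrieval with cosine similarity, top-k accuracy means: for each cell, rank all subcluster CNV embeddings by similarity to the cell's expression embedding, and count a hit if the cell's own subcluster is among the k highest. A NumPy sketch of that metric (the function name is hypothetical):

```python
import numpy as np

def topk_retrieval_accuracy(z_expr, z_cnv, labels, k=5):
    """Fraction of cells whose true subcluster is among the k most similar CNV embeddings."""
    z = z_expr / np.linalg.norm(z_expr, axis=1, keepdims=True)
    c = z_cnv / np.linalg.norm(z_cnv, axis=1, keepdims=True)
    sims = z @ c.T  # cells x subclusters cosine similarities
    # Indices of the k largest similarities per row (order within the top k does not matter)
    topk = np.argpartition(-sims, kth=k - 1, axis=1)[:, :k]
    return float(np.mean(np.any(topk == np.asarray(labels)[:, None], axis=1)))
```

As a baseline, uninformative embeddings would score about k / n_subclusters, i.e. 5/78 ≈ 6.4% with the 78 synthetic subclusters used above.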

In [None]:
# Evaluate
results = evaluate_alignment(
    model=model_trained,
    dataset=dataset,
    cnv_profiles_tensor=cnv_profiles_tensor,
    device=device,
    k=5
)

print(f"\n{'='*50}")
print("ALIGNMENT RESULTS")
print(f"{'='*50}")
print(f"Z-space top-5 accuracy: {results['z_space_accuracy']:.1%}")
print(f"H-space top-5 accuracy: {results['h_space_accuracy']:.1%}")
print("\nReported z-space top-5 accuracy (paper, real data): 97.4%")

## Step 9: Visualizations

In [None]:
# Plot CNV similarity heatmap
plot_similarity_heatmap(
    results['z_similarities'],
    save_path='../figures/cnv_similarity_heatmap.png',
    title="CNV Embeddings Similarity Matrix (Z-space)"
)

In [None]:
# UMAP visualization colored by cancer vs normal
cancer_labels = (adata_processed.obs['cancer_vs_normal'] == 'Cancer').astype(int).values

umap_coords = plot_umap(
    results['z_expr'],
    cancer_labels,
    color_by='Cancer (1) vs Normal (0)',
    title="UMAP of Expression Embeddings",
    save_path='../figures/umap_cancer_vs_normal.png'
)

## Step 10: Store Embeddings in AnnData

In [None]:
# Store embeddings in adata for downstream analysis
adata_processed.obsm['X_h_expr'] = results['h_expr']
adata_processed.obsm['X_z_expr'] = results['z_expr']
adata_processed.obsm['X_umap'] = umap_coords

# Save (create the output folder if needed)
Path('../data/processed').mkdir(parents=True, exist_ok=True)
adata_processed.write_h5ad('../data/processed/adata_with_embeddings.h5ad')
print("Saved AnnData with embeddings")

## Step 11: Downstream Analysis

### Traditional Differential Expression

In [None]:
# Traditional DE: Cancer vs Normal in Epithelial cells
epithelial_mask = adata_processed.obs['cell_type'] == 'Epithelial'
adata_epithelial = adata_processed[epithelial_mask].copy()

print(f"Epithelial cells: {adata_epithelial.n_obs}")

# Run Wilcoxon test
sc.tl.rank_genes_groups(
    adata_epithelial,
    groupby='cancer_vs_normal',
    reference='Normal',
    method='wilcoxon'
)

# Get top markers
markers_df = sc.get.rank_genes_groups_df(adata_epithelial, group='Cancer')
print("\nTop 10 differentially expressed genes:")
print(markers_df.head(10))

### CNV-Conditioned Clustering

In [None]:
# Use h-space embeddings for clustering
sc.pp.neighbors(adata_processed, use_rep='X_h_expr', n_neighbors=30)
sc.tl.leiden(adata_processed, resolution=0.5)

print(f"\nFound {adata_processed.obs['leiden'].nunique()} Leiden clusters")

# Visualize clusters on UMAP
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Color by Leiden cluster
scatter1 = axes[0].scatter(
    umap_coords[:, 0],
    umap_coords[:, 1],
    c=adata_processed.obs['leiden'].astype(int),
    cmap='tab20',
    s=1,
    alpha=0.6
)
axes[0].set_title('Leiden Clusters (CNV-informed)')
plt.colorbar(scatter1, ax=axes[0])

# Color by cancer vs normal
scatter2 = axes[1].scatter(
    umap_coords[:, 0],
    umap_coords[:, 1],
    c=cancer_labels,
    cmap='RdBu_r',
    s=1,
    alpha=0.6
)
axes[1].set_title('Cancer vs Normal')
plt.colorbar(scatter2, ax=axes[1])

plt.tight_layout()
plt.savefig('../figures/clustering_results.png', dpi=300, bbox_inches='tight')
plt.show()

## Next Steps

1. **With Real Data:**
   - Download GSE131907 Patient P0006
   - Run inferCNV to get actual CNV profiles
   - Train for 100 epochs with batch_size=4096
   - Compare top-5 accuracy with the paper's reported 97.4%

2. **CNV-Conditioned DE:**
   - Identify neighborhoods with mixed cancer/normal cells
   - Perform DE within CNV-consistent regions
   - Compare to traditional DE results

3. **Biomarker Validation:**
   - Validate APOC1 downregulation
   - Check other top markers (Table 1 in paper)
   - Explore early malignant transition signatures

4. **Multi-Patient Analysis:**
   - Extend to other patients in GSE131907
   - Assess reproducibility
   - Study inter-patient variability
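For the CNV-conditioned DE item, one hedged way to flag neighborhoods with mixed cancer/normal cells is to reuse the kNN graph that `sc.pp.neighbors` built on `X_h_expr` (stored in `adata_processed.obsp['connectivities']`) and compute, for each cell, the connectivity-weighted fraction of cancer neighbors. The function name and the 0.2 to 0.8 "mixed" band are arbitrary illustrative choices, not taken from the source:

```python
import numpy as np
import scipy.sparse as sp

def mixed_neighborhoods(connectivities, is_cancer, low=0.2, high=0.8):
    """Return each cell's weighted cancer fraction among its graph neighbors and a 'mixed' mask."""
    A = sp.csr_matrix(connectivities, dtype=np.float64)
    y = np.asarray(is_cancer, dtype=np.float64)
    weight = np.asarray(A.sum(axis=1)).ravel()
    # Weighted fraction of each cell's neighbors labelled cancer (0 for isolated cells)
    frac = np.divide(A @ y, weight, out=np.zeros_like(weight), where=weight > 0)
    mixed = (frac >= low) & (frac <= high)
    return frac, mixed
```

In this notebook the call would be `frac, mixed = mixed_neighborhoods(adata_processed.obsp['connectivities'], cancer_labels)`; DE could then be restricted to cells where `mixed` is true by subsetting before `sc.tl.rank_genes_groups`.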

## Summary

This notebook demonstrates the complete pipeline:

✓ Data preprocessing  
✓ Dataset creation  
✓ Model training  
✓ Evaluation (top-k retrieval)  
✓ Visualization  
✓ Downstream analysis  

**Key files created:**
- `model.py`: Encoder architectures
- `losses.py`: Three loss functions
- `data_processing.py`: Dataset utilities
- `train.py`: Training loop
- `evaluation.py`: Metrics and visualization

**Next:** Run with real data to try to reproduce the paper's reported 97.4% top-5 accuracy.