In [None]:
import os
import warnings
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm
import anndata as ad
import lightning as L
from os.path import join
from modlyn.io.loading import read_lazy

import lamindb as ln

from modlyn.io.datamodules import ClassificationDataModule
from modlyn.models.linear import Linear
from modlyn.io.loading import read_lazy

In [None]:
# Location of the Tahoe-100M chunk on disk. Set TAHOE_CHUNK_PATH to point at a
# different copy instead of editing the notebook; the default is unchanged.
store_path = Path(os.environ.get("TAHOE_CHUNK_PATH", "/home/ubuntu/tahoe100M_chunk_1"))


In [None]:
adata = read_lazy(store_path)
var = pd.read_parquet("var_subset_tahoe100M.parquet")

# The gene table replaces adata.var wholesale, so it must have exactly one row
# per gene in the chunk; fail early with an explicit message if it does not.
if len(var) != adata.n_vars:
    raise ValueError(
        f"var table has {len(var)} rows but the chunk has {adata.n_vars} genes"
    )
# NOTE(review): rows are attached positionally, which assumes the parquet rows are
# in the same gene order as the chunk — confirm, or align with var.reindex(adata.var.index).
print(var.head())  # preview instead of dumping the full table
adata.var = var

In [None]:
# Show the AnnData repr as a quick sanity check of the loaded chunk and attached var table.
adata

In [None]:
# Integer class label per cell: the category code of its cell line.
# remove_unused_categories() keeps the codes dense (0..n_classes-1) even if
# cell_line is already categorical with categories absent from this chunk;
# otherwise the largest code could exceed the class count given to the model.
adata.obs["y"] = (
    adata.obs["cell_line"]
    .astype("category")
    .cat.remove_unused_categories()
    .cat.codes.to_numpy()
    .astype("i8")
)

In [None]:
# Contiguous split: the first N_TRAIN cells are used for training, the rest for
# validation. NOTE(review): this follows storage order; if cells are stored
# grouped (e.g. by plate or cell line) the split is not random — confirm.
N_TRAIN = 800_000
if adata.n_obs <= N_TRAIN:
    raise ValueError(
        f"Need more than {N_TRAIN} cells for a train/validation split, got {adata.n_obs}"
    )
adata_train = adata[:N_TRAIN]
adata_val = adata[N_TRAIN:]

datamodule = ClassificationDataModule(
    adata_train=adata_train,
    adata_val=adata_val,
    label_column="y",
    train_dataloader_kwargs={
        "batch_size": 2048,
        "drop_last": True,
    },
    val_dataloader_kwargs={
        "batch_size": 2048,
        "drop_last": False,
    },
)

In [None]:
# n_covariates is used as the number of output classes (one per label code).
# max(y) + 1 always covers every label, whereas nunique() is smaller whenever
# some codes are absent and would make the training targets out of range.
n_classes = int(adata.obs["y"].max()) + 1
linear = Linear(
    n_genes=adata.n_vars,
    n_covariates=n_classes,
    learning_rate=1e-2,
)

In [None]:
# Tutorial-sized run: training stops at whichever limit is reached first,
# 3000 optimizer steps or 3 epochs; metrics are logged every 100 steps.
trainer = L.Trainer(
    max_steps=3000,
    max_epochs=3,
    log_every_n_steps=100,
)

In [None]:
# Fit the linear classifier on the train/validation datamodule built above.
trainer.fit(linear, datamodule=datamodule)

## Quick analysis

In [None]:
import importlib

import LinearModuleAnalyzer

# Reload so edits to LinearModuleAnalyzer.py are picked up without a kernel restart.
LinearModuleAnalyzer = importlib.reload(LinearModuleAnalyzer)
from LinearModuleAnalyzer import full_analysis, quick_analysis_with_scanpy_dotplot

# Full weight analysis of the trained model (the quicker dot-plot variant is
# quick_analysis_with_scanpy_dotplot(linear, adata, datamodule)).
results = full_analysis(linear, adata, datamodule)

# Uncertainty scores

In [None]:
# import UncertaintyEstimation
# importlib.reload(UncertaintyEstimation)
# from UncertaintyEstimation import get_proper_uncertainty

# results = get_proper_uncertainty(linear, adata, datamodule)


In [None]:
import Figures

# Reload so edits to Figures.py take effect without restarting the kernel.
Figures = importlib.reload(Figures)
from Figures import create_publication_figures

nf, legends = create_publication_figures(linear, adata)

MODLYN: LINEAR MODELS FOR MASSIVE SINGLE-CELL PERTURBATION ANALYSIS
================================================================

ABSTRACT
--------
We present MODLYN, a scalable framework for analyzing massive single-cell perturbation datasets
using interpretable linear models. Applied to the Tahoe-100M dataset (100M cells in total, of which
the current analysis uses a single chunk; 19,177 genes; 50 cell lines), our approach enables rapid
identification of cell-line-specific gene signatures, mechanism clustering, and biomarker
discovery at unprecedented scale.

INTRODUCTION
------------
Single-cell RNA sequencing has revolutionized our understanding of cellular responses to 
perturbations. However, analyzing datasets with hundreds of millions of cells presents 
computational and interpretability challenges. Traditional non-linear methods, while powerful, 
often lack the transparency needed for biological interpretation and struggle with scale.

We hypothesized that linear models, despite their simplicity, could effectively capture 
perturbation-specific signatures while maintaining computational efficiency and interpretability. 
The MODLYN framework tests this hypothesis on the largest single-cell perturbation dataset 
to date.

RESULTS
-------

Dataset Scale and Computational Performance (TODO: numbers to be updated)
Our analysis of the Tahoe-100M dataset represents a [TODO: XYZ]% increase in scale
compared with typical single-cell studies. The linear model achieved:
- Training time: 25.3 minutes
- Peak memory usage: 8.5 GB  
- Model parameters: 958,850 weights
- Inference speed: ~1ms per cell

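Gene Importance and Statistical Significance
We identified 959 highly predictive genes
(>95th percentile importance). Statistical uncertainty analysis revealed:
- 0 significant gene–cell-line associations (p<0.05)
- 0 highly significant associations (p<0.001)
- Mean standard error: 0.0000

TODO: these uncertainty figures are internally inconsistent. A standard error of exactly 0 would make
every non-zero weight significant, not none of them. The uncertainty estimation must be re-run and
checked before these values are reported.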

CONCLUSIONS
-----------
The MODLYN framework enables scalable, interpretable analysis of massive single-cell
perturbation data. Linear models are surprisingly effective at this scale, offering
a compelling alternative to complex non-linear approaches for many biological questions.



In [None]:
import OverviewFig

# Reload so edits to OverviewFig.py take effect without restarting the kernel.
OverviewFig = importlib.reload(OverviewFig)
from OverviewFig import create_modlyn_figure

# Build the MODLYN overview figure together with its caption text.
fig, caption = create_modlyn_figure()

# Dataset / Biological analysis

Figure 1: Expression Overview & Quality Control

Figure 2: Differential Expression Analysis

Figure 3: Cell Clustering Analysis

Figure 4: Drug Response Analysis

Figure 5: Scanpy Expression Analysis

**Note:** some of the functions used below are mock implementations.

In [None]:
import importlib

import gene_level_analysis

# Reload so edits to gene_level_analysis.py are picked up live.
gene_level_analysis = importlib.reload(gene_level_analysis)
from gene_level_analysis import GeneExpressionAnalyzer

analyzer = GeneExpressionAnalyzer(adata)
# To run every figure in one go use analyzer.run_complete_gene_analysis() instead.
analyzer.figure_1_expression_overview()

In [None]:
# analyzer.figure_2_differential_expression() 
# analyzer.figure_3_cell_clustering_analysis()


In [None]:
# Figure 4: drug response analysis (see the figure list above; marked as mock code).
analyzer.figure_4_drug_response_analysis()


In [None]:
# Figure 5: Scanpy expression analysis (see the figure list above; marked as mock code).
analyzer.figure_5_scanpy_expression_analysis()


In [None]:
# Produce the written biological summary from the analyzer built above.
analyzer.generate_biological_narrative()

In [None]:
# !pip install dask_ml
import dask_ml

In [None]:
# import importlib
# import comprehensive_analysis
# importlib.reload(comprehensive_analysis)
# from comprehensive_analysis import ComprehensiveAnalysis

# analysis = ComprehensiveAnalysis()
# final_results = analysis.run_complete_analysis(n_cells=1000)

In [None]:
import matplotlib.pyplot as plt

# Use one sans-serif font for every figure.
# Alternatives: 'Liberation Sans', or the generic family 'sans-serif'.
plt.rcParams.update({'font.family': 'DejaVu Sans'})

In [None]:
from comprehensive_analysis import ComprehensiveAnalysis
from figure_generator import FigureGenerator
from blog_generator import BlogGenerator
from pathlib import Path

# --- Configure where the chunk and gene table live ---------------------------
analysis = ComprehensiveAnalysis()
analysis.chunk_path = Path("/home/ubuntu/tahoe100M_chunk_1")
analysis.var_path = Path("/home/ubuntu/var_subset_tahoe100M.parquet")

# --- Load data (n_cells=5000). NOTE: this rebinds the notebook-level `adata`,
# so later cells operate on this object rather than the full chunk.
adata = analysis.load_data(n_cells=5000)

# --- Run every method (LinearSCVI always runs here; remove its line to skip it)
scanpy_results = analysis.run_scanpy_analysis(adata)
modlyn_results = analysis.run_modlyn_analysis(adata)
linscvi_results = analysis.run_linscvi_analysis(adata)

analysis.results = dict(
    adata=adata,
    scanpy=scanpy_results,
    modlyn=modlyn_results,
    linscvi=linscvi_results,
)

# --- Figures -----------------------------------------------------------------
fig_gen = FigureGenerator(analysis)
fig_gen.generate_all_figures()

# --- Tables feeding the blog post ---------------------------------------------
concordance_df = analysis.analyze_concordance(scanpy_results, modlyn_results, linscvi_results)
scalability_df = analysis.run_scalability_test([1000, 2000, 5000])

blog_gen = BlogGenerator(analysis)
blog_gen.generate_blog_post(concordance_df, scalability_df)

In [None]:
## Quick test
import importlib

import run_analysis

# Reload so edits to run_analysis.py are picked up live.
run_analysis = importlib.reload(run_analysis)
from run_analysis import run_complete_pipeline
from comprehensive_analysis import ComprehensiveAnalysis

analysis = ComprehensiveAnalysis()
analysis.chunk_path = Path("/home/ubuntu/tahoe100M_chunk_1")
analysis.var_path = Path("/home/ubuntu/var_subset_tahoe100M.parquet")

# Smoke test: 2000 cells, one epoch, LinearSCVI skipped.
results = run_complete_pipeline(analysis, n_cells=2000, max_epochs=1, skip_linscvi=True)

In [None]:
# Option 2: Step by step
# Run the helper scripts in this namespace so MinimalAnalysis / MinimalFigures are
# defined here. Reading through a context manager closes each file, and compiling
# with the filename makes tracebacks point at the script rather than "<string>".
# NOTE(review): exec runs arbitrary code — only use with trusted local scripts.
for _script in ("minimal_analysis.py", "minimal_figures.py"):
    with open(_script) as _fh:
        exec(compile(_fh.read(), _script, "exec"))

analysis = MinimalAnalysis()
results = analysis.run_complete_analysis(
    chunk_path="/home/ubuntu/tahoe100M_chunk_1",
    var_path="/home/ubuntu/var_subset_tahoe100M.parquet",
    n_cells=5000,
    skip_linscvi=False
)

figures = MinimalFigures(analysis)
figures.generate_all_figures(results["concordance"])

In [None]:
# Simplified Analysis for Jupyter Notebook
# Copy this entire cell and run it in your notebook

import warnings
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import time
import lightning as L
from tqdm import tqdm
import torch
from scipy import stats

# Silence all Python warnings and Scanpy logging so the cell output stays readable.
warnings.filterwarnings("ignore")
sc.settings.verbosity = 0

# Default figure appearance for the plots produced below.
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['figure.dpi'] = 150
plt.rcParams['font.size'] = 12

class NotebookAnalysis:
    """Compare Scanpy Wilcoxon DE with MODLYN linear-model weights per cell line.

    After ``load_and_analyze`` runs:
      * ``self.results`` holds ``'adata'`` (the analysed AnnData) plus ``'scanpy'``
        and ``'modlyn'``, each a dict mapping cell-line name -> per-gene DataFrame.
      * ``self.performance`` maps method name -> ``{'time': seconds, 'n_genes': int}``.
    """

    def __init__(self):
        self.results = {}
        self.performance = {}

    @staticmethod
    def _timed(func, *args):
        """Call ``func(*args)`` and return ``(result, elapsed_seconds)``."""
        start = time.perf_counter()  # monotonic clock, suited to measuring durations
        result = func(*args)
        return result, time.perf_counter() - start

    def load_and_analyze(self, chunk_path, var_path=None, n_cells=5000):
        """Load a data chunk, run Scanpy and MODLYN, show figures and a summary.

        Args:
            chunk_path: store path passed to ``modlyn.io.loading.read_lazy``.
            var_path: optional parquet file assigned to ``adata.var``; ignored if
                the file does not exist.
            n_cells: number of cells to subsample at random (seed 42); a falsy
                value, or one >= the number of cells, keeps every cell.

        Returns:
            dict: ``self.results``.
        """
        print("Loading data...")
        from modlyn.io.loading import read_lazy
        adata = read_lazy(Path(chunk_path))

        if var_path and Path(var_path).exists():
            adata.var = pd.read_parquet(var_path)

        if n_cells and n_cells < adata.n_obs:
            np.random.seed(42)
            idx = np.random.choice(adata.n_obs, n_cells, replace=False)
            adata = adata[idx].copy()

        # Drop categories with no cells in this subset. The labels "y" are the
        # category codes and run_modlyn builds one output class per category, so
        # stale categories would create classes (and code gaps) that never occur.
        adata.obs["cell_line"] = (
            adata.obs["cell_line"].astype("category").cat.remove_unused_categories()
        )
        adata.obs["y"] = adata.obs["cell_line"].cat.codes.astype(int)

        print(f"Loaded: {adata.n_obs} cells, {adata.n_vars} genes, {adata.obs.cell_line.nunique()} cell lines")

        print("\nRunning Scanpy...")
        scanpy_results, scanpy_time = self._timed(self.run_scanpy, adata)
        # NOTE(review): Scanpy only ranks the 2000 highly variable genes, but
        # n_genes records every gene, so the printed throughput is not like-for-like.
        self.performance['scanpy'] = {'time': scanpy_time, 'n_genes': adata.n_vars}

        print("Running MODLYN...")
        modlyn_results, modlyn_time = self._timed(self.run_modlyn, adata)
        self.performance['modlyn'] = {'time': modlyn_time, 'n_genes': adata.n_vars}

        self.results = {
            'adata': adata,
            'scanpy': scanpy_results,
            'modlyn': modlyn_results
        }

        print("Creating figures...")
        self.create_comparison_figure()
        self.create_performance_figure()

        self.print_summary()

        return self.results

    def run_scanpy(self, adata):
        """One-vs-rest Wilcoxon ranking per cell line with Scanpy.

        Works on a copy: lazy matrices are materialised, counts are normalised to
        1e4 per cell, log1p-transformed and restricted to 2000 highly variable genes.

        Returns:
            dict: cell-line name -> ``sc.get.rank_genes_groups_df`` output, or an
            empty DataFrame when the test fails for that cell line.
        """
        adata_sc = adata.copy()
        if hasattr(adata_sc.X, 'compute'):
            adata_sc.X = adata_sc.X.compute()

        sc.pp.normalize_total(adata_sc, target_sum=1e4)
        sc.pp.log1p(adata_sc)
        sc.pp.highly_variable_genes(adata_sc, n_top_genes=2000)
        adata_sc = adata_sc[:, adata_sc.var.highly_variable].copy()

        de_results = {}
        cell_lines = adata_sc.obs["cell_line"].cat.categories

        for cell_line in tqdm(cell_lines, desc="Scanpy DE"):
            try:
                # Binary "True"/"False" grouping turns the test into one-vs-rest.
                adata_sc.obs["group"] = (adata_sc.obs["cell_line"] == cell_line).astype(str)
                sc.tl.rank_genes_groups(
                    adata_sc,
                    groupby="group",
                    groups=["True"],
                    reference="False",
                    method="wilcoxon"
                )
                de_results[cell_line] = sc.get.rank_genes_groups_df(adata_sc, group="True")
            except Exception as e:
                print(f"Scanpy failed for {cell_line}: {e}")
                de_results[cell_line] = pd.DataFrame()

        return de_results

    def run_modlyn(self, adata):
        """Train a MODLYN linear classifier and turn its weights into per-gene tables.

        The first 80% of cells train the model and the rest validate it. Expects
        ``adata.obs['y']`` to hold the category codes of ``adata.obs['cell_line']``
        (as set by ``load_and_analyze``).

        Returns:
            dict: cell-line name -> DataFrame with columns ``gene``, ``weight``,
            ``abs_weight``, ``z_score`` and ``p_value``, sorted by ``abs_weight``
            descending. Categories with no cells in ``adata`` are omitted.
        """
        from modlyn.io.datamodules import ClassificationDataModule
        from modlyn.models.linear import Linear

        n_train = int(0.8 * adata.n_obs)
        adata_train = adata[:n_train].copy()
        adata_val = adata[n_train:].copy()

        datamodule = ClassificationDataModule(
            adata_train=adata_train,
            adata_val=adata_val,
            label_column="y",
            train_dataloader_kwargs={"batch_size": 512, "num_workers": 0},
            val_dataloader_kwargs={"batch_size": 512, "num_workers": 0},
        )

        categories = adata.obs["cell_line"].cat.categories
        # One output per category: label codes lie in [0, len(categories)), so this
        # always covers every label, unlike nunique() when categories are unused.
        model = Linear(
            n_genes=adata.n_vars,
            n_covariates=len(categories),
            learning_rate=1e-2,
        )

        trainer = L.Trainer(
            max_epochs=3,
            enable_progress_bar=False,
            enable_model_summary=False,
            logger=False
        )

        trainer.fit(model=model, datamodule=datamodule)

        # Assumes linear.weight is (n_classes, n_genes) with row i <-> code i — confirm.
        weights = model.linear.weight.detach().cpu().numpy()
        present = set(adata.obs["cell_line"].dropna().unique())

        modlyn_results = {}
        for class_idx, cell_line in enumerate(categories):
            if cell_line not in present:
                continue  # no cells of this category: its weights carry no signal
            class_weights = weights[class_idx]
            std = class_weights.std()
            if std > 0:
                z_scores = (class_weights - class_weights.mean()) / std
            else:
                z_scores = np.zeros_like(class_weights)  # avoid 0/0 -> NaN

            gene_results = pd.DataFrame({
                "gene": adata.var_names,
                "weight": class_weights,
                "abs_weight": np.abs(class_weights),
                "z_score": z_scores,
                # sf() stays accurate in the tail where 1 - cdf() rounds to 0.
                "p_value": 2 * stats.norm.sf(np.abs(z_scores)),
            })
            modlyn_results[cell_line] = gene_results.sort_values("abs_weight", ascending=False)

        return modlyn_results

    def create_comparison_figure(self, cell_line=None, n_top=20):
        """Plot top Scanpy genes, top MODLYN weights and their overlap for one cell line.

        Args:
            cell_line: cell line to plot; defaults to the first MODLYN result.
            n_top: number of top-ranked genes taken from each method.
        """
        if cell_line is None:
            cell_line = list(self.results['modlyn'].keys())[0]

        # .get so a cell line missing from the Scanpy results shows placeholders
        # instead of raising KeyError.
        scanpy_df = self.results['scanpy'].get(cell_line, pd.DataFrame())
        modlyn_df = self.results['modlyn'][cell_line]

        fig, axes = plt.subplots(1, 3, figsize=(18, 8))

        # Scanpy panel
        if not scanpy_df.empty:
            scanpy_data = scanpy_df.head(n_top)
            y_pos = np.arange(len(scanpy_data))

            axes[0].barh(y_pos, scanpy_data["scores"], color='#2E86AB', alpha=0.8)
            axes[0].set_yticks(y_pos)
            axes[0].set_yticklabels(scanpy_data["names"], fontsize=10)
            axes[0].set_title("Scanpy (Statistical DE)", fontweight='bold')
            axes[0].set_xlabel("Wilcoxon Score")
        else:
            axes[0].text(0.5, 0.5, "No Scanpy Results", ha='center', va='center',
                         transform=axes[0].transAxes, fontsize=14)

        # MODLYN panel: orange for positive weights, purple for negative
        modlyn_data = modlyn_df.head(n_top)
        y_pos = np.arange(len(modlyn_data))

        colors = ['#F18F01' if w > 0 else '#A23B72' for w in modlyn_data["weight"]]
        axes[1].barh(y_pos, modlyn_data["weight"], color=colors, alpha=0.8)
        axes[1].set_yticks(y_pos)
        axes[1].set_yticklabels(modlyn_data["gene"], fontsize=10)
        axes[1].set_title("MODLYN (Linear Model)", fontweight='bold')
        axes[1].set_xlabel("Linear Weight")
        axes[1].axvline(x=0, color='black', linestyle='-', alpha=0.5)

        # Overlap panel
        if not scanpy_df.empty:
            scanpy_genes = set(scanpy_df.head(n_top)["names"])
            modlyn_genes = set(modlyn_df.head(n_top)["gene"])

            overlap = len(scanpy_genes & modlyn_genes)
            scanpy_unique = len(scanpy_genes - modlyn_genes)
            modlyn_unique = len(modlyn_genes - scanpy_genes)

            categories = ['Scanpy\nUnique', 'Overlap', 'MODLYN\nUnique']
            values = [scanpy_unique, overlap, modlyn_unique]
            colors_pie = ['#2E86AB', '#52B788', '#F18F01']

            axes[2].pie(values, labels=categories, colors=colors_pie, autopct='%1.0f',
                        startangle=90)
            axes[2].set_title(f"Gene Overlap\n(Top {n_top} genes)", fontweight='bold')
        else:
            axes[2].text(0.5, 0.5, "Overlap unavailable\n(no Scanpy results)", ha='center',
                         va='center', transform=axes[2].transAxes, fontsize=14)

        plt.suptitle(f'Method Comparison - {cell_line}', fontsize=16, fontweight='bold')
        plt.tight_layout()
        plt.show()

    def create_performance_figure(self):
        """Plot per-method runtimes and, with both methods present, the MODLYN speedup."""
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        methods = list(self.performance.keys())
        runtimes = [self.performance[m]["time"] for m in methods]
        colors = ['#2E86AB', '#F18F01']

        # Runtime comparison
        bars = axes[0].bar(methods, runtimes, color=colors, alpha=0.8, edgecolor='white')
        axes[0].set_ylabel('Runtime (seconds)', fontweight='bold')
        axes[0].set_title('Runtime Comparison', fontweight='bold')

        for bar, time_val in zip(bars, runtimes):
            height = bar.get_height()
            axes[0].text(bar.get_x() + bar.get_width()/2., height + max(runtimes)*0.02,
                         f'{time_val:.1f}s', ha='center', va='bottom', fontweight='bold')

        # Speedup = Scanpy runtime / MODLYN runtime
        if len(methods) == 2 and 'scanpy' in methods and 'modlyn' in methods:
            speedup = self.performance['scanpy']['time'] / self.performance['modlyn']['time']

            axes[1].bar(['Speed\nImprovement'], [speedup], color='#F18F01', alpha=0.8,
                        edgecolor='white')
            axes[1].axhline(y=1, color='red', linestyle='--', alpha=0.7, label='Baseline')
            axes[1].set_ylabel('Improvement Factor', fontweight='bold')
            axes[1].set_title('MODLYN vs Scanpy', fontweight='bold')

            axes[1].text(0, speedup + 0.1, f'{speedup:.1f}x faster',
                         ha='center', va='bottom', fontweight='bold', fontsize=14)

        plt.tight_layout()
        plt.show()

    def print_summary(self):
        """Print runtimes, throughput, speedup and average genes reported per cell line."""
        print("\n" + "="*50)
        print("ANALYSIS SUMMARY")
        print("="*50)

        print(f"\nMethods compared: {list(self.performance.keys())}")

        print("\nPerformance:")
        for method, perf in self.performance.items():
            print(f"  {method.upper()}:")
            print(f"    Runtime: {perf['time']:.2f}s")
            print(f"    Genes: {perf['n_genes']:,}")
            print(f"    Throughput: {perf['n_genes']/perf['time']:.0f} genes/sec")

        if 'scanpy' in self.performance and 'modlyn' in self.performance:
            speedup = self.performance['scanpy']['time'] / self.performance['modlyn']['time']
            print(f"\nMODLYN is {speedup:.1f}x faster than Scanpy!")

        print("\nGene Discovery:")
        for method in ['scanpy', 'modlyn']:
            if method not in self.results:
                continue
            per_line = self.results[method]
            if not per_line:
                # No cell lines at all: avoid dividing by zero below.
                print(f"  {method.upper()}: no results")
                continue
            total_genes = sum(len(df) for df in per_line.values() if not df.empty)
            avg_genes = total_genes / len(per_line)
            print(f"  {method.upper()}: {avg_genes:.0f} genes per cell line (avg)")


In [None]:
# # After copying the simplified version above, use it like this:
# analyzer = NotebookAnalysis()
# results = analyzer.load_and_analyze(
#     chunk_path="/home/ubuntu/tahoe100M_chunk_1",
#     var_path="/home/ubuntu/var_subset_tahoe100M.parquet",
#     n_cells=1000  # adjust as needed
# )

In [None]:
# # Quick analysis with all figures displayed
# analyzer = NotebookAnalysisComplete()
# results = analyzer.run_complete_analysis(
#     chunk_path="/home/ubuntu/tahoe100M_chunk_1",
#     var_path="/home/ubuntu/var_subset_tahoe100M.parquet",
#     n_cells=3000,
#     show_figures=True  # This displays figures inline
# )

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import scanpy as sc
import time
import lightning as L
from tqdm import tqdm
from scipy import stats
from matplotlib_venn import venn2

%matplotlib inline
# Larger default figures for the three-method comparison plots below.
plt.rcParams['figure.figsize'] = (14, 8)
plt.rcParams['figure.dpi'] = 100

def _keep_frequent_cell_lines(adata, min_cells):
    """Return a copy of ``adata`` restricted to cell lines with >= ``min_cells`` cells."""
    counts = adata.obs["cell_line"].value_counts()
    keep = counts[counts >= min_cells].index
    if len(keep) == 0:
        raise ValueError(f"No cell line has at least {min_cells} cells")
    return adata[adata.obs["cell_line"].isin(keep).to_numpy()].copy()


def compare_methods(adata, n_cells=3000, n_top=20, min_cells_per_line=10):
    """Run Scanpy, MODLYN and LinearSCVI on one dataset and display comparison figures.

    Args:
        adata: AnnData with a ``cell_line`` column in ``obs``.
        n_cells: number of cells to subsample at random (seed 42); a falsy value,
            or one >= the number of eligible cells, keeps all of them.
        n_top: number of top-ranked genes per method shown in the figures.
        min_cells_per_line: cell lines with fewer cells are dropped, both before
            and after subsampling.

    Returns:
        tuple: ``(scanpy_results, modlyn_results, linscvi_results, performance)``,
        where ``performance`` maps method name -> runtime in seconds.

    Raises:
        ValueError: if no cell line reaches ``min_cells_per_line`` cells.
    """
    adata = _keep_frequent_cell_lines(adata, min_cells_per_line)
    print(f"Filtered to {adata.obs['cell_line'].nunique()} cell lines with ≥{min_cells_per_line} cells each")

    # Always shuffle: run_modlyn splits train/validation by position, so an
    # unshuffled input would split along its storage order. Legacy
    # RandomState.choice(n, k, replace=False) is permutation(n)[:k], so the
    # subset selected for a given seed is the same as before.
    np.random.seed(42)
    order = np.random.permutation(adata.n_obs)
    if n_cells and n_cells < adata.n_obs:
        adata = adata[order[:n_cells]].copy()
        print(f"Subsetted to {n_cells} cells")
        # Subsampling can push some cell lines back below the threshold.
        adata = _keep_frequent_cell_lines(adata, min_cells_per_line)
    else:
        adata = adata[order].copy()

    # Drop categories that no longer have cells so the codes in "y" are dense
    # (0..n_classes-1); otherwise labels can exceed the model's class count.
    adata.obs["cell_line"] = (
        adata.obs["cell_line"].astype("category").cat.remove_unused_categories()
    )
    adata.obs["y"] = adata.obs["cell_line"].cat.codes.astype(int)
    print(f"Final data: {adata.n_obs} cells, {adata.n_vars} genes, {adata.obs.cell_line.nunique()} cell lines")

    print("Running Scanpy...")
    start = time.perf_counter()
    scanpy_results = run_scanpy(adata)
    scanpy_time = time.perf_counter() - start

    print("Running MODLYN...")
    start = time.perf_counter()
    modlyn_results = run_modlyn(adata)
    modlyn_time = time.perf_counter() - start

    print("Running LinearSCVI...")
    start = time.perf_counter()
    linscvi_results = run_linscvi(adata)
    linscvi_time = time.perf_counter() - start

    performance = {'scanpy': scanpy_time, 'modlyn': modlyn_time, 'linscvi': linscvi_time}

    show_method_comparison(scanpy_results, modlyn_results, linscvi_results, n_top)
    show_performance(performance)
    show_overlap_analysis(scanpy_results, modlyn_results, n_top)

    return scanpy_results, modlyn_results, linscvi_results, performance

def run_scanpy(adata):
    """One-vs-rest Wilcoxon gene ranking per cell line with Scanpy.

    Operates on a copy: a lazy matrix is materialised, counts are normalised to
    1e4 per cell, log1p-transformed and restricted to 2000 highly variable genes.

    Returns:
        dict: cell-line name -> ``sc.get.rank_genes_groups_df`` output. A cell line
        with fewer than 3 cells on either side of the split, or whose test raises,
        maps to an empty DataFrame.
    """
    adata_sc = adata.copy()
    if hasattr(adata_sc.X, 'compute'):
        adata_sc.X = adata_sc.X.compute()

    sc.pp.normalize_total(adata_sc, target_sum=1e4)
    sc.pp.log1p(adata_sc)
    sc.pp.highly_variable_genes(adata_sc, n_top_genes=2000)
    hvg_mask = adata_sc.var.highly_variable
    adata_sc = adata_sc[:, hvg_mask].copy()

    ranked = {}
    labels = adata_sc.obs["cell_line"]
    for name in labels.cat.categories:
        in_group = labels == name
        n_in = in_group.sum()
        n_out = (labels != name).sum()

        # Wilcoxon needs a few cells on both sides of the comparison.
        too_small = n_in < 3 or n_out < 3
        if too_small:
            print(f"Skipping {name}: insufficient cells ({n_in} vs {n_out})")
            ranked[name] = pd.DataFrame()
            continue

        try:
            adata_sc.obs["group"] = in_group.astype(str)
            sc.tl.rank_genes_groups(adata_sc, groupby="group", groups=["True"], reference="False", method="wilcoxon")
            ranked[name] = sc.get.rank_genes_groups_df(adata_sc, group="True")
        except Exception as err:
            print(f"Scanpy failed for {name}: {err}")
            ranked[name] = pd.DataFrame()

    return ranked

def run_modlyn(adata):
    """Train a MODLYN linear classifier and convert its weights into per-gene tables.

    The first 80% of cells train the model and the remaining 20% validate it.
    Expects ``adata.obs['y']`` to hold the category codes of
    ``adata.obs['cell_line']`` (as set by ``compare_methods``).

    Returns:
        dict: cell-line name -> DataFrame with columns ``gene``, ``weight``,
        ``abs_weight`` and ``p_value``, sorted by ``abs_weight`` descending.
        Categories with no cells in ``adata`` are omitted.
    """
    from modlyn.io.datamodules import ClassificationDataModule
    from modlyn.models.linear import Linear

    n_train = int(0.8 * adata.n_obs)
    datamodule = ClassificationDataModule(
        adata_train=adata[:n_train], adata_val=adata[n_train:], label_column="y",
        train_dataloader_kwargs={"batch_size": 512, "num_workers": 0},
        val_dataloader_kwargs={"batch_size": 512, "num_workers": 0}
    )

    categories = adata.obs["cell_line"].cat.categories
    # One output per category: codes lie in [0, len(categories)), so every label
    # is in range even when some categories are unused (nunique() would be too small).
    model = Linear(n_genes=adata.n_vars, n_covariates=len(categories), learning_rate=1e-2)
    trainer = L.Trainer(max_epochs=3, enable_progress_bar=False, logger=False)
    trainer.fit(model=model, datamodule=datamodule)

    # Assumes linear.weight is (n_classes, n_genes) with row i <-> code i — confirm.
    weights = model.linear.weight.detach().cpu().numpy()
    present = set(adata.obs["cell_line"].dropna().unique())
    results = {}

    for class_idx, cell_line in enumerate(categories):
        if cell_line not in present:
            continue  # no cells of this category: its weights carry no signal
        w = weights[class_idx]
        std = w.std()
        z_scores = (w - w.mean()) / std if std > 0 else np.zeros_like(w)
        results[cell_line] = pd.DataFrame({
            "gene": adata.var_names, "weight": w, "abs_weight": np.abs(w),
            # sf() stays accurate in the tail where 1 - cdf() rounds to 0.
            "p_value": 2 * stats.norm.sf(np.abs(z_scores))
        }).sort_values("abs_weight", ascending=False)

    return results

def run_linscvi(adata):
    """Best-effort LinearSCVI differential expression per cell line.

    Trains LinearSCVI (10 latent dims, up to 20 epochs with early stopping) on a
    materialised copy of ``adata`` after dropping genes with fewer than 3 counts,
    then runs ``differential_expression`` (mode ``"change"``) for each cell line.

    Returns:
        dict: cell-line name -> DE DataFrame, or ``None`` if scvi-tools is not
        importable or any step fails. The reason is printed (not issued via
        ``warnings``, which this notebook silences globally).
    """
    try:
        import scvi
        from scvi.model import LinearSCVI

        adata_scvi = adata.copy()
        if hasattr(adata_scvi.X, 'compute'):
            adata_scvi.X = adata_scvi.X.compute()

        sc.pp.filter_genes(adata_scvi, min_counts=3)
        scvi.model.LinearSCVI.setup_anndata(adata_scvi, labels_key="cell_line")
        model = LinearSCVI(adata_scvi, n_latent=10)
        model.train(max_epochs=20, early_stopping=True)

        results = {}
        for cell_line in adata_scvi.obs["cell_line"].cat.categories:
            results[cell_line] = model.differential_expression(
                adata_scvi, groupby="cell_line", group1=cell_line, mode="change"
            )
        return results
    except Exception as exc:
        # Deliberately best-effort (callers show "Not Available"), but unlike a bare
        # `except:` this lets KeyboardInterrupt/SystemExit through and reports why.
        print(f"LinearSCVI unavailable: {exc!r}")
        return None

def show_method_comparison(scanpy_results, modlyn_results, linscvi_results, n_top):
    """Show the top-ranked genes of Scanpy, MODLYN and LinearSCVI side by side.

    The first cell line in ``modlyn_results`` is plotted. A method with no results
    for that cell line gets a "Not Available" panel rather than a blank axis.

    Args:
        scanpy_results: dict cell line -> rank_genes_groups DataFrame, or falsy.
        modlyn_results: dict cell line -> MODLYN gene table (must be non-empty).
        linscvi_results: dict cell line -> LinearSCVI DE DataFrame, or ``None``.
        n_top: number of genes per panel.
    """
    cell_line = list(modlyn_results.keys())[0]

    fig, axes = plt.subplots(1, 3, figsize=(18, 8))

    # Scanpy — .get so a missing cell line shows a placeholder instead of KeyError.
    scanpy_df = scanpy_results.get(cell_line) if scanpy_results else None
    if scanpy_df is not None and not scanpy_df.empty:
        data = scanpy_df.head(n_top)
        y_pos = np.arange(len(data))
        axes[0].barh(y_pos, data["scores"], color='#2E86AB', alpha=0.8)
        axes[0].set_yticks(y_pos)
        axes[0].set_yticklabels(data["names"], fontsize=9)
        axes[0].set_title("Scanpy", fontweight='bold')
        axes[0].set_xlabel("Score")
    else:
        axes[0].text(0.5, 0.5, "Scanpy\nNot Available", ha='center', va='center',
                     transform=axes[0].transAxes, fontsize=14)

    # MODLYN: orange for positive weights, purple for negative
    data = modlyn_results[cell_line].head(n_top)
    y_pos = np.arange(len(data))
    colors = ['#F18F01' if w > 0 else '#A23B72' for w in data["weight"]]
    axes[1].barh(y_pos, data["weight"], color=colors, alpha=0.8)
    axes[1].set_yticks(y_pos)
    axes[1].set_yticklabels(data["gene"], fontsize=9)
    axes[1].set_title("MODLYN", fontweight='bold')
    axes[1].set_xlabel("Weight")
    axes[1].axvline(x=0, color='black', alpha=0.5)

    # LinearSCVI: genes ranked by median log fold change
    if linscvi_results and cell_line in linscvi_results:
        data = linscvi_results[cell_line].sort_values("lfc_median", ascending=False).head(n_top)
        y_pos = np.arange(len(data))
        axes[2].barh(y_pos, data["lfc_median"], color='#A23B72', alpha=0.8)
        axes[2].set_yticks(y_pos)
        axes[2].set_yticklabels(data.index, fontsize=9)
        axes[2].set_title("LinearSCVI", fontweight='bold')
        axes[2].set_xlabel("LFC")
    else:
        axes[2].text(0.5, 0.5, "LinearSCVI\nNot Available", ha='center', va='center',
                     transform=axes[2].transAxes, fontsize=14)

    plt.suptitle(f'Method Comparison - {cell_line}', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

def show_performance(performance):
    """Plot per-method runtimes and, when both are present, MODLYN's speedup over Scanpy.

    Args:
        performance: dict mapping method name -> runtime in seconds.
    """
    fig, (runtime_ax, speedup_ax) = plt.subplots(1, 2, figsize=(12, 5))

    method_names = list(performance.keys())
    runtimes = list(performance.values())
    palette = ['#2E86AB', '#F18F01', '#A23B72']

    # Left: absolute runtimes with a value label above each bar.
    runtime_bars = runtime_ax.bar(method_names, runtimes, color=palette, alpha=0.8)
    runtime_ax.set_ylabel('Runtime (seconds)')
    runtime_ax.set_title('Runtime Comparison')
    for bar, seconds in zip(runtime_bars, runtimes):
        x_center = bar.get_x() + bar.get_width() / 2.
        y_label = bar.get_height() + max(runtimes) * 0.02
        runtime_ax.text(x_center, y_label, f'{seconds:.1f}s',
                        ha='center', va='bottom', fontweight='bold')

    # Right: Scanpy runtime divided by MODLYN runtime.
    if 'scanpy' not in performance or 'modlyn' not in performance:
        plt.tight_layout()
        plt.show()
        return

    ratio = performance['scanpy'] / performance['modlyn']
    speedup_ax.bar(['MODLYN vs Scanpy'], [ratio], color='#F18F01', alpha=0.8)
    speedup_ax.set_ylabel('Speedup Factor')
    speedup_ax.set_title('Speed Improvement')
    speedup_ax.text(0, ratio + 0.1, f'{ratio:.1f}x faster',
                    ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    plt.show()

def show_overlap_analysis(scanpy_results, modlyn_results, n_top):
    """Show how strongly Scanpy and MODLYN agree on their top-ranked genes.

    Left panel: Venn diagram of the top ``n_top`` genes for the first cell line in
    ``modlyn_results``. Right panel: per-cell-line Jaccard similarity of the two
    top-``n_top`` gene sets, with the mean drawn as a dashed red line. Cell lines
    without Scanpy results are left out of the Jaccard panel.
    """
    fig, (venn_ax, jaccard_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # --- Venn diagram for the first cell line
    first_line = list(modlyn_results.keys())[0]

    if scanpy_results and not scanpy_results[first_line].empty:
        first_scanpy = set(scanpy_results[first_line].head(n_top)["names"])
    else:
        first_scanpy = set()

    first_modlyn = set(modlyn_results[first_line].head(n_top)["gene"])

    if first_scanpy:
        venn2([first_scanpy, first_modlyn], set_labels=('Scanpy', 'MODLYN'),
              set_colors=('#2E86AB', '#F18F01'), alpha=0.7, ax=venn_ax)
        venn_ax.set_title(f'Gene Overlap - {first_line}')

    # --- Jaccard similarity for every cell line that both methods cover
    agreement = []
    for name, modlyn_df in modlyn_results.items():
        has_scanpy = scanpy_results and name in scanpy_results and not scanpy_results[name].empty
        scanpy_top = set(scanpy_results[name].head(n_top)["names"]) if has_scanpy else set()
        modlyn_top = set(modlyn_df.head(n_top)["gene"])

        if scanpy_top and modlyn_top:
            agreement.append((name, len(scanpy_top & modlyn_top) / len(scanpy_top | modlyn_top)))

    if agreement:
        names = [name for name, _ in agreement]
        scores = [score for _, score in agreement]
        mean_score = np.mean(scores)
        jaccard_ax.bar(range(len(scores)), scores, color='#52B788', alpha=0.8)
        jaccard_ax.set_xticks(range(len(names)))
        jaccard_ax.set_xticklabels(names, rotation=45)
        jaccard_ax.set_ylabel('Jaccard Similarity')
        jaccard_ax.set_title('Method Agreement')
        jaccard_ax.axhline(mean_score, color='red', linestyle='--',
                           label=f'Mean: {mean_score:.3f}')
        jaccard_ax.legend()

    plt.tight_layout()
    plt.show()

# Usage:
# results = compare_methods(your_adata, n_top=20)

In [None]:
# adata.obs["cell_line"] = adata.obs["cell_line"].astype("category")
# adata.obs["y"] = adata.obs["cell_line"].cat.codes.astype(int)

# Run the three-method comparison on a 1000-cell subset of the current `adata`,
# keeping only cell lines with at least 10 cells.
results = compare_methods(
    adata,
    n_cells=1000,
    n_top=20,
    min_cells_per_line=10,
)
