# Downstream Analysis

Use this notebook for downstream analysis of your optical pooled screen (OPS).
Cells marked with <font color='red'>SET PARAMETERS</font> contain crucial variables that need to be set according to your specific experimental setup and data organization.
Please review and modify these variables as needed before proceeding with the analysis.

## <font color='red'>SET PARAMETERS</font>

### Fixed parameters for cluster module

- `CONFIG_FILE_PATH`: Path to a Brieflow config file used during processing. Absolute or relative to where workflows are run from.

In [None]:
CONFIG_FILE_PATH = "config/config.yml"

In [None]:
from pathlib import Path
import warnings
warnings.filterwarnings('ignore', category=UserWarning)
import yaml
import pandas as pd
import matplotlib.pyplot as plt

from lib.cluster.cluster_analysis import (
    differential_analysis, 
    waterfall_plot, 
    two_feature_plot, 
    cluster_heatmap,
    volcano_plot,
)
from lib.cluster.cluster_eval import merge_bootstrap_with_genes
from lib.cluster.phate_leiden_clustering import plot_phate_leiden_clusters
from lib.shared.metrics import get_all_stats

In [None]:
# load config file and determine root path
with open(CONFIG_FILE_PATH, "r") as config_file:
    config = yaml.safe_load(config_file)
    ROOT_FP = Path(config["all"]["root_fp"])

## Pipeline Statistics Report

`get_all_stats` summarizes the entire data processing pipeline, from raw microscopy images to clustered perturbation profiles. It reports:

1. **Preprocessing**: Input files and generated image tiles
2. **SBS**: Cell segmentation and barcode mapping success rates  
3. **Phenotype**: Cells and morphological features extracted
4. **Merge**: Phenotype/SBS matching rates and single gene mapping
5. **Aggregation**: Perturbation coverage and cell counts per class
6. **Clustering**: Pathway enrichment metrics (CORUM, KEGG, STRING)

To include batch effect metrics (slower), use: `get_all_stats(config, include_batch_effects=True)`

In [None]:
statistics = get_all_stats(config)

## Find Optimal Resolution

Use benchmark results from Snakemake outputs to identify the best Leiden resolution for each cell class/channel combination.
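
As a rough sketch of what choosing an optimal resolution amounts to, the snippet below picks the row with the highest value of a chosen metric from a small benchmark table. The helper `pick_best_resolution` and the numbers are illustrative assumptions, not Brieflow code or real benchmark output; `find_optimal_resolution` may combine metrics or break ties differently.

```python
# Illustrative only: toy benchmark rows, not real screen results.
benchmark_rows = [
    {"resolution": 2, "corum_enrichment": 0.10, "kegg_enrichment": 0.08},
    {"resolution": 5, "corum_enrichment": 0.25, "kegg_enrichment": 0.12},
    {"resolution": 10, "corum_enrichment": 0.18, "kegg_enrichment": 0.15},
]


def pick_best_resolution(rows, metric="corum_enrichment"):
    """Return the row whose `metric` value is highest (hypothetical helper)."""
    scored = [row for row in rows if row.get(metric) is not None]
    if not scored:
        raise ValueError(f"No rows contain metric '{metric}'")
    return max(scored, key=lambda row: row[metric])


best = pick_best_resolution(benchmark_rows)
print(best["resolution"])
```

Different metrics can favor different resolutions, so it is worth comparing the full results table before fixing `LEIDEN_RESOLUTION`.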

In [None]:
# Load cell classes and channel combos from config
cluster_combo_fp = config["cluster"]["cluster_combo_fp"]
cluster_combos = pd.read_csv(cluster_combo_fp, sep="\t")

CHANNEL_COMBOS = list(cluster_combos["channel_combo"].unique())
CELL_CLASSES = list(cluster_combos["cell_class"].unique())

print(f"Channel Combos: {CHANNEL_COMBOS}")
print(f"Cell classes: {CELL_CLASSES}")

In [None]:
from lib.cluster.cluster_eval import find_optimal_resolution, plot_resolution_comparison

# Find optimal resolution for each cell class/channel combo
optimal_resolutions = {}

for cell_class in CELL_CLASSES:
    for channel_combo in CHANNEL_COMBOS:
        try:
            result = find_optimal_resolution(
                root_fp=ROOT_FP,
                channel_combo=channel_combo,
                cell_class=cell_class,
                use_filtered=False,
                metric="corum_enrichment",  # or "combined", "kegg_enrichment", "string_f1"
            )
            key = f"{cell_class}_{channel_combo}"
            optimal_resolutions[key] = result
            
            print(f"\n{cell_class} / {channel_combo}:")
            print(f"  Optimal resolution: {result['optimal_resolution']}")
            print(f"  Metric used: {result['metric_used']}")
            
            # Show metrics comparison plot
            fig = plot_resolution_comparison(result['all_results'])
            plt.suptitle(f"{cell_class} / {channel_combo}", fontsize=12)
            plt.tight_layout()
            plt.show()
            
            # Display results table
            display(result['all_results'][['resolution', 'corum_enrichment', 'kegg_enrichment', 'string_f1', 'corum_num_enriched']])
            
        except Exception as e:
            print(f"\n{cell_class} / {channel_combo}: No benchmark results found ({e})")

## <font color='red'>SET PARAMETERS</font>

### Cluster Selection for Analysis

Set these parameters to select the specific cluster data to analyze:
- `CHANNEL_COMBO`: Select from the available channel combinations.
- `CELL_CLASS`: Select from the available cell classes.
- `LEIDEN_RESOLUTION`: Select from the available Leiden resolutions.

These parameters determine which folder of cluster data will be analyzed.
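
The directory layout used in this notebook is `<root>/cluster/<channel combo>/<cell class>/<resolution>`. A minimal sketch of building that path with an early check for unset parameters follows; `build_cluster_path` is a hypothetical helper, and the example values are placeholders rather than names from a real screen.

```python
from pathlib import Path


def build_cluster_path(root_fp, channel_combo, cell_class, leiden_resolution):
    """Build the cluster directory path, failing early on unset parameters."""
    params = {
        "CHANNEL_COMBO": channel_combo,
        "CELL_CLASS": cell_class,
        "LEIDEN_RESOLUTION": leiden_resolution,
    }
    missing = [name for name, value in params.items() if value is None]
    if missing:
        raise ValueError(f"Set these parameters first: {', '.join(missing)}")
    return Path(root_fp) / "cluster" / channel_combo / cell_class / str(leiden_resolution)


# Placeholder values; use entries from CHANNEL_COMBOS and CELL_CLASSES instead.
example = build_cluster_path("analysis_root", "DAPI_TUBULIN", "all", 5)
print(example)
```

Leaving any of the three parameters as `None` would otherwise surface as a `TypeError` when the path is joined, which is harder to interpret.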

In [None]:
CHANNEL_COMBO = None
CELL_CLASS = None
LEIDEN_RESOLUTION = None

In [None]:
aggregate_file = ROOT_FP / "aggregate" / "tsvs" / f"CeCl-{CELL_CLASS}_ChCo-{CHANNEL_COMBO}__features_genes.tsv"
print(f"Aggregate file: {aggregate_file}")

if not aggregate_file.exists():
    print(f"Aggregate file does not exist: {aggregate_file}")
else:
    print("Aggregate file found")

cluster_path = ROOT_FP / "cluster" / CHANNEL_COMBO / CELL_CLASS / str(LEIDEN_RESOLUTION)
print(f"Cluster path: {cluster_path}")

if not cluster_path.exists():
    print(f"Cluster directory does not exist: {cluster_path}")
else:
    print("Cluster directory found")

## Feature Plot Analysis

This section generates visualizations to explore the phenotypic effects of gene perturbations in your screen. The plots will help you:

1. **Differential Feature Analysis**: Identify genes with significant phenotypic changes vs. controls
2. **Waterfall Plots**: Rank genes by their effect on specific features of interest
3. **Two-Feature Plots**: Discover relationships between different phenotypic measurements
4. **Heatmaps**: Visualize patterns across multiple features and gene sets simultaneously

The interactive analysis allows you to customize each visualization for your specific biological questions.
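
The differential analysis in this notebook requests `normalize_method="robust_zscore"`. As background, here is a minimal sketch of one common definition of a robust z-score, which centers on the control median and scales by the median absolute deviation (MAD). This is an assumption about the general technique; the implementation inside `differential_analysis` may differ in detail.

```python
from statistics import median


def robust_zscores(values, reference):
    """Scale values by the median and MAD of a reference (control) group."""
    center = median(reference)
    mad = median(abs(v - center) for v in reference)
    if mad == 0:
        raise ValueError("MAD of the reference group is zero; cannot scale")
    # 1.4826 makes the MAD comparable to a standard deviation for normal data.
    scale = 1.4826 * mad
    return [(v - center) / scale for v in values]


controls = [1.0, 2.0, 3.0, 4.0, 5.0]  # toy control measurements
perturbed = [3.0, 6.0]  # toy perturbation measurements
print(robust_zscores(perturbed, controls))
```

Because the median and MAD are insensitive to a few extreme cells, this scaling is less affected by outliers than a mean and standard deviation would be.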

In [None]:
cluster_file = cluster_path / "phate_leiden_clustering.tsv"
cluster_df = pd.read_csv(cluster_file, sep="\t")
display(cluster_df)

In [None]:
aggregate_df = pd.read_csv(aggregate_file, sep="\t")
display(aggregate_df)

## <font color='red'>SET PARAMETERS</font>

### Cluster Selection for Visualization

Set this parameter to select the specific cluster to analyze:
- `CLUSTER_ID`: The cluster of interest to generate plots for.

In [None]:
CLUSTER_ID = None

In [None]:
print(f"\n{'='*50}")
print(f"Analyzing Cluster {CLUSTER_ID}")
cluster_genes = cluster_df[cluster_df['cluster'] == CLUSTER_ID][config["aggregate"]["perturbation_name_col"]]
print(f"Genes in Cluster {CLUSTER_ID}:", ", ".join(cluster_genes.unique()))
print(f"{'='*50}")

# Run differential analysis using robust z-score
diff_results = differential_analysis(
    feature_df=aggregate_df,  
    cluster_df=cluster_df,    
    cluster_id=CLUSTER_ID,
    control_type="nontargeting",  
    control_label=config["aggregate"]["control_key"],
    use_nonparametric=True,  
    normalize_method="robust_zscore"
)

# Display results in a more user-friendly format
print("\nTop Upregulated Features:")
up_df = diff_results['top_up'][['feature', 'robust_zscore', 'p_value', 'median_test', 'median_control']]
up_df = up_df.rename(columns={
    'robust_zscore': 'Z-score',
    'p_value': 'p-value',
    'median_test': 'Median (test)',
    'median_control': 'Median (control)'
})
display(up_df.style.format({
    'Z-score': '{:.2f}',
    'p-value': '{:.2e}',
    'Median (test)': '{:.3f}',
    'Median (control)': '{:.3f}'
}))

print("\nTop Downregulated Features:")
down_df = diff_results['top_down'][['feature', 'robust_zscore', 'p_value', 'median_test', 'median_control']]
down_df = down_df.rename(columns={
    'robust_zscore': 'Z-score',
    'p_value': 'p-value',
    'median_test': 'Median (test)',
    'median_control': 'Median (control)'
})
display(down_df.style.format({
    'Z-score': '{:.2f}',
    'p-value': '{:.2e}',
    'Median (test)': '{:.3f}',
    'Median (control)': '{:.3f}'
}))

## <font color='red'>SET PARAMETERS</font>

### Feature Selection for Visualization

Set these parameters to select the features and genes to visualize:
- `FEATURES_TO_ANALYZE`: The features found in the differential analysis to generate plots for.
- `GENES_TO_LABEL`: The genes within the cluster to label on the plots.
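
The two-feature plots are drawn for every unique pair of entries in `FEATURES_TO_ANALYZE`, so n features produce n(n-1)/2 plots. A short sketch of that pairing with `itertools.combinations` is shown here; the feature names are placeholders.

```python
from itertools import combinations

features = ["feature_a", "feature_b", "feature_c"]  # placeholder names

# Each unordered pair appears once, in the order the features were listed.
feature_pairs = list(combinations(features, 2))
print(feature_pairs)
```

Keeping the list to a handful of features keeps the number of plots manageable, since the pair count grows quadratically.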

In [None]:
FEATURES_TO_ANALYZE = None
GENES_TO_LABEL = None

In [None]:
# Create waterfall plots for selected features
print("\n--- Waterfall Plots for Selected Features ---")
for feature in FEATURES_TO_ANALYZE:
    print(f"\nPlotting: {feature}")
    waterfall_plot(
        feature_df=aggregate_df, 
        feature=feature,
        cluster_df=cluster_df,
        cluster_id=CLUSTER_ID,
        nontargeting_pattern=config["aggregate"]["control_key"],
        title=f"Cluster {CLUSTER_ID}: {feature}",
        label_genes=GENES_TO_LABEL
    )

# Create two-feature plots for combinations
print("\n--- Two-Feature Plots ---")
# Generate all unique pairs of features
feature_pairs = [(FEATURES_TO_ANALYZE[i], FEATURES_TO_ANALYZE[j]) 
                 for i in range(len(FEATURES_TO_ANALYZE)) 
                 for j in range(i+1, len(FEATURES_TO_ANALYZE))]

for feature1, feature2 in feature_pairs:
    print(f"\nPlotting: {feature1} vs {feature2}")
    two_feature_plot(
        feature_df=aggregate_df,
        x=feature1,
        y=feature2,
        cluster_df=cluster_df,
        cluster_id=CLUSTER_ID,
        nontargeting_pattern=config["aggregate"]["control_key"],
        title=f"Cluster {CLUSTER_ID}: {feature1} vs {feature2}",
        label_genes=GENES_TO_LABEL
    )

# Create heatmap with all differential features
print("\n--- Heatmap of Differential Features ---")
# Get top features from differential analysis
top_up = diff_results['top_up']['feature'].tolist() if not diff_results['top_up'].empty else []
top_down = diff_results['top_down']['feature'].tolist() if not diff_results['top_down'].empty else []
diff_features = top_up + top_down

fig, ax, heatmap_data = cluster_heatmap(
    feature_df=aggregate_df,
    cluster_df=cluster_df,
    cluster_ids=[CLUSTER_ID],
    features=diff_features,
    perturbation_name_col=config["aggregate"]["perturbation_name_col"],
    z_score="global",
    title=f"Cluster {CLUSTER_ID}: Top Differential Features",
)

# PHATE plot for the cluster
print("\n--- PHATE Plot for Cluster ---")
phate_fig = plot_phate_leiden_clusters(
    phate_leiden_clustering=cluster_df,
    perturbation_name_col=config["aggregate"]["perturbation_name_col"],
    control_key=config["aggregate"]["control_key"],
    clusters_of_interest=[CLUSTER_ID],
)

## Volcano Plot Analysis

Volcano plots visualize the relationship between effect size (z-score) and statistical significance (-log10 p-value) from bootstrap analysis. This helps identify genes with both large effects and high confidence.

**Prerequisites:** Bootstrap analysis must have been run during the aggregate step.

In [None]:
# Load bootstrap results
bootstrap_file = ROOT_FP / "aggregate" / "bootstrap" / f"CeCl-{CELL_CLASS}_ChCo-{CHANNEL_COMBO}__all_gene_bootstrap_results.tsv"

if bootstrap_file.exists():
    bootstrap_df = pd.read_csv(bootstrap_file, sep="\t")
    print(f"Loaded bootstrap results: {len(bootstrap_df)} genes")
    
    # Merge bootstrap with feature table
    merged_df = merge_bootstrap_with_genes(
        bootstrap_df=bootstrap_df,
        genes_df=aggregate_df,
        perturbation_name_col=config["aggregate"]["perturbation_name_col"],
        bootstrap_gene_col="gene",
    )
    print(f"Merged data: {len(merged_df)} genes with {len([c for c in merged_df.columns if c.endswith('_fdr')])} features tested")
    display(merged_df.head())
else:
    print(f"Bootstrap file not found: {bootstrap_file}")
    print("Run the aggregate module with bootstrap enabled to generate this data.")
    merged_df = None

## <font color='red'>SET PARAMETERS</font>

### Volcano Plot Configuration

- `VOLCANO_FEATURE`: Feature to plot (must match a bootstrapped feature)
- `VOLCANO_FDR_THRESHOLD`: FDR threshold for significance
- `VOLCANO_ZSCORE_THRESHOLD`: Z-score threshold for effect size
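
To make the two thresholds concrete, the sketch below labels a gene as up, down, or not significant using the same rule as the summary statistics in this notebook (FDR below the threshold and |z| at or above the z-score threshold). `classify_gene` is a hypothetical helper and the gene values are made up for illustration.

```python
import math


def classify_gene(zscore, fdr, fdr_threshold=0.05, zscore_threshold=2.0):
    """Label a gene 'up', 'down', or 'ns' (not significant)."""
    if fdr < fdr_threshold and zscore >= zscore_threshold:
        return "up"
    if fdr < fdr_threshold and zscore <= -zscore_threshold:
        return "down"
    return "ns"


# Toy values for illustration only: (z-score, FDR) per gene.
genes = {"GENE_A": (3.1, 0.001), "GENE_B": (-2.4, 0.02), "GENE_C": (2.8, 0.30)}
for name, (z, fdr) in genes.items():
    # -log10 turns small significance values into large y-axis values.
    print(name, classify_gene(z, fdr), round(-math.log10(fdr), 2))
```

A gene such as `GENE_C` has a large effect but fails the FDR cut, so it would sit low on the volcano plot despite its z-score.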

In [None]:
# Volcano plot parameters
VOLCANO_FEATURE = None  # Set to a feature from differential analysis (e.g., from diff_results)
VOLCANO_FDR_THRESHOLD = 0.05
VOLCANO_ZSCORE_THRESHOLD = 2.0

In [None]:
if merged_df is not None and VOLCANO_FEATURE is not None:
    # Check if feature has bootstrap data
    log10_col = f"{VOLCANO_FEATURE}_log10"
    if log10_col in merged_df.columns:
        print(f"Creating volcano plot for: {VOLCANO_FEATURE}")
        
        fig, ax = volcano_plot(
            merged_df=merged_df,
            feature=VOLCANO_FEATURE,
            perturbation_name_col=config["aggregate"]["perturbation_name_col"],
            cluster_df=cluster_df,
            cluster_id=CLUSTER_ID,
            fdr_threshold=VOLCANO_FDR_THRESHOLD,
            zscore_threshold=VOLCANO_ZSCORE_THRESHOLD,
            title=f"Cluster {CLUSTER_ID}: {VOLCANO_FEATURE}",
            label_genes=GENES_TO_LABEL,
        )
        plt.show()
        
        # Show summary statistics
        fdr_col = f"{VOLCANO_FEATURE}_fdr"
        if fdr_col in merged_df.columns:
            sig_up = ((merged_df[fdr_col] < VOLCANO_FDR_THRESHOLD) & (merged_df[VOLCANO_FEATURE] >= VOLCANO_ZSCORE_THRESHOLD)).sum()
            sig_down = ((merged_df[fdr_col] < VOLCANO_FDR_THRESHOLD) & (merged_df[VOLCANO_FEATURE] <= -VOLCANO_ZSCORE_THRESHOLD)).sum()
            print(f"\nSignificant genes (FDR < {VOLCANO_FDR_THRESHOLD}, |z| >= {VOLCANO_ZSCORE_THRESHOLD}):")
            print(f"  Upregulated: {sig_up}")
            print(f"  Downregulated: {sig_down}")
    else:
        print(f"Feature '{VOLCANO_FEATURE}' not found in bootstrap results.")
        print(f"Available features: {[c.replace('_log10', '') for c in merged_df.columns if c.endswith('_log10')][:10]}...")
elif merged_df is None:
    print("No bootstrap data available - load bootstrap results first")
else:
    print("Set VOLCANO_FEATURE to create a volcano plot")

## Cell Montage Generation

Generate montages of cells sorted by a specific feature value. This helps visualize the morphological phenotypes associated with perturbations.

**Prerequisites:** Requires phenotype images to be available in the output directory.

## <font color='red'>SET PARAMETERS</font>

### Montage Configuration

- `MONTAGE_FEATURE`: Feature to sort cells by (e.g., from differential analysis above)
- `MONTAGE_GENE`: Gene/perturbation to generate montage for (or None for top perturbations)
- `MONTAGE_NUM_CELLS`: Number of cells to include in montage
- `MONTAGE_CELL_SIZE`: Pixel size of the cell bounding box
- `MONTAGE_SHAPE`: Grid shape of the montage as (rows, cols)
- `MONTAGE_ASCENDING`: Sort order (True = lowest values first)
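
As a sketch of how these parameters interact, the snippet below sorts toy cell records by a feature and keeps the first `num_cells`, after checking that the count fits in the grid. `select_cells` is a hypothetical helper; how `create_cell_montage` itself selects cells or handles a mismatch between cell count and grid shape is not specified here.

```python
def select_cells(cells, sort_by, num_cells, ascending=False):
    """Return up to `num_cells` cell records ordered by the `sort_by` value."""
    ordered = sorted(cells, key=lambda cell: cell[sort_by], reverse=not ascending)
    return ordered[:num_cells]


shape = (3, 10)  # rows, cols
num_cells = 30
if num_cells > shape[0] * shape[1]:
    raise ValueError("num_cells does not fit in the montage grid")

# Toy records with a placeholder feature name.
cells = [{"cell_id": i, "feature_x": v} for i, v in enumerate([0.2, 0.9, 0.5, 0.1])]
top = select_cells(cells, "feature_x", num_cells=2)
print([cell["cell_id"] for cell in top])
```

With `ascending=False` the highest feature values come first, which matches the default `MONTAGE_ASCENDING = False`.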

In [None]:
# Montage parameters
MONTAGE_FEATURE = None  # Set to a feature name from differential analysis
MONTAGE_GENE = None  # Set to a gene symbol, or None for top perturbations by feature
MONTAGE_NUM_CELLS = 30
MONTAGE_CELL_SIZE = 40  # Pixel size of cell bounding box
MONTAGE_SHAPE = (3, 10)  # Grid shape (rows, cols)
MONTAGE_ASCENDING = False  # False = highest values first

In [None]:
from lib.aggregate.montage_utils import create_cell_montage, add_filenames
from skimage.exposure import rescale_intensity
import numpy as np

if MONTAGE_FEATURE is not None:
    # Load merge data with cell-level info (needed for image paths and coordinates)
    merge_data_path = ROOT_FP / "merge" / "parquet" / f"CeCl-{CELL_CLASS}__merge.parquet"
    
    if merge_data_path.exists():
        merge_data = pd.read_parquet(merge_data_path)
        
        # Add image file paths
        merge_data = add_filenames(merge_data, ROOT_FP)
        
        # Filter to specific gene if provided
        if MONTAGE_GENE is not None:
            cell_data = merge_data[merge_data[config["aggregate"]["perturbation_name_col"]] == MONTAGE_GENE].copy()
            title_suffix = f"Gene: {MONTAGE_GENE}"
        else:
            # Use top perturbations by feature value
            cell_data = merge_data.copy()
            title_suffix = "All perturbations"
        
        if len(cell_data) > 0:
            # Get channel names from config
            channels = config["phenotype"]["channel_names"]
            
            # Create montage sorted by feature
            montages = create_cell_montage(
                cell_data=cell_data,
                channels=channels,
                num_cells=MONTAGE_NUM_CELLS,
                cell_size=MONTAGE_CELL_SIZE,
                shape=MONTAGE_SHAPE,
                selection_params={
                    "method": "sorted",
                    "sort_by": MONTAGE_FEATURE,
                    "ascending": MONTAGE_ASCENDING,
                },
            )
            
            # Display montages for each channel
            fig, axes = plt.subplots(1, len(channels), figsize=(4 * len(channels), 4))
            if len(channels) == 1:
                axes = [axes]
            
            for ax, (channel, montage_array) in zip(axes, montages.items()):
                # Normalize for display
                montage_display = rescale_intensity(montage_array, in_range="image", out_range=(0, 1))
                ax.imshow(montage_display, cmap="gray")
                ax.set_title(channel)
                ax.axis("off")
            
            sort_order = "ascending" if MONTAGE_ASCENDING else "descending"
            plt.suptitle(f"Montage sorted by {MONTAGE_FEATURE} ({sort_order})\n{title_suffix}", fontsize=12)
            plt.tight_layout()
            plt.show()
        else:
            print(f"No cells found ({title_suffix})")
    else:
        print(f"Merge data not found: {merge_data_path}")
else:
    print("Set MONTAGE_FEATURE to generate a montage")

## Mozzarellm: LLM-based Gene Cluster Analysis

### Overview
[Mozzarellm](https://github.com/cheeseman-lab/mozzarellm) is a Python package that leverages Large Language Models (LLMs) to analyze gene clusters for pathway identification and novel gene discovery.

### Prerequisites

Install mozzarellm in your Brieflow environment:

```bash
pip install git+https://github.com/cheeseman-lab/mozzarellm.git
```

Set up API keys in a `.env` file in the analysis directory:

```bash
# .env file
ANTHROPIC_API_KEY=your_key_here
# or OPENAI_API_KEY=your_key_here
# or GOOGLE_API_KEY=your_key_here
```
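
A quick way to confirm that at least one key is visible to Python is sketched below. It assumes the `.env` values have already been loaded into the process environment (for example by the run script or a dotenv loader); `configured_keys` is a hypothetical helper, not part of mozzarellm.

```python
import os

PROVIDER_KEYS = ("ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GOOGLE_API_KEY")


def configured_keys(env=None):
    """Return the provider key names that are set to a non-empty value."""
    env = os.environ if env is None else env
    return [name for name in PROVIDER_KEYS if env.get(name)]


found = configured_keys()
print(f"API keys found: {found}" if found else "No API key found in the environment")
```

Only the key names are printed, never the values, so the output is safe to keep in a saved notebook.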

### Workflow

1. **Configure** mozzarellm parameters below and write to `config.yml`
2. **Run** the analysis script: `bash 13.run_mozzarellm.sh`

The script supports resume (automatically skips completed clusters) and saves results incrementally.
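
One way such resume behavior could work is to skip any cluster that already has a saved result file. The sketch below illustrates the idea only; the `cluster_<id>.json` naming and the `pending_clusters` helper are assumptions, not the actual logic of `13.run_mozzarellm.sh` or mozzarellm.

```python
import json
import tempfile
from pathlib import Path


def pending_clusters(cluster_ids, results_dir):
    """Return the cluster IDs that have no saved JSON result yet."""
    results_dir = Path(results_dir)
    return [
        cid for cid in cluster_ids
        if not (results_dir / f"cluster_{cid}.json").exists()
    ]


with tempfile.TemporaryDirectory() as tmp:
    # Pretend cluster 0 finished in an earlier run.
    (Path(tmp) / "cluster_0.json").write_text(json.dumps({"cluster": 0}))
    todo = pending_clusters([0, 1, 2], tmp)
    print(todo)
```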

## <font color='red'>SET PARAMETERS</font>

### Mozzarellm Configuration

Configure which clustering result to analyze with mozzarellm:

- `MOZZARELLM_CELL_CLASS`: Cell class to analyze
- `MOZZARELLM_CHANNEL_COMBO`: Channel combination
- `MOZZARELLM_RESOLUTION`: Leiden resolution (choose your optimal resolution)
- `MOZZARELLM_MODEL`: LLM model to use
- `MOZZARELLM_SCREEN_CONTEXT`: Context about your screen for the LLM

In [None]:
# Mozzarellm configuration - set these based on your analysis above
MOZZARELLM_CELL_CLASS = CELL_CLASS  # Use the same as selected above, or override
MOZZARELLM_CHANNEL_COMBO = CHANNEL_COMBO  # Use the same as selected above, or override
MOZZARELLM_RESOLUTION = LEIDEN_RESOLUTION  # Use the same as selected above, or override

# Model selection
MOZZARELLM_MODEL = "claude-sonnet-4-5-20250929"  # or "gpt-4o", "gemini-2.0-flash", etc.
MOZZARELLM_TEMPERATURE = 0.0

# Screen context - customize this for your specific screen
MOZZARELLM_SCREEN_CONTEXT = """
These clusters are from an optical pooled screen (OPS) that measured morphological
phenotypes in human cells. The screen involved perturbing genes using CRISPR knockout
and imaging the resulting cellular morphology via fluorescence microscopy using a
cell-painting-derived panel. Genes grouped within a cluster tend to exhibit similar
phenotypes, suggesting they may participate in the same biological process or pathway.
"""

In [None]:
# Verify the clustering file exists
mozzarellm_cluster_path = ROOT_FP / "cluster" / MOZZARELLM_CHANNEL_COMBO / MOZZARELLM_CELL_CLASS / str(MOZZARELLM_RESOLUTION)
mozzarellm_cluster_file = mozzarellm_cluster_path / "phate_leiden_clustering.tsv"

print(f"Mozzarellm will analyze: {mozzarellm_cluster_file}")
if mozzarellm_cluster_file.exists():
    cluster_preview = pd.read_csv(mozzarellm_cluster_file, sep="\t")
    print(f"File exists: {len(cluster_preview)} genes, {cluster_preview['cluster'].nunique()} clusters")
else:
    print("File not found - make sure clustering has been run")

## Write Mozzarellm Configuration to config.yml

Run this cell to save the mozzarellm configuration. The shell script `13.run_mozzarellm.sh` will read these parameters.

In [None]:
# Add mozzarellm section to config
config["mozzarellm"] = {
    "cell_class": MOZZARELLM_CELL_CLASS,
    "channel_combo": MOZZARELLM_CHANNEL_COMBO,
    "leiden_resolution": MOZZARELLM_RESOLUTION,
    "model": MOZZARELLM_MODEL,
    "temperature": MOZZARELLM_TEMPERATURE,
    "screen_context": MOZZARELLM_SCREEN_CONTEXT.strip(),
}

# Write updated config
from lib.shared.configuration_utils import CONFIG_FILE_HEADER

with open(CONFIG_FILE_PATH, "w") as config_file:
    config_file.write(CONFIG_FILE_HEADER)
    yaml.dump(config, config_file, default_flow_style=False, sort_keys=False)

print(f"Mozzarellm configuration written to {CONFIG_FILE_PATH}")
print("\nConfiguration:")
print(f"  Cell class: {MOZZARELLM_CELL_CLASS}")
print(f"  Channel combo: {MOZZARELLM_CHANNEL_COMBO}")
print(f"  Resolution: {MOZZARELLM_RESOLUTION}")
print(f"  Model: {MOZZARELLM_MODEL}")

## Run Mozzarellm Analysis

Now run the analysis from the terminal:

```bash
cd analysis/
bash 13.run_mozzarellm.sh
```

The script will:
1. Read configuration from `config.yml`
2. Load the clustering results
3. Analyze each cluster with the LLM (with progress bar)
4. Save results incrementally (can resume if interrupted)

**Output location:** `{cluster_path}/mozzarellm/`
- `clusters/` - Individual cluster results (JSON, for resume)
- `{model}_results.json` - Combined analysis with all clusters
- `{model}_summaries.tsv` - One row per cluster (pathway, confidence, counts)
- `{model}_flagged_genes.tsv` - One row per flagged gene (priority, rationale)
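
Once the run finishes, the TSV outputs can be inspected with the standard library. The sketch below reads a tab-separated file into a list of dicts; the path is a placeholder to replace with your own cluster path and model name, and no particular column names are assumed.

```python
import csv
from pathlib import Path


def read_tsv_rows(tsv_path):
    """Read a tab-separated file into a list of dicts keyed by column name."""
    with open(tsv_path, newline="") as handle:
        return list(csv.DictReader(handle, delimiter="\t"))


# Placeholder path; substitute your cluster_path and model name.
summaries_path = Path("cluster_path") / "mozzarellm" / "model_summaries.tsv"
if summaries_path.exists():
    rows = read_tsv_rows(summaries_path)
    columns = list(rows[0]) if rows else []
    print(f"{len(rows)} clusters summarized; columns: {columns}")
else:
    print(f"Not found yet: {summaries_path}")
```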