# nichecompass_analyse_sample_integration

Analysis of a trained NicheCompass sample-integration model.

In [None]:
# Parameters (papermill will override these)
# Default notebook parameters. Keep each one as a plain `name = value  # help`
# assignment so papermill can inspect and override it.
# NOTE: the "Resolve run paths and load params" cell re-reads most key names
# (counts/adj/GP/latent/sample/cell-type/covariate keys) from the training
# run's run_config.json; values found there take precedence over these.
from pathlib import Path

# Locating the training run: artifacts are read from
# <outdir>/<prefix>_<timestamp>/artifacts/sample_integration/<prefix>/
outdir = Path.cwd()              # Base output directory (default: cwd)
prefix = "nichecompass"          # Must match training --prefix
timestamp = ""                   # REQUIRED: YYYYMMDD_HHMMSS from training stdout
# NOTE(review): `species` is not referenced by later cells in this notebook.
species = "human"

# Dataset / graph keys (should match training unless you know what you're doing)
# NOTE(review): `spatial_key` is not referenced by later cells in this notebook.
spatial_key = "spatial"
# Also passed to compute_communication_gp_network in section 4.3.3.
n_neighbors = 4
counts_key = "counts"            # Falls back to 'X' if missing
adj_key = "spatial_connectivities"
gp_names_key = "nichecompass_gp_names"
active_gp_names_key = "nichecompass_active_gp_names"
gp_targets_mask_key = "nichecompass_gp_targets"
gp_targets_categories_mask_key = "nichecompass_gp_targets_categories"
gp_sources_mask_key = "nichecompass_gp_sources"
gp_sources_categories_mask_key = "nichecompass_gp_sources_categories"
# Key of the neighbor graph in adata.uns that the Leiden clustering uses.
latent_key = "nichecompass_latent"

# Labels / plotting
sample_key = "batch"
cell_type_key = "Main_molecular_cell_type"
# Only the first entry is used (batch colours / batch plots).
cat_covariates_keys = ["batch"]
spot_size = 0.2

# Clustering and differential analysis
latent_leiden_resolution = 0.2
differential_gp_test_results_key = "nichecompass_differential_gp_test_results"
log_bayes_factor_thresh = 2.3

# Optional: communication analysis. If empty, this section is skipped.
gp_name = ""  # e.g., "Lefty1_ligand_receptor_target_gene_GP"

In [None]:
### Package imports
from pathlib import Path
import json, warnings

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler

from nichecompass.models import NicheCompass
from nichecompass.utils import (
    create_new_color_dict,
    compute_communication_gp_network,
    generate_enriched_gp_info_plots,
    visualize_communication_gp_network,
)

# Suppress all Python warnings so the executed notebook output stays readable.
# NOTE(review): this also hides deprecation and other library warnings
# (pandas, scanpy, nichecompass); comment it out when debugging.
warnings.filterwarnings("ignore")

In [None]:
### Resolve run paths and load params
# Validate timestamp. Raise explicitly rather than using `assert`, which is
# stripped when Python runs with -O.
if not timestamp:
    raise ValueError("Parameter 'timestamp' must be provided (YYYYMMDD_HHMMSS).")

run_root = Path(outdir) / f"{prefix}_{timestamp}"
if not run_root.exists():
    raise FileNotFoundError(f"Run directory '{run_root}' does not exist.")

artifacts_folder_path = run_root / "artifacts"
figure_folder_path = artifacts_folder_path / "sample_integration" / prefix / "figures"
model_folder_path = artifacts_folder_path / "sample_integration" / prefix / "model"
figure_folder_path.mkdir(parents=True, exist_ok=True)

# Parse parameters used in training. Values found in run_config.json take
# precedence over both the defaults and any papermill-provided values; a
# missing or empty entry keeps the current value.
cfg_path = model_folder_path / "run_config.json"
if cfg_path.exists():
    with open(cfg_path, "r", encoding="utf-8") as f:
        rc: dict = json.load(f)
    counts_key = rc.get("counts_key_effective", rc.get("counts_key", counts_key)) or counts_key
    adj_key = rc.get("adj_key", adj_key) or adj_key
    gp_names_key = rc.get("gp_names_key", gp_names_key) or gp_names_key
    active_gp_names_key = rc.get("active_gp_names_key", active_gp_names_key) or active_gp_names_key
    gp_targets_mask_key = rc.get("gp_targets_mask_key", gp_targets_mask_key) or gp_targets_mask_key
    gp_targets_categories_mask_key = rc.get("gp_targets_categories_mask_key", gp_targets_categories_mask_key) or gp_targets_categories_mask_key
    gp_sources_mask_key = rc.get("gp_sources_mask_key", gp_sources_mask_key) or gp_sources_mask_key
    gp_sources_categories_mask_key = rc.get("gp_sources_categories_mask_key", gp_sources_categories_mask_key) or gp_sources_categories_mask_key
    latent_key = rc.get("latent_key", latent_key) or latent_key

    sample_key = rc.get("sample_key", sample_key) or sample_key
    cell_type_key = rc.get("cell_type_key", cell_type_key) or cell_type_key
    cat_covariates_keys = rc.get("cat_covariates_keys", cat_covariates_keys) or cat_covariates_keys

# Normalise categorical covariates regardless of whether a config was found:
# later cells index `cat_covariates_keys[0]`, so a single key given as a string
# (e.g. via papermill) is wrapped in a list, and an empty value falls back to
# the sample key.
if isinstance(cat_covariates_keys, str):
    cat_covariates_keys = [cat_covariates_keys]
if not cat_covariates_keys:
    cat_covariates_keys = [sample_key]

In [None]:
# Load the trained NicheCompass model together with the AnnData saved as
# "model.h5ad" in the model folder.
model = NicheCompass.load(
    dir_path=str(model_folder_path),
    adata=None,
    adata_file_name="model.h5ad",
    gp_names_key=gp_names_key,
)

# Distinct sample labels in order of first appearance; drives the per-sample
# spatial panels below.
sample_labels = model.adata.obs[sample_key]
samples = sample_labels.unique().tolist()

### 4.1 Visualize NicheCompass Latent GP Space

Inspect how well the integration worked by visualizing the batch and cell-type annotations in the latent GP space (UMAP) and in physical space for each sample.

In [None]:
# Colour dictionaries shared by the latent/physical-space plots below: one for
# the first categorical covariate (batches), one for cell types. Built in this
# order, batches first.
batch_colors, cell_type_colors = (
    create_new_color_dict(adata=model.adata, cat_key=key)
    for key in (cat_covariates_keys[0], cell_type_key)
)

In [None]:
# Batches: UMAP of the latent space across the whole top row, one spatial
# panel per sample underneath, and a single figure-level legend.
n_samples = len(samples)
grid = (2, n_samples)

fig = plt.figure(figsize=(12, 14))
title = fig.suptitle("NicheCompass Batches in Latent and Physical Space", y=0.96, x=0.55, fontsize=20)

latent_ax = plt.subplot2grid(grid, (0, 0), colspan=n_samples)
sc.pl.umap(
    adata=model.adata,
    color=[cat_covariates_keys[0]],
    palette=batch_colors,
    title="Batches in Latent Space",
    ax=latent_ax,
    show=False,
)

for col_idx, sample in enumerate(samples):
    sample_mask = model.adata.obs[sample_key] == sample
    sc.pl.spatial(
        adata=model.adata[sample_mask],
        color=[cat_covariates_keys[0]],
        palette=batch_colors,
        spot_size=spot_size,
        title=f"Batches in Physical Space\n(Sample: {sample})",
        legend_loc=None,
        ax=plt.subplot2grid(grid, (1, col_idx)),
        show=False,
    )

# Re-create the UMAP legend at figure level so it serves every panel, then
# drop the per-axes copy.
legend_handles, legend_labels = latent_ax.get_legend_handles_labels()
lgd = fig.legend(legend_handles, legend_labels, loc="center left", bbox_to_anchor=(0.98, 0.5))
latent_ax.get_legend().remove()

plt.subplots_adjust(wspace=0.2, hspace=0.25)
fig.savefig(
    figure_folder_path / "batches_latent_physical_space.svg",
    bbox_extra_artists=(lgd, title),
    bbox_inches="tight",
)
plt.show()

In [None]:
# Cell types: UMAP of the latent space across the whole top row, one spatial
# panel per sample underneath, and a single figure-level legend.
n_samples = len(samples)
grid = (2, n_samples)

fig = plt.figure(figsize=(12, 14))
title = fig.suptitle("Cell Types in Latent and Physical Space", y=0.96, x=0.55, fontsize=20)

latent_ax = plt.subplot2grid(grid, (0, 0), colspan=n_samples)
sc.pl.umap(
    adata=model.adata,
    color=[cell_type_key],
    palette=cell_type_colors,
    title="Cell Types in Latent Space",
    ax=latent_ax,
    show=False,
)

for col_idx, sample in enumerate(samples):
    sample_mask = model.adata.obs[sample_key] == sample
    sc.pl.spatial(
        adata=model.adata[sample_mask],
        color=[cell_type_key],
        palette=cell_type_colors,
        spot_size=spot_size,
        title=f"Cell Types in Physical Space\n(Sample: {sample})",
        legend_loc=None,
        ax=plt.subplot2grid(grid, (1, col_idx)),
        show=False,
    )

# Re-create the UMAP legend at figure level so it serves every panel, then
# drop the per-axes copy.
legend_handles, legend_labels = latent_ax.get_legend_handles_labels()
lgd = fig.legend(legend_handles, legend_labels, loc="center left", bbox_to_anchor=(0.98, 0.5))
latent_ax.get_legend().remove()

plt.subplots_adjust(wspace=0.2, hspace=0.25)
fig.savefig(
    figure_folder_path / "cell_types_latent_physical_space.svg",
    bbox_extra_artists=(lgd, title),
    bbox_inches="tight",
)
plt.show()

### 4.2 Identify Niches (Leiden on latent graph)

Computes Leiden clustering of the NicheCompass latent GP space to identify spatially consistent cell niches.

In [None]:
# The Leiden clustering below uses the neighbor graph registered under
# `latent_key` in adata.uns; stop early with actionable guidance if it is absent.
neighbors_missing = latent_key not in model.adata.uns
if neighbors_missing:
    raise KeyError(
        f"Neighbors key '{latent_key}' not found in adata.uns. "
        "Check that your training run computed neighbors with key_added=latent_key "
        "and that run_config.json matches this notebook."
    )

In [None]:
# Compute latent Leiden clustering on the neighbor graph stored under
# `latent_key`; cluster labels are written to obs[latent_cluster_key].
latent_cluster_key = f"latent_leiden_{latent_leiden_resolution}"
sc.tl.leiden(
    adata=model.adata,
    resolution=latent_leiden_resolution,
    key_added=latent_cluster_key,
    neighbors_key=latent_key,
)
latent_cluster_colors = create_new_color_dict(adata=model.adata, cat_key=latent_cluster_key)

# Niches: UMAP of the latent space across the whole top row, one spatial
# panel per sample underneath, and a single figure-level legend.
n_samples = len(samples)
grid = (2, n_samples)

fig = plt.figure(figsize=(12, 14))
title = fig.suptitle("NicheCompass Niches in Latent and Physical Space", y=0.96, x=0.55, fontsize=20)

latent_ax = plt.subplot2grid(grid, (0, 0), colspan=n_samples)
sc.pl.umap(
    adata=model.adata,
    color=[latent_cluster_key],
    palette=latent_cluster_colors,
    title="Niches in Latent Space",
    ax=latent_ax,
    show=False,
)

for col_idx, sample in enumerate(samples):
    sample_mask = model.adata.obs[sample_key] == sample
    sc.pl.spatial(
        adata=model.adata[sample_mask],
        color=[latent_cluster_key],
        palette=latent_cluster_colors,
        spot_size=spot_size,
        title=f"Niches in Physical Space\n(Sample: {sample})",
        legend_loc=None,
        ax=plt.subplot2grid(grid, (1, col_idx)),
        show=False,
    )

# Re-create the UMAP legend at figure level so it serves every panel, then
# drop the per-axes copy.
legend_handles, legend_labels = latent_ax.get_legend_handles_labels()
lgd = fig.legend(legend_handles, legend_labels, loc="center left", bbox_to_anchor=(0.98, 0.5))
latent_ax.get_legend().remove()

plt.subplots_adjust(wspace=0.2, hspace=0.25)
fig.savefig(
    figure_folder_path / f"res_{latent_leiden_resolution}_niches_latent_physical_space.svg",
    bbox_extra_artists=(lgd, title),
    bbox_inches="tight",
)
plt.show()

### 4.3 Characterise Niches

#### 4.3.1 Niche Composition

In [None]:
# Niche composition (batches): stacked cell counts per niche, coloured with the
# same batch palette as the latent/physical-space plots.
save_path_batches = figure_folder_path / f"res_{latent_leiden_resolution}_niche_composition_batches.svg"
df_counts = (
    model.adata.obs
    .groupby([latent_cluster_key, cat_covariates_keys[0]], observed=False)
    .size()
    .unstack(fill_value=0)  # niche/batch combinations with no cells count as 0
)

# Without an explicit colour list pandas uses matplotlib's default colour cycle
# (10 colours by default): bars would not match the other figures and would
# repeat colours beyond 10 batches. Categories missing from the palette are grey.
bar_colors = [batch_colors.get(cat, "lightgrey") for cat in df_counts.columns]

ax = df_counts.plot(kind="bar", stacked=True, figsize=(10, 10), color=bar_colors)
legend = plt.legend(bbox_to_anchor=(1, 1), loc="upper left", prop={"size": 10})
legend.set_title("Batch Annotations", prop={"size": 10})
plt.title("Batch Composition of Niches")
plt.xlabel("Niche")
plt.ylabel("Cell Counts")
plt.savefig(save_path_batches, bbox_extra_artists=(legend,), bbox_inches="tight")
plt.show()

In [None]:
# Niche composition (cell-type): stacked cell counts per niche, coloured with
# the same cell-type palette as the latent/physical-space plots.
save_path_celltypes = figure_folder_path / f"res_{latent_leiden_resolution}_niche_composition_cell_types.svg"
df_counts = (
    model.adata.obs
    .groupby([latent_cluster_key, cell_type_key], observed=False)
    .size()
    .unstack(fill_value=0)  # niche/cell-type combinations with no cells count as 0
)

# Without an explicit colour list pandas uses matplotlib's default colour cycle
# (10 colours by default): with more than 10 cell types, different types would
# share a colour and become indistinguishable. Categories missing from the
# palette are grey.
bar_colors = [cell_type_colors.get(cat, "lightgrey") for cat in df_counts.columns]

ax = df_counts.plot(kind="bar", stacked=True, figsize=(10, 10), color=bar_colors)
legend = plt.legend(bbox_to_anchor=(1, 1), loc="upper left", prop={"size": 10})
legend.set_title("Cell Type Annotations", prop={"size": 10})
plt.title("Cell Type Composition of Niches")
plt.xlabel("Niche")
plt.ylabel("Cell Counts")
plt.savefig(save_path_celltypes, bbox_extra_artists=(legend,), bbox_inches="tight")
plt.show()

#### 4.3.2 Differential GPs
We test which GPs are differentially expressed in a niche.

To this end, we perform "one-vs-rest" differential GP testing, i.e., all niches (```selected_cats = None```) are each tested against all other niches (```comparison_cats = "rest"```).

However, differential GP testing can also be performed in the following ways:
- Set ```selected_cats = ["0"]``` to perform differential GP testing for a specific niche only, in this case niche "0".
- Set ```comparison_cats = ["2"]``` to perform differential GP testing against niche "2" as opposed to against all other niches.

By default, we use an absolute log Bayes factor threshold of 2.3 (parameter `log_bayes_factor_thresh`) to determine strongly enriched GPs (see https://en.wikipedia.org/wiki/Bayes_factor).

In [None]:
# Report how many of the model's GPs are active, relative to all GPs
# registered under `gp_names_key`.
active_gps = model.get_active_gps()
n_total_gps = len(model.adata.uns[gp_names_key])
n_active_gps = len(active_gps)
print(f"Total GPs: {n_total_gps}; Active GPs: {n_active_gps}.")

In [None]:
# Display example active GPs. The `gp_active` column is used directly as the
# boolean mask; comparing it to True is redundant (flake8/ruff E712).
# Assumes `gp_active` holds booleans.
gp_summary_df = model.get_gp_summary()
active_gp_mask = gp_summary_df["gp_active"]
display(gp_summary_df[active_gp_mask].head())

In [None]:
# Run differential gp testing and store the result
selected_cats = None
comparison_cats = "rest"

enriched_gps = model.run_differential_gp_tests(
    cat_key=latent_cluster_key,
    selected_cats=selected_cats,
    comparison_cats=comparison_cats,
    # Store the results table under the configured key, so the lookup below and
    # generate_enriched_gp_info_plots (which both read
    # `differential_gp_test_results_key`) still find it when the parameter is
    # overridden.
    key_added=differential_gp_test_results_key,
    log_bayes_factor_thresh=log_bayes_factor_thresh,
)

# Store a summary of the enriched GPs, ordered as returned by the test.
if differential_gp_test_results_key in model.adata.uns:
    res_path = figure_folder_path / f"log_bayes_factor_{log_bayes_factor_thresh}_niche_enriched_gps_summary.csv"
    gp_summary_cols = [
        "gp_name", "n_source_genes", "n_non_zero_source_genes", "n_target_genes", "n_non_zero_target_genes",
        "gp_source_genes", "gp_target_genes", "gp_source_genes_importances", "gp_target_genes_importances",
    ]
    enriched_gp_summary_df = gp_summary_df[gp_summary_df["gp_name"].isin(enriched_gps)]
    # An ordered categorical makes sort_values follow the order of `enriched_gps`.
    cat_dtype = pd.CategoricalDtype(categories=enriched_gps, ordered=True)
    enriched_gp_summary_df = enriched_gp_summary_df.assign(gp_name=enriched_gp_summary_df["gp_name"].astype(cat_dtype))
    enriched_gp_summary_df = enriched_gp_summary_df.sort_values(by="gp_name")[gp_summary_cols]
    enriched_gp_summary_df.to_csv(res_path, index=False)
    display(enriched_gp_summary_df)

In [None]:
# Heatmap of normalized GP activities across niches: the mean activity of each
# enriched GP per niche, min-max scaled to [0, 1] independently per GP column.
if enriched_gps:
    niche_gp_means = (
        model.adata.obs[[latent_cluster_key] + enriched_gps]
        .groupby(latent_cluster_key)
        .mean()
    )
    normalized_df = pd.DataFrame(
        MinMaxScaler().fit_transform(niche_gp_means),
        index=niche_gp_means.index,
        columns=niche_gp_means.columns,
    )

    plt.figure(figsize=(16, 8))
    sns.heatmap(normalized_df, cmap="viridis", annot=False, linewidths=0)
    plt.xticks(rotation=45, fontsize=8, ha="right")
    plt.xlabel("Gene Programs", fontsize=16)
    plt.tight_layout()
    plt.savefig(figure_folder_path / "enriched_gps_heatmap.svg", bbox_inches="tight")
    plt.show()
else:
    print(f"No enriched GPs at log_bayes_factor_thresh={log_bayes_factor_thresh}. Skipping heatmap.")

In [None]:
# Per-GP info plots (one window of the enriched-GP ranking) for each sample.
if not enriched_gps:
    print("No enriched GPs; skipping per-GP info plots.")
else:
    # Encode the test configuration in the output label. The comparison part
    # follows `comparison_cats` instead of being hard-coded to "rest", so the
    # outputs of a one-vs-one test are not mislabelled as one-vs-rest.
    selected_label = selected_cats[0] if selected_cats else "None"
    if isinstance(comparison_cats, str):
        comparison_label = comparison_cats
    else:
        comparison_label = "_".join(str(cat) for cat in comparison_cats)
    plot_label = (
        f"log_bayes_factor_{log_bayes_factor_thresh}"
        f"_cluster_{selected_label}_vs_{comparison_label}"
    )

    # NOTE(review): start/end indices 20-30 select a window of the enriched-GP
    # ranking that skips the top entries; confirm this window is intended.
    generate_enriched_gp_info_plots(
        plot_label=plot_label,
        model=model,
        sample_key=sample_key,
        differential_gp_test_results_key=differential_gp_test_results_key,
        cat_key=latent_cluster_key,
        cat_palette=latent_cluster_colors,
        n_top_enriched_gp_start_idx=20,
        n_top_enriched_gp_end_idx=30,
        feature_spaces=samples,  # alternatively ["latent"]
        n_top_genes_per_gp=3,
        save_figs=True,
        figure_folder_path=f"{figure_folder_path}/",
        spot_size=spot_size,
    )

#### 4.3.3 Cell-Cell Communication (optional)

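We use the inferred activity of an enriched combined interaction GP to analyze the involved intercellular interactions. This section runs only when the `gp_name` parameter is set (e.g., to one of the enriched GPs above); otherwise it is skipped.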

In [None]:
# Optional cell-cell communication analysis for a single GP; skipped when the
# `gp_name` parameter is empty.
if not gp_name:
    print("Skipping communication analysis (no gp_name provided).")
else:
    # Validate up front so a misspelled GP name produces a clear error message.
    if gp_name not in set(model.adata.uns[gp_names_key]):
        raise ValueError(
            f"gp_name '{gp_name}' not found in model.adata.uns['{gp_names_key}']."
        )

    network_df = compute_communication_gp_network(
        gp_list=[gp_name],
        model=model,
        group_key=latent_cluster_key,
        n_neighbors=n_neighbors,
    )

    visualize_communication_gp_network(
        adata=model.adata,
        network_df=network_df,
        figsize=(9, 7),
        # Reuse the niche palette from the Leiden cell so niche colours match
        # the earlier figures.
        cat_colors=latent_cluster_colors,
        edge_type_colors=["#1f77b4"],
        cat_key=latent_cluster_key,
        save=True,
        save_path=str(figure_folder_path / f"gp_network_{gp_name}.svg"),
    )