# Spatial-IT RRMap OT Analysis

This notebook analyzes timepoint transitions in an `AnnData` object using the optimal transport (OT) workflow in `rrmap_ot.py`.

## Objective
- Compute transitions between consecutive `course` labels.
- Compare source-to-destination `anno_L2` state flow.
- Review QC diagnostics: outgoing entropy per source state, plus mass diagnostics when running unbalanced OT.


In [4]:
from __future__ import annotations

from pathlib import Path
import os
import sys
import warnings

# Point matplotlib and numba at writable, notebook-local cache directories so the
# imports below work in read-only / home-restricted environments. Values already
# present in the environment win (setdefault); the target directory is created either way.
NOTEBOOK_TMP = Path.cwd().resolve() / "tmp"
for _cache_var, _cache_subdir in (("MPLCONFIGDIR", "matplotlib"), ("NUMBA_CACHE_DIR", "numba_cache")):
    os.environ.setdefault(_cache_var, str(NOTEBOOK_TMP / _cache_subdir))
    Path(os.environ[_cache_var]).mkdir(parents=True, exist_ok=True)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import scanpy as sc
from IPython.display import display

# Reproducibility: seed NumPy's legacy global RNG once, up front.
# (rrmap_ot receives its own seed via RANDOM_STATE in the config cell.)
SEED = 0
np.random.seed(seed=SEED)

# Resolve repo root even when Jupyter starts from a subdirectory
def find_project_root(start: Path, marker: str = "rrmap_ot.py") -> Path:
    """Return the nearest directory at or above ``start`` that contains ``marker``.

    Args:
        start: Directory to begin the search from (resolved to an absolute path).
        marker: File or directory name whose presence identifies the project root.

    Returns:
        The resolved directory containing ``marker``.

    Raises:
        FileNotFoundError: If neither ``start`` nor any of its ancestors contains ``marker``.
    """
    here = start.resolve()
    for directory in (here, *here.parents):
        if directory.joinpath(marker).exists():
            return directory
    raise FileNotFoundError(
        f"Could not find {marker} by walking up from {here}."
    )

# Locate the repo root and put it first on sys.path so `import rrmap_ot` picks up
# the repository copy (unless an earlier import is already cached in sys.modules).
# The import below must stay after this insert.
PROJECT_ROOT = find_project_root(Path.cwd())
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import rrmap_ot

# Echo resolved locations so a wrong working directory or checkout is easy to spot.
print(f"Notebook cwd: {Path.cwd().resolve()}")
print(f"Project root: {PROJECT_ROOT}")
print(f"Using rrmap_ot from: {Path(rrmap_ot.__file__).resolve()}")


Notebook cwd: /Users/christoffer/work/karolinska/development/spatial-OT/output/jupyter-notebook
Project root: /Users/christoffer/work/karolinska/development/spatial-OT
Using rrmap_ot from: /Users/christoffer/work/karolinska/development/spatial-OT/rrmap_ot.py


## Configuration

Set `ADATA_PATH` to your RRMap / Spatial-IT `.h5ad` file.

Defaults here use unbalanced centroid OT (recommended first pass for large datasets).


In [5]:
# ---- User configuration ----
# Input .h5ad. Export RRMAP_ADATA_PATH to point elsewhere without editing this
# machine-specific default.
ADATA_PATH = Path(
    os.environ.get(
        "RRMAP_ADATA_PATH",
        "/Volumes/processing2/RRmap/data/RRmap_metadata_fixed_update.h5ad",
    )
).expanduser()

COURSE_KEY = "course"
STATE_KEY = "anno_L2"
EMBEDDING_KEY = None  # None => rrmap_ot fallback: X_scVI -> X_umap -> X_pca

# Run one model at a time to avoid cross-model transitions.
# Each model's "courses" list is used as the course order for consecutive
# transitions; its first entry is expected to be the baseline.
MODELS = {
    "Chronic": {
        "baseline": "MOG CFA",
        "courses": ["MOG CFA", "non symptomatic", "early onset", "chronic peak", "chronic long"],
    },
    "RR": {
        "baseline": "PLP CFA",
        "courses": [
            "PLP CFA", "onset I", "onset II", "monophasic",
            "peak I", "remitt I", "peak II", "remitt II", "peak III",
        ],
    },
}
MODEL_NAME = "Chronic"  # "Chronic", "RR", or None for full-dataset run
STRIP_COURSE_LABELS = True
# Used only when MODEL_NAME is None. None here => rrmap_ot chooses the order
# (reportedly lexical — verify it matches biological time).
COURSE_ORDER = None

METHOD = "unbalanced"  # "balanced" or "unbalanced"
MODE = "centroid"      # "centroid" (fast) or "cell" (heavy)
REG = 0.05
REG_M = 10.0
MAX_CELLS_PER_COURSE = 20_000
UNBALANCED_MASS_MODE = "normalized"  # "normalized" or "raw"
INCLUDE_ALL_STATES = False
RANDOM_STATE = 0  # seed passed to rrmap_ot; keep in sync with SEED in the setup cell

TOP_K = 3

# Fail fast on typos in string options instead of deep inside rrmap_ot.
for _option, _value, _allowed in (
    ("METHOD", METHOD, ("balanced", "unbalanced")),
    ("MODE", MODE, ("centroid", "cell")),
    ("UNBALANCED_MASS_MODE", UNBALANCED_MASS_MODE, ("normalized", "raw")),
):
    if _value not in _allowed:
        raise ValueError(f"{_option}={_value!r}; expected one of {_allowed}")


In [6]:
# Load adata.
# An `adata` already in the session is reused, which avoids re-reading a large
# .h5ad on re-run. The exception: if this cell previously loaded it from a
# different ADATA_PATH, the stale object is replaced. ADATA_LOADED_FROM records
# the path this cell last read.
_adata_source = globals().get("ADATA_LOADED_FROM")
_reuse_adata = (
    "adata" in globals()
    and hasattr(adata, "obs")
    and (_adata_source is None or _adata_source == ADATA_PATH)
)

if _reuse_adata:
    print("Using existing `adata` already present in notebook session.")
else:
    if _adata_source is not None and _adata_source != ADATA_PATH:
        print(f"ADATA_PATH changed ({_adata_source} -> {ADATA_PATH}); reloading.")
    if not ADATA_PATH.exists():
        raise FileNotFoundError(
            f"ADATA_PATH does not exist: {ADATA_PATH}\n"
            "Set ADATA_PATH in the config cell to your .h5ad file."
        )
    adata = sc.read_h5ad(ADATA_PATH)
    ADATA_LOADED_FROM = ADATA_PATH

print(adata)


AnnData object with n_obs × n_vars = 877141 × 5101
    obs: 'cell', 'centroid_x', 'centroid_y', 'centroid_z', 'component', 'volume', 'surface_area', 'scale', 'region', 'sample_id', 'proseg_cluster', 'output_folder', 'Num', 'n_genes', 'n_counts', 'louvain_0.5', 'louvain_1', 'louvain_1.5', 'louvain_2', 'louvain_2.5', 'louvain_3', 'louvain_3.5', 'Cluster', 'Level1', 'Level2', 'Level3', 'Level3.1', 'grid_label', 'rbd_domain', 'rbd_domain_0.1', 'rbd_domain_0.2', 'rbd_domain_0.3', 'rbd_domain_0.5', 'rbd_domain_0.6', 'rbd_domain_0.7', 'rbd_domain_0.8', 'rbd_domain_0.9', 'rbd_domain_1.1', 'rbd_domain_1', 'rbd_domain_1.25', 'rbd_domain_1.4', 'rbd_domain_1.5', 'leiden_0.5', 'leiden_1', 'leiden_1.5', 'leiden_2', 'leiden_2.5', 'leiden_3', 'leiden_3.5', 'sample_name', 'course', 'condition', 'model', 'cytetype_annotation_louvain_3.5', 'cytetype_cellOntologyTerm_louvain_3.5', 'cluster_id', 'author_label', 'annotation', 'Class', 'state', 'CL_term', 'CL_term_id', 'confidence', 'author_label_similarity_

In [7]:
# Basic validation and quick summary.
# Fail fast if the obs columns this notebook depends on are absent.
required_obs = [COURSE_KEY, STATE_KEY]
missing_obs = [key for key in required_obs if key not in adata.obs.columns]
if missing_obs:
    raise KeyError(f"Missing required obs keys: {missing_obs}")

print("Required obs keys present.")
print(f"Available embeddings in adata.obsm: {list(adata.obsm.keys())}")

print("\nCells per course (raw labels):")
display(
    adata.obs[COURSE_KEY]
    .value_counts(dropna=False)
    .rename("n_cells")
    .to_frame()
)

# Report (without modifying adata) course labels that differ only by surrounding
# whitespace; the OT cell creates the stripped column it actually uses.
if STRIP_COURSE_LABELS:
    course_raw = adata.obs[COURSE_KEY].astype("string")
    course_clean = course_raw.str.strip()
    n_changed = int(course_raw.ne(course_clean).fillna(False).sum())
    if n_changed:
        print(f"\nDetected {n_changed} cells with leading/trailing spaces in `{COURSE_KEY}` labels.")
        print("Cells per course (stripped labels):")
        display(
            course_clean
            .value_counts(dropna=False)
            .rename("n_cells")
            .to_frame()
        )

print("\nCells per state (top 20):")
display(
    adata.obs[STATE_KEY]
    .value_counts(dropna=False)
    .head(20)
    .rename("n_cells")
    .to_frame()
)


Required obs keys present.
Available embeddings in adata.obsm: ['X_mana_gauss', 'X_mana_gauss_2neigh', 'X_pca', 'X_scVI', 'X_umap', 'spatial']

Cells per course:


Unnamed: 0_level_0,n_cells
course,Unnamed: 1_level_1
peak III,108344
peak I,105761
monophasic,93898
chronic long,78239
remitt I,71522
peak II,70152
chronic peak,65703
PLP CFA,59168
early onset,58116
onset II,39470



Cells per state (top 20):


Unnamed: 0_level_0,n_cells
anno_L2,Unnamed: 1_level_1
Oligodendrocyte,174835
Interneuron,130506
Microglia,103816
Astrocyte,102125
Endothelial,60810
Meningeal fibroblast,42899
Macrophage,42745
Neuron,39572
T cell,31211
Schwann cell,24869


## Compute OT Transitions

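Runs `compute_ot_transitions(...)` from `rrmap_ot.py` on `adata_for_ot`. That is either the full `adata` or, when `MODEL_NAME` is set, a copy restricted to that model's courses. Results are returned as `transitions` and `plans`. Anything `rrmap_ot` stores in `.uns["ot_transitions"]` is written to `adata_for_ot`, so it does not reach `adata` for model runs.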


In [8]:
# Compute OT transitions for one model's courses (MODEL_NAME set) or the full dataset.
course_key_for_ot = COURSE_KEY
if STRIP_COURSE_LABELS:
    # Write stripped labels to a separate column so the raw COURSE_KEY column stays untouched.
    course_key_for_ot = f"{COURSE_KEY}__stripped"
    adata.obs[course_key_for_ot] = adata.obs[COURSE_KEY].astype("string").str.strip()

adata_for_ot = adata
course_order_for_ot = COURSE_ORDER

if MODEL_NAME is not None:
    if MODEL_NAME not in MODELS:
        raise KeyError(f"MODEL_NAME={MODEL_NAME!r} not found in MODELS: {list(MODELS.keys())}")

    model_cfg = MODELS[MODEL_NAME]
    model_courses = [str(c).strip() for c in model_cfg["courses"]]
    if not model_courses:
        raise ValueError(f"MODELS[{MODEL_NAME!r}] has an empty 'courses' list.")

    # Default the baseline to the first course only when none is configured.
    # A `dict.get(key, model_courses[0])` default would be evaluated eagerly.
    configured_baseline = model_cfg.get("baseline")
    baseline = model_courses[0] if configured_baseline is None else str(configured_baseline).strip()
    if baseline != model_courses[0]:
        warnings.warn(
            f"Model {MODEL_NAME!r} baseline {baseline!r} does not match first course {model_courses[0]!r}."
        )

    model_mask = adata.obs[course_key_for_ot].isin(model_courses).to_numpy()
    if int(model_mask.sum()) == 0:
        raise ValueError(
            f"No cells matched MODEL_NAME={MODEL_NAME!r} using course labels: {model_courses}"
        )

    # .copy() materializes the subset, so anything rrmap_ot writes lands on
    # adata_for_ot rather than on adata.
    adata_for_ot = adata[model_mask].copy()
    course_order_for_ot = model_courses

    present = set(adata_for_ot.obs[course_key_for_ot].dropna().unique().tolist())
    missing_courses = [c for c in model_courses if c not in present]
    if missing_courses:
        warnings.warn(
            f"Some configured courses are missing for MODEL_NAME={MODEL_NAME!r}: {missing_courses}"
        )

    print(f"Running model: {MODEL_NAME}")
    print(f"Course order: {course_order_for_ot}")
    print(f"Cells in model subset: {adata_for_ot.n_obs:,}")
else:
    print("Running full dataset (MODEL_NAME=None).")
    if course_order_for_ot is not None:
        print(f"Using explicit COURSE_ORDER with {len(course_order_for_ot)} labels.")

transitions, plans = rrmap_ot.compute_ot_transitions(
    adata=adata_for_ot,
    course_key=course_key_for_ot,
    state_key=STATE_KEY,
    embedding_key=EMBEDDING_KEY,
    method=METHOD,
    mode=MODE,
    course_order=course_order_for_ot,
    reg=REG,
    reg_m=REG_M,
    max_cells_per_course=MAX_CELLS_PER_COURSE,
    random_state=RANDOM_STATE,
    return_plans=True,
    include_all_states=INCLUDE_ALL_STATES,
    unbalanced_mass_mode=UNBALANCED_MASS_MODE,
)

print(f"Computed {len(transitions)} course-pair transition matrix/matrices.")
print("Pairs:", list(transitions.keys()))


Computed 13 course-pair transition matrix/matrices.
Pairs: [('MOG CFA', 'early onset'), ('early onset', 'chronic peak'), ('chronic peak', 'chronic long'), ('chronic long', 'PLP CFA'), ('PLP CFA', 'non symptomatic'), ('non symptomatic', 'monophasic'), ('monophasic', 'onset I'), ('onset I', 'onset II'), ('onset II', 'peak I'), ('peak I', 'remitt I'), ('remitt I', 'peak II'), ('peak II', 'remitt II'), ('remitt II', 'peak III')]


In [None]:
# Inspect top-k destinations per source state for each course pair.
# Tables are kept in all_top_tables (keyed by course pair) for the export cell.
_top_columns = ["pair", "source_state", "rank", "dest_state", "score"]
all_top_tables = {}

for pair, t_df in transitions.items():
    top = rrmap_ot.top_k_destinations(t_df, k=TOP_K, normalize_rows=True)
    pair_label = f"{pair[0]}->{pair[1]}"

    rows = [
        {
            "pair": pair_label,
            "source_state": src_state,
            "rank": rank,
            "dest_state": dst_state,
            "score": score,
        }
        for src_state, dsts in top.items()
        for rank, (dst_state, score) in enumerate(dsts, start=1)
    ]

    # Explicit columns keep the schema even when there are no rows, so the
    # exported CSV for an empty pair still has a header.
    top_df = pd.DataFrame(rows, columns=_top_columns)
    all_top_tables[pair] = top_df

    print(f"\nTop transitions for {pair[0]} -> {pair[1]}")
    if top_df.empty:
        print("No positive mass transitions.")
    else:
        display(top_df.sort_values(["source_state", "rank"]).head(30))


In [None]:
# Plot the transition heatmap for the first computed course pair.
if not transitions:
    raise ValueError("No transitions were computed. Check course ordering and required keys.")

pair_to_plot = list(transitions)[0]
t_plot = transitions[pair_to_plot]

ax = rrmap_ot.plot_transition_heatmap(
    t_plot,
    cmap="viridis",
    figsize=(10, 6),
    title=f"OT transition: {pair_to_plot[0]} -> {pair_to_plot[1]}",
)
plt.tight_layout()
plt.show()


In [None]:
# QC: outgoing entropy + unbalanced mass diagnostics for the first course pair.
# Guard explicitly, as the heatmap cell does; next(iter(...)) on an empty
# mapping would raise a bare StopIteration.
if not transitions:
    raise ValueError("No transitions were computed. Check course ordering and required keys.")

pair_qc = next(iter(transitions.keys()))
t_qc = transitions[pair_qc]

entropy = rrmap_ot.outgoing_entropy(t_qc).sort_values(ascending=False)
print(f"Outgoing entropy ({pair_qc[0]} -> {pair_qc[1]})")
display(entropy.to_frame().head(30))

if METHOD == "unbalanced":
    # `plans` comes from compute_ot_transitions(..., return_plans=True) in the OT cell.
    if pair_qc not in plans:
        raise KeyError(
            f"No OT plan stored for pair {pair_qc!r}; re-run the OT cell with return_plans=True."
        )
    payload = plans[pair_qc]
    source_labels = payload.get("source_labels")
    target_labels = payload.get("target_labels")

    # a / b are presumably the source / target marginals used by the solver — confirm in rrmap_ot.
    diag = rrmap_ot.unbalanced_mass_diagnostics(
        plan=payload["plan"],
        a=payload["a"],
        b=payload["b"],
        source_labels=source_labels,
        target_labels=target_labels,
    )

    print("Unbalanced mass summary:")
    display(pd.Series(diag["summary"], name="value").to_frame())

    print("Source mass diagnostics (head):")
    display(diag["source"].head(20))

    print("Target mass diagnostics (head):")
    display(diag["target"].head(20))


In [None]:
# Optional: save transition matrices and top tables for downstream review.
def _pair_file_stem(pair) -> str:
    """Filename-safe stem for a (source, destination) course pair ("/" becomes "-")."""
    return f"{pair[0]}__to__{pair[1]}".replace("/", "-")


run_label = ("full_dataset" if MODEL_NAME is None else MODEL_NAME).replace(" ", "_")
out_dir = PROJECT_ROOT / "output" / "ot_transition_tables" / run_label
out_dir.mkdir(parents=True, exist_ok=True)

# Transition matrices keep their index (source states); top-k tables do not need one.
for pair, t_df in transitions.items():
    t_df.to_csv(out_dir / f"transition_{_pair_file_stem(pair)}.csv")

for pair, top_df in all_top_tables.items():
    top_df.to_csv(out_dir / f"topk_{_pair_file_stem(pair)}.csv", index=False)

print(f"Saved tables to: {out_dir}")


In [None]:
# Optional heavy run (cell mode): set RUN_CELL_MODE = True when needed.
# Uses the same subset, stripped course key and course order as the main OT run.
# Using the raw `adata` / COURSE_KEY / COURSE_ORDER instead would mix models and
# whitespace-variant labels.
RUN_CELL_MODE = False
CELL_MODE_MAX_CELLS = 5_000  # per course; cell mode is much heavier than centroid mode

if RUN_CELL_MODE:
    transitions_cell = rrmap_ot.compute_ot_transitions(
        adata=adata_for_ot,
        course_key=course_key_for_ot,
        state_key=STATE_KEY,
        embedding_key=EMBEDDING_KEY,
        method=METHOD,
        mode="cell",
        course_order=course_order_for_ot,
        reg=REG,
        reg_m=REG_M,
        max_cells_per_course=CELL_MODE_MAX_CELLS,
        random_state=RANDOM_STATE,
        include_all_states=INCLUDE_ALL_STATES,
        unbalanced_mass_mode=UNBALANCED_MASS_MODE,
    )
    print("Cell-mode pairs:", list(transitions_cell.keys()))


## Next Steps
- Compare `METHOD="balanced"` vs `METHOD="unbalanced"` and inspect how entropy changes.
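- For full-dataset runs (`MODEL_NAME=None`), provide an explicit `COURSE_ORDER` if the default (lexical) sorting does not match biological time. Model runs take their order from `MODELS[...]["courses"]`.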
- If centroid mode is too coarse, run cell mode on selected course pairs with lower `max_cells_per_course`.
