# RRmap RewireSpace Analysis

This notebook runs RewireSpace on your dataset:
`/Volumes/processing2/RRmap/data/RRmap_metadata_fixed_update.h5ad`.

Workflow:
- Inspect available metadata and graph keys
- Load a subset of samples into memory
- Compute per-sample contact/Z matrices
- Fit stage-wise rewiring effects and plot top edges


In [None]:
import sys
from pathlib import Path

# Put a local ``src`` directory (next to the working directory, or one level
# up) at the front of ``sys.path`` when it exists and is not already listed.
# Presumably this lets the ``rewirespace`` imports below resolve to an in-repo
# checkout without installing the package.
_here = Path.cwd()
for src_dir in (_here / "src", _here.parent / "src"):
    src_path = str(src_dir)
    if not src_dir.exists() or src_path in sys.path:
        continue
    sys.path.insert(0, src_path)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import anndata as ad

from rewirespace.spatial_rewiring import per_sample_contacts, fit_stage_effect
from rewirespace.plot_rewiring import plot_stage_heatmaps, plot_rewiring_curves


In [None]:
# ---------------------------------------------------------------------------
# Notebook configuration. Every later cell reads these module-level names.
# ---------------------------------------------------------------------------

# Input AnnData file (.h5ad) analysed by this notebook.
DATA_PATH = Path("/Volumes/processing2/RRmap/data/RRmap_metadata_fixed_update.h5ad")

# Defaults inferred from this file on 2026-02-19.
# The three obs keys below are required columns of ``adata.obs``.
CELL_TYPE_KEY = "annotation"  # alternatives: "Level2", "Class", "anno_L2"
STAGE_KEY = "stage"
SAMPLE_KEY = "sample_id"
# Optional obs column; when None (or absent from adata.obs) the analysis runs
# without a subject column.
SUBJECT_KEY = None  # e.g. "sample_name" if repeated measures are available
# Adjacency key; it must be present in ``adata.obsp`` or ``adata.uns``.
ADJ_KEY = "spatial_connectivities"

# Runtime knobs
# Explicit list of sample ids to load; None means "take the first
# N_SAMPLES_HEAD samples from the per-sample cell counts".
SAMPLES_TO_USE = None  # e.g. ["G3_L1_0", "G3_L1_1"]
N_SAMPLES_HEAD = 4
MAX_CELLS_PER_SAMPLE = 15000  # set None to disable downsampling
# Passed to ``per_sample_contacts`` as ``n_perm``.
N_PERM = 50  # increase for final runs (e.g. 100-500)
# Seeds the downsampling RNG only; it is not forwarded to the contact
# computation.
RANDOM_STATE = 42

# Fail fast before any heavy work. ``is_file()`` (rather than ``exists()``)
# also rejects a directory that happens to sit at this path, which would
# otherwise only fail later when the file is opened.
if not DATA_PATH.is_file():
    raise FileNotFoundError(f"Missing file: {DATA_PATH}")


In [None]:
# Open the dataset with ``backed="r"`` and print an overview of its contents.
adata_b = ad.read_h5ad(DATA_PATH, backed="r")

print(adata_b)
print("n_obs:", adata_b.n_obs, "n_vars:", adata_b.n_vars)
print("obsp keys:", list(adata_b.obsp.keys()))
obs_columns = adata_b.obs.columns
print("obs columns (first 40):", list(obs_columns)[:40])

# Validate the configured keys up front so that later cells do not fail
# half-way through the analysis.
required_obs = [CELL_TYPE_KEY, STAGE_KEY, SAMPLE_KEY]
missing_obs = [key for key in required_obs if key not in obs_columns]
if missing_obs:
    raise KeyError(f"Missing required adata.obs keys: {missing_obs}")

# The adjacency may be stored either in ``obsp`` or in ``uns``.
has_adjacency = ADJ_KEY in adata_b.obsp or ADJ_KEY in adata_b.uns
if not has_adjacency:
    raise KeyError(f"Adjacency key '{ADJ_KEY}' not in adata.obsp/adata.uns.")

# Cells per sample (missing sample ids are counted too); shown as the cell
# output and reused by the sample-selection cell.
sample_counts = adata_b.obs[SAMPLE_KEY].value_counts(dropna=False)
print("n unique samples:", len(sample_counts))
sample_counts.head(12)


In [None]:
# Decide which samples to analyse: an explicit SAMPLES_TO_USE list wins,
# otherwise take the first N_SAMPLES_HEAD entries of ``sample_counts``
# (``value_counts`` orders them from most to fewest cells).
if SAMPLES_TO_USE is None:
    # ``sample_counts`` was built with dropna=False, so a missing (NaN) sample
    # id can appear in its index. It is not a real sample; skip it so it does
    # not use up one of the N_SAMPLES_HEAD slots.
    selected_samples = sample_counts.index.dropna()[:N_SAMPLES_HEAD].tolist()
else:
    selected_samples = list(SAMPLES_TO_USE)
    # Report requested ids that match nothing instead of dropping them silently.
    not_found = [s for s in selected_samples if s not in sample_counts.index]
    if not_found:
        print("WARNING: requested samples not present in adata.obs:", not_found)

print("selected_samples:", selected_samples)

# Row positions of all cells that belong to the selected samples.
mask = adata_b.obs[SAMPLE_KEY].isin(selected_samples).to_numpy()
idx = np.flatnonzero(mask)
print("selected cells:", idx.size)
if idx.size == 0:
    raise ValueError("No cells matched selected samples; update SAMPLES_TO_USE or SAMPLE_KEY.")

# Materialise only the selected cells, then release the backed file handle.
# Because the handle is closed here, the cell that opens ``adata_b`` has to be
# re-run before this cell can be executed again.
adata = adata_b[idx].to_memory()
adata_b.file.close()
print(adata)


In [None]:
# --- Optional per-sample downsampling --------------------------------------
# Cap every sample at MAX_CELLS_PER_SAMPLE cells, drawn without replacement
# from a generator seeded with RANDOM_STATE. Samples are visited in order of
# first appearance, so the draw is reproducible for a fixed input and seed.
# NOTE(review): cells are removed after the spatial graph was built, which
# presumably also removes their graph edges -- confirm that running the
# contact statistics on a thinned graph is intended.
if MAX_CELLS_PER_SAMPLE is not None:
    rng = np.random.default_rng(RANDOM_STATE)
    sample_vals = adata.obs[SAMPLE_KEY].to_numpy()
    keep_pos = []
    for sid in pd.unique(sample_vals):
        pos = np.flatnonzero(sample_vals == sid)
        if pos.size > MAX_CELLS_PER_SAMPLE:
            pos = rng.choice(pos, size=MAX_CELLS_PER_SAMPLE, replace=False)
        keep_pos.append(np.sort(pos))
    keep_pos = np.concatenate(keep_pos)
    # Restore the original cell order across samples.
    keep_pos.sort()
    adata = adata[keep_pos].copy()
    print(f"Downsampled to {adata.n_obs} cells (max {MAX_CELLS_PER_SAMPLE} per sample).")

# --- Stage ordering ---------------------------------------------------------
# Keep an existing categorical as-is; otherwise build an ordered categorical
# from the sorted non-missing values.
if isinstance(adata.obs[STAGE_KEY].dtype, pd.CategoricalDtype):
    print("Using existing categorical stage order.")
else:
    stage_vals = [v for v in pd.unique(adata.obs[STAGE_KEY]) if pd.notna(v)]
    try:
        # Plain ``sorted``: string labels sort lexicographically ("E10" comes
        # before "E9"), so check the printed order below.
        stage_order = sorted(stage_vals)
    except TypeError:
        # Mixed, non-comparable labels: fall back to order of first appearance.
        stage_order = stage_vals
    adata.obs[STAGE_KEY] = pd.Categorical(adata.obs[STAGE_KEY], categories=stage_order, ordered=True)
    print("Inferred stage order:", stage_order)

# The stage-effect fit compares stages, so a selection that ends up with fewer
# than two distinct stages cannot give a meaningful result. Flag it here, before
# the expensive contact computation.
n_stages_present = int(adata.obs[STAGE_KEY].nunique())
if n_stages_present < 2:
    print(
        f"WARNING: only {n_stages_present} distinct stage(s) among the selected cells; "
        "stage-wise effects cannot be compared. Adjust SAMPLES_TO_USE or N_SAMPLES_HEAD."
    )


In [None]:
# Use the subject key only when it is configured AND present in adata.obs;
# otherwise run without it and tell the user why.
subject_key_for_contacts = None
if SUBJECT_KEY is not None:
    if SUBJECT_KEY in adata.obs.columns:
        subject_key_for_contacts = SUBJECT_KEY
    else:
        print(f"Subject key '{SUBJECT_KEY}' not found in adata.obs; running without subject-level adjustment.")

# Long-format table of per-sample contact statistics.
# NOTE(review): RANDOM_STATE is not forwarded here. If the n_perm permutations
# are drawn randomly, results may differ between runs -- check whether
# ``per_sample_contacts`` accepts a seed.
contact_kwargs = {
    "adata": adata,
    "adj_key": ADJ_KEY,
    "cell_type_key": CELL_TYPE_KEY,
    "sample_key": SAMPLE_KEY,
    "stage_key": STAGE_KEY,
    "subject_key": subject_key_for_contacts,
    "n_perm": N_PERM,
}
df_long = per_sample_contacts(**contact_kwargs)

print(df_long.shape)
df_long.head()


In [None]:
def _mean_abs(values):
    """Return the mean of the absolute values of *values* as a plain float."""
    return float(np.mean(np.abs(values)))


# One row per stage: number of distinct samples, number of rows in the long
# table, and the mean absolute Z value.
per_stage = df_long.groupby("stage", observed=False)
summary = per_stage.agg(
    n_samples=("sample_id", "nunique"),
    n_rows=("type_i", "size"),
    mean_abs_Z=("Z", _mean_abs),
).reset_index()
summary


In [None]:
# Pass a subject column to the model only when a subject key was used for the
# contacts and the long table actually contains a "subject" column.
subject_col_for_model = None
if subject_key_for_contacts is not None and "subject" in df_long.columns:
    subject_col_for_model = "subject"

# Fit the stage effect on the Z values of the long table.
fit_kwargs = {
    "df_long": df_long,
    "value_col": "Z",
    "stage_col": "stage",
    "subject_col": subject_col_for_model,
}
results = fit_stage_effect(**fit_kwargs)

print(results.shape)
results.head(20)


In [None]:
# Rank edges by q-value (ascending; missing q-values are treated as 1.0 so
# they sort last), breaking ties by the larger effect range. The helper
# column "_q" exists only for sorting and is dropped again.
ranked = (
    results.assign(_q=results["q_value"].fillna(1.0))
    .sort_values(by=["_q", "effect_range"], ascending=[True, False])
    .drop(columns="_q")
)
top_edges = ranked.head(10)

display_cols = ["rank", "type_i", "type_j", "effect_range", "p_value", "q_value", "model"]
top_edges[display_cols]


In [None]:
# Heatmaps of the Z values, split by stage. Presumably ``ncols`` sets the
# number of panel columns in the figure grid -- confirm in plot_stage_heatmaps.
heatmap_kwargs = {
    "df_long": df_long,
    "value_col": "Z",
    "stage_col": "stage",
    "ncols": 2,
}
fig, axes = plot_stage_heatmaps(**heatmap_kwargs)
plt.show()


In [None]:
# Take the (type_i, type_j) pairs of the first six top-ranked edges as plain
# tuples and plot their Z values across stages.
pair_rows = top_edges[["type_i", "type_j"]].head(6)
pairs = list(pair_rows.itertuples(index=False, name=None))

curve_kwargs = {
    "df_long": df_long,
    "pairs": pairs,
    "value_col": "Z",
    "stage_col": "stage",
    "show_sem": True,
}
fig, ax = plot_rewiring_curves(**curve_kwargs)
plt.show()


In [None]:
# Write both result tables as CSV (without the index) using a shared filename
# prefix; relative names, so they resolve against the current working
# directory.
out_prefix = "rrmap_rewiring"
contacts_csv = f"{out_prefix}_per_sample_contacts.csv"
effects_csv = f"{out_prefix}_stage_effects.csv"

df_long.to_csv(contacts_csv, index=False)
results.to_csv(effects_csv, index=False)
print(f"Saved {contacts_csv} and {effects_csv}")
