# Spatial Rewiring Showcase

This notebook runs the full neighbor-matrix-centric rewiring pipeline on AnnData.
- If a variable named `adata` is already defined in the notebook session, the pipeline runs on that dataset.
- Otherwise, it builds a synthetic spatial omics dataset so the workflow is runnable end-to-end.

In [None]:
import sys
from pathlib import Path

# Make a local `src/` directory importable (looked up in the working
# directory and in its parent). Every hit is pushed to the front of sys.path,
# so when both exist the parent's `src/` ends up ahead of the local one.
_here = Path.cwd()
for src_dir in (_here / "src", _here.parent / "src"):
    src_entry = str(src_dir)
    if not src_dir.exists() or src_entry in sys.path:
        continue
    sys.path.insert(0, src_entry)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import sparse

from rewirespace.spatial_rewiring import per_sample_contacts, fit_stage_effect
from rewirespace.plot_rewiring import plot_stage_heatmaps, plot_rewiring_curves

In [None]:
# Configure keys for your AnnData

# Pairwise matrix in `adata.obsp` holding the neighbor graph
ADJ_KEY = "spatial_connectivities"

# Columns of `adata.obs`
CELL_TYPE_KEY = "cell_type"
SAMPLE_KEY = "sample_id"
STAGE_KEY = "stage"
SUBJECT_KEY = "mouse_id"  # set to None if repeated measures are not available

# Value forwarded as `n_perm` to per_sample_contacts
N_PERM = 200

In [None]:
# Build synthetic showcase data only if `adata` is not already defined.
#
# The synthetic dataset has 8 samples (2 subjects x 4 stages), 120 cells per
# sample, one categorical label per cell and a symmetrised 6-nearest-neighbour
# graph per sample. The per-sample graphs are stacked block-diagonally, so no
# edge ever connects cells from different samples.
if "adata" not in globals():
    import anndata as ad

    rng = np.random.default_rng(42)  # fixed seed keeps the showcase reproducible

    n_cells = 120  # cells simulated per sample
    k = 6  # nearest neighbours per cell before symmetrisation

    stage_order = ["baseline", "early", "mid", "late"]
    # (sample id, stage, subject): each subject is observed once per stage
    samples = [
        ("s1", "baseline", "subj1"),
        ("s2", "baseline", "subj2"),
        ("s3", "early", "subj1"),
        ("s4", "early", "subj2"),
        ("s5", "mid", "subj1"),
        ("s6", "mid", "subj2"),
        ("s7", "late", "subj1"),
        ("s8", "late", "subj2"),
    ]

    cell_types = np.array(["T", "B", "Myeloid", "Stromal"])
    # Stage-specific compositional drift to create non-trivial rewiring patterns
    # (probabilities are listed in the order of `cell_types`)
    stage_probs = {
        "baseline": np.array([0.45, 0.25, 0.20, 0.10]),
        "early":    np.array([0.35, 0.20, 0.30, 0.15]),
        "mid":      np.array([0.28, 0.18, 0.37, 0.17]),
        "late":     np.array([0.22, 0.16, 0.42, 0.20]),
    }

    obs_frames = []
    adjacency_blocks = []

    for sample_id, stage, subject in samples:
        labels = rng.choice(cell_types, size=n_cells, p=stage_probs[stage])

        # Random spatial coordinates and kNN graph per sample
        coords = rng.uniform(0, 1, size=(n_cells, 2))
        d2 = ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(axis=2)
        np.fill_diagonal(d2, np.inf)  # a cell must never be its own neighbour

        # After partitioning, the first k columns of each row hold the indices
        # of that cell's k closest cells (in arbitrary order).
        nn_idx = np.argpartition(d2, kth=k, axis=1)[:, :k]
        r = np.repeat(np.arange(n_cells), k)
        c = nn_idx.reshape(-1)

        A = sparse.csr_matrix((np.ones_like(r, dtype=np.float32), (r, c)), shape=(n_cells, n_cells))
        # Symmetrise: keep an edge if either endpoint lists the other as a neighbour
        A = ((A + A.T) > 0).astype(np.float32)
        adjacency_blocks.append(A)

        # One metadata frame per sample; the scalar values are broadcast to
        # every cell, which avoids building one Python dict per cell.
        sample_obs = {
            CELL_TYPE_KEY: labels,
            STAGE_KEY: stage,
            SAMPLE_KEY: sample_id,
        }
        if SUBJECT_KEY is not None:
            sample_obs[SUBJECT_KEY] = subject
        obs_frames.append(pd.DataFrame(sample_obs))

    obs = pd.concat(obs_frames, ignore_index=True)
    # AnnData keeps obs names as strings; converting here (same labels
    # "0", "1", ...) avoids its implicit-conversion warning for integer indexes.
    obs.index = obs.index.astype(str)
    obs[STAGE_KEY] = pd.Categorical(obs[STAGE_KEY], categories=stage_order, ordered=True)

    # X is a single all-zero placeholder column; only obs and the graph are populated.
    adata = ad.AnnData(X=np.zeros((obs.shape[0], 1), dtype=np.float32), obs=obs)
    adata.obsp[ADJ_KEY] = sparse.block_diag(adjacency_blocks, format="csr")

# Preview the dataset that the rest of the notebook will use.
# SUBJECT_KEY is shown only when it is configured AND present in adata.obs:
# a user-supplied `adata` may lack that column, and indexing with a missing
# column would raise KeyError here.
print(adata)
preview_cols = [CELL_TYPE_KEY, STAGE_KEY, SAMPLE_KEY]
if SUBJECT_KEY is not None and SUBJECT_KEY in adata.obs.columns:
    preview_cols.append(SUBJECT_KEY)
print(adata.obs[preview_cols].head())

In [None]:
# Pass the subject column on only when it is configured and actually present
# in adata.obs; otherwise run without it and tell the user why.
subject_key_for_contacts = None
if SUBJECT_KEY is not None:
    if SUBJECT_KEY in adata.obs.columns:
        subject_key_for_contacts = SUBJECT_KEY
    else:
        print(f"Subject key '{SUBJECT_KEY}' not found in adata.obs; running without subject-level adjustment.")

# Long-format contact table computed from the neighbor matrix stored under ADJ_KEY.
contact_kwargs = {
    "adata": adata,
    "adj_key": ADJ_KEY,
    "cell_type_key": CELL_TYPE_KEY,
    "sample_key": SAMPLE_KEY,
    "stage_key": STAGE_KEY,
    "subject_key": subject_key_for_contacts,
    "n_perm": N_PERM,
}
df_long = per_sample_contacts(**contact_kwargs)

print(df_long.shape)
df_long.head()

In [None]:
def _mean_abs(values):
    """Return the mean of the absolute values in *values*."""
    return np.mean(np.abs(values))


# Per-stage overview: number of distinct samples and mean |Z| over all rows.
# observed=False keeps every category of a categorical "stage" column.
per_stage = df_long.groupby("stage", observed=False)
summary = per_stage.agg(
    n_samples=("sample_id", "nunique"),
    mean_abs_Z=("Z", _mean_abs),
).reset_index()
summary

In [None]:
# Hand a subject column to the model only when subject information was used
# for the contacts AND df_long actually carries a "subject" column.
has_subject_column = subject_key_for_contacts is not None and "subject" in df_long.columns
if has_subject_column:
    subject_col_for_model = "subject"
else:
    subject_col_for_model = None

# Fit the stage effect on the Z values of df_long.
results = fit_stage_effect(df_long=df_long, value_col="Z", stage_col="stage", subject_col=subject_col_for_model)

print(results.shape)
results.head(15)

In [None]:
# Rank top changing edges: smallest q-value first (a missing q-value counts
# as 1.0), ties broken by the largest effect_range; keep the first ten.
ranked = (
    results.assign(_q=results["q_value"].fillna(1.0))
    .sort_values(["_q", "effect_range"], ascending=[True, False])
    .drop(columns="_q")
)
top_edges = ranked.head(10)

report_cols = ["rank", "type_i", "type_j", "effect_range", "p_value", "q_value", "model"]
top_edges[report_cols]

In [None]:
# Stage heatmaps of the Z values in df_long (ncols=2 is passed for the layout).
heatmap_kwargs = {
    "df_long": df_long,
    "value_col": "Z",
    "stage_col": "stage",
    "ncols": 2,
}
fig, axes = plot_stage_heatmaps(**heatmap_kwargs)
plt.show()

In [None]:
# Plot Z across stages for the first six (type_i, type_j) pairs of top_edges.
pair_columns = top_edges[["type_i", "type_j"]].head(6)
pairs = list(pair_columns.itertuples(index=False, name=None))

curve_kwargs = {
    "df_long": df_long,
    "pairs": pairs,
    "value_col": "Z",
    "stage_col": "stage",
    "show_sem": True,
}
fig, ax = plot_rewiring_curves(**curve_kwargs)
plt.show()

In [None]:
# Optional: save outputs from this showcase run.
# Both tables are written as CSV (row index omitted) to the paths given
# below, which are relative to the current working directory.
for frame, csv_name in (
    (df_long, "rewiring_showcase_per_sample_contacts.csv"),
    (results, "rewiring_showcase_stage_effects.csv"),
):
    frame.to_csv(csv_name, index=False)
print("Saved rewiring_showcase_per_sample_contacts.csv and rewiring_showcase_stage_effects.csv")