# UniVI TEA-seq tri-modal data integration demonstration/tutorial

Andrew Ashford, Pathways + Omics Group, Oregon Health & Science University - 11/18/2025

This notebook walks through preprocessing, training, and evaluating a UniVI model on human PBMC TEA-seq tri-modal data (RNA, ADT surface protein, and ATAC).


#### Import modules

In [1]:
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import scipy.sparse as sp
import scanpy as sc
import anndata as ad

from sklearn.decomposition import TruncatedSVD
from sklearn.preprocessing import normalize

import snapatac2 as snap


In [2]:
# -------------------------
# 0. Wire up package import
# -------------------------
# Make the repository root (the parent of this notebook's working directory)
# importable so `univi` resolves to the local source tree. Append it only once,
# so re-running this cell does not keep growing sys.path.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if not any(entry == project_root for entry in sys.path):
    sys.path.append(project_root)

from univi import (
    UniVIMultiModalVAE,
    ModalityConfig,
    UniVIConfig,
    TrainingConfig,
    matching,
)
from univi.data import MultiModalDataset
from univi.trainer import UniVITrainer


In [3]:
import torch
# Report the PyTorch build and pick the compute device once; the training config
# later takes `device` from here.
cuda_ok = torch.cuda.is_available()
for label, value in (
    ("Torch:", torch.__version__),
    ("torch.version.cuda:", torch.version.cuda),
    ("CUDA available:", cuda_ok),
):
    print(label, value)

if cuda_ok:
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)


Torch: 2.9.1+cu128
torch.version.cuda: 12.8
CUDA available: True
Using device: cuda


#### Read in and preprocess data as needed

In [4]:
# All four inputs share one filename prefix; only the per-modality suffix differs.
data_dir = Path("..") / "data" / "TEA-seq_data"
prefix = "GSM5123953_X066-MP0C1W5_leukopak_perm-cells_tea"


def _sample_file(suffix):
    """Return ``data_dir / "<prefix>_<suffix>"``."""
    return data_dir / f"{prefix}_{suffix}"


rna_h5 = _sample_file("200M_cellranger-arc_filtered_feature_bc_matrix.h5")
adt_counts_csv = _sample_file("48M_adt_counts.csv.gz")
frag_tsv = _sample_file("200M_atac_filtered_fragments.tsv.gz")
atac_meta_csv = _sample_file("200M_atac_filtered_metadata.csv.gz")


In [5]:
# ----------------------
# 1) RNA
# ----------------------
print("Reading RNA (ARC filtered_feature_bc_matrix.h5)...")
m = sc.read_10x_h5(rna_h5)  # `m` stays as read; all further work uses the copy below
print(m)
print("Feature types:", m.var["feature_types"].unique())

# Gene symbols contain duplicates (anndata warns about duplicate var names),
# so de-duplicate them on the working copy.
rna_adata = m.copy()
rna_adata.var_names_make_unique()
print("Initial RNA shape:", rna_adata.shape)


Reading RNA (ARC filtered_feature_bc_matrix.h5)...


  utils.warn_names_duplicates("var")


AnnData object with n_obs × n_vars = 8496 × 36601
    var: 'gene_ids', 'feature_types', 'genome', 'interval'
Feature types: ['Gene Expression']
Initial RNA shape: (8496, 36601)


  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


In [6]:
# Sanity check
# Raw RNA counts: expect a non-negative minimum and an integer-valued maximum.
for reducer in ("min", "max"):
    print(getattr(rna_adata.X, reducer)())


0.0
3801.0


In [7]:
# ----------------------
# 2) ADT
# ----------------------
print("Reading ADT counts...")
# Table layout: one row per cell barcode, one column per antibody tag.
adt_df = pd.read_csv(adt_counts_csv, index_col=0)
print("ADT counts shape:", adt_df.shape)

adt_counts_matrix = sp.csr_matrix(adt_df.to_numpy())
adt_obs = pd.DataFrame(index=adt_df.index.astype(str))
adt_var = pd.DataFrame(index=adt_df.columns.astype(str))
adt_adata = ad.AnnData(X=adt_counts_matrix, obs=adt_obs, var=adt_var)
adt_adata.var_names_make_unique()
print("ADT AnnData:", adt_adata.shape)


Reading ADT counts...
ADT counts shape: (719170, 47)
ADT AnnData: (719170, 47)


In [8]:
# Sanity check
# Raw ADT counts: expect a non-negative minimum.
for reducer in ("min", "max"):
    print(getattr(adt_adata.X, reducer)())


0
37856


In [None]:
# ----------------------
# 3) ATAC: fragments + metadata -> LSI
# ----------------------
print("Importing ATAC fragments with snapatac2...")
fragment_import_kwargs = dict(
    fragment_file=str(frag_tsv),
    chrom_sizes=snap.genome.hg38,  # the hg38 genome object supplies chromosome sizes
    sorted_by_barcode=False,       # do not assume the fragment file is barcode-sorted
)
atac_raw = snap.pp.import_data(**fragment_import_kwargs)

print("ATAC raw object:", atac_raw)


In [None]:
# ---- 3a. Attach external metadata and map barcodes ----
meta = pd.read_csv(atac_meta_csv)
print("ATAC meta columns:", list(meta.columns))

# Metadata rows are keyed by the hex barcodes that also index atac_raw.
meta = meta.set_index("barcodes")

# Restrict to cells present in both the fragment import and the metadata table.
common_ids = atac_raw.obs_names.intersection(meta.index)
n_atac_before = atac_raw.n_obs
print("ATAC cells with metadata:", len(common_ids), "of", n_atac_before)

atac_raw = atac_raw[common_ids].copy()
atac_raw.obs = atac_raw.obs.join(meta, how="left")


In [None]:
# ----------------------
# 4) Simple ATAC QC (using n_fragments from metadata)
# ----------------------
# Minimum number of fragments a cell needs to be kept. It is defined once so the
# filter and the log message cannot drift apart.
min_fragments = 1500

if "n_fragments" in atac_raw.obs.columns:
    mask = atac_raw.obs["n_fragments"] >= min_fragments
    print(
        "Keeping", mask.sum(), "of", atac_raw.n_obs,
        f"ATAC cells with n_fragments >= {min_fragments}",
    )
    atac_raw = atac_raw[mask].copy()
else:
    # Make the skipped QC visible instead of silently keeping every cell.
    print("No 'n_fragments' column in ATAC metadata; skipping fragment-count QC.")


In [None]:
# Sanity check
# Inspect the value range of the ATAC matrix after QC.
for reducer in ("min", "max"):
    print(getattr(atac_raw.X, reducer)())


In [None]:
# ----------------------
# 5) Tile matrix + TF–IDF + LSI
# ----------------------
print("Adding ATAC tile matrix...")
# The call's return value is not kept: afterwards atac_raw itself holds the tile
# counts, which the next cells read from atac_raw.X.
snap.pp.add_tile_matrix(atac_raw)
tile_shape = atac_raw.shape
print("Tile matrix shape:", tile_shape, "(cells x genomic tiles)")


In [None]:
print("Computing TF–IDF and LSI on ATAC tiles...")
# The TF–IDF step uses sparse-matrix methods (.multiply), so densely stored
# tiles are converted to CSR first.
X = atac_raw.X
X = X if sp.issparse(X) else sp.csr_matrix(X)

n_cells = X.shape[0]
n_feats = X.shape[1]
print(f"ATAC tile matrix: {n_cells} cells × {n_feats} features")


In [None]:
# Term frequency: rescale each cell's tile counts to sum to 1 (L1 row norm).
tf = normalize(X, axis=1, norm="l1")

# Document frequency: number of cells with a non-zero count in each tile.
df = np.asarray((X > 0).sum(axis=0)).ravel()

# Smoothed inverse document frequency: log(1 + n_cells / (1 + df)).
idf = np.log1p(n_cells / (df + 1.0))

# TF-IDF: idf has one weight per tile and broadcasts across rows.
X_tfidf = tf.multiply(idf)


In [None]:
'''
import numpy as np

# X: sparse matrix of shape (n_cells, n_peaks)
min_cells_per_peak = 50  # or 20, depending on sparsity

# number of nonzeros per peak (i.e., per column)
peak_nnz = X.getnnz(axis=0)          # returns a 1D np.array-like
peak_nnz = np.asarray(peak_nnz).ravel()

keep_peaks = peak_nnz >= min_cells_per_peak

X_filtered = X[:, keep_peaks]
'''

In [None]:
# Run the SVD in single precision; skip the copy when the matrix already is float32.
X_tfidf = X_tfidf if X_tfidf.dtype == np.float32 else X_tfidf.astype(np.float32)


In [None]:
# SVD → LSI
# Latent semantic indexing: truncated SVD of the TF–IDF matrix, seeded for reproducibility.
n_lsi = 100
svd = TruncatedSVD(random_state=42, n_components=n_lsi)
lsi = svd.fit_transform(X_tfidf)  # one n_lsi-dimensional row per cell


In [None]:
# row L2-normalize
lsi = normalize(lsi, axis=1, norm="l2")  # rescale each cell's LSI vector to unit length


In [None]:
# Wrap as ATAC AnnData for UniVI
# The normalised LSI embedding becomes the ATAC modality for UniVI. obs carries the
# joined per-cell metadata, including original_barcodes, which is used for barcode mapping.
atac_lsi = lsi.astype(np.float32)
atac_obs = atac_raw.obs.copy()
atac_adata = ad.AnnData(X=atac_lsi, obs=atac_obs)
print("ATAC LSI AnnData:", atac_adata.shape)


In [None]:
# ----------------------
# 6) Put everything into shared barcode space
# ----------------------
def strip_suffix(idx):
    """Remove a trailing ``-<digits>`` suffix (e.g. the ``-1`` on 10x barcodes).

    Parameters
    ----------
    idx : pandas.Series or pandas.Index
        Barcodes; values are cast to ``str`` before matching.

    Returns
    -------
    pandas.Series or pandas.Index
        Same container type as ``idx``; entries without such a suffix are unchanged.
    """
    as_text = idx.astype(str)
    return as_text.str.replace(r"-\d+$", "", regex=True)

# RNA and ADT are indexed by 10x barcodes with a "-<n>" suffix (e.g. "-1").
# Strip it exactly once; a second pass could over-strip barcodes carrying
# more than one such suffix.
rna_adata.obs_names = strip_suffix(rna_adata.obs_names.to_series())
adt_adata.obs_names = strip_suffix(adt_adata.obs_names.to_series())

# ATAC cells are indexed by the metadata's hex barcodes. Map them back to the original
# 10x barcodes through the `original_barcodes` column, then strip the suffix the same way.
atac_adata.obs["barcode_10x"] = atac_adata.obs["original_barcodes"].astype(str)
atac_adata.obs_names = strip_suffix(atac_adata.obs["barcode_10x"])

# Barcode-based subsetting in the next step is only well-defined for unique names.
for mod_name, mod_adata in (("RNA", rna_adata), ("ADT", adt_adata), ("ATAC", atac_adata)):
    if not mod_adata.obs_names.is_unique:
        print(
            f"Warning: {mod_name} barcodes are not unique after suffix stripping; "
            "barcode-based subsetting below may be ambiguous."
        )


In [None]:
# ----------------------
# 7) Tri-modal intersection
# ----------------------
# Keep only barcodes seen in all three modalities. Use a deterministic (sorted) order
# and subset every modality to it, so row i is the same cell everywhere.
common_barcodes = set(rna_adata.obs_names).intersection(
    adt_adata.obs_names, atac_adata.obs_names
)

print("Common tri-modal cells:", len(common_barcodes))
if not common_barcodes:
    raise ValueError("No overlapping barcodes across RNA/ADT/ATAC after mapping.")

common_barcodes = sorted(common_barcodes)

rna_adata, adt_adata, atac_adata = [
    mod_data[common_barcodes].copy() for mod_data in (rna_adata, adt_adata, atac_adata)
]

print("Aligned shapes:")
for label, aligned in (("  RNA :", rna_adata), ("  ADT :", adt_adata), ("  ATAC:", atac_adata)):
    print(label, aligned.shape)


In [None]:
# ----------------------
# 8) Optional subsampling
# ----------------------
# Cap the number of cells with a seeded draw without replacement. The same
# barcode subset is applied to all modalities so they stay aligned.
target_n = 10000
n_cells = rna_adata.n_obs

if n_cells > target_n:
    rng = np.random.default_rng(42)
    keep_idx = rng.choice(n_cells, size=target_n, replace=False)
    keep_barcodes = rna_adata.obs_names[keep_idx]
    rna_adata, adt_adata, atac_adata = [
        mod_data[keep_barcodes].copy() for mod_data in (rna_adata, adt_adata, atac_adata)
    ]

print("After optional subsampling:", rna_adata.n_obs, "cells.")


#### Preprocess each modality for Gaussian decoders (used here to prioritize cross-modal alignment)

In [None]:
# ----------------------
# 9) Modality-specific preprocessing
# ----------------------
'''
# --- RNA: log-normalize + HVGs ---
rna = rna_adata.copy()
rna.layers["counts"] = rna.X.copy()

sc.pp.normalize_total(rna, target_sum=1e4)
sc.pp.log1p(rna)
sc.pp.highly_variable_genes(rna, n_top_genes=2000, flavor="seurat_v3")
rna = rna[:, rna.var["highly_variable"]].copy()
print("RNA (HVG log1p) shape:", rna.shape)
'''

In [None]:
'''
# --- ADT: CLR per cell ---
adt = adt_adata.copy()
adt.layers["counts"] = adt.X.copy()

X = adt.layers["counts"].astype(float)
if sp.issparse(X):
    X = X.toarray()

eps = 1e-6
X_log = np.log1p(X + eps)
X_clr = X_log - X_log.mean(axis=1, keepdims=True)
adt.X = X_clr.astype(np.float32)
print("ADT CLR shape:", adt.shape)
'''

In [None]:
'''
# --- ATAC: z-score each LSI dimension ---
atac = atac_adata.copy()
X_atac = atac.X.astype(np.float32)

mean = X_atac.mean(axis=0, keepdims=True)
std  = X_atac.std(axis=0, keepdims=True) + 1e-6
X_z  = (X_atac - mean) / std
atac.X = X_z.astype(np.float32)
print("ATAC LSI-z shape:", atac.shape)
'''

In [None]:
import scanpy as sc
import numpy as np
import scipy.sparse as sp

# ---------- RNA ----------
rna = rna_adata.copy()
rna.layers["counts"] = rna.X.copy()  # keep raw counts before normalization

# Normalize + log1p (X becomes log-normalized; layers["counts"] stays raw)
sc.pp.normalize_total(rna, target_sum=1e4)
sc.pp.log1p(rna)

# HVGs: the "seurat_v3" flavor is defined on raw counts, so select genes from the
# stored counts layer rather than from the log-normalized X.
sc.pp.highly_variable_genes(
    rna, n_top_genes=2000, flavor="seurat_v3", layer="counts"
)
rna = rna[:, rna.var["highly_variable"]].copy()
print("RNA (HVG log1p) shape:", rna.shape)

# Z-score per gene across cells (for Gaussian decoder), clipped at max_value=10
sc.pp.scale(rna, max_value=10)
print("RNA scaled shape:", rna.shape)


# ---------- ADT ----------
adt = adt_adata.copy()
adt.layers["counts"] = adt.X.copy()

X = adt.layers["counts"].astype(float)
if sp.issparse(X):
    X = X.toarray()

eps = 1e-6  # negligible under log1p; kept so results match earlier runs
# CLR per cell: log1p counts centred on each cell's mean
X_log = np.log1p(X + eps)
X_clr = X_log - X_log.mean(axis=1, keepdims=True)

# Then per-feature z-score across cells
mean_adt = X_clr.mean(axis=0, keepdims=True)
std_adt  = X_clr.std(axis=0, keepdims=True) + 1e-6
X_clr_z  = (X_clr - mean_adt) / std_adt

adt.X = X_clr_z.astype(np.float32)

print("ADT CLR+z shape:", adt.shape)


# ---------- ATAC ----------
atac = atac_adata.copy()

X_atac = atac.X.astype(np.float32)  # the LSI embedding built above
mean_atac = X_atac.mean(axis=0, keepdims=True)
std_atac  = X_atac.std(axis=0, keepdims=True) + 1e-6
X_z = (X_atac - mean_atac) / std_atac  # z-score each LSI dimension

atac.X = X_z.astype(np.float32)
print("ATAC LSI-z shape:", atac.shape)


#### Initialize the model and the data loaders

In [None]:
# Modality name -> preprocessed AnnData. The keys mirror the ModalityConfig names
# ("rna", "adt", "atac"); presumably the dataset/model match them by name.
adata_dict = dict(rna=rna, adt=adt, atac=atac)


In [None]:
from univi.data import MultiModalDataset
from univi.config import UniVIConfig, ModalityConfig, TrainingConfig
from univi.models.univi import UniVIMultiModalVAE
from univi.trainer import UniVITrainer
'''
# ---------- UniVI config (Gaussian for all 3) ----------
univi_cfg = UniVIConfig(
    latent_dim=60,
    beta=100.0,
    gamma=120.0,
    encoder_dropout=0.0,
    decoder_dropout=0.0,
    encoder_batchnorm=True,
    decoder_batchnorm=False,
    kl_anneal_start=0,
    kl_anneal_end=0,
    align_anneal_start=0,
    align_anneal_end=0,
    modalities=[
        ModalityConfig(
            name="rna",
            input_dim=rna.n_vars,
            encoder_hidden=[512, 256],
            decoder_hidden=[256, 512],
            likelihood="gaussian",
        ),
        ModalityConfig(
            name="adt",
            input_dim=adt.n_vars,
            encoder_hidden=[128, 64],
            decoder_hidden=[64, 128],
            likelihood="gaussian",
        ),
        ModalityConfig(
            name="atac",
            input_dim=atac.n_vars,  # n_lsi (e.g. 50)
            encoder_hidden=[128, 64],
            decoder_hidden=[64, 128],
            likelihood="gaussian",
        ),
    ],
)

train_cfg = TrainingConfig(
    n_epochs=200,
    batch_size=256,
    lr=1e-3,
    weight_decay=1e-4,
    #device="cuda",   # or "cpu"
    device=device,
    log_every=10,
    grad_clip=5.0,
    num_workers=0,
    seed=42,
    early_stopping=True,
    patience=20,
    min_delta=0.0,
)
'''

In [None]:
from univi.config import UniVIConfig, ModalityConfig, TrainingConfig

univi_cfg = UniVIConfig(
    latent_dim=80,
    beta=10.0,    # KL weight (an earlier run used beta=150.0)
    gamma=40.0,   # cross-modal alignment weight
    encoder_dropout=0.0,
    decoder_dropout=0.0,
    encoder_batchnorm=True,
    decoder_batchnorm=False,
    # Both schedules use start == end == 0, i.e. an empty warm-up window, so the KL
    # and alignment terms presumably apply at full weight from epoch 0 (confirm
    # against univi's annealing code). To ramp KL in, set e.g. kl_anneal_end=50.
    # To let reconstructions stabilise first, raise align_anneal_start/_end.
    kl_anneal_start=0,
    kl_anneal_end=0,
    align_anneal_start=0,
    align_anneal_end=0,
    modalities=[
        ModalityConfig(
            name="rna",
            input_dim=rna.n_vars,
            encoder_hidden=[512, 256],
            decoder_hidden=[256, 512],
            likelihood="gaussian",
        ),
        ModalityConfig(
            name="adt",
            input_dim=adt.n_vars,
            encoder_hidden=[128, 64],
            decoder_hidden=[64, 128],
            likelihood="gaussian",
        ),
        ModalityConfig(
            name="atac",
            input_dim=atac.n_vars,  # n_lsi
            encoder_hidden=[256, 128],
            decoder_hidden=[128, 256],
            likelihood="gaussian",
        ),
    ],
)

train_cfg = TrainingConfig(
    n_epochs=300,
    batch_size=256,
    lr=1e-3,
    weight_decay=1e-4,
    device=device,      # chosen in the setup cell: "cuda" if available, else "cpu"
    log_every=10,
    grad_clip=5.0,
    num_workers=0,
    seed=42,
    early_stopping=True,
    patience=35,        # presumably epochs without validation improvement before stopping
    min_delta=0.0,
)


In [None]:
from torch.utils.data import DataLoader, Subset

# Dataset over the three aligned AnnData objects (identical cell order).
dataset = MultiModalDataset(adata_dict=adata_dict, X_key="X", device=train_cfg.device)

# Seeded shuffle, then split 80% train / 10% val / remainder test.
n_cells = dataset.n_cells
indices = np.arange(n_cells)
rng = np.random.default_rng(42)
rng.shuffle(indices)

frac_train, frac_val = 0.8, 0.1
n_train = int(frac_train * n_cells)
n_val = int(frac_val * n_cells)
train_idx, val_idx, test_idx = np.split(indices, [n_train, n_train + n_val])

train_ds, val_ds = Subset(dataset, train_idx), Subset(dataset, val_idx)


def _make_loader(subset, shuffle):
    """Wrap `subset` in a DataLoader using train_cfg's batch size and worker count."""
    return DataLoader(
        subset,
        batch_size=train_cfg.batch_size,
        shuffle=shuffle,
        num_workers=train_cfg.num_workers,
    )


train_loader = _make_loader(train_ds, shuffle=True)
val_loader = _make_loader(val_ds, shuffle=False)

model = UniVIMultiModalVAE(univi_cfg).to(train_cfg.device)
trainer = UniVITrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    train_cfg=train_cfg,
    device=train_cfg.device,
)


#### Train model

In [None]:
# ---------- train ----------
# Run training with the loaders and train_cfg handed to UniVITrainer above.
# The returned `history` feeds the training-curve plots.
history = trainer.fit()


In [None]:
import matplotlib.pyplot as plt

def _plot_history(series, ylabel, title):
    """Plot each (history key, legend label) pair in `series` against epoch index."""
    fig, ax = plt.subplots()
    for key, legend_label in series:
        ax.plot(history[key], label=legend_label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    plt.tight_layout()
    plt.show()


# Quick training curves (this notebook trains on TEA-seq, not Multiome, data)
_plot_history(
    [("train_loss", "train"), ("val_loss", "val")],
    "Loss",
    "UniVI TEA-seq training curves",
)

# KL (beta) and alignment (gamma) weights over training
_plot_history(
    [("beta", "beta"), ("gamma", "gamma")],
    "Weight",
    "KL / alignment annealing",
)


In [None]:
from dataclasses import asdict
import torch

save_dir = "../saved_models"
os.makedirs(save_dir, exist_ok=True)

# NOTE(review): the filename says beta_1 / latent_dims_120, but univi_cfg uses
# beta=10.0 and latent_dim=80. It is left unchanged because the loading cell
# reads this exact path; rename both together.
ckpt_path = f"{save_dir}/univi_tea_seq_beta_1_gamma_40_latent_dims_120.pt"

checkpoint = {
    # Per the original note, trainer.model already holds the best weights because
    # the trainer restores best_state_dict (verify in UniVITrainer).
    "state_dict": trainer.model.state_dict(),
    "univi_cfg": asdict(univi_cfg),  # plain dict; rebuilt later via UniVIConfig(**...)
    "best_epoch": trainer.best_epoch,
    "best_val_loss": trainer.best_val_loss,
}
torch.save(checkpoint, ckpt_path)
print("Saved best model to:", ckpt_path)


#### Evaluate model

In [None]:
import torch
from univi.config import UniVIConfig, ModalityConfig
from univi.models.univi import UniVIMultiModalVAE

# Pick the device the same way as the setup cell, so this cell also runs on
# CPU-only machines (a hard-coded "cuda" raises there).
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): recent PyTorch (the setup cell shows 2.9) defaults torch.load to
# weights_only=True. That only succeeds if every checkpoint value (e.g.
# best_val_loss) is a plain Python/tensor type. Prefer converting such values at
# save time over passing weights_only=False.
ckpt = torch.load(
    "../saved_models/univi_tea_seq_beta_1_gamma_40_latent_dims_120.pt",
    map_location=device,
)

# ---- Rebuild UniVIConfig, making sure modalities are ModalityConfig objects ----
cfg_dict = ckpt["univi_cfg"]

# If this is an OmegaConf object or similar, make sure it's a plain dict
try:
    from omegaconf import DictConfig, OmegaConf
    if isinstance(cfg_dict, DictConfig):
        cfg_dict = OmegaConf.to_container(cfg_dict, resolve=True)
except ImportError:
    pass  # omegaconf not installed: the saved config is already a plain dict

# asdict() stored each modality as a dict; turn them back into ModalityConfig objects.
modalities = [ModalityConfig(**m) for m in cfg_dict["modalities"]]
cfg_dict = {**cfg_dict, "modalities": modalities}

univi_cfg_loaded = UniVIConfig(**cfg_dict)

# ---- Rebuild model + load weights ----
model_loaded = UniVIMultiModalVAE(univi_cfg_loaded).to(device)
model_loaded.load_state_dict(ckpt["state_dict"])
# Inference mode: BatchNorm uses running statistics (encoders may use BatchNorm),
# and dropout is disabled.
model_loaded.eval()
# NOTE(review): the evaluation cells below encode with `trainer` / `model`, not
# `model_loaded`; switch them over if the reloaded checkpoint should be evaluated.

print("Best epoch was:", ckpt.get("best_epoch"), "val loss =", ckpt.get("best_val_loss"))


In [None]:
# Encode every cell of each modality into the shared latent space and store the
# result on the corresponding AnnData.
z_rna, z_adt, z_atac = (
    trainer.encode_modality(mod_data, modality=mod_name)
    for mod_name, mod_data in (("rna", rna), ("adt", adt), ("atac", atac))
)
for mod_data, mod_z in ((rna, z_rna), (adt, z_adt), (atac, z_atac)):
    mod_data.obsm["X_univi"] = mod_z


In [None]:
# Print the raw history object returned by trainer.fit() for inspection.
print(history)


In [None]:
# The train/val/test index arrays are positional, so every modality must list
# the same cells in the same order.
for other_modality in (adt, atac):
    assert np.array_equal(rna.obs_names, other_modality.obs_names)


In [None]:
# ------------------------------
# Build train / val / test adatas
# ------------------------------

# Apply the same positional train/val/test split to every modality.
rna_train_adata, rna_val_adata, rna_test_adata = (
    rna[split].copy() for split in (train_idx, val_idx, test_idx)
)
adt_train_adata, adt_val_adata, adt_test_adata = (
    adt[split].copy() for split in (train_idx, val_idx, test_idx)
)
atac_train_adata, atac_val_adata, atac_test_adata = (
    atac[split].copy() for split in (train_idx, val_idx, test_idx)
)


In [None]:
for test_split in (rna_test_adata, adt_test_adata, atac_test_adata):
    print(test_split)


In [None]:
import scanpy as sc

# Make plots a bit bigger / nicer
# (global scanpy defaults: they apply to every sc.pl.* figure drawn after this cell)
sc.settings.set_figure_params(dpi=200, figsize=(16, 14))


In [None]:
# ============================
# UniVI TEA-seq evaluation (RNA / ADT / ATAC) – label-free, tri-modal
# ============================
import os
import numpy as np
import torch
import scipy.sparse as sp
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc

from sklearn.metrics import silhouette_score
from sklearn.neighbors import NearestNeighbors

from univi import evaluation as univi_eval
from univi import plotting as univi_plot  # not strictly needed, but left for convenience

from sklearn.metrics import silhouette_score
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import pairwise_distances  # NEW: for per-cell FOSCTTM

# -----------------------------------------
# CONFIG
# -----------------------------------------
# Every evaluation figure below is written under FIGDIR.
FIGDIR = "../figures/teaseq_univi_tri_modal_eval"
Path(FIGDIR).mkdir(parents=True, exist_ok=True)

# Encode on the same device that was used for training.
device = train_cfg.device

# -----------------------------------------
# 0. Sanity checks
# -----------------------------------------
# Pairwise metrics (FOSCTTM, top-k matching) assume row i is the same cell in
# every modality's test set.
n_rna_test = rna_test_adata.n_obs
assert n_rna_test == adt_test_adata.n_obs == atac_test_adata.n_obs, (
    "RNA / ADT / ATAC TEST sets must have same #cells"
)
for other_label, other_test in (("ADT", adt_test_adata), ("ATAC", atac_test_adata)):
    assert np.array_equal(rna_test_adata.obs_names, other_test.obs_names), (
        f"RNA and {other_label} obs_names must match 1:1 for pairwise metrics."
    )

print(f"Test cells: {n_rna_test}")

# -----------------------------------------
# 1. Encode latent embeddings for test sets
# -----------------------------------------
print("\nEncoding test sets into UniVI latent space...")

test_sets = (("rna", rna_test_adata), ("adt", adt_test_adata), ("atac", atac_test_adata))
z_rna, z_adt, z_atac = (
    univi_eval.encode_adata(model, test_data, modality=mod_name, device=device)
    for mod_name, test_data in test_sets
)
for (mod_name, test_data), mod_z in zip(test_sets, (z_rna, z_adt, z_atac)):
    test_data.obsm["X_univi"] = mod_z

print("Latent shapes (test):")
for label, mod_z in (("  RNA :", z_rna), ("  ADT :", z_adt), ("  ATAC:", z_atac)):
    print(label, mod_z.shape)

# -----------------------------------------
# 2. FOSCTTM (pairwise alignment)
# -----------------------------------------
'''
print("\nComputing FOSCTTM for each modality pair (lower = better)...")
fos_rna_adt  = univi_eval.compute_foscttm(z_rna,  z_adt)
fos_rna_atac = univi_eval.compute_foscttm(z_rna,  z_atac)
fos_adt_atac = univi_eval.compute_foscttm(z_adt,  z_atac)

print(f"  RNA  vs ADT : {fos_rna_adt:.4f}")
print(f"  RNA  vs ATAC: {fos_rna_atac:.4f}")
print(f"  ADT  vs ATAC: {fos_adt_atac:.4f}")

plt.figure(figsize=(4, 4))
pairs = ["RNA–ADT", "RNA–ATAC", "ADT–ATAC"]
vals = [fos_rna_adt, fos_rna_atac, fos_adt_atac]
sns.barplot(x=pairs, y=vals)
plt.ylabel("FOSCTTM (mean)")
plt.title("Tri-modal FOSCTTM (lower = better)")
#plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "foscttm_barplot.png"))
plt.show()
plt.close()
'''

# Added error bars
def foscttm_per_cell(z_src, z_tgt, metric="euclidean"):
    """Per-cell FOSCTTM (fraction of samples closer than the true match).

    Row i of ``z_src`` and row i of ``z_tgt`` are taken to be the same cell. For
    each source cell, count the target cells strictly closer than its true partner,
    divided by the total number of target cells n.

    Parameters
    ----------
    z_src, z_tgt : array-like with the same number of rows
        Paired embeddings.
    metric : str
        Distance metric forwarded to ``pairwise_distances``.

    Returns
    -------
    numpy.ndarray of shape (n,)
        Per-cell fractions; lower is better.
    """
    assert z_src.shape[0] == z_tgt.shape[0], "Need 1:1 pairing for FOSCTTM"
    dist = pairwise_distances(z_src, z_tgt, metric=metric)
    matched = np.diagonal(dist)                 # d(i, true partner of i)
    closer = dist < matched[:, np.newaxis]      # targets beating the true partner
    return closer.mean(axis=1)


print("\nComputing FOSCTTM for each modality pair (lower = better)...")

fos_rna_adt_cells = foscttm_per_cell(z_rna, z_adt)
fos_rna_atac_cells = foscttm_per_cell(z_rna, z_atac)
fos_adt_atac_cells = foscttm_per_cell(z_adt, z_atac)

n = z_rna.shape[0]


def _mean_and_sem(per_cell, n_obs):
    """Mean and standard error of the mean (ddof=1) of per-cell values."""
    return per_cell.mean(), per_cell.std(ddof=1) / np.sqrt(n_obs)


fos_rna_adt_mean, fos_rna_adt_sem = _mean_and_sem(fos_rna_adt_cells, n)
fos_rna_atac_mean, fos_rna_atac_sem = _mean_and_sem(fos_rna_atac_cells, n)
fos_adt_atac_mean, fos_adt_atac_sem = _mean_and_sem(fos_adt_atac_cells, n)

pairs = ["RNA–ADT", "RNA–ATAC", "ADT–ATAC"]
vals = [fos_rna_adt_mean, fos_rna_atac_mean, fos_adt_atac_mean]
errs = [fos_rna_adt_sem, fos_rna_atac_sem, fos_adt_atac_sem]

for pair_label, mean_val, sem_val in zip(["RNA  vs ADT ", "RNA  vs ATAC", "ADT  vs ATAC"], vals, errs):
    print(f"  {pair_label}: {mean_val:.4f} ± {sem_val:.4f} (SEM)")

fig_fos, ax_fos = plt.subplots(figsize=(4, 4))
ax_fos.bar(pairs, vals, yerr=errs, capsize=5)
ax_fos.set_ylabel("FOSCTTM (mean ± SEM)")
ax_fos.set_title("Tri-modal FOSCTTM (lower = better)")
fig_fos.savefig(os.path.join(FIGDIR, "foscttm_barplot.png"))
plt.show()
plt.close(fig_fos)

# -----------------------------------------
# 3. Modality mixing (all three modalities)
# -----------------------------------------
# Stack the three test-set latents: RNA rows, then ADT rows, then ATAC rows.
Z_joint = np.concatenate([z_rna, z_adt, z_atac], axis=0)
modality_labels = np.repeat(
    ["rna", "adt", "atac"],
    [z_rna.shape[0], z_adt.shape[0], z_atac.shape[0]],
)

mixing_score = univi_eval.compute_modality_mixing(Z_joint, modality_labels, k=30)
print(f"\nGlobal modality mixing score (RNA/ADT/ATAC, k=30): {mixing_score:.3f}")

# kNN modality composition: for each centre modality, the share of its k nearest
# neighbours (self excluded) that come from each modality.
print("Computing kNN modality composition...")
k = 30
nn = NearestNeighbors(n_neighbors=k + 1).fit(Z_joint)
_, idx = nn.kneighbors(Z_joint)

idx_neighbors = idx[:, 1:]  # assumes column 0 is the query point itself
neighbor_mods = modality_labels[idx_neighbors]

modalities = np.array(["rna", "adt", "atac"])
# rows = centre modality, columns = neighbour modality
comp_matrix = np.array([
    [(neighbor_mods[modality_labels == center].ravel() == neigh).mean() for neigh in modalities]
    for center in modalities
])

same_mod_frac = (neighbor_mods == modality_labels[:, None]).mean()
print(f"  Fraction of neighbors with same modality (k={k}): {same_mod_frac:.3f}")

fig_comp, ax_comp = plt.subplots(figsize=(5, 4))
sns.heatmap(
    comp_matrix,
    annot=True,
    fmt=".2f",
    xticklabels=modalities,
    yticklabels=modalities,
    cmap="viridis",
    ax=ax_comp,
)
ax_comp.set_xlabel("Neighbor modality")
ax_comp.set_ylabel("Center modality")
ax_comp.set_title(f"kNN modality composition (k={k})")
fig_comp.savefig(os.path.join(FIGDIR, "knn_modality_composition.png"))
plt.show()
plt.close(fig_comp)

# -----------------------------------------
# 4. UMAP on UniVI latent (tri-modal)
# -----------------------------------------
print("\nBuilding tri-modal UMAP on UniVI latent...")

# Tag a copy of each test set with its modality before concatenating.
rna_tmp, adt_tmp, atac_tmp = (
    test_split.copy() for test_split in (rna_test_adata, adt_test_adata, atac_test_adata)
)
for tagged, source in ((rna_tmp, "rna"), (adt_tmp, "adt"), (atac_tmp, "atac")):
    tagged.obs["univi_source"] = source

# NOTE(review): recent anndata releases deprecate AnnData.concatenate in favour of
# anndata.concat. With index_unique=None, each barcode also appears once per
# modality, so combined.obs_names are not unique.
combined = rna_tmp.concatenate(
    adt_tmp,
    atac_tmp,
    join="outer",
    batch_key="univi_batch",
    batch_categories=["rna", "adt", "atac"],
    index_unique=None,
)

# Stack the latents in the same RNA / ADT / ATAC row order as the concatenation.
combined.obsm["X_univi"] = np.vstack([
    rna_test_adata.obsm["X_univi"],
    adt_test_adata.obsm["X_univi"],
    atac_test_adata.obsm["X_univi"],
])

# neighbors / UMAP / Leiden, all computed on the UniVI latent
sc.pp.neighbors(combined, use_rep="X_univi", n_neighbors=30)
sc.tl.umap(combined)
sc.tl.leiden(combined, key_added="univi_leiden", resolution=0.8)


def _save_umap(adata_to_plot, color, filename):
    """Draw a scanpy UMAP coloured by `color`, save it under FIGDIR, show, then close."""
    sc.pl.umap(adata_to_plot, color=color, size=65, alpha=0.8, show=False)
    plt.savefig(os.path.join(FIGDIR, filename), bbox_inches="tight")
    plt.show()
    plt.close()


_save_umap(combined, "univi_source", "umap_tri_modal_modality.png")
_save_umap(combined, "univi_leiden", "umap_tri_modal_leiden.png")  # pseudo-clusters

# Per-modality UMAPs, colored by Leiden
for mod in ["rna", "adt", "atac"]:
    sub = combined[combined.obs["univi_source"] == mod].copy()
    _save_umap(sub, "univi_leiden", f"umap_{mod}_only_leiden.png")

# -----------------------------------------
# 5. Latent geometry diagnostics
# -----------------------------------------
print("\nLatent geometry diagnostics...")

def latent_norms(z, label):
    """Return per-row L2 norms of ``z`` and an equally long array filled with ``label``.

    The label array lets norms from several modalities be concatenated and plotted
    against their modality.
    """
    row_norms = np.linalg.norm(z, ord=None, axis=1)
    row_labels = np.full(row_norms.shape[0], label)
    return row_norms, row_labels

norm_rna, lab_rna = latent_norms(z_rna, "RNA")
norm_adt, lab_adt = latent_norms(z_adt, "ADT")
norm_atac, lab_atac = latent_norms(z_atac, "ATAC")

norms_all = np.concatenate([norm_rna, norm_adt, norm_atac])
labs_all = np.concatenate([lab_rna, lab_adt, lab_atac])

# Distribution of latent vector lengths per modality
fig_norm, ax_norm = plt.subplots(figsize=(5, 4))
sns.violinplot(x=labs_all, y=norms_all, inner="box", ax=ax_norm)
ax_norm.set_ylabel("‖z‖ (L2 norm)")
ax_norm.set_xlabel("Modality")
ax_norm.set_title("Latent L2-norm distribution by modality")
fig_norm.savefig(os.path.join(FIGDIR, "latent_norms_by_modality.png"))
plt.show()
plt.close(fig_norm)

# Correlation between RNA latent dimensions (columns are variables)
corr_latent = np.corrcoef(z_rna, rowvar=False)
fig_corr, ax_corr = plt.subplots(figsize=(6, 5))
sns.heatmap(corr_latent, cmap="vlag", center=0, ax=ax_corr)
ax_corr.set_title("RNA latent dimension correlation (test set)")
fig_corr.savefig(os.path.join(FIGDIR, "latent_corr_heatmap_rna.png"))
plt.show()
plt.close(fig_corr)

# Silhouette on modality labels (lower = better mixing); undefined for a single label
sil_mod = (
    silhouette_score(Z_joint, modality_labels)
    if np.unique(modality_labels).size > 1
    else np.nan
)
print(f"Silhouette (modality) on UniVI latent: {sil_mod:.3f}")

# -----------------------------------------
# 6. Local modality entropy (kNN) + UMAP
# -----------------------------------------
print("\nComputing local modality entropy...")

n_neighbors_local = 30
# Ask for one extra neighbour: the first hit is (normally) the query point itself
# and is dropped below. Each point then keeps exactly n_neighbors_local neighbours,
# matching the k + 1 convention of the kNN composition heatmap.
nn_local = NearestNeighbors(n_neighbors=n_neighbors_local + 1, metric="euclidean")
nn_local.fit(Z_joint)
_, idx_local = nn_local.kneighbors(Z_joint)

mods = modality_labels
# (n_points, n_neighbors_local): modality of each retained neighbour
neigh_mod = mods[idx_local[:, 1:]]

# Shannon entropy (bits) of each point's neighbour-modality distribution. With the
# three modalities it ranges from 0 (one modality only) to log2(3) ≈ 1.585.
local_modality_entropy = np.zeros(Z_joint.shape[0])
for m in modalities:
    p = (neigh_mod == m).mean(axis=1)
    present = p > 0  # 0 * log2(0) is taken as 0
    local_modality_entropy[present] -= p[present] * np.log2(p[present])

combined.obs["local_modality_entropy"] = local_modality_entropy

print(f"Mean local modality entropy (k={n_neighbors_local}): {local_modality_entropy.mean():.3f}")

plt.figure(figsize=(5, 4))
plt.hist(local_modality_entropy, bins=30)
plt.xlabel("Local modality entropy (bits)")
plt.ylabel("Cells")
plt.title("kNN modality entropy")
plt.savefig(os.path.join(FIGDIR, "local_modality_entropy_hist.png"))
plt.show()
plt.close()

# UMAP colored by local modality entropy
sc.pl.umap(
    combined,
    color="local_modality_entropy",
    size=65,
    alpha=0.8,
    show=False,
)
plt.savefig(os.path.join(FIGDIR, "umap_local_modality_entropy.png"), bbox_inches="tight")
plt.show()
plt.close()

# -----------------------------------------
# 7. Pairwise matching metrics (top-k) for all modality pairs
# -----------------------------------------
_topk_levels_text = " / ".join(f"top-{level}" for level in (1, 5, 10, 25, 50, 75, 100))
print(f"\nPairwise matching metrics ({_topk_levels_text})...")
'''
def topk_matching(z_src, z_tgt, pair_name: str, k_match: int = 100):
    nn = NearestNeighbors(n_neighbors=k_match, metric="euclidean")
    nn.fit(z_tgt)
    _, idx_knn = nn.kneighbors(z_src)

    true_idx = np.arange(z_src.shape[0])
    top1_hits  = (idx_knn[:, 0] == true_idx)
    top5_hits  = (idx_knn[:, :5] == true_idx[:, None]).any(axis=1)
    top10_hits = (idx_knn[:, :10] == true_idx[:, None]).any(axis=1)
    top25_hits = (idx_knn[:, :25] == true_idx[:, None]).any(axis=1)
    top50_hits = (idx_knn[:, :50] == true_idx[:, None]).any(axis=1)
    top75_hits = (idx_knn[:, :75] == true_idx[:, None]).any(axis=1)
    top100_hits = (idx_knn[:, :100] == true_idx[:, None]).any(axis=1)

    print(f"  {pair_name}:")
    print(f"    Top-1 accuracy:  {top1_hits.mean():.3f}")
    print(f"    Top-5 accuracy:  {top5_hits.mean():.3f}")
    print(f"    Top-10 accuracy: {top10_hits.mean():.3f}")
    print(f"    Top-25 accuracy: {top25_hits.mean():.3f}")    
    print(f"    Top-50 accuracy: {top50_hits.mean():.3f}")
    print(f"    Top-75 accuracy: {top75_hits.mean():.3f}")    
    print(f"    Top-100 accuracy: {top100_hits.mean():.3f}")

    plt.figure(figsize=(16, 14))
    plt.bar(
        ["Top-1", "Top-5", "Top-10", "Top-25", "Top-50", "Top-75", "Top-100"],
        [top1_hits.mean(), top5_hits.mean(), top10_hits.mean(), top25_hits.mean(), top50_hits.mean(), 
         top75_hits.mean(), top100_hits.mean()],
    )
    plt.ylabel("Fraction of correctly matched pairs")
    plt.title(f"Cross-modal matching accuracy ({pair_name})")
    #plt.tight_layout()
    fname = f"matching_{pair_name.replace(' ', '_').replace('→','to')}.png"
    plt.savefig(os.path.join(FIGDIR, fname))
    plt.show()
    plt.close()
'''

def topk_matching(
    z_src,
    z_tgt,
    pair_name: str,
    k_match: int = 100,
    ks=(1, 5, 10, 25, 50, 75, 100),
):
    """
    Top-k cross-modal matching accuracy (with binomial SEM error bars).

    Row ``i`` of ``z_src`` and row ``i`` of ``z_tgt`` are assumed to be the same
    cell. For every source cell, the nearest target cells (Euclidean) are
    retrieved; a top-k hit is recorded when the true partner is among the
    first ``k`` of them.

    Parameters
    ----------
    z_src, z_tgt : array-like, shape (n_cells, latent_dim)
        Paired latent embeddings (same number of rows, same order).
    pair_name : str
        Label used for printed output, the plot title and the figure filename.
    k_match : int
        Number of neighbours to retrieve; capped at the number of target cells.
    ks : sequence of int
        Top-k levels to report. Levels larger than the number of retrieved
        neighbours are dropped (with a message) instead of being silently
        reported under a misleading label.

    Returns
    -------
    pandas.DataFrame
        Columns ``k``, ``accuracy`` and ``sem`` (binomial standard error).

    Raises
    ------
    ValueError
        If ``z_src`` and ``z_tgt`` have different numbers of rows (the metric
        relies on positional pairing).
    """
    # Local import so this cell does not depend on an import made in another cell.
    from sklearn.neighbors import NearestNeighbors

    n_cells = z_src.shape[0]
    if z_tgt.shape[0] != n_cells:
        raise ValueError(
            f"{pair_name}: z_src has {n_cells} rows but z_tgt has {z_tgt.shape[0]}; "
            "top-k matching needs paired rows."
        )

    n_neighbors = min(k_match, n_cells)
    nn = NearestNeighbors(n_neighbors=n_neighbors, metric="euclidean")
    nn.fit(z_tgt)
    _, idx_knn = nn.kneighbors(z_src)

    ks_used = [int(k) for k in ks if k <= n_neighbors]
    dropped = [int(k) for k in ks if k > n_neighbors]
    if dropped:
        print(
            f"  {pair_name}: only {n_neighbors} neighbours retrieved; "
            f"skipping top-k levels {', '.join(map(str, dropped))}."
        )

    # hit[i, j] is True when the j-th retrieved neighbour of source cell i is its partner.
    hit = idx_knn == np.arange(n_cells)[:, None]
    means = np.array([hit[:, :k].any(axis=1).mean() for k in ks_used])

    # Binomial standard error: sqrt(p * (1 - p) / n)
    ses = np.sqrt(means * (1.0 - means) / n_cells)

    labels = [f"Top-{k}" for k in ks_used]
    print(f"  {pair_name}:")
    for lab, acc, se in zip(labels, means, ses):
        print(f"    {lab} accuracy: {acc:.3f} ± {se:.3f} (SEM)")

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.bar(labels, means, yerr=ses, capsize=5)
    ax.set_ylabel("Fraction of correctly matched pairs")
    ax.set_title(f"Cross-modal matching accuracy ({pair_name})")
    fname = f"matching_{pair_name.replace(' ', '_').replace('→','to')}.png"
    fig.savefig(os.path.join(FIGDIR, fname))
    plt.show()
    plt.close(fig)

    return pd.DataFrame({"k": ks_used, "accuracy": means, "sem": ses})

# Evaluate matching in both directions for every pair of modalities.
for _z_from, _z_to, _pair_label in (
    (z_rna,  z_adt,  "RNA→ADT"),
    (z_adt,  z_rna,  "ADT→RNA"),
    (z_rna,  z_atac, "RNA→ATAC"),
    (z_atac, z_rna,  "ATAC→RNA"),
    (z_adt,  z_atac, "ADT→ATAC"),
    (z_atac, z_adt,  "ATAC→ADT"),
):
    topk_matching(_z_from, _z_to, _pair_label)

# -----------------------------------------
# 8. Cross-modal reconstruction metrics
# -----------------------------------------
def _to_dense(X):
    return X.toarray() if sp.issparse(X) else np.asarray(X)

def _save_feature_hist(values, xlabel: str, title: str, fname: str):
    """Plot a 40-bin histogram of finite per-feature values, save it under FIGDIR and show it."""
    # assumes `values` is NumPy-convertible (e.g. the univi_eval per-feature arrays)
    values = np.asarray(values, dtype=float)
    values = values[np.isfinite(values)]
    if values.size == 0:
        print(f"  No finite values to plot for '{title}'; skipping histogram.")
        return
    fig, ax = plt.subplots(figsize=(5, 4))
    sns.histplot(values, bins=40, kde=False, ax=ax)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Count")
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(os.path.join(FIGDIR, fname))
    plt.show()
    plt.close(fig)


def cross_modal_metrics(
    model,
    src_adata,
    tgt_adata,
    src_mod: str,
    tgt_mod: str,
    name_prefix: str,
    device: str,
):
    """
    Predict ``tgt_mod`` from ``src_mod`` and report per-feature reconstruction quality.

    Uses ``univi_eval.cross_modal_predict`` on ``src_adata`` and compares the
    prediction against ``tgt_adata.X`` (row-aligned, same cells). Prints the
    mean per-feature MSE and Pearson r and saves histograms of both to FIGDIR
    as ``{name_prefix}_corr_hist.png`` and ``{name_prefix}_mse_hist.png``.

    Features whose Pearson r is not finite are excluded from the mean and the
    histogram and counted separately. Presumably these are features constant
    in the test set; TODO confirm how univi_eval handles them.

    Returns
    -------
    (mse_feat, corr_feat)
        Per-feature MSE and Pearson r as returned by univi_eval.

    Raises
    ------
    ValueError
        If the prediction shape differs from ``tgt_adata.X``.
    """
    Xhat_tgt = univi_eval.cross_modal_predict(
        model,
        adata_src=src_adata,
        src_mod=src_mod,
        tgt_mod=tgt_mod,
        device=device,
        batch_size=512,
    )

    X_tgt = _to_dense(tgt_adata.X)
    if tuple(np.shape(Xhat_tgt)) != X_tgt.shape:
        raise ValueError(
            f"{src_mod} → {tgt_mod}: prediction shape {tuple(np.shape(Xhat_tgt))} "
            f"does not match target shape {X_tgt.shape}"
        )

    mse_feat  = univi_eval.mse_per_feature(X_tgt, Xhat_tgt)
    corr_feat = univi_eval.pearson_corr_per_feature(X_tgt, Xhat_tgt)

    # A single undefined correlation would otherwise turn the mean into NaN.
    corr_arr = np.asarray(corr_feat, dtype=float)
    n_undefined = int((~np.isfinite(corr_arr)).sum())

    print(f"\nCross-modal reconstruction: {src_mod} → {tgt_mod}")
    print(f"  Mean feature MSE: {np.nanmean(mse_feat):.4f}")
    print(f"  Mean feature Pearson r: {np.nanmean(corr_arr[np.isfinite(corr_arr)]) if n_undefined < corr_arr.size else float('nan'):.3f}")
    if n_undefined:
        print(f"  Features with undefined Pearson r (excluded): {n_undefined}/{corr_arr.size}")

    _save_feature_hist(
        corr_arr,
        "Per-feature Pearson r",
        f"{src_mod} → {tgt_mod}: feature-wise correlation",
        f"{name_prefix}_corr_hist.png",
    )
    _save_feature_hist(
        mse_feat,
        "Per-feature MSE",
        f"{src_mod} → {tgt_mod}: feature-wise MSE",
        f"{name_prefix}_mse_hist.png",
    )

    return mse_feat, corr_feat

# Evaluate the key reconstruction directions on the held-out TEST set.
_reconstruction_directions = (
    (rna_test_adata,  adt_test_adata,  "rna",  "adt",  "RNA_to_ADT_test"),
    (rna_test_adata,  atac_test_adata, "rna",  "atac", "RNA_to_ATAC_test"),
    (adt_test_adata,  rna_test_adata,  "adt",  "rna",  "ADT_to_RNA_test"),
    (atac_test_adata, rna_test_adata,  "atac", "rna",  "ATAC_to_RNA_test"),
)
for _src, _tgt, _src_mod, _tgt_mod, _prefix in _reconstruction_directions:
    _ = cross_modal_metrics(
        model, _src, _tgt,
        src_mod=_src_mod, tgt_mod=_tgt_mod,
        name_prefix=_prefix, device=device,
    )


In [None]:
# ============================
# 10. Summarize TEA-seq metrics & save as JSON
# ============================
import json

# Set to True to write the summary JSON next to the figures (off by default,
# so running the notebook does not create the file unless requested).
SAVE_TEASEQ_METRICS_JSON = False


def _finite_or_none(x):
    """Return float(x), or None when x is NaN/inf (strict JSON has no NaN/Infinity)."""
    x = float(x)
    return x if np.isfinite(x) else None


teaseq_metrics = {
    # Pairwise FOSCTTM
    "foscttm_rna_adt": _finite_or_none(fos_rna_adt),
    "foscttm_rna_atac": _finite_or_none(fos_rna_atac),
    "foscttm_adt_atac": _finite_or_none(fos_adt_atac),

    # Global mixing
    "modality_mixing_k20": _finite_or_none(mixing_score),
    "same_modality_neighbor_frac_k20": _finite_or_none(same_mod_frac),

    # Latent geometry
    "silhouette_modality": _finite_or_none(sil_mod),
    "mean_local_modality_entropy_k20": _finite_or_none(local_modality_entropy.mean()),
    "median_local_modality_entropy_k20": _finite_or_none(np.median(local_modality_entropy)),

    # Dataset sizes
    "n_cells_test": int(rna_test_adata.n_obs),
    "n_genes_rna": int(rna_test_adata.n_vars),
    "n_features_adt": int(adt_test_adata.n_vars),
    "n_features_atac": int(atac_test_adata.n_vars),
}

if SAVE_TEASEQ_METRICS_JSON:
    metrics_path = os.path.join(FIGDIR, "teaseq_univi_metrics.json")
    with open(metrics_path, "w") as f:
        # allow_nan=False guarantees standards-compliant JSON output.
        json.dump(teaseq_metrics, f, indent=2, allow_nan=False)

    print(f"\n[TEA-seq] Saved benchmark metrics to: {metrics_path}")

# ============================
# 11. kNN distance diagnostics: same vs cross-modality neighbors
# ============================
print("\nComputing kNN distance distributions (same vs cross-modality)...")

from sklearn.neighbors import NearestNeighbors

k_dist = 30
nn_dist = NearestNeighbors(n_neighbors=k_dist + 1, metric="euclidean")
nn_dist.fit(Z_joint)
dists_all, idx_all = nn_dist.kneighbors(Z_joint)

# Drop column 0, which normally is each cell's own self-match (distance 0).
dists_neighbors = dists_all[:, 1:]       # (n_cells, k_dist)
idx_neighbors_dist = idx_all[:, 1:]      # (n_cells, k_dist)
flat_dists = dists_neighbors.reshape(-1)

# Label every (center, neighbor) edge of THIS k_dist graph. The `neighbor_mods`
# array computed earlier comes from a different kNN graph, so its entries are
# not the neighbours whose distances are in `dists_neighbors` (and its size
# only matches n_cells * k_dist by coincidence).
modality_arr = np.asarray(modality_labels)
center_mods_flat = np.repeat(modality_arr, k_dist)
neighbor_mods_flat = modality_arr[idx_neighbors_dist].reshape(-1)

same_mod_edge = neighbor_mods_flat == center_mods_flat

fig, ax = plt.subplots(figsize=(6, 4))
for edge_mask, edge_label in (
    (same_mod_edge, "same modality"),
    (~same_mod_edge, "different modality"),
):
    # A KDE needs at least two points; skip a curve rather than fail.
    if edge_mask.sum() >= 2:
        sns.kdeplot(flat_dists[edge_mask], label=edge_label, fill=True, alpha=0.6, ax=ax)
    else:
        print(f"  Too few '{edge_label}' edges to draw a KDE; skipping that curve.")
ax.set_xlabel("Euclidean distance in UniVI latent")
ax.set_ylabel("Density")
ax.set_title(f"kNN distance distribution (k={k_dist})")
ax.legend()
fig.tight_layout()
fig.savefig(os.path.join(FIGDIR, "knn_distance_same_vs_cross_modality.png"), dpi=200)
plt.show()
plt.close(fig)


# ============================
# 12. Cross-modal neighbor fraction per cell (and UMAP)
# ============================
print("\nComputing per-cell cross-modal neighbor fraction...")

# For each cell: share of its earlier-computed neighbours (`neighbor_mods`)
# whose modality differs from the cell's own modality.
_same_mod_neighbor = neighbor_mods == modality_labels[:, None]
cross_mod_neighbor_frac = 1.0 - _same_mod_neighbor.mean(axis=1)
combined.obs["cross_mod_neighbor_frac"] = cross_mod_neighbor_frac

fig_frac, ax_frac = plt.subplots(figsize=(5, 4))
sns.histplot(cross_mod_neighbor_frac, bins=30, ax=ax_frac)
ax_frac.set_xlabel("Fraction of neighbors with different modality")
ax_frac.set_ylabel("Cells")
ax_frac.set_title(f"Cross-modal neighbor fraction (k={k})")
fig_frac.tight_layout()
fig_frac.savefig(os.path.join(FIGDIR, "cross_mod_neighbor_fraction_hist.png"), dpi=200)
plt.show()
plt.close(fig_frac)

# Same quantity painted onto the joint UMAP.
sc.pl.umap(
    combined,
    color="cross_mod_neighbor_frac",
    size=65,
    alpha=0.8,
    cmap="viridis",
    show=False,
)
plt.title("UMAP – cross-modal neighbor fraction")
plt.savefig(
    os.path.join(FIGDIR, "umap_cross_mod_neighbor_fraction.png"),
    bbox_inches="tight",
    dpi=200,
)
plt.show()
plt.close()


# ============================
# 13. Per-cluster modality composition (Leiden clusters)
# ============================
print("\nComputing modality composition per Leiden cluster...")

if "univi_leiden" in combined.obs.columns:
    comp_ct = pd.crosstab(combined.obs["univi_leiden"], combined.obs["univi_source"])
    comp_prop = comp_ct.div(comp_ct.sum(axis=1), axis=0)  # row-normalize

    # Draw into an explicit Axes: DataFrame.plot() without `ax=` opens its own
    # figure, which would leave a separately created sized figure empty (and
    # unclosed) and ignore the requested figsize.
    fig, ax = plt.subplots(figsize=(7, 5))
    comp_prop.sort_index().plot(
        kind="bar",
        stacked=True,
        width=0.9,
        colormap="tab10",
        ax=ax,
    )
    ax.set_xlabel("UniVI Leiden cluster")
    ax.set_ylabel("Fraction of cells")
    ax.set_title("Modality composition per UniVI Leiden cluster")
    ax.legend(title="Modality", bbox_to_anchor=(1.05, 1), loc="upper left")
    fig.tight_layout()
    fig.savefig(os.path.join(FIGDIR, "cluster_modality_composition_stacked_bar.png"),
                dpi=200)
    plt.show()
    plt.close(fig)
else:
    print("  WARNING: 'univi_leiden' not found in combined.obs; skipping cluster composition plot.")


# ============================
# 14. Feature-level cross-modal scatter plots (RNA → ADT / ATAC)
# ============================
print("\nFeature-level scatter plots for cross-modal prediction (RNA→ADT / RNA→ATAC)...")


def _hexbin_true_vs_pred(y_true, y_pred, xlabel: str, ylabel: str, title: str, fname: str):
    """Hexbin of true vs predicted values for one feature; saved under FIGDIR (dpi=200) and shown."""
    fig, ax = plt.subplots(figsize=(4.5, 4))
    ax.hexbin(y_true, y_pred, gridsize=50, mincnt=1)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(os.path.join(FIGDIR, fname), dpi=200)
    plt.show()
    plt.close(fig)


# Example: RNA → ADT for a small subset of cells & selected markers
# (adjust marker names to what you actually have in adt_test_adata.var_names)
adt_markers_to_plot = [
    # put your favorite ADT markers here
    # e.g. "CD3", "CD4", "CD8A", "CD56", ...
]

n_scatter_cells = min(5000, adt_test_adata.n_obs)  # subsample if huge
cell_idx = np.random.default_rng(42).choice(adt_test_adata.n_obs, n_scatter_cells, replace=False)

# Decoding thousands of cells is only worth it when there is something to plot.
Xhat_adt_sub = X_adt_sub = None
if adt_markers_to_plot:
    Xhat_adt_sub = univi_eval.cross_modal_predict(
        model,
        adata_src=rna_test_adata[cell_idx],
        src_mod="rna",
        tgt_mod="adt",
        device=device,
        batch_size=512,
    )
    X_adt_sub = _to_dense(adt_test_adata[cell_idx].X)

for marker in adt_markers_to_plot:
    if marker not in adt_test_adata.var_names:
        print(f"  [RNA→ADT] Marker '{marker}' not found in adt_test_adata.var_names; skipping.")
        continue

    j = np.where(adt_test_adata.var_names == marker)[0][0]
    _hexbin_true_vs_pred(
        X_adt_sub[:, j],
        Xhat_adt_sub[:, j],
        f"True ADT ({marker})",
        f"Predicted ADT ({marker})",
        f"RNA→ADT prediction for {marker}",
        f"scatter_RNA_to_ADT_{marker}.png".replace(" ", "_"),
    )


# Example: RNA → ATAC for a few peaks / gene-body features (if named)
# (This is more exploratory since ATAC features are often peaks; choose a few named ones if available.)
# If you have named ATAC features (e.g. gene body aggregates), you can list them here.
atac_features_to_plot = [
    # e.g. "TNFRSF4_body", "IFNG_enh", ...
]

n_scatter_cells_atac = min(5000, atac_test_adata.n_obs)
cell_idx_atac = np.random.default_rng(123).choice(atac_test_adata.n_obs, n_scatter_cells_atac, replace=False)

# A dense (cells × peaks) prediction can be very large; skip it when unused.
Xhat_atac_sub = X_atac_sub = None
if atac_features_to_plot:
    Xhat_atac_sub = univi_eval.cross_modal_predict(
        model,
        adata_src=rna_test_adata[cell_idx_atac],
        src_mod="rna",
        tgt_mod="atac",
        device=device,
        batch_size=512,
    )
    X_atac_sub = _to_dense(atac_test_adata[cell_idx_atac].X)

for feat in atac_features_to_plot:
    if feat not in atac_test_adata.var_names:
        print(f"  [RNA→ATAC] Feature '{feat}' not found in atac_test_adata.var_names; skipping.")
        continue

    j = np.where(atac_test_adata.var_names == feat)[0][0]
    _hexbin_true_vs_pred(
        X_atac_sub[:, j],
        Xhat_atac_sub[:, j],
        f"True ATAC ({feat})",
        f"Predicted ATAC ({feat})",
        f"RNA→ATAC prediction for {feat}",
        f"scatter_RNA_to_ATAC_{feat}.png".replace(" ", "_"),
    )


In [None]:
import scanpy as sc
import numpy as np

import scanpy as sc
import numpy as np
import scipy.sparse as sp

# -----------------------------------------
# Helper: unimodal Leiden clustering (safer)
# -----------------------------------------
def _finite_X_copy(adata, mod: str):
    """
    Return ``adata.copy()`` with non-finite entries of ``.X`` replaced by 0.

    Sparse matrices stay sparse (only stored entries can be non-finite), so big
    peak/gene matrices are never densified just for this check. Sparse input
    is returned as CSR if anything had to be replaced. Prints a warning with
    the number of replaced entries, if any.
    """
    tmp = adata.copy()
    X = tmp.X
    if sp.issparse(X):
        X = X.tocsr()  # tmp is already a private copy, so in-place edits are safe
        bad = ~np.isfinite(X.data)
        n_bad = int(bad.sum())
        if n_bad:
            print(f"[{mod}] Warning: found {n_bad} non-finite entries in .X; replacing with 0.")
            X.data[bad] = 0
            X.eliminate_zeros()
            tmp.X = X
    else:
        X = np.asarray(X)
        n_bad = int(np.sum(~np.isfinite(X)))
        if n_bad:
            print(f"[{mod}] Warning: found {n_bad} non-finite entries in .X; replacing with 0.")
            tmp.X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
    return tmp


def _pca_into_obsm(adata, mod: str, n_pcs: int) -> str:
    """
    Run PCA on a sanitized copy of ``adata.X`` and store it in ``adata.obsm["X_pca"]``.

    ``n_pcs`` is capped at ``min(n_obs, n_vars) - 1``. Some PCA solvers
    (e.g. ARPACK) require strictly fewer components than the smaller matrix
    dimension, which matters for small ADT panels or tiny test sets.
    Returns the obsm key that was written ("X_pca").
    """
    tmp = _finite_X_copy(adata, mod)
    n_comps = min(n_pcs, min(tmp.n_obs, tmp.n_vars) - 1)
    if n_comps < 1:
        raise ValueError(
            f"[{mod}] Too few cells/features ({tmp.n_obs} x {tmp.n_vars}) to compute PCA."
        )
    if n_comps < n_pcs:
        print(f"[{mod}] Capping n_pcs from {n_pcs} to {n_comps} for this data size.")

    # PCA directly on (already processed) X — no extra scaling.
    sc.tl.pca(tmp, n_comps=n_comps)
    adata.obsm["X_pca"] = tmp.obsm["X_pca"]
    return "X_pca"


def compute_unimodal_leiden(
    adata,
    mod: str,
    key_added: str = None,
    resolution: float = 1.0,
    n_neighbors: int = 15,
    n_pcs: int = 30,
):
    """
    Compute unimodal Leiden clusters for a single AnnData.

    Representation used for the kNN graph:
      - "atac": existing ``obsm["X_lsi"]``; otherwise a fresh PCA of ``.X``
        (an existing ``X_pca`` is not reused for ATAC).
      - any other modality: existing ``obsm["X_pca"]``; otherwise a fresh PCA
        of ``.X``.
    A fresh PCA overwrites ``adata.obsm["X_pca"]``. Non-finite values in
    ``.X`` are replaced by 0 in a temporary copy before PCA; ``adata.X`` is
    not modified.

    Parameters
    ----------
    adata : AnnData
        Modality-specific AnnData (e.g. RNA-only, ADT-only, ATAC-only).
    mod : {"rna", "adt", "atac"}
        Name of the modality (used to pick rep + default key).
    key_added : str, optional
        Name of the .obs column for clusters (default: f"{mod}_leiden").
    resolution : float
        Leiden resolution parameter.
    n_neighbors : int
        Number of neighbors for kNN graph.
    n_pcs : int
        Number of PCs to use when computing PCA (if needed); capped to the
        data size.

    Returns
    -------
    adata : AnnData
        Same object with a new .obs[key_added] column. The scanpy neighbour
        graph stored on ``adata`` is also (re)computed.
    """
    if adata is None or adata.n_obs == 0:
        print(f"[{mod}] No cells; skipping Leiden.")
        return adata

    if key_added is None:
        key_added = f"{mod}_leiden"

    # ---------- choose rep ----------
    if mod == "atac" and "X_lsi" in adata.obsm:
        use_rep = "X_lsi"
        print(f"[{mod}] Using existing X_lsi for neighbors/Leiden.")
    elif mod != "atac" and "X_pca" in adata.obsm:
        use_rep = "X_pca"
        print(f"[{mod}] Using existing X_pca for neighbors/Leiden.")
    else:
        if mod in ("rna", "adt"):
            print(f"[{mod}] Computing PCA on existing .X for neighbors/Leiden (no extra scaling)...")
        elif mod == "atac":
            print(f"[{mod}] X_lsi not found; computing PCA on .X for neighbors/Leiden (no extra scaling)...")
        else:
            print(f"[{mod}] Unknown modality; computing PCA on .X for neighbors/Leiden (no extra scaling)...")
        use_rep = _pca_into_obsm(adata, mod, n_pcs)

    # ---------- neighbors + Leiden ----------
    print(f"[{mod}] Computing neighbors (n_neighbors={n_neighbors}, use_rep={use_rep})...")
    sc.pp.neighbors(adata, n_neighbors=n_neighbors, use_rep=use_rep)

    print(f"[{mod}] Running Leiden (resolution={resolution}) → obs['{key_added}']...")
    sc.tl.leiden(adata, key_added=key_added, resolution=resolution)

    n_clusters = adata.obs[key_added].nunique()
    print(f"[{mod}] Leiden done: {n_clusters} clusters in obs['{key_added}'].")

    return adata

# -----------------------------------------
# Run unimodal Leiden on each test set (missing objects are skipped)
# -----------------------------------------

# Per-modality Leiden resolution; tune individually if desired.
resolutions = dict(rna=1.0, adt=1.0, atac=1.0)

# Look the test AnnData objects up by name so a missing one does not raise.
specs = [
    ("rna",  locals().get("rna_test_adata",  None)),
    ("adt",  locals().get("adt_test_adata",  None)),
    ("atac", locals().get("atac_test_adata", None)),
]

for mod, adata in specs:
    if adata is not None:
        compute_unimodal_leiden(
            adata,
            mod=mod,
            key_added=f"{mod}_leiden",
            resolution=resolutions.get(mod, 1.0),
            n_neighbors=15,
            n_pcs=30,
        )
    else:
        print(f"[{mod}] No AnnData object found (e.g. rna_test is None); skipping.")


In [None]:
# -----------------------------------------
# 9. Denoising with decoders + pseudo-celltypes
# -----------------------------------------
print("\nDenoising on test sets with unimodal Leiden clusters as pseudo-celltypes...")

# (AnnData, modality, display tag, expected Leiden key in .obs)
denoise_specs = [
    (rna_test_adata,  "rna",  "rna_test_adata",  "rna_leiden"),
    (adt_test_adata,  "adt",  "adt_test_adata",  "adt_leiden"),
    (atac_test_adata, "atac", "atac_test_adata", "atac_leiden"),
]

n_top_heatmap_features = 40  # number of features (genes/ADTs/peaks) to show per heatmap


def _select_heatmap_features(adata, X_raw, n_top: int):
    """
    Pick up to ``n_top`` feature (column) indices for the cluster heatmaps.

    If ``adata.var["highly_variable"]`` exists and has any True entry, returns
    the first ``n_top`` flagged features in var order; otherwise the ``n_top``
    features with the highest variance in ``X_raw``.
    """
    if "highly_variable" in adata.var.columns and adata.var["highly_variable"].any():
        return np.flatnonzero(adata.var["highly_variable"].values)[:n_top]
    return np.argsort(X_raw.var(axis=0))[::-1][:n_top]


def _cluster_means(X, cluster_labels, cluster_ids):
    """
    Column means of ``X`` within each cluster, one row per entry of ``cluster_ids``.

    ``cluster_labels`` is a pandas Series aligned with the rows of ``X``.
    Clusters without rows get an all-NaN row.
    """
    rows = []
    for cl in cluster_ids:
        mask = (cluster_labels == cl).to_numpy()
        if mask.any():
            rows.append(X[mask].mean(axis=0))
        else:
            rows.append(np.full(X.shape[1], np.nan))
    return np.vstack(rows)


def _save_cluster_heatmap(df, cluster_key: str, title: str, fname: str, **heatmap_kws):
    """Draw a (clusters × features) heatmap of ``df``, save it under FIGDIR and show it."""
    n_clusters, n_feats = df.shape
    fig, ax = plt.subplots(figsize=(0.35 * n_feats + 4, 0.35 * n_clusters + 3))
    sns.heatmap(df, ax=ax, **heatmap_kws)
    ax.set_xlabel("Features")
    ax.set_ylabel(cluster_key)
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(os.path.join(FIGDIR, fname))
    plt.show()
    plt.close(fig)


for adata, mod, tag, cluster_key in denoise_specs:
    if adata is None or adata.n_obs == 0:
        print(f"  Skipping {tag} ({mod}) – no cells.")
        continue

    print(f"\n  Denoising {tag} ({mod})...")

    # Run the UniVI decoder; the result is read back from layers["univi_denoised"].
    univi_eval.denoise_adata(
        model,
        adata,
        modality=mod,
        device=device,
        batch_size=512,
        out_layer="univi_denoised",
    )

    # Per-cell MSE (raw vs denoised) in the feature space used for that modality
    X_raw = _to_dense(adata.X)
    X_den = _to_dense(adata.layers["univi_denoised"])
    per_cell_mse = ((X_raw - X_den) ** 2).mean(axis=1)

    # Store per-cell MSE in obs so we can group by clusters later
    adata.obs[f"{mod}_denoise_mse"] = per_cell_mse

    # Global histogram of per-cell MSE
    fig, ax = plt.subplots(figsize=(5, 4))
    sns.histplot(per_cell_mse, bins=40, ax=ax)
    ax.set_xlabel("Per-cell MSE (raw vs denoised)")
    ax.set_ylabel("Count")
    ax.set_title(f"Denoising quality: {tag} ({mod})")
    fig.tight_layout()
    fig.savefig(os.path.join(FIGDIR, f"denoise_{tag}_{mod}.png"))
    plt.show()
    plt.close(fig)

    # -----------------------------------------
    # Cluster-level summaries using unimodal Leiden as pseudo-celltypes
    # -----------------------------------------
    if cluster_key not in adata.obs.columns:
        print(f"  [!] No '{cluster_key}' column in adata.obs – skipping pseudo-celltype summaries.")
        continue

    print(f"  Computing denoising summaries by cluster: {cluster_key} (pseudo-celltypes)")

    # observed=True: only clusters that actually contain cells; this also
    # avoids pandas' FutureWarning about the categorical groupby default.
    df_clusters = (
        adata.obs[[cluster_key]]
        .assign(denoise_mse=per_cell_mse)
        .groupby(cluster_key, observed=True)["denoise_mse"]
        .agg(["mean", "median", "std", "count"])
        .sort_values("mean")
    )
    if df_clusters.empty:
        print(f"  [!] '{cluster_key}' has no labelled cells – skipping pseudo-celltype summaries.")
        continue

    print(f"\n  Per-cluster denoising MSE for {tag} ({mod}) using {cluster_key}:")
    print(df_clusters)

    # Boxplot of per-cell MSE by cluster (ordered by cluster mean MSE)
    order = df_clusters.index.tolist()
    fig, ax = plt.subplots(figsize=(max(6, 0.4 * len(order)), 4))
    sns.boxplot(
        data=adata.obs,
        x=cluster_key,
        y=f"{mod}_denoise_mse",
        order=order,
        ax=ax,
    )
    ax.tick_params(axis="x", labelrotation=90)
    ax.set_xlabel(f"{cluster_key} (pseudo-celltypes)")
    ax.set_ylabel("Per-cell MSE (raw vs denoised)")
    ax.set_title(f"Denoising MSE by cluster: {tag} ({mod})")
    fig.tight_layout()
    fig.savefig(os.path.join(FIGDIR, f"denoise_{tag}_{mod}_by_{cluster_key}.png"))
    plt.show()
    plt.close(fig)

    # -----------------------------------------
    # Per-cluster expression heatmaps: raw vs denoised vs delta
    # -----------------------------------------
    print(f"  Building cluster-level expression heatmaps for {tag} ({mod})...")

    feat_idx = np.asarray(_select_heatmap_features(adata, X_raw, n_top_heatmap_features))
    feat_names = adata.var_names[feat_idx]

    cluster_ids = order  # same ordering as boxplot
    cluster_labels = adata.obs[cluster_key]
    # Subset the columns once, then average per cluster.
    raw_means = _cluster_means(X_raw[:, feat_idx], cluster_labels, cluster_ids)
    den_means = _cluster_means(X_den[:, feat_idx], cluster_labels, cluster_ids)
    delta_means = den_means - raw_means

    # Shared color scale for raw/denoised
    raw_and_den = np.concatenate([raw_means, den_means])
    vmin = np.nanpercentile(raw_and_den, 2)
    vmax = np.nanpercentile(raw_and_den, 98)

    df_raw   = pd.DataFrame(raw_means,   index=cluster_ids, columns=feat_names)
    df_den   = pd.DataFrame(den_means,   index=cluster_ids, columns=feat_names)
    df_delta = pd.DataFrame(delta_means, index=cluster_ids, columns=feat_names)

    _save_cluster_heatmap(
        df_raw, cluster_key,
        f"{tag} ({mod}) – cluster-level raw expression",
        f"heatmap_{tag}_{mod}_raw_by_{cluster_key}.png",
        cmap="viridis", vmin=vmin, vmax=vmax,
        cbar_kws={"label": "Mean expression (raw)"},
    )
    _save_cluster_heatmap(
        df_den, cluster_key,
        f"{tag} ({mod}) – cluster-level denoised expression",
        f"heatmap_{tag}_{mod}_denoised_by_{cluster_key}.png",
        cmap="viridis", vmin=vmin, vmax=vmax,
        cbar_kws={"label": "Mean expression (denoised)"},
    )

    # Symmetric limits around 0; fall back to automatic limits when the deltas
    # are all zero / non-finite (a degenerate [0, 0] or NaN range otherwise).
    finite_abs_delta = np.abs(delta_means[np.isfinite(delta_means)])
    v_abs = finite_abs_delta.max() if finite_abs_delta.size else 0.0
    delta_limits = {"vmin": -v_abs, "vmax": v_abs} if v_abs > 0 else {}
    _save_cluster_heatmap(
        df_delta, cluster_key,
        f"{tag} ({mod}) – change in cluster-level expression",
        f"heatmap_{tag}_{mod}_delta_by_{cluster_key}.png",
        cmap="vlag", center=0,
        cbar_kws={"label": "Δ mean expression (denoised - raw)"},
        **delta_limits,
    )

print("\nUsed unimodal Leiden clusters as pseudo-celltypes for denoising summaries (where available).")


In [None]:
# -----------------------------------------
# After building `combined` and running UMAP / univi_leiden:
# copy each modality's unimodal Leiden labels onto its rows of `combined`.
# -----------------------------------------
# combined was made from rna_tmp, adt_tmp, atac_tmp and has:
#   combined.obs["univi_source"] in {"rna", "adt", "atac"}
#
# Alignment is positional: the k-th `combined` row with univi_source == mod is
# taken to be the k-th cell of `<mod>_test_adata`. Only the counts are checked
# here — TODO confirm the concatenation preserves cell order.

_leiden_sources = {
    "rna": rna_test_adata,
    "adt": adt_test_adata,
    "atac": atac_test_adata,
}

# Initialize as object dtype (NaN placeholders): the columns receive string
# labels, and writing strings into a float64 NaN column triggers pandas'
# deprecated incompatible-dtype upcast.
for mod in _leiden_sources:
    combined.obs[f"{mod}_leiden"] = pd.Series(np.nan, index=combined.obs.index, dtype=object)

for mod, src_adata in _leiden_sources.items():
    key = f"{mod}_leiden"
    if key not in src_adata.obs:
        continue

    mask = combined.obs["univi_source"] == mod
    n_combined = int(mask.sum())
    assert n_combined == src_adata.n_obs, (
        f"{mod.upper()} counts mismatch: combined has {n_combined}, "
        f"{mod}_test_adata has {src_adata.n_obs}"
    )
    combined.obs.loc[mask, key] = src_adata.obs[key].astype(str).values


In [None]:
# -----------------------------------------
# 10. Overlap of unimodal clusters in UniVI latent space
# -----------------------------------------
print("\nVisualizing UniVI latent UMAP colored by unimodal clusters...")

# The *_leiden columns are expected in combined.obs (they are copied there
# from the per-modality test sets); each is filled only on its own modality's rows.
for key in ("rna_leiden", "adt_leiden", "atac_leiden"):
    if key not in combined.obs.columns:
        print(f"  [!] {key} not found in combined.obs – skipping this UMAP.")
        continue

    print(f"  Plotting UMAP colored by {key}...")
    sc.pl.umap(combined, color=key, size=65, alpha=0.8, show=False)
    out_path = os.path.join(FIGDIR, f"umap_tri_modal_{key}.png")
    plt.savefig(out_path, bbox_inches="tight")
    plt.show()
    plt.close()

# -----------------------------------------
# Cluster overlap tables: how do unimodal clusters align cell-by-cell?
# -----------------------------------------
import pandas as pd

def cluster_overlap_heatmap(
    adata_a,
    key_a: str,
    adata_b,
    key_b: str,
    pair_name: str,
    normalize: str = "index",
):
    """
    Plot and return the normalized contingency table of two clusterings.

    Rows are the clusters in ``adata_a.obs[key_a]``, columns the clusters in
    ``adata_b.obs[key_b]``; ``normalize`` is forwarded to ``pd.crosstab``
    ("index" = each row sums to 1). Both AnnData objects must list the same
    cells in the same order (checked via ``obs_names``). The heatmap is saved
    to FIGDIR as ``cluster_overlap_<pair_name>.png`` and shown.
    """
    same_cells = np.array_equal(adata_a.obs_names, adata_b.obs_names)
    assert same_cells, f"{pair_name}: obs_names do not match 1:1"

    overlap = pd.crosstab(
        adata_a.obs[key_a].astype("category"),
        adata_b.obs[key_b].astype("category"),
        normalize=normalize,
    )

    n_rows, n_cols = overlap.shape
    fig, ax = plt.subplots(figsize=(0.5 * n_cols + 4, 0.5 * n_rows + 4))
    sns.heatmap(
        overlap,
        annot=False,
        cmap="viridis",
        cbar_kws={"label": f"Fraction (normalized by {normalize})"},
        ax=ax,
    )
    ax.set_xlabel(key_b)
    ax.set_ylabel(key_a)
    ax.set_title(f"Cluster overlap: {pair_name}")
    fig.tight_layout()
    safe_pair = pair_name.replace(' ', '_').replace('→', 'to')
    fig.savefig(os.path.join(FIGDIR, f"cluster_overlap_{safe_pair}.png"))
    plt.show()
    plt.close(fig)

    return overlap

# Compare a pair of clusterings only when both Leiden columns are present.
rna_has_leiden = "rna_leiden" in rna_test_adata.obs
adt_has_leiden = "adt_leiden" in adt_test_adata.obs
atac_has_leiden = "atac_leiden" in atac_test_adata.obs

if rna_has_leiden and adt_has_leiden:
    print("\nCluster overlap RNA vs ADT (unimodal Leiden)...")
    tab_rna_adt = cluster_overlap_heatmap(
        rna_test_adata, "rna_leiden", adt_test_adata, "adt_leiden",
        pair_name="RNA→ADT", normalize="index",
    )

if rna_has_leiden and atac_has_leiden:
    print("\nCluster overlap RNA vs ATAC (unimodal Leiden)...")
    tab_rna_atac = cluster_overlap_heatmap(
        rna_test_adata, "rna_leiden", atac_test_adata, "atac_leiden",
        pair_name="RNA→ATAC", normalize="index",
    )

if adt_has_leiden and atac_has_leiden:
    print("\nCluster overlap ADT vs ATAC (unimodal Leiden)...")
    tab_adt_atac = cluster_overlap_heatmap(
        adt_test_adata, "adt_leiden", atac_test_adata, "atac_leiden",
        pair_name="ADT→ATAC", normalize="index",
    )


In [None]:
# -----------------------------------------
# 11. Marker exploration (input vs latent vs recon)
# -----------------------------------------
import os
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns  # already imported above, but just in case


def ensure_umap_on_univi(
    adata,
    n_neighbors: int = 30,
    min_dist: float = 0.5,
    force: bool = False,
):
    """
    Ensure adata has a UMAP in .obsm['X_umap'] computed from .obsm['X_univi'].

    Parameters
    ----------
    adata : AnnData
        Must contain ``.obsm['X_univi']``.
    n_neighbors : int
        Neighbours for the kNN graph built on ``X_univi``.
    min_dist : float
        UMAP ``min_dist``.
    force : bool
        By default an existing ``.obsm['X_umap']`` is reused as-is. This
        function cannot tell which representation that UMAP was built from.
        Pass ``force=True`` to recompute neighbours + UMAP on ``X_univi``
        regardless.

    Notes
    -----
    When (re)computing, ``sc.pp.neighbors`` and ``sc.tl.umap`` write their
    results into ``adata``, replacing any previous neighbour graph and X_umap.

    Raises
    ------
    ValueError
        If ``.obsm['X_univi']`` is missing.
    """
    if "X_univi" not in adata.obsm:
        raise ValueError("adata.obsm['X_univi'] is missing; cannot build UMAP on UniVI latent.")
    if "X_umap" in adata.obsm and not force:
        # Reuse the existing embedding (provenance not checked).
        return

    print(f"[umap] Computing neighbors/UMAP on UniVI latent (n_neighbors={n_neighbors}, min_dist={min_dist})...")
    sc.pp.neighbors(adata, use_rep="X_univi", n_neighbors=n_neighbors)
    sc.tl.umap(adata, min_dist=min_dist)


def plot_markers_on_umap(
    adata,
    marker_dict: dict,
    title_prefix: str,
    figdir: str = FIGDIR,
    n_neighbors: int = 30,
    min_dist: float = 0.5,
):
    """
    Save one UMAP figure per marker group, colored by that group's markers.

    The UMAP is taken from / built by ``ensure_umap_on_univi`` (on
    ``.obsm['X_univi']``). Markers absent from ``adata.var_names`` are ignored;
    a group with no present markers is reported and skipped.

    marker_dict = { "group_name": [marker1, marker2, ...], ... }
    Figures are written to ``figdir`` (created if needed) as
    ``umap_<title_prefix>_<group>_markers.png`` with spaces replaced by "_".
    """
    os.makedirs(figdir, exist_ok=True)

    # Guarantee a UMAP on the UniVI latent before plotting.
    ensure_umap_on_univi(adata, n_neighbors=n_neighbors, min_dist=min_dist)

    available = set(adata.var_names)

    for group, genes in marker_dict.items():
        present = [gene for gene in genes if gene in available]
        if not present:
            print(f"[marker] In {title_prefix}, no markers from {group} present in var_names.")
            continue

        print(f"[marker] {title_prefix}: plotting {group} markers: {present}")
        sc.pl.umap(adata, color=present, size=50, alpha=0.8, show=False)
        out_name = f"umap_{title_prefix}_{group}_markers.png".replace(" ", "_")
        plt.savefig(os.path.join(figdir, out_name), bbox_inches="tight")
        plt.show()
        plt.close()

def cluster_level_marker_deltas(
    adata,
    markers,
    cluster_key: str,
    mod: str,
    tag: str,
    figdir: str = FIGDIR,
):
    """
    Compare per-cluster mean marker values in ``.X`` against ``layers["univi_denoised"]``.

    Two heatmaps are saved to ``figdir`` (created if missing):
      * ``markers_{tag}_{mod}_raw_by_{cluster_key}.png``: per-cluster means of ``.X``
      * ``markers_{tag}_{mod}_delta_by_{cluster_key}.png``: denoised minus raw means,
        on a colour scale symmetric around 0 when a finite non-zero delta exists.

    Parameters
    ----------
    adata : AnnData
        Must contain ``layers["univi_denoised"]`` (KeyError otherwise).
    markers : list of str
        Features to compare. Names absent from ``adata.var_names`` are dropped.
        If none remain, a message is printed and nothing is plotted.
    cluster_key : str
        ``adata.obs`` column with cluster labels (heatmap rows, sorted).
    mod, tag : str
        Labels used in titles and file names.
    figdir : str
        Output directory.
    """
    markers = [g for g in markers if g in adata.var_names]
    if not markers:
        print(f"[marker] No markers present for {tag} ({mod}) – skipping.")
        return

    labels = adata.obs[cluster_key]
    cluster_ids = sorted(labels.unique())
    if not cluster_ids:
        print(f"[marker] No '{cluster_key}' labels for {tag} ({mod}) – skipping.")
        return

    os.makedirs(figdir, exist_ok=True)

    idx = adata.var_names.get_indexer(markers)
    # Densify only the marker columns. Converting a genome-wide sparse matrix
    # in full is wasteful and can exhaust memory.
    X_raw = _to_dense(adata.X[:, idx])
    X_den = _to_dense(adata.layers["univi_denoised"][:, idx])

    raw_rows, den_rows = [], []
    for cl in cluster_ids:
        mask = (labels == cl).to_numpy()
        if mask.any():
            raw_rows.append(X_raw[mask].mean(axis=0))
            den_rows.append(X_den[mask].mean(axis=0))
        else:
            # Guard kept from the original; labels come from unique(), so this
            # should only trigger for labels that never compare equal (e.g. NaN).
            raw_rows.append(np.full(len(markers), np.nan))
            den_rows.append(np.full(len(markers), np.nan))

    raw_means = np.vstack(raw_rows)
    delta_means = np.vstack(den_rows) - raw_means

    df_raw = pd.DataFrame(raw_means, index=cluster_ids, columns=markers)
    df_delta = pd.DataFrame(delta_means, index=cluster_ids, columns=markers)
    figsize = (0.7 * len(markers) + 4, 0.4 * len(cluster_ids) + 3)

    # Heatmap of raw marker expression
    plt.figure(figsize=figsize)
    sns.heatmap(df_raw, cmap="viridis", cbar_kws={"label": "Mean raw expression"})
    plt.xlabel("Markers")
    plt.ylabel(cluster_key)
    plt.title(f"{tag} ({mod}) – raw marker expression by cluster")
    plt.tight_layout()
    plt.savefig(os.path.join(figdir, f"markers_{tag}_{mod}_raw_by_{cluster_key}.png"))
    plt.show()
    plt.close()

    # Symmetric limits keep 0 at the centre of the diverging map. They are
    # skipped when there is no finite non-zero delta, because vmin == vmax
    # (or NaN limits) would give a degenerate colour scale.
    finite_abs = np.abs(delta_means[np.isfinite(delta_means)])
    v_abs = finite_abs.max() if finite_abs.size else 0.0
    limits = {"vmin": -v_abs, "vmax": v_abs} if v_abs > 0 else {}

    # Heatmap of delta (denoised - raw)
    plt.figure(figsize=figsize)
    sns.heatmap(
        df_delta,
        cmap="vlag",
        center=0,
        cbar_kws={"label": "Δ denoised - raw"},
        **limits,
    )
    plt.xlabel("Markers")
    plt.ylabel(cluster_key)
    plt.title(f"{tag} ({mod}) – change in marker expression by cluster")
    plt.tight_layout()
    plt.savefig(os.path.join(figdir, f"markers_{tag}_{mod}_delta_by_{cluster_key}.png"))
    plt.show()
    plt.close()


In [None]:
# Example usage, kept as an unevaluated string literal so running this cell
# does not execute it. To use it, paste the body into a code cell and replace
# the marker lists with your own.
# NOTE(review): it expects `rna_test_adata` with obs["rna_leiden"] and a
# "univi_denoised" layer to exist when it runs.
'''
# Example (fill in your real markers)
rna_markers = {
    "CD4_T":  ["CD4", "IL7R"],
    "CD8_T":  ["CD8A", "CD8B"],
    "B_cell": ["MS4A1", "CD79A"],
    "NK":     ["NKG7", "GNLY"],
}

adt_markers = {
    "T_cell": ["CD3", "CD4", "CD8"],
    "B_cell": ["CD19", "CD20"],
}

# UMAPs colored by RNA markers
plot_markers_on_umap(
    rna_test_adata,
    marker_dict=rna_markers,
    title_prefix="rna_univi_latent",
)

# Cluster-level marker deltas (using RNA unimodal clusters)
cluster_level_marker_deltas(
    rna_test_adata,
    markers=["CD4", "CD8A", "MS4A1"],  # or whatever
    cluster_key="rna_leiden",
    mod="rna",
    tag="rna_test_adata",
)
'''

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr

# ------------------------------------------------------
# 1) Define richer marker panels (RNA + ADT)
# ------------------------------------------------------

# Marker panels used by the UMAP and heatmap diagnostics below.
# Each panel is a whitespace-separated string split into a list. Gene and
# protein names contain no whitespace, so the split is lossless and long panels
# stay compact. Group order and within-group order are significant because
# plots follow them.
rna_markers = {
    # CD4-ish / helper / memory T
    "CD4_like_T": """
        CD4 IL7R CCR7 LEF1 TCF7 BCL11B IKZF2 MEF2C RUNX3
        TNFRSF4 TNFRSF18 CXCR4 CCR6
    """.split(),
    # Cytotoxic CD8 / NK-like
    "Cytotoxic_T_NK": """
        CD8A CD8B PRF1 GZMB GZMH GNLY NKG7 IFNG KLRD1
        KLRK1 CX3CR1 CCL5
    """.split(),
    # B cells / plasmablasts
    "B_cell": """
        MS4A1 CD19 CD22 CD79A CD79B BANK1 BLK PAX5
        CD74 HLA-DRA HLA-DRB1 HLA-DQB1
        IGKC IGLC1 IGHD IGHM IGHA1 IGHA2 IGHE MZB1 XBP1
    """.split(),
    # Mono / myeloid / DC
    "Mono_like": """
        LYZ S100A8 S100A9 S100A12 FCN1 IL1B TNFAIP3 NFKBIA
        LST1 CSF2RA CSF3R ITGAX ITGAD CTSS LAPTM5 SRGN
        CCR2 CXCL8 CD14
    """.split(),
    # General activation / exhaustion-ish markers
    "Activation": """
        CD69 IFNG TNFRSF9 TNFRSF18 TNFRSF4 LAG3 TIGIT PDCD1
        TOX BATF EGR1 FOSB
    """.split(),
}

# ADT (protein) panels; names follow the ADT feature naming (e.g. "TCR-a/b").
adt_markers = {
    "T_all": "CD3 TCR-a/b TCR-g/d".split(),
    "CD4_T": "CD3 CD4 CD45RA CD45RO CD27 CD127 CD279".split(),
    "CD8_T": "CD3 CD8a CD45RA CD45RO KLRG1 CD27 CD279".split(),
    "NK": "CD56 CD16 KLRG1".split(),
    "B_cell": "CD19 CD21 CD24 IgD IgM CD38".split(),
    "Mono_DC": "CD14 CD16 CD11b CD11c HLA-DR CD141 CD172a CD192 CD304".split(),
    "Activation": "CD25 CD40 CD80 CD86 CD71 CD95 CD278 CD279".split(),
}

# ------------------------------------------------------
# 2) Helper: filter markers to those present in the AnnData
# ------------------------------------------------------

def filter_markers_to_var(marker_dict, adata, verbose=True, label=""):
    """
    Restrict every marker group to the names present in ``adata.var_names``.

    Groups with no surviving marker are omitted from the result. With
    ``verbose`` set, one line per group reports how many markers were kept and
    which were missing.

    Returns
    -------
    dict
        ``{group: [present markers in their original order]}`` for non-empty groups.
    """
    known = set(adata.var_names)
    kept = {}
    for group_name, candidates in marker_dict.items():
        found, absent = [], []
        for name in candidates:
            (found if name in known else absent).append(name)

        if verbose:
            if found:
                print(f"[{label}] {group_name}: using {len(found)} markers; missing: {absent}")
            else:
                print(f"[{label}] {group_name}: no markers present; all missing: {absent}")

        if found:
            kept[group_name] = found
    return kept

# Keep only the markers that exist in each modality's feature space (RNA, then ADT).
rna_markers_f, adt_markers_f = (
    filter_markers_to_var(panel, source_adata, label=panel_label)
    for panel, source_adata, panel_label in (
        (rna_markers, rna_test_adata, "RNA"),
        (adt_markers, adt_test_adata, "ADT"),
    )
)

# ------------------------------------------------------
# 3) UMAP plotting of marker groups on UniVI latent UMAPs
# ------------------------------------------------------

def ensure_umap_on_univi(adata, n_neighbors=30, min_dist=0.5, force=False):
    """
    Make sure ``adata.obsm['X_umap']`` exists, computing it from the UniVI latent if needed.

    If ``X_umap`` is already present and ``force`` is False, nothing is done.
    The key is only checked, so an ``X_umap`` computed from another
    representation is reused as-is. Otherwise neighbors are built on
    ``adata.obsm['X_univi']`` and scanpy's UMAP is run, overwriting ``X_umap``.

    Parameters
    ----------
    adata : AnnData
    n_neighbors : int
        Neighborhood size for ``sc.pp.neighbors``.
    min_dist : float
        Passed to ``sc.tl.umap``. 0.5 is scanpy's default, so existing calls
        behave as before. Accepting it keeps this definition call-compatible
        with the earlier ``ensure_umap_on_univi(adata, n_neighbors, min_dist)``
        that it shadows.
    force : bool
        Recompute even when ``X_umap`` already exists.

    Raises
    ------
    ValueError
        If a UMAP has to be computed but ``adata.obsm['X_univi']`` is missing.
    """
    keys = adata.obsm_keys()
    if "X_umap" in keys and not force:
        return
    if "X_univi" not in keys:
        raise ValueError("No 'X_univi' in .obsm; cannot build UMAP on UniVI latent.")
    sc.pp.neighbors(adata, use_rep="X_univi", n_neighbors=n_neighbors)
    sc.tl.umap(adata, min_dist=min_dist)

def plot_markers_on_umap(adata, marker_dict, title_prefix, figdir=FIGDIR, size=200):
    """
    Save one UMAP (built on the UniVI latent if needed) per marker group.

    Parameters
    ----------
    adata : AnnData
        Must hold ``obsm['X_univi']`` if ``obsm['X_umap']`` is absent.
    marker_dict : dict
        ``{group: [marker, ...]}``. Names found neither in ``adata.var_names``
        nor in ``adata.obs`` columns are dropped and reported, instead of making
        ``sc.pl.umap`` raise. Groups left empty are skipped.
    title_prefix : str
        Used in log lines and in the output file name.
    figdir : str
        Output directory; created if it does not exist.
    size : float
        Point size for ``sc.pl.umap``.
    """
    os.makedirs(figdir, exist_ok=True)
    ensure_umap_on_univi(adata)

    # Valid colour keys for sc.pl.umap: feature names and obs columns.
    available = set(adata.var_names) | set(adata.obs.columns)

    for group, genes in marker_dict.items():
        present = [g for g in genes if g in available]
        missing = [g for g in genes if g not in available]
        if missing:
            print(f"[marker UMAP] {title_prefix}: {group} – dropping missing markers {missing}")
        if not present:
            continue
        print(f"[marker UMAP] {title_prefix}: {group} -> {present}")
        sc.pl.umap(
            adata,
            color=present,
            size=size,
            alpha=0.8,
            show=False,
        )
        fname = f"umap_{title_prefix}_{group}_markers.png".replace(" ", "_")
        plt.savefig(os.path.join(figdir, fname), bbox_inches="tight")
        plt.show()
        plt.close()

# RNA marker UMAPs (on UniVI latent)
# Marker-group UMAPs on the UniVI latent: RNA panels first, then ADT panels.
for _umap_adata, _umap_markers, _umap_prefix in (
    (rna_test_adata, rna_markers_f, "rna_univi_latent"),
    (adt_test_adata, adt_markers_f, "adt_univi_latent"),
):
    plot_markers_on_umap(
        _umap_adata,
        marker_dict=_umap_markers,
        title_prefix=_umap_prefix,
    )

# ------------------------------------------------------
# 4) Cluster-level marker heatmaps (raw vs denoised) per modality
# ------------------------------------------------------

def cluster_marker_heatmaps(
    adata,
    marker_dict,
    cluster_key,
    mod: str,
    tag: str,
    layer_raw=None,             # None -> use .X
    layer_denoised="univi_denoised",
    figdir=FIGDIR,
):
    """
    Plot cluster × marker mean heatmaps for a reference matrix and a comparison layer.

    For each marker group the following files are saved to ``figdir``:
      * ``heatmap_{tag}_{mod}_{group}_raw.png``: means of the reference matrix
      * ``heatmap_{tag}_{mod}_{group}_denoised.png``: means of ``layer_denoised``
      * ``heatmap_{tag}_{mod}_{group}_delta.png``: denoised minus raw
    The last two are produced only if ``layer_denoised`` exists in ``adata.layers``.

    Parameters
    ----------
    adata : AnnData
    marker_dict : dict
        ``{group: [marker, ...]}``. Markers absent from ``var_names`` are ignored.
    cluster_key : str
        ``adata.obs`` column with cluster labels. If absent, nothing is plotted.
        Categories with no cells are skipped.
    mod, tag : str
        Labels used in titles and file names.
    layer_raw : str or None
        Reference layer; ``None`` means ``adata.X``.
    layer_denoised : str
        Layer compared against the reference. Any layer works, e.g. a
        cross-modal prediction.
    figdir : str
        Output directory; created if it does not exist.
    """
    if cluster_key not in adata.obs.columns:
        print(f"[{mod}] No cluster_key='{cluster_key}' in obs; skipping.")
        return

    # Keep references to the stored (possibly sparse) matrices. Only the marker
    # columns of each group are densified below, never the whole matrix.
    if layer_raw is None:
        src_raw, raw_label = adata.X, "X"
    else:
        src_raw, raw_label = adata.layers[layer_raw], layer_raw

    has_den = layer_denoised in adata.layers
    src_den = adata.layers[layer_denoised] if has_den else None
    if not has_den:
        print(f"[{mod}] layer '{layer_denoised}' not found; only raw heatmaps will be plotted.")

    var_index = pd.Index(adata.var_names)
    clusters = adata.obs[cluster_key].astype("category")

    # (row label, boolean row mask) for each category with at least one cell.
    cluster_masks = []
    for cl in clusters.cat.categories:
        mask = (clusters == cl).to_numpy()
        if mask.any():
            cluster_masks.append((str(cl), mask))
    if not cluster_masks:
        print(f"[{mod}] no cells carry a '{cluster_key}' label; skipping.")
        return

    os.makedirs(figdir, exist_ok=True)

    def _cluster_means(src, cols, genes):
        """(cluster × gene) DataFrame of per-cluster means of src[:, cols]."""
        X_sub = _to_dense(src[:, cols])
        return pd.DataFrame(
            [X_sub[mask].mean(axis=0) for _, mask in cluster_masks],
            index=[name for name, _ in cluster_masks],
            columns=genes,
        )

    def _save_heatmap(df, title, fname, **heatmap_kwargs):
        """Draw df as a heatmap, save it under figdir, show it, then close it."""
        plt.figure(figsize=(0.5 * df.shape[1] + 4, 0.5 * len(df) + 3))
        sns.heatmap(df, **heatmap_kwargs)
        plt.xlabel("Markers")
        plt.ylabel(cluster_key)
        plt.title(title)
        plt.tight_layout()
        plt.savefig(os.path.join(figdir, fname.replace(" ", "_")))
        plt.show()
        plt.close()

    for group, genes in marker_dict.items():
        if not genes:
            continue
        genes_present = [g for g in genes if g in var_index]
        if not genes_present:
            print(f"[{mod}] [{group}] no markers present; skipping.")
            continue

        cols = var_index.get_indexer(genes_present)

        df_raw = _cluster_means(src_raw, cols, genes_present)
        _save_heatmap(
            df_raw,
            f"{tag} ({mod}) – {group} markers (raw)",
            f"heatmap_{tag}_{mod}_{group}_raw.png",
            cmap="viridis",
            cbar_kws={"label": f"Mean {raw_label}"},
        )

        if not has_den:
            continue

        df_den = _cluster_means(src_den, cols, genes_present)
        _save_heatmap(
            df_den,
            f"{tag} ({mod}) – {group} markers (denoised)",
            f"heatmap_{tag}_{mod}_{group}_denoised.png",
            cmap="viridis",
            cbar_kws={"label": f"Mean {layer_denoised}"},
        )

        # Delta (denoised - raw); rows and columns align by label.
        _save_heatmap(
            df_den - df_raw,
            f"{tag} ({mod}) – {group} markers Δ (den - raw)",
            f"heatmap_{tag}_{mod}_{group}_delta.png",
            cmap="vlag",
            center=0,
            cbar_kws={"label": "Mean (denoised - raw)"},
        )

# RNA cluster-level marker heatmaps
# Cluster-level marker heatmaps for each modality's test set (RNA, then ADT).
# NOTE(review): for ADT the reference is the "counts" layer while the
# comparison layer holds model output. Confirm both are on the same scale
# before reading the delta heatmaps quantitatively.
for _heatmap_kwargs in (
    dict(
        adata=rna_test_adata,
        marker_dict=rna_markers_f,
        cluster_key="rna_leiden",
        mod="rna",
        tag="rna_test_adata",
        layer_raw=None,                  # .X, the working RNA space
        layer_denoised="univi_denoised",
    ),
    dict(
        adata=adt_test_adata,
        marker_dict=adt_markers_f,
        cluster_key="adt_leiden",
        mod="adt",
        tag="adt_test_adata",
        layer_raw="counts",
        layer_denoised="univi_denoised",
    ),
):
    cluster_marker_heatmaps(**_heatmap_kwargs)


In [None]:
# ------------------------------------------------------
# 5) ATAC LSI ↔ RNA marker correlations
# ------------------------------------------------------

def correlate_lsi_with_rna_markers(atac_adata, rna_adata, genes, layer_rna=None, top_k=5):
    """
    Pearson-correlate selected RNA genes with every column of the ATAC matrix.

    ``atac_adata.X`` is treated as a cells × n_lsi embedding (one column per
    LSI dimension). Each requested gene's RNA values (``rna_adata.X`` or
    ``rna_adata.layers[layer_rna]``) are correlated with each column.

    Parameters
    ----------
    atac_adata, rna_adata : AnnData
        Must contain the same cells in the same order (checked via ``obs_names``).
    genes : iterable of str
        Genes missing from ``rna_adata.var_names``, or constant across cells,
        are reported and skipped.
    layer_rna : str or None
        RNA layer to use; ``None`` means ``.X``.
    top_k : int
        Number of strongest dimensions (by |r|) kept per gene in ``df_top``.

    Returns
    -------
    (df, df_top) : tuple of DataFrames
        ``df`` has columns ``gene, lsi_dim, corr, abs_corr``, one row per
        gene × dimension. ``df_top`` holds the ``top_k`` rows per gene with the
        largest ``abs_corr``. When no gene qualifies, both frames are empty but
        still carry these columns, so column access downstream does not raise.

    Raises
    ------
    ValueError
        If the ``obs_names`` of the two objects differ.
    """
    result_columns = ["gene", "lsi_dim", "corr", "abs_corr"]

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not np.array_equal(np.asarray(atac_adata.obs_names), np.asarray(rna_adata.obs_names)):
        raise ValueError("ATAC and RNA obs_names must match 1:1.")

    genes = list(genes)  # iterated twice below; also accepts generators
    X_lsi = _to_dense(atac_adata.X)  # cells × n_lsi

    rna_src = rna_adata.X if layer_rna is None else rna_adata.layers[layer_rna]
    var_index = pd.Index(rna_adata.var_names)

    # Densify only the requested gene columns, not the whole RNA matrix.
    wanted = list(dict.fromkeys(g for g in genes if g in var_index))
    col_of = {g: j for j, g in enumerate(wanted)}
    X_genes = _to_dense(rna_src[:, [var_index.get_loc(g) for g in wanted]]) if wanted else None

    records = []
    for g in genes:
        if g not in col_of:
            print(f"[ATAC-RNA corr] {g} not in rna var_names; skipping.")
            continue
        g_vec = np.asarray(X_genes[:, col_of[g]]).ravel()

        # skip all-constant genes (Pearson r is undefined for them)
        if g_vec.size < 2 or np.allclose(g_vec, g_vec[0]):
            print(f"[ATAC-RNA corr] {g} is constant across cells; skipping.")
            continue

        for k in range(X_lsi.shape[1]):
            r, _ = pearsonr(X_lsi[:, k], g_vec)
            records.append({"gene": g, "lsi_dim": k, "corr": r})

    if not records:
        print("No correlations computed (no genes found).")
        empty = pd.DataFrame(columns=result_columns)
        return empty, empty.copy()

    df = pd.DataFrame(records)
    df["abs_corr"] = df["corr"].abs()
    df_top = (
        df.sort_values(["gene", "abs_corr"], ascending=[True, False])
        .groupby("gene")
        .head(top_k)
    )
    return df, df_top

# Pick a marker subset to probe ATAC LSI
# Pick a marker subset to probe ATAC LSI
genes_for_atac = [
    # B cell / Ig
    "MS4A1", "CD74", "HLA-DRA", "IGKC", "IGHM",
    # Cytotoxic / NK
    "PRF1", "GZMH", "GNLY", "NKG7", "IFNG",
    # Mono / myeloid
    "S100A8", "S100A9", "FCN1", "IL1B", "LYZ",
]

df_corr, df_corr_top = correlate_lsi_with_rna_markers(
    atac_test_adata,
    rna_test_adata,
    genes_for_atac,
    layer_rna=None,  # or "log1p" etc if you prefer
    top_k=3,
)

print("\nTop LSI dims per marker gene:")
print(df_corr_top)

# Optionally: visualize the most strongly correlated LSI dims on the ATAC
# UniVI-latent UMAP.
N_LSI_DIMS_TO_PLOT = 6

if df_corr_top.empty:
    # An empty table may lack the 'lsi_dim' / 'abs_corr' columns entirely.
    print("[ATAC-RNA corr] No correlations available; skipping LSI UMAPs.")
else:
    ensure_umap_on_univi(atac_test_adata)
    os.makedirs(FIGDIR, exist_ok=True)

    # Rank dims by their strongest |r| over all genes and keep the top few,
    # plotted in index order. Taking the lowest-indexed dims would ignore
    # correlation strength.
    lsi_dims_to_plot = sorted(
        df_corr_top.groupby("lsi_dim")["abs_corr"].max()
        .sort_values(ascending=False)
        .index[:N_LSI_DIMS_TO_PLOT]
    )

    X_lsi_atac = _to_dense(atac_test_adata.X)  # densify once, not once per dim
    for d in lsi_dims_to_plot:
        atac_test_adata.obs[f"LSI_{d}"] = X_lsi_atac[:, d]
        sc.pl.umap(
            atac_test_adata,
            color=f"LSI_{d}",
            size=85,
            alpha=0.8,
            show=False,
        )
        fname = f"umap_atac_univi_LSI_dim_{d}.png"
        plt.savefig(os.path.join(FIGDIR, fname), bbox_inches="tight")
        plt.show()
        plt.close()


In [None]:
# ------------------------------------------------------
# 6) Sample cells per unimodal cluster across modalities
# ------------------------------------------------------

def sample_per_cluster(adata, cluster_key, n_per_cluster=2000, random_state=0):
    """
    Draw at most ``n_per_cluster`` cells from every category of ``adata.obs[cluster_key]``.

    Sampling is without replacement, one category at a time in category order,
    using ``np.random.default_rng(random_state)``, so results are reproducible.
    Smaller clusters are kept whole. The selected rows keep their original
    relative order.

    Returns
    -------
    AnnData
        A copy of the selected cells.

    Raises
    ------
    ValueError
        If ``cluster_key`` is not an ``adata.obs`` column.
    """
    if cluster_key not in adata.obs.columns:
        raise ValueError(f"cluster_key='{cluster_key}' not in adata.obs")

    generator = np.random.default_rng(random_state)
    labels = adata.obs[cluster_key].astype("category")

    picked = []
    for category in labels.cat.categories:
        members = np.flatnonzero((labels == category).to_numpy())
        if members.size:
            take = min(n_per_cluster, members.size)
            picked.extend(generator.choice(members, size=take, replace=False))

    picked.sort()
    print(f"[sample_per_cluster] {cluster_key}: keeping {len(picked)} cells "
          f"(<= {n_per_cluster} per cluster)")
    return adata[picked].copy()

n_per_cluster = 2000  # per-cluster cap; clusters smaller than this are kept whole

# Down-sample each modality's test set within its own unimodal clusters
# (RNA, then ADT, then ATAC), all with the same seed.
rna_sample, adt_sample, atac_sample = (
    sample_per_cluster(source_adata, source_key, n_per_cluster, random_state=42)
    for source_adata, source_key in (
        (rna_test_adata, "rna_leiden"),
        (adt_test_adata, "adt_leiden"),
        (atac_test_adata, "atac_leiden"),
    )
)

# Record which modality every sampled cell came from.
for _sampled, _source in ((rna_sample, "rna"), (adt_sample, "adt"), (atac_sample, "atac")):
    _sampled.obs["univi_source"] = _source

# ------------------------------------------------------
# 7) Build combined sampled object in UniVI latent
# ------------------------------------------------------

# Concatenate the sampled cells of all three modalities (rna, adt, atac order).
# - `ad.concat` replaces the deprecated `AnnData.concatenate`.
# - The modality test sets are expected to share obs_names (the ATAC/RNA
#   correlation step requires it), so the same barcode can be sampled in several
#   modalities. `index_unique="-"` appends the modality key ("-rna", "-adt",
#   "-atac") to keep the combined obs_names unique; `index_unique=None` left
#   duplicates.
# NOTE(review): with join="outer", a feature name present in more than one
# modality (e.g. an RNA gene and an ADT protein with the same name) becomes a
# single column of .X. Only obs and obsm["X_univi"] are used below, so the plots
# are unaffected, but do not read combined_sample.X as a per-modality matrix.
_sample_parts = {"rna": rna_sample, "adt": adt_sample, "atac": atac_sample}
combined_sample = ad.concat(
    _sample_parts,
    join="outer",
    label="univi_batch",
    index_unique="-",
)

# Stack UniVI latents in the same order as the concatenation.
combined_sample.obsm["X_univi"] = np.vstack(
    [part.obsm["X_univi"] for part in _sample_parts.values()]
)

# Neighbors/UMAP on UniVI latent for the sampled cells
sc.pp.neighbors(combined_sample, use_rep="X_univi", n_neighbors=30)
sc.tl.umap(combined_sample)

# ------------------------------------------------------
# 8) UMAP visualizations: modality + unimodal clusters
# ------------------------------------------------------

# UMAP colored by modality
def _save_sampled_umap(color_key, file_name):
    """Draw combined_sample's UMAP coloured by `color_key`, save it to FIGDIR/file_name, show, close."""
    sc.pl.umap(
        combined_sample,
        color=color_key,
        size=85,
        alpha=0.8,
        show=False,
    )
    plt.savefig(os.path.join(FIGDIR, file_name), bbox_inches="tight")
    plt.show()
    plt.close()


# Coloured by source modality
_save_sampled_umap("univi_source", "umap_sampled_tri_modal_by_modality.png")

# Coloured by each unimodal Leiden labelling (pseudo-celltypes), when present
for key in ["rna_leiden", "adt_leiden", "atac_leiden"]:
    if key not in combined_sample.obs.columns:
        print(f"[sampled UMAP] {key} not in combined_sample.obs; skipping.")
        continue
    print(f"[sampled UMAP] Coloring by {key}")
    _save_sampled_umap(key, f"umap_sampled_tri_modal_{key}.png")

# ------------------------------------------------------
# 9) (Optional) Use denoised reconstructions for sampled cells
# ------------------------------------------------------
# The sampled objects are `.copy()`s of the test sets, so they already carry a
# 'univi_denoised' layer if univi_eval.denoise_adata was run on the full test
# sets. Any sample still missing it is decoded here.
#
# The `from univi import evaluation as univi_eval` visible in this part of the
# notebook is in a later cell. Importing it here as well (a no-op if already
# imported) keeps this cell working under Restart & Run All instead of
# depending on execution order.
from univi import evaluation as univi_eval

for adata_sample, mod, tag in [
    (rna_sample,  "rna",  "rna_sample"),
    (adt_sample,  "adt",  "adt_sample"),
    (atac_sample, "atac", "atac_sample"),
]:
    if "univi_denoised" not in adata_sample.layers:
        print(f"[denoise sampled] Running decoder for {tag} ({mod})...")
        univi_eval.denoise_adata(
            model,
            adata_sample,
            modality=mod,
            device=device,
            batch_size=512,
            out_layer="univi_denoised",
        )

# Cluster-level marker heatmaps on the sampled RNA / ADT cells
cluster_marker_heatmaps(
    rna_sample,
    marker_dict=rna_markers_f,
    cluster_key="rna_leiden",
    mod="rna",
    tag="rna_sample",
    layer_raw=None,
    layer_denoised="univi_denoised",
)

cluster_marker_heatmaps(
    adt_sample,
    marker_dict=adt_markers_f,
    cluster_key="adt_leiden",
    mod="adt",
    tag="adt_sample",
    layer_raw="counts",
    layer_denoised="univi_denoised",
)


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from univi import evaluation as univi_eval

# ------------------------------------------------------
# 1) Cross-modal predictions from RNA to ADT and ATAC
# ------------------------------------------------------

print("\n[Cross-modal] Predicting ADT and ATAC from RNA test data...")

# The predictions are row-aligned with rna_test_adata and are stored as layers
# on the ADT/ATAC objects. That is only valid when those objects hold the same
# cells in the same order. A shape check alone would accept equally sized
# objects in a different order, so compare obs_names before running the model.
for _tgt_name, _tgt_adata in (
    ("adt_test_adata", adt_test_adata),
    ("atac_test_adata", atac_test_adata),
):
    if not np.array_equal(np.asarray(rna_test_adata.obs_names), np.asarray(_tgt_adata.obs_names)):
        raise ValueError(
            f"rna_test_adata and {_tgt_name} obs_names differ; "
            "RNA-based predictions cannot be stored on it cell-for-cell."
        )

# RNA → ADT
Xhat_adt_from_rna = univi_eval.cross_modal_predict(
    model,
    adata_src=rna_test_adata,
    src_mod="rna",
    tgt_mod="adt",
    device=device,
    batch_size=512,
)
# shape check
print("  RNA→ADT prediction shape:", Xhat_adt_from_rna.shape)

adt_test_adata.layers["univi_pred_from_rna"] = Xhat_adt_from_rna

# RNA → ATAC
Xhat_atac_from_rna = univi_eval.cross_modal_predict(
    model,
    adata_src=rna_test_adata,
    src_mod="rna",
    tgt_mod="atac",
    device=device,
    batch_size=512,
)
print("  RNA→ATAC prediction shape:", Xhat_atac_from_rna.shape)

atac_test_adata.layers["univi_pred_from_rna"] = Xhat_atac_from_rna

print("Stored predictions in:")
print("  adt_test_adata.layers['univi_pred_from_rna']")
print("  atac_test_adata.layers['univi_pred_from_rna']")


In [None]:
# ------------------------------------------------------
# 2) ADT cluster-level marker heatmaps: true vs RNA-predicted
# ------------------------------------------------------
# Relies on objects from earlier cells:
#   - adt_markers_f: ADT marker groups filtered to adt_test_adata.var_names
#   - cluster_marker_heatmaps(...)
#   - adt_test_adata.layers["counts"] (reference ADT values) and
#     adt_test_adata.layers["univi_pred_from_rna"] (RNA→ADT prediction)
# The "denoised" slot of cluster_marker_heatmaps is reused for the cross-modal
# prediction, so its "delta" heatmaps show prediction − reference.
# NOTE(review): the reference here is "counts", while the later UMAP cell
# compares predictions against "scaled" (a copy of .X). Confirm which space the
# RNA→ADT predictions live in before interpreting delta magnitudes.
cluster_marker_heatmaps(
    adt_test_adata,
    adt_markers_f,
    "adt_leiden",
    mod="adt",
    tag="adt_test_from_rna",
    layer_denoised="univi_pred_from_rna",
    layer_raw="counts",
)


In [None]:
# ------------------------------------------------------
# 3) ADT: nearest-centroid cluster prediction using RNA→ADT
# ------------------------------------------------------

def assign_clusters_by_centroid(
    adata,
    cluster_key: str,
    layer: str,
    out_key: str,
):
    """
    Label every cell with the cluster whose centroid (in ``layer``) is nearest.

    Centroids are per-cluster means of ``adata.layers[layer]`` under the labels
    in ``adata.obs[cluster_key]``. Each cell is assigned the cluster with the
    smallest Euclidean distance to its own row of the same layer. The result is
    written to ``adata.obs[out_key]`` as a categorical with the same
    categories as the (category-cast) input labels.

    Categories with no cells, e.g. left over after subsetting, have no centroid
    and are never assigned.

    Raises
    ------
    ValueError
        If ``cluster_key`` is not in ``adata.obs``, ``layer`` is not in
        ``adata.layers``, or no category has any cells.
    """
    if cluster_key not in adata.obs.columns:
        raise ValueError(f"{cluster_key} not in adata.obs")

    if layer not in adata.layers:
        raise ValueError(f"{layer} not in adata.layers")

    X = _to_dense(adata.layers[layer])  # cells × features
    clusters = adata.obs[cluster_key].astype("category")
    cats = clusters.cat.categories

    # Distances are computed one centroid at a time, so memory stays
    # O(cells × features) instead of building a cells × clusters × features
    # array. Empty categories stay at +inf and can never win the argmin. The
    # old code appended a scalar NaN for them, which made np.vstack fail.
    dists = np.full((X.shape[0], len(cats)), np.inf)
    n_centroids = 0
    for j, cl in enumerate(cats):
        mask = (clusters == cl).to_numpy()
        if not mask.any():
            continue
        centroid = X[mask].mean(axis=0)
        dists[:, j] = np.linalg.norm(X - centroid, axis=1)
        n_centroids += 1

    if n_centroids == 0:
        raise ValueError(f"No cells have a label in obs['{cluster_key}']; cannot compute centroids.")

    idx_min = np.argmin(dists, axis=1)
    assigned = cats[idx_min]

    adata.obs[out_key] = pd.Categorical(assigned, categories=cats)
    print(f"[assign_clusters_by_centroid] Wrote predicted clusters to obs['{out_key}'].")


def confusion_true_vs_pred(
    adata,
    true_key: str,
    pred_key: str,
    tag: str,
    normalize: str = "index",
    figdir=FIGDIR,
):
    """
    Plot and return the crosstab of true vs predicted cluster labels.

    Parameters
    ----------
    adata : AnnData
        ``adata.obs`` must contain ``true_key`` and ``pred_key``.
    true_key, pred_key : str
        obs columns with the reference (rows) and predicted (columns) labels.
    tag : str
        Used in the title and in the file name ``cluster_confusion_{tag}.png``.
    normalize : str or bool
        Forwarded to ``pd.crosstab``: "index" (rows sum to 1), "columns",
        "all" / True, or False for raw cell counts.
    figdir : str
        Output directory; created if it does not exist.

    Returns
    -------
    pandas.DataFrame
        The (optionally normalized) crosstab that was plotted.
    """
    s_true = adata.obs[true_key].astype("category")
    s_pred = adata.obs[pred_key].astype("category")

    tab = pd.crosstab(s_true, s_pred, normalize=normalize)

    # With normalize=False the cells hold counts, not fractions.
    if normalize is False:
        cbar_label = "Number of cells"
    else:
        cbar_label = f"Fraction (normalized by {normalize})"

    os.makedirs(figdir, exist_ok=True)

    plt.figure(figsize=(0.5 * tab.shape[1] + 4, 0.5 * tab.shape[0] + 4))
    sns.heatmap(
        tab,
        cmap="viridis",
        annot=False,
        cbar_kws={"label": cbar_label},
    )
    plt.xlabel(f"Predicted ({pred_key})")
    plt.ylabel(f"True ({true_key})")
    plt.title(f"Cluster confusion: {tag}")
    plt.tight_layout()
    fname = f"cluster_confusion_{tag}.png".replace(" ", "_")
    plt.savefig(os.path.join(figdir, fname))
    plt.show()
    plt.close()

    return tab

# Predict ADT clusters using RNA->ADT predicted profiles
# Re-label ADT cells by the nearest ADT-cluster centroid, with centroids and
# distances computed on the RNA→ADT predicted profiles.
assign_clusters_by_centroid(
    adt_test_adata,
    "adt_leiden",
    "univi_pred_from_rna",
    "adt_leiden_pred_from_rna",
)

# Row-normalised confusion: for each true ADT cluster, the fraction of its
# cells given each label when labelled from RNA-predicted features.
tab_adt_conf = confusion_true_vs_pred(
    adt_test_adata,
    "adt_leiden",
    "adt_leiden_pred_from_rna",
    "ADT_true_vs_RNA_predicted",
    normalize="index",
)
print("\nADT cluster confusion (true vs RNA-predicted features):")
print(tab_adt_conf)


In [None]:
# Quick look at the ADT test object and the value range of its working matrix.
print(adt_test_adata)
print(adt_test_adata.X.min())
print(adt_test_adata.X.max())

# Snapshot .X as the "scaled" reference layer used by the true-vs-predicted
# UMAPs below. `.copy()` makes it independent. Assigning .X directly can leave
# the layer sharing memory with .X, so a later in-place change to .X would
# silently change this reference too.
# NOTE(review): the layer name assumes .X holds scaled ADT values; confirm upstream.
adt_test_adata.layers['scaled'] = adt_test_adata.X.copy()


In [None]:
# ------------------------------------------------------
# 4) ADT: UMAPs of true vs RNA-predicted marker expression
# ------------------------------------------------------

def plot_true_vs_pred_markers_on_umap(
    adata,
    markers,
    layer_true="scaled",
    layer_pred="univi_pred_from_rna",
    title_prefix="adt_true_vs_rna_pred",
    figdir=FIGDIR,
    size=200,
):
    """
    For each marker, save a two-panel UMAP: measured (``layer_true``) vs predicted (``layer_pred``).

    The UMAP is built on ``adata.obsm['X_univi']`` if ``X_umap`` is absent.
    For every marker found in ``adata.var_names``, its values from both layers
    are written to ``adata.obs[f"{m}_true"]`` and ``adata.obs[f"{m}_pred"]``
    (these columns remain on ``adata``) and plotted side by side. Markers not
    in ``var_names`` are reported and skipped.

    Raises
    ------
    ValueError
        If either layer is missing from ``adata.layers``.
    """
    if layer_true not in adata.layers:
        raise ValueError(f"{layer_true} not in adata.layers")
    if layer_pred not in adata.layers:
        raise ValueError(f"{layer_pred} not in adata.layers")

    os.makedirs(figdir, exist_ok=True)
    ensure_umap_on_univi(adata)  # UMAP on X_univi

    X_true = _to_dense(adata.layers[layer_true])
    X_pred = _to_dense(adata.layers[layer_pred])

    var_index = pd.Index(adata.var_names)

    for m in markers:
        if m not in var_index:
            print(f"[UMAP true vs pred] Marker {m} not in ADT var_names; skipping.")
            continue
        idx = var_index.get_loc(m)
        adata.obs[f"{m}_true"] = X_true[:, idx]
        adata.obs[f"{m}_pred"] = X_pred[:, idx]

        print(f"[UMAP true vs pred] Plotting {m} (true vs RNA-pred).")
        sc.pl.umap(
            adata,
            color=[f"{m}_true", f"{m}_pred"],
            size=size,
            alpha=0.8,
            show=False,
        )
        # ADT names can contain "/" (e.g. "TCR-a/b"). In a file name that is
        # read as a sub-directory and makes savefig fail.
        safe_name = str(m).replace("/", "_")
        fname = f"umap_{title_prefix}_{safe_name}.png".replace(" ", "_")
        plt.savefig(os.path.join(figdir, fname), bbox_inches="tight")
        plt.show()
        plt.close()

# example ADT markers to inspect
# ADT markers to compare (measured vs RNA-predicted) on the UniVI-latent UMAP.
# Names absent from adt_test_adata.var_names are skipped by the function.
adt_markers_to_plot = "CD3 CD4 CD8a CD19 CD14 CD56 HLA-DR".split()

plot_true_vs_pred_markers_on_umap(
    adt_test_adata,
    adt_markers_to_plot,
    "scaled",
    "univi_pred_from_rna",
    "adt_true_vs_rna_pred",
)


In [None]:
# ------------------------------------------------------
# 5) ATAC: cluster-level comparison of true vs RNA-predicted ATAC features
# ------------------------------------------------------

def atac_cluster_feature_heatmaps(
    atac_adata,
    cluster_key="atac_leiden",
    layer_true="X",                    # treat .X as "true"
    layer_pred="univi_pred_from_rna",
    tag="atac_test_from_rna",
    figdir=FIGDIR,
    max_features_for_plot=40,
):
    """
    Compare per-cluster feature means of measured ATAC vs RNA-predicted ATAC.

    Features are the columns of ``atac_adata.X`` (e.g. LSI dims). Only the
    ``max_features_for_plot`` columns with the highest variance in the measured
    matrix are used, ordered by decreasing variance. Three heatmaps are saved
    to ``figdir``: ``heatmap_{tag}_true.png``, ``heatmap_{tag}_pred.png`` and
    ``heatmap_{tag}_delta.png`` (pred − true).

    Parameters
    ----------
    atac_adata : AnnData
    cluster_key : str
        ``obs`` column with cluster labels; categories without cells are skipped.
    layer_true : str
        ``"X"`` for ``atac_adata.X``, otherwise a layer name.
    layer_pred : str
        Layer holding the predicted values.
    tag : str
        Used in titles and file names.
    figdir : str
        Output directory; created if it does not exist.
    max_features_for_plot : int
        Number of highest-variance features to show.

    Returns
    -------
    (df_true, df_pred, df_delta) : tuple of DataFrames
        Cluster × feature means. When the function skips (missing cluster
        column, missing predicted layer, or no labelled cells) it returns three
        empty DataFrames instead of ``None``, so
        ``a, b, c = atac_cluster_feature_heatmaps(...)`` does not fail on unpacking.
    """
    skipped = (pd.DataFrame(), pd.DataFrame(), pd.DataFrame())
    if cluster_key not in atac_adata.obs.columns:
        print(f"[ATAC] No cluster_key='{cluster_key}' in obs; skipping.")
        return skipped
    if layer_pred not in atac_adata.layers:
        print(f"[ATAC] No predicted layer '{layer_pred}'; skipping.")
        return skipped

    # Measured ("true") and predicted feature matrices
    if layer_true == "X":
        X_true, true_label = _to_dense(atac_adata.X), "X"
    else:
        X_true, true_label = _to_dense(atac_adata.layers[layer_true]), layer_true
    X_pred = _to_dense(atac_adata.layers[layer_pred])

    clusters = atac_adata.obs[cluster_key].astype("category")
    # (row label, boolean row mask) for each category with at least one cell.
    cluster_masks = []
    for cl in clusters.cat.categories:
        mask = (clusters == cl).to_numpy()
        if mask.any():
            cluster_masks.append((str(cl), mask))
    if not cluster_masks:
        print(f"[ATAC] No cells carry a '{cluster_key}' label; skipping.")
        return skipped

    os.makedirs(figdir, exist_ok=True)

    # Restrict the plot to the highest-variance features of the measured matrix
    idx_plot = np.argsort(X_true.var(axis=0))[::-1][:max_features_for_plot]
    feature_names = [str(atac_adata.var_names[i]) for i in idx_plot]
    row_names = [name for name, _ in cluster_masks]

    def _cluster_means(X):
        """(cluster × selected feature) DataFrame of per-cluster means of X."""
        X_sel = X[:, idx_plot]
        return pd.DataFrame(
            [X_sel[mask].mean(axis=0) for _, mask in cluster_masks],
            index=row_names,
            columns=feature_names,
        )

    def _save_heatmap(df, title, fname, **heatmap_kwargs):
        """Draw df as a heatmap, save it under figdir, show it, then close it."""
        plt.figure(figsize=(0.4 * len(feature_names) + 4, 0.5 * len(df) + 3))
        sns.heatmap(df, **heatmap_kwargs)
        plt.xlabel("ATAC features (e.g., LSI dims)")
        plt.ylabel(cluster_key)
        plt.title(title)
        plt.tight_layout()
        plt.savefig(os.path.join(figdir, fname))
        plt.show()
        plt.close()

    df_true = _cluster_means(X_true)
    df_pred = _cluster_means(X_pred)
    df_delta = df_pred - df_true

    _save_heatmap(
        df_true,
        f"{tag} – ATAC (true)",
        f"heatmap_{tag}_true.png",
        cmap="viridis",
        cbar_kws={"label": f"Mean {true_label}"},
    )
    _save_heatmap(
        df_pred,
        f"{tag} – ATAC (RNA-predicted)",
        f"heatmap_{tag}_pred.png",
        cmap="viridis",
        cbar_kws={"label": "Mean pred_from_rna"},
    )
    _save_heatmap(
        df_delta,
        f"{tag} – ATAC Δ (pred - true)",
        f"heatmap_{tag}_delta.png",
        cmap="vlag",
        center=0,
        cbar_kws={"label": "Mean (pred - true)"},
    )

    return df_true, df_pred, df_delta

# Per-cluster comparison of measured ATAC features (.X) vs RNA-predicted ATAC,
# limited to the 40 highest-variance features.
(
    df_atac_true,
    df_atac_pred,
    df_atac_delta,
) = atac_cluster_feature_heatmaps(
    atac_test_adata,
    "atac_leiden",
    "X",
    "univi_pred_from_rna",
    "atac_from_rna",
    max_features_for_plot=40,
)


In [None]:
# ------------------------------------------------------
# 6) ATAC: UMAP of selected feature dims (true vs RNA-pred)
# ------------------------------------------------------

def plot_atac_dims_true_vs_pred_umap(
    atac_adata,
    dims,
    layer_pred="univi_pred_from_rna",
    title_prefix="atac_true_vs_rna_pred_dims",
    figdir=FIGDIR,
    size=200,
):
    """
    Plot two-panel UMAPs of selected ATAC feature columns: true (.X) vs predicted.

    For every index ``d`` in ``dims`` this writes two columns into
    ``atac_adata.obs`` -- ``ATAC_dim_{d}_true`` (column ``d`` of ``.X``) and
    ``ATAC_dim_{d}_pred`` (column ``d`` of ``.layers[layer_pred]``) -- then draws a
    UMAP colored by both, saves it as ``umap_{title_prefix}_dim_{d}.png`` in
    ``figdir`` and shows it.

    Parameters
    ----------
    atac_adata : AnnData
        ATAC AnnData; ``ensure_umap_on_univi`` is called on it before plotting.
    dims : iterable of int
        Feature (column) indices into ``atac_adata.X``. Negative or out-of-range
        indices are reported and skipped.
    layer_pred : str
        Layer holding the predicted values (assumed to have the same columns as .X).
    title_prefix : str
        Prefix used in the saved file names.
    figdir : str
        Output directory for the PNGs; created if it does not exist.
    size : int
        Point size passed to ``sc.pl.umap``.

    Raises
    ------
    ValueError
        If ``layer_pred`` is not present in ``atac_adata.layers``.
    """
    # Fail fast with the same error style as the ADT plotting helper
    if layer_pred not in atac_adata.layers:
        raise ValueError(f"{layer_pred} not in atac_adata.layers")

    ensure_umap_on_univi(atac_adata)
    # plt.savefig does not create missing directories
    os.makedirs(figdir, exist_ok=True)

    X_true = _to_dense(atac_adata.X)
    X_pred = _to_dense(atac_adata.layers[layer_pred])

    for d in dims:
        # Negative indices are rejected rather than wrapped around
        if d < 0 or d >= atac_adata.n_vars:
            print(f"[ATAC UMAP] dim {d} out of range; skipping.")
            continue

        col_true = f"ATAC_dim_{d}_true"
        col_pred = f"ATAC_dim_{d}_pred"
        atac_adata.obs[col_true] = X_true[:, d]
        atac_adata.obs[col_pred] = X_pred[:, d]

        sc.pl.umap(
            atac_adata,
            color=[col_true, col_pred],
            size=size,
            alpha=0.8,
            show=False,
        )
        fname = f"umap_{title_prefix}_dim_{d}.png"
        plt.savefig(os.path.join(figdir, fname), bbox_inches="tight")
        plt.show()
        plt.close()

# Example: plot the first four ATAC feature columns (indices 0-3)
dims_to_plot = list(range(4))
plot_atac_dims_true_vs_pred_umap(atac_test_adata, dims_to_plot)


In [None]:
# ------------------------------------------------------------------
# RNA-gene ↔ ADT-marker mapping, tailored to your RNA var_names.
# Every listed gene that is present in rna_adata.var_names gets its own
# UMAP panel in the plotting helper; genes that are absent are skipped,
# so canonical genes can be listed even when they are not among the HVGs.
# ------------------------------------------------------------------

rna_gene_map_for_adt_markers = {
    # -------------------------
    # T cell core markers
    # -------------------------
    # CD3 complex – only CD247 (CD3ζ) is clearly in your HVGs
    "CD3": ["CD247", "CD3D", "CD3E", "CD3G"],

    # CD4 helper / naïve-like: the CD4 gene itself (used when present; it is not
    # among your HVGs) followed by good proxies
    "CD4": ["CD4", "GATA3", "TCF7", "LEF1", "BCL11B", "IL7R"],

    # CD8 / cytotoxic T: no CD8A in your set, but strong cytotoxic signature is present
    "CD8": ["NKG7", "GNLY", "PRF1", "GZMH", "STAT4", "CD8A", "CD8B"],

    # Early activation
    "CD69": ["CD69", "NR4A2", "TNFAIP3"],

    # Co-stimulatory-ish / surface T cell marker
    "CD6": ["CD6"],

    # -------------------------
    # NK / cytotoxic markers
    # -------------------------
    # CD56 = NCAM1; your RNA has NCAM1 + classic NK/cytotoxic genes
    "CD56": ["NCAM1", "NKG7", "GNLY", "PRF1", "GZMH"],

    # NKG2D is encoded by KLRK1; cytotoxic genes follow as proxies
    "NKG2D": ["KLRK1", "NKG7", "GNLY", "PRF1"],

    # Generic “NK / cytotoxic” channel if you have one
    "NK_signature": ["NKG7", "GNLY", "PRF1", "GZMH", "STAT4"],

    # -------------------------
    # B cell markers
    # -------------------------
    # Your RNA HVGs clearly include CD22, MS4A1 (CD20), BANK1, PAX5, FCER2, FCRLA, etc.
    "CD19": ["CD22", "MS4A1", "BANK1", "PAX5", "CD19"],
    "CD20": ["MS4A1", "BANK1", "PAX5"],
    "CD22": ["CD22", "BANK1", "PAX5"],
    "CD23": ["FCER2"],
    "CD74": ["CD74", "HLA-DRA", "HLA-DRB1"],

    # More B-cell-ish markers if you have them in ADT
    "FCRL1": ["FCRL1", "BANK1", "PAX5"],
    "IgM":   ["IGHM", "IGKC"],
    "IgA":   ["IGHA1", "IGHA2"],
    "IgD":   ["IGHD"],

    # -------------------------
    # Myeloid / mono / DC
    # -------------------------
    # Monocytes: the CD14 gene itself (used when present; it is not in your list)
    # followed by myeloid proxies
    "CD14": ["CD14", "LYZ", "S100A8", "S100A9", "FCN1"],

    # CD16: FCGR3A is the isoform on NK cells and CD16+ monocytes (the main CD16+
    # populations in PBMCs); FCGR3B (neutrophils) is clearly in your var_names
    "CD16": ["FCGR3A", "FCGR3B"],

    # CD11c – ITGAX is present as a gene
    "CD11c": ["ITGAX"],
    "ITGAX": ["ITGAX"],  # in case the ADT channel is actually named ITGAX

    # Chemokine receptors
    "CCR2": ["CCR2"],

    # Cross-presenting cDC1s express XCR1 and CLEC9A. (XCL1/XCL2 encode the XCR1
    # *ligand*, made mainly by NK/CD8 T cells, so they do not mark XCR1+ cells.)
    "XCR1": ["XCR1", "CLEC9A"],

    # -------------------------
    # MHC-II / antigen presentation
    # -------------------------
    # You have HLA-DRA, HLA-DRB1, HLA-DQB1, CD74
    "HLA-DR": ["HLA-DRA", "HLA-DRB1", "HLA-DQB1", "CD74"],
    "HLA-DQ": ["HLA-DQB1", "HLA-DQA1"],  # HLA-DQA1 may or may not be present
    "HLA-DP": ["HLA-DPA1", "HLA-DPB1"],  # optional; will be auto-dropped if absent

    # -------------------------
    # Activation / costim / checkpoints
    # (many of these genes may *not* be in your HVGs, but we include them
    # so your code can still use this dict with other datasets)
    # -------------------------
    "CD83": ["CD83"],
    "CD96": ["CD96"],

    # Classic T cell activation / exhaustion markers
    "OX40":  ["TNFRSF4"],
    "4-1BB": ["TNFRSF9"],
    "PD-1":  ["PDCD1"],
    "CTLA4": ["CTLA4"],
    "TIGIT": ["TIGIT"],

    # -------------------------
    # T-reg / Th skewing
    # -------------------------
    "FOXP3": ["FOXP3", "IL2RA"],
    "GATA3": ["GATA3"],
    "TCF7_hi": ["TCF7", "LEF1", "BACH2"],

    # -------------------------
    # Cytokines / effector molecules
    # -------------------------
    "IFNg": ["IFNG", "IFNG-AS1"],
    "IL1b": ["IL1B"],
    "TNFa": ["TNF"],  # may not be present in this HVG set, but common in other datasets
}


In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt

# ------------------------------------------------------
# 4) ADT: UMAPs of true vs RNA-predicted marker expression
#     + RNA expression (mapped from marker -> gene symbol[s])
# ------------------------------------------------------

def _guess_rna_genes_for_marker(marker: str) -> list[str]:
    """
    Very simple heuristic to propose RNA gene symbols from an ADT marker name.
    Used only if rna_gene_map does not specify anything.
    """
    cand = []

    # Raw marker
    cand.append(marker)

    # Uppercase
    up = marker.upper()
    if up not in cand:
        cand.append(up)

    # Remove hyphens/slashes
    stripped = up.replace("-", "").replace("/", "")
    if stripped not in cand:
        cand.append(stripped)

    # Handle common CD8a / CD8A / CD8α style
    if up.startswith("CD8") and up not in cand:
        cand.append("CD8A")

    # HLA-DR → HLA-DRA / HLA-DRB1 guesses
    if up.startswith("HLA-DR"):
        for g in ["HLA-DRA", "HLA-DRB1", "HLA-DRB5"]:
            if g not in cand:
                cand.append(g)

    return cand


def _dense_column(X, j) -> np.ndarray:
    """Return column ``j`` of a dense or sparse 2-D matrix as a flat 1-D numpy array.

    Only the requested column is densified, so a large sparse matrix (e.g. full
    RNA counts) is never materialized in full.
    """
    return np.asarray(_to_dense(X[:, j])).ravel()


def plot_true_vs_pred_markers_on_umap_with_rna(
    adt_adata,
    rna_adata,
    markers,
    rna_gene_map: dict[str, list[str]] | None = None,
    layer_true: str = "scaled",              # ADT "true" (e.g. CLR/scale)
    layer_pred: str = "univi_pred_from_rna", # ADT predicted from RNA
    rna_layer: str | None = "log1p",         # RNA layer for expression; if None, use .X
    title_prefix: str = "adt_true_vs_rna_pred_with_rna",
    figdir: str = FIGDIR,
    size: int = 200,
):
    """
    Plot UMAPs (on the ADT UniVI embedding) of true ADT, RNA-predicted ADT and the
    matched RNA gene expression for each requested marker.

    For each marker in ``markers`` that exists in ``adt_adata.var_names``, per-cell
    values are written into ``adt_adata.obs`` and one multi-panel UMAP is drawn,
    saved to ``figdir`` and shown. Panels (obs column names):
      - ``{m}_adt_true``            : ADT value from ``layers[layer_true]``
      - ``{m}_adt_pred_from_rna``   : ADT value from ``layers[layer_pred]``
      - ``{m}_rna_{label}_{gene}``  : one panel per mapped RNA gene present in
                                      ``rna_adata.var_names`` (label = rna_layer or "X")

    RNA mapping:
      * Primary: ``rna_gene_map[marker]`` -> gene symbol list (a single string is
        also accepted).
      * Fallback, only when the map has no entry: ``_guess_rna_genes_for_marker``.
      Genes missing from ``rna_adata.var_names`` are dropped.

    Parameters
    ----------
    adt_adata, rna_adata : AnnData
        Paired objects; their ``obs_names`` must be identical and in the same order.
    markers : iterable of str
        ADT feature names to plot; names not in ``adt_adata.var_names`` are skipped.
    rna_gene_map : dict or None
        Marker -> RNA gene symbol(s).
    layer_true, layer_pred : str
        ADT layers holding the observed and the RNA-predicted values.
    rna_layer : str or None
        RNA layer to read expression from; ``None`` reads ``rna_adata.X``.
    title_prefix : str
        Prefix of the saved file names.
    figdir : str
        Output directory; created if missing.
    size : int
        Point size passed to ``sc.pl.umap``.

    Raises
    ------
    ValueError
        If the obs_names differ between the two objects, or a requested layer is missing.
    """
    if rna_gene_map is None:
        rna_gene_map = {}

    # Per-cell comparisons require the same cells in the same order.
    # (Explicit raise instead of `assert`, which is stripped under `python -O`.)
    if not np.array_equal(adt_adata.obs_names, rna_adata.obs_names):
        raise ValueError("ADT and RNA obs_names must match 1:1 for per-cell comparisons.")

    # Check ADT layers
    for layer in (layer_true, layer_pred):
        if layer not in adt_adata.layers:
            raise ValueError(f"{layer} not in adt_adata.layers")

    X_true_adt = adt_adata.layers[layer_true]
    X_pred_adt = adt_adata.layers[layer_pred]
    adt_var_index = pd.Index(adt_adata.var_names)

    # RNA matrix (either a layer or .X); columns are densified lazily per gene
    if rna_layer is None:
        X_rna = rna_adata.X
        rna_label = "X"
    else:
        if rna_layer not in rna_adata.layers:
            raise ValueError(f"{rna_layer} not in rna_adata.layers")
        X_rna = rna_adata.layers[rna_layer]
        rna_label = rna_layer
    rna_var_index = pd.Index(rna_adata.var_names)

    # Ensure we have a UMAP on the ADT UniVI latent (author note: adt_adata.obsm["X_univi"])
    ensure_umap_on_univi(adt_adata)
    # plt.savefig does not create missing directories
    os.makedirs(figdir, exist_ok=True)

    for m in markers:
        if m not in adt_var_index:
            print(f"[UMAP true vs pred + RNA] Marker {m} not in ADT var_names; skipping.")
            continue

        # ADT per-cell values
        j_adt = adt_var_index.get_loc(m)
        adt_adata.obs[f"{m}_adt_true"] = _dense_column(X_true_adt, j_adt)
        adt_adata.obs[f"{m}_adt_pred_from_rna"] = _dense_column(X_pred_adt, j_adt)

        colors = [f"{m}_adt_true", f"{m}_adt_pred_from_rna"]

        # --- RNA mapping for this marker ---
        # 1) explicit map if provided (a bare string would otherwise be split into characters)
        mapped = rna_gene_map.get(m) or []
        if isinstance(mapped, str):
            mapped = [mapped]
        # 2) heuristic guesses when the map has nothing for this marker
        candidates = list(mapped) or _guess_rna_genes_for_marker(m)

        # keep only genes actually present in rna_adata (de-duplicated, order kept)
        present_genes = [g for g in dict.fromkeys(candidates) if g in rna_var_index]

        if present_genes:
            print(
                f"[UMAP true vs pred + RNA] {m}: "
                f"using RNA genes {present_genes} (layer={rna_label})."
            )
            for g in present_genes:
                j_rna = rna_var_index.get_loc(g)
                colname = f"{m}_rna_{rna_label}_{g}"
                adt_adata.obs[colname] = _dense_column(X_rna, j_rna)
                colors.append(colname)
        else:
            print(
                f"[UMAP true vs pred + RNA] {m}: "
                "no matching RNA genes found in var_names; "
                "plotting only ADT true + ADT pred_from_rna."
            )

        # UMAP panels: ADT true, ADT predicted, (optional) RNA gene panels
        sc.pl.umap(
            adt_adata,
            color=colors,
            size=size,
            alpha=0.8,
            show=False,
        )
        # "/" in marker names (e.g. "HLA-A/B/C") would otherwise be read as a subdirectory
        fname = f"umap_{title_prefix}_{m}.png".replace(" ", "_").replace("/", "_")
        plt.savefig(os.path.join(figdir, fname), bbox_inches="tight")
        plt.show()
        plt.close()


# --------------------------------------------------------------------
# Example usage of plot_true_vs_pred_markers_on_umap_with_rna (inert).
# The triple-quoted string below is a bare string expression: it is NOT
# executed and is kept only as a usage reference; the live call is in the
# next cell. The ADT -> RNA gene mapping itself is rna_gene_map_for_adt_markers
# (feel free to tweak it based on what actually exists in rna_test_adata.var_names).
# NOTE(review): to run the example, `adt_markers_to_plot` must be uncommented
# first. Inside it, `rna_layer=None` makes the function read rna_adata.X, not a
# "log1p" layer, despite its inline comment; pass rna_layer="log1p" for that layer.
# --------------------------------------------------------------------

# Example ADT markers to inspect
#adt_markers_to_plot = ["CD3", "CD4", "CD8a", "CD19", "CD14", "CD56", "HLA-DR"]

'''
plot_true_vs_pred_markers_on_umap_with_rna(
    adt_adata=adt_test_adata,
    rna_adata=rna_test_adata,
    markers=adt_markers_to_plot,
    rna_gene_map=rna_gene_map_for_adt_markers,
    layer_true="scaled",                 # whatever you used for ADT "true"
    layer_pred="univi_pred_from_rna",    # RNA→ADT predictions (decoded marker space)
    rna_layer=None,                      # use RNA log1p layer if you have it
    title_prefix="adt_true_vs_rna_pred_with_rna",
)
'''


In [None]:
# ADT true vs RNA-predicted ADT, plus mapped RNA gene panels, for a core PBMC panel
plot_true_vs_pred_markers_on_umap_with_rna(
    adt_test_adata,
    rna_test_adata,
    [
        "CD3", "CD4", "CD8", "CD19", "CD20",
        "CD14", "CD16", "CD56", "HLA-DR",
        "CD22", "CD74", "CD69", "CD83", "CD96", "CD11c",
    ],
    rna_gene_map=rna_gene_map_for_adt_markers,
    title_prefix="adt_true_vs_rna_pred_with_rna",
    layer_pred="univi_pred_from_rna",
    layer_true="scaled",
    rna_layer=None,  # None -> RNA values come from rna_test_adata.X; pass a layer name (e.g. "log1p") to use that layer
)
