# UniVI TEA-seq tri-modal data integration demonstration/tutorial

Andrew Ashford, Pathways + Omics Group, Oregon Health & Science University - 11/18/2025

This Jupyter notebook outlines the steps for training a UniVI model on human PBMC TEA-seq tri-modal (RNA, ADT, ATAC) data and for evaluating the resulting model.


#### Import modules

In [2]:
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import scipy.sparse as sp
import scanpy as sc
import anndata as ad

from sklearn.decomposition import TruncatedSVD
from sklearn.preprocessing import normalize

import snapatac2 as snap


In [3]:
# -------------------------
# 0. Wire up package import
# -------------------------
# Register the parent of the current working directory on sys.path (only once),
# so the package imports that follow can be resolved from there.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
    sys.path.append(project_root)

from univi import (
    UniVIMultiModalVAE,
    ModalityConfig,
    UniVIConfig,
    TrainingConfig,
    matching,
)
from univi.data import MultiModalDataset
from univi.trainer import UniVITrainer


#### Read in and preprocess data as needed

In [4]:
# ----------------------
# Helpers
# ----------------------

def strip_suffix(idx):
    """Return ``idx`` cast to strings with a trailing ``-<digits>`` removed.

    Only a suffix anchored at the very end of each value is dropped (for
    example the ``-1`` of a 10x barcode such as ``AAAC-1``); values without
    such a suffix pass through unchanged. ``idx`` must provide ``.astype``
    and the pandas ``.str`` accessor.
    """
    as_text = idx.astype(str)
    return as_text.str.replace(r"-\d+$", "", regex=True)


def load_teaseq_sample(
    prefix: str,
    data_dir: Path,
    target_n: int | None = None,
    min_fragments: int = 1000,
    n_lsi: int = 50,
):
    """
    Load one TEA-seq sample and return row-aligned RNA, ADT and ATAC AnnData objects.

    The following files must exist in ``data_dir`` for the given prefix:
      {prefix}_200M_cellranger-arc_filtered_feature_bc_matrix.h5
      {prefix}_48M_adt_counts.csv.gz
      {prefix}_200M_atac_filtered_fragments.tsv.gz
      {prefix}_200M_atac_filtered_metadata.csv.gz

    Steps, in order:
      1. read the RNA matrix with ``sc.read_10x_h5``;
      2. read the ADT CSV (first column used as the row index) into a sparse AnnData;
      3. import the ATAC fragments with snapatac2 (hg38 chromosome sizes), keep
         the cells found in the metadata table and, when an ``n_fragments``
         column is present, drop cells below ``min_fragments``;
      4. build the tile matrix, apply TF-IDF and a truncated SVD, and L2-normalize
         each cell's coordinates;
      5. strip the trailing ``-<number>`` from barcodes and keep the cells present
         in all three modalities, in sorted barcode order;
      6. optionally subsample to ``target_n`` cells (fixed seed 42);
      7. rename cells to ``<barcode>|<prefix>``.

    Parameters
    ----------
    prefix
        File-name prefix of the sample; also stored in ``obs["sample"]``.
    data_dir
        Directory holding the four input files.
    target_n
        Maximum number of cells to keep; ``None`` keeps every shared cell.
    min_fragments
        Minimum value of the metadata column ``n_fragments`` for an ATAC cell.
    n_lsi
        Number of truncated-SVD components computed from the TF-IDF matrix.

    Returns
    -------
    tuple
        ``(rna_adata, adt_adata, atac_adata)`` with identical ``obs_names``.
        ``atac_adata.X`` holds the float32 LSI coordinates; the RNA and ADT
        matrices are returned as read (no normalization is applied here).

    Raises
    ------
    FileNotFoundError
        If any of the four input files is missing (the path is the argument).
    ValueError
        If no barcode is shared by the three modalities.

    Notes
    -----
    The metadata table must contain the columns ``barcodes`` and
    ``original_barcodes``; they are accessed unconditionally.
    NOTE(review): the SVD is fitted inside this function, i.e. once per sample,
    so LSI axes of different samples are fitted independently of each other.
    """
    print(f"\n===== Loading TEA-seq sample: {prefix} =====")

    rna_h5         = data_dir / f"{prefix}_200M_cellranger-arc_filtered_feature_bc_matrix.h5"
    adt_counts_csv = data_dir / f"{prefix}_48M_adt_counts.csv.gz"
    frag_tsv       = data_dir / f"{prefix}_200M_atac_filtered_fragments.tsv.gz"
    atac_meta_csv  = data_dir / f"{prefix}_200M_atac_filtered_metadata.csv.gz"

    # Fail before any expensive work if an input is missing.
    for p in [rna_h5, adt_counts_csv, frag_tsv, atac_meta_csv]:
        if not p.exists():
            raise FileNotFoundError(p)

    # ----------------------
    # 1) RNA (raw counts)
    # ----------------------
    print("Reading RNA (ARC filtered_feature_bc_matrix.h5)...")
    m = sc.read_10x_h5(rna_h5)
    rna_adata = m.copy()
    rna_adata.var_names_make_unique()
    print("  RNA shape:", rna_adata.shape)

    # ----------------------
    # 2) ADT (raw counts)
    # ----------------------
    print("Reading ADT counts...")
    # NOTE(review): every column other than the index column becomes an ADT
    # feature below — confirm the CSV carries no summary column (e.g. a total).
    adt_df = pd.read_csv(adt_counts_csv, index_col=0)  # rows = barcodes
    print("  ADT counts df shape:", adt_df.shape)

    adt_adata = ad.AnnData(
        X=sp.csr_matrix(adt_df.values),
        obs=pd.DataFrame(index=adt_df.index.astype(str)),
        var=pd.DataFrame(index=adt_df.columns.astype(str)),
    )
    adt_adata.var_names_make_unique()
    print("  ADT AnnData:", adt_adata.shape)

    # ----------------------
    # 3) ATAC: fragments + metadata
    # ----------------------
    print("Importing ATAC fragments with snapatac2...")
    atac_raw = snap.pp.import_data(
        fragment_file=str(frag_tsv),
        chrom_sizes=snap.genome.hg38,
        sorted_by_barcode=False,
    )
    print("  ATAC raw object:", atac_raw)

    # Attach metadata (provides mapping: original_barcodes <-> barcodes)
    meta = pd.read_csv(atac_meta_csv)
    print("ATAC meta columns:", meta.columns.tolist())
    meta = meta.set_index("barcodes")  # hex IDs

    # Align ATAC cells to metadata
    common_ids = atac_raw.obs_names.intersection(meta.index)
    print("  ATAC cells with metadata:", len(common_ids), "of", atac_raw.n_obs)
    atac_raw = atac_raw[common_ids].copy()
    atac_raw.obs = atac_raw.obs.join(meta, how="left")

    # QC on n_fragments if available
    if "n_fragments" in atac_raw.obs.columns:
        mask = atac_raw.obs["n_fragments"] >= min_fragments
        print("  Keeping", mask.sum(), "of", atac_raw.n_obs,
              f"ATAC cells with n_fragments >= {min_fragments}")
        atac_raw = atac_raw[mask].copy()

    # ----------------------
    # 4) Tile matrix + TF–IDF + LSI
    # ----------------------
    print("Adding ATAC tile matrix...")
    snap.pp.add_tile_matrix(atac_raw)
    print("  Tile matrix shape:", atac_raw.shape, "(cells x genomic tiles)")

    print("Computing TF–IDF and LSI on ATAC tiles...")
    X = atac_raw.X
    if not sp.issparse(X):
        X = sp.csr_matrix(X)

    n_cells, n_feats = X.shape
    print(f"  ATAC tile matrix: {n_cells} cells × {n_feats} features")

    # TF-IDF
    # tf: each cell's row scaled to sum to 1; df: number of cells in which a
    # tile is non-zero; idf: log1p(n_cells / (1 + df)).
    tf = normalize(X, norm="l1", axis=1)
    df = np.array((X > 0).sum(axis=0)).ravel()
    idf = np.log1p(n_cells / (1.0 + df))
    X_tfidf = tf.multiply(idf)

    # LSI: truncated SVD of the TF-IDF matrix, then unit L2 norm per cell.
    svd = TruncatedSVD(n_components=n_lsi, random_state=42)
    lsi = svd.fit_transform(X_tfidf)
    lsi = normalize(lsi, norm="l2", axis=1)

    atac_adata = ad.AnnData(
        X=lsi.astype(np.float32),
        obs=atac_raw.obs.copy(),
    )
    print("  ATAC LSI AnnData:", atac_adata.shape)

    # ----------------------
    # 5) Put all three in shared barcode space
    # ----------------------
    # RNA / ADT: 10x barcodes with -1 suffix
    rna_adata.obs_names = strip_suffix(rna_adata.obs_names.to_series())
    adt_adata.obs_names = strip_suffix(adt_adata.obs_names.to_series())

    # ATAC: use original_barcodes (10x) instead of hex barcodes
    atac_adata.obs["barcode_10x"] = atac_adata.obs["original_barcodes"].astype(str)
    atac_adata.obs_names = strip_suffix(atac_adata.obs["barcode_10x"])

    # Intersection of barcodes present in all 3 modalities
    common_barcodes = (
        set(rna_adata.obs_names)
        & set(adt_adata.obs_names)
        & set(atac_adata.obs_names)
    )
    print("  Common tri-modal cells:", len(common_barcodes))
    if len(common_barcodes) == 0:
        raise ValueError(f"No overlapping barcodes across RNA/ADT/ATAC for prefix {prefix}.")

    # Sorting gives the three objects one shared, deterministic row order.
    common_barcodes = sorted(common_barcodes)
    rna_adata  = rna_adata[common_barcodes].copy()
    adt_adata  = adt_adata[common_barcodes].copy()
    atac_adata = atac_adata[common_barcodes].copy()

    print("  Aligned shapes:")
    print("    RNA :", rna_adata.shape)
    print("    ADT :", adt_adata.shape)
    print("    ATAC:", atac_adata.shape)

    # ----------------------
    # 6) Optional per-sample subsampling
    # ----------------------
    if target_n is not None and rna_adata.n_obs > target_n:
        # Fixed seed: the same cells are drawn on every run.
        rng = np.random.default_rng(42)
        keep_idx = rng.choice(rna_adata.n_obs, size=target_n, replace=False)
        keep_barcodes = rna_adata.obs_names[keep_idx]

        # The same barcode list indexes all three objects, keeping rows aligned.
        rna_adata  = rna_adata[keep_barcodes].copy()
        adt_adata  = adt_adata[keep_barcodes].copy()
        atac_adata = atac_adata[keep_barcodes].copy()

        print(f"  After subsampling: {rna_adata.n_obs} cells.")

    # ----------------------
    # 7) Make obs_names globally unique (barcode|sample)
    # ----------------------
    sample_id = prefix  # you can shorten this if you like

    for adata in (rna_adata, adt_adata, atac_adata):
        adata.obs["barcode"] = adata.obs_names.astype(str)
        adata.obs["sample"]  = sample_id
        adata.obs_names = adata.obs["barcode"] + "|" + sample_id

    return rna_adata, adt_adata, atac_adata


In [5]:
# ----------------------
# Discover all TEA-seq prefixes in the folder
# ----------------------
data_dir = Path("../data/TEA-seq_data")

# Each sample is recognised by its cellranger-arc RNA matrix; the sample prefix
# is the file name with that fixed tail removed.
h5_tail = "_200M_cellranger-arc_filtered_feature_bc_matrix.h5"
h5_files = sorted(data_dir.glob("*" + h5_tail))
prefixes = []
for h5_path in h5_files:
    prefixes.append(h5_path.name.replace(h5_tail, ""))

print("Found TEA-seq prefixes:")
for p in prefixes:
    print(f"  {p}")


Found TEA-seq prefixes:
  GSM5123949_X066-MP0C1W1_leukopak_nuclei_multiome
  GSM5123950_X066-MP0C1W2_leukopak_perm-cells_multiome
  GSM5123951_X066-MP0C1W3_leukopak_perm-cells_tea
  GSM5123952_X066-MP0C1W4_leukopak_perm-cells_tea
  GSM5123953_X066-MP0C1W5_leukopak_perm-cells_tea
  GSM5123954_X066-MP0C1W6_leukopak_perm-cells_tea


In [None]:
# ----------------------
# Load all samples
# ----------------------
# Loading is best-effort: a sample that cannot be loaded is reported and
# skipped so that the remaining samples are still processed. The three lists
# stay index-aligned (entry i of each list belongs to the same sample).
rna_list, adt_list, atac_list = [], [], []
skipped_prefixes = []

for prefix in prefixes:
    try:
        rna_i, adt_i, atac_i = load_teaseq_sample(
            prefix=prefix,
            data_dir=data_dir,
            target_n=10000,     # per-sample cap; set to None for all cells
            min_fragments=1000,
            n_lsi=200,
        )
    except FileNotFoundError as e:
        # Expected for samples that lack one of the input files.
        print("  Skipping sample (missing file):", e)
        skipped_prefixes.append(prefix)
        continue
    except Exception as e:
        # Unexpected failure: keep going, but show the exception type so the
        # cause is not lost behind a bare message.
        print("  ERROR while loading sample, skipping:", prefix, "->",
              f"{type(e).__name__}: {e}")
        skipped_prefixes.append(prefix)
        continue

    rna_list.append(rna_i)
    adt_list.append(adt_i)
    atac_list.append(atac_i)

print(f"\nLoaded {len(rna_list)} of {len(prefixes)} samples "
      f"({len(skipped_prefixes)} skipped).")

# Stop here with a clear message instead of failing later on empty lists.
if not rna_list:
    raise RuntimeError(
        f"No TEA-seq sample could be loaded from {data_dir}; "
        "see the messages above for the reason each sample was skipped."
    )



===== Loading TEA-seq sample: GSM5123949_X066-MP0C1W1_leukopak_nuclei_multiome =====
  Skipping sample (missing file): ../data/TEA-seq_data/GSM5123949_X066-MP0C1W1_leukopak_nuclei_multiome_48M_adt_counts.csv.gz

===== Loading TEA-seq sample: GSM5123950_X066-MP0C1W2_leukopak_perm-cells_multiome =====
  Skipping sample (missing file): ../data/TEA-seq_data/GSM5123950_X066-MP0C1W2_leukopak_perm-cells_multiome_48M_adt_counts.csv.gz

===== Loading TEA-seq sample: GSM5123951_X066-MP0C1W3_leukopak_perm-cells_tea =====
Reading RNA (ARC filtered_feature_bc_matrix.h5)...


  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


  RNA shape: (7966, 36601)
Reading ADT counts...
  ADT counts df shape: (720873, 47)


In [None]:
# Stack the per-sample objects along the cell axis, one modality at a time.
# Cell names already carry the sample id, and the three lists were filled in
# the same order, so the results stay aligned cell by cell.
# NOTE(review): the ATAC matrices hold LSI coordinates that load_teaseq_sample
# fits separately for each sample — confirm that stacking them column-wise
# across samples is intended.
rna_all, adt_all, atac_all = (
    ad.concat(per_sample, join="outer", merge="same")
    for per_sample in (rna_list, adt_list, atac_list)
)

print("\nCombined shapes:")
print(f"  RNA all : {rna_all.shape}")
print(f"  ADT all : {adt_all.shape}")
print(f"  ATAC all: {atac_all.shape}")


#### Preprocess each modality into normalized, continuous features, because all three decoders in this tutorial use a Gaussian likelihood (chosen to prioritize cross-modal alignment)

In [None]:
# ----------------------
# Global preprocessing per modality
# ----------------------

# --- RNA: log-normalize + HVGs ---
rna = rna_all.copy()
# Keep the untouched counts: X is overwritten by the normalization below.
rna.layers["counts"] = rna.X.copy()

sc.pp.normalize_total(rna, target_sum=1e4)
sc.pp.log1p(rna)
# flavor="seurat_v3" expects raw counts, whereas X is log-normalized at this
# point, so the gene selection is computed on the "counts" layer instead.
sc.pp.highly_variable_genes(
    rna,
    n_top_genes=2000,
    flavor="seurat_v3",
    layer="counts",
)
# Model input: log-normalized expression restricted to the selected genes.
rna = rna[:, rna.var["highly_variable"]].copy()
print("RNA (HVG log1p) shape:", rna.shape)


In [None]:
# --- ADT: CLR per cell ---
# Centered log-ratio style transform: log1p of the counts, minus each cell's
# mean log value across all ADT features. The counts are kept in a layer.
adt = adt_all.copy()
adt.layers["counts"] = adt.X.copy()

adt_counts = adt.layers["counts"].astype(float)
adt_counts = adt_counts.toarray() if sp.issparse(adt_counts) else adt_counts

eps = 1e-6
log_counts = np.log1p(adt_counts + eps)
per_cell_mean = log_counts.mean(axis=1, keepdims=True)
adt.X = (log_counts - per_cell_mean).astype(np.float32)
print("ADT CLR shape:", adt.shape)


In [None]:
# --- ATAC: z-score each LSI dimension globally ---
# Every LSI column is centred and scaled using statistics taken over all cells
# of all samples; the small constant guards against division by zero.
atac = atac_all.copy()
lsi_coords = atac.X.astype(np.float32)

col_mean = lsi_coords.mean(axis=0, keepdims=True)
col_std = lsi_coords.std(axis=0, keepdims=True) + 1e-6
atac.X = ((lsi_coords - col_mean) / col_std).astype(np.float32)
print("ATAC LSI-z shape:", atac.shape)


#### Initialize model and data via dataloaders

In [None]:
# Preprocessed AnnData per modality; the keys are the modality names used in
# the model configuration.
adata_dict = dict(rna=rna, adt=adt, atac=atac)


In [None]:
from univi.data import MultiModalDataset
from univi.config import UniVIConfig, ModalityConfig, TrainingConfig
from univi.models.univi import UniVIMultiModalVAE
from univi.trainer import UniVITrainer

# ---------- UniVI config (Gaussian for all 3) ----------
# One row per modality: (name, preprocessed AnnData, encoder hidden sizes).
# Each decoder uses the encoder's hidden sizes in reverse order, and the input
# width is read from the AnnData (for ATAC: the number of LSI components).
modality_specs = [
    ("rna", rna, [512, 256]),
    ("adt", adt, [128, 64]),
    ("atac", atac, [224, 128]),
]

univi_cfg = UniVIConfig(
    latent_dim=60,
    beta=80.0,
    gamma=120.0,
    encoder_dropout=0.0,
    decoder_dropout=0.0,
    encoder_batchnorm=True,
    decoder_batchnorm=False,
    kl_anneal_start=5,
    kl_anneal_end=45,
    align_anneal_start=5,
    align_anneal_end=45,
    modalities=[
        ModalityConfig(
            name=mod_name,
            input_dim=mod_adata.n_vars,
            encoder_hidden=list(enc_sizes),
            decoder_hidden=list(reversed(enc_sizes)),
            likelihood="gaussian",
        )
        for mod_name, mod_adata, enc_sizes in modality_specs
    ],
)

# ---------- Optimisation settings ----------
train_cfg = TrainingConfig(
    n_epochs=200,
    batch_size=256,
    lr=1e-3,
    weight_decay=1e-4,
    device="cpu",   # or "cuda"
    log_every=10,
    grad_clip=5.0,
    num_workers=4,
    seed=42,
    early_stopping=True,
    patience=20,
    min_delta=0.0,
)


In [None]:
# Subset and DataLoader are not imported by any earlier cell of this notebook.
from torch.utils.data import DataLoader, Subset

dataset = MultiModalDataset(
    adata_dict=adata_dict,
    X_key="X",
    device=train_cfg.device,
)

# Reproducible random 80 / 10 / 10 split of the cell indices.
n_cells = dataset.n_cells
indices = np.arange(n_cells)
rng = np.random.default_rng(42)
rng.shuffle(indices)

frac_train = 0.8
frac_val   = 0.1
n_train = int(frac_train * n_cells)
n_val   = int(frac_val * n_cells)

train_idx = indices[:n_train]
val_idx   = indices[n_train:n_train + n_val]
test_idx  = indices[n_train + n_val:]   # everything left after train + val

train_ds = Subset(dataset, train_idx)
val_ds   = Subset(dataset, val_idx)

# Held-out cells as AnnData objects; the evaluation cells further down read
# these three names. The same index array is applied to every modality, so the
# three test sets stay aligned row by row.
# NOTE(review): assumes dataset item i corresponds to row i of each AnnData in
# adata_dict — confirm against MultiModalDataset.
rna_test_adata  = rna[test_idx].copy()
adt_test_adata  = adt[test_idx].copy()
atac_test_adata = atac[test_idx].copy()

train_loader = DataLoader(
    train_ds,
    batch_size=train_cfg.batch_size,
    shuffle=True,
    num_workers=train_cfg.num_workers,
)

# Validation batches are served in a fixed order.
val_loader = DataLoader(
    val_ds,
    batch_size=train_cfg.batch_size,
    shuffle=False,
    num_workers=train_cfg.num_workers,
)

model = UniVIMultiModalVAE(univi_cfg).to(train_cfg.device)
trainer = UniVITrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    train_cfg=train_cfg,
    device=train_cfg.device,
)


#### Train model

In [None]:
# ---------- train ----------
# Run training with the trainer built above. The returned history is read by
# the plotting cell through the keys "train_loss", "val_loss", "beta", "gamma".
history = trainer.fit()


In [None]:
# Quick training curves
# pyplot is not imported by any earlier cell, so bring it into scope here.
import matplotlib.pyplot as plt

# Training vs. validation loss per epoch.
fig, ax = plt.subplots()
ax.plot(history["train_loss"], label="train")
ax.plot(history["val_loss"], label="val")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("UniVI TEA-seq training curves")
ax.legend()
plt.tight_layout()
plt.show()

# Values recorded in the history under "beta" and "gamma" per epoch.
fig, ax = plt.subplots()
ax.plot(history["beta"], label="beta")
ax.plot(history["gamma"], label="gamma")
ax.set_xlabel("Epoch")
ax.set_ylabel("Weight")
ax.set_title("KL / alignment annealing")
ax.legend()
plt.tight_layout()
plt.show()


In [None]:
from dataclasses import asdict

# torch is needed for torch.save and is not imported by any earlier cell.
import torch

os.makedirs("../saved_models", exist_ok=True)

# The checkpoint bundles the weights with the model configuration (as a plain
# dict) so the architecture can be rebuilt when loading.
# NOTE(review): presumably the trainer has already restored its best weights
# into trainer.model at this point — confirm in UniVITrainer.
ckpt_path = "../saved_models/univi_tea_seq_beta_80_gamma_120_60_latent_dims.pt"
torch.save(
    {
        "state_dict": trainer.model.state_dict(),
        "univi_cfg": asdict(univi_cfg),
        "best_epoch": trainer.best_epoch,
        "best_val_loss": trainer.best_val_loss,
    },
    ckpt_path,
)
print("Saved best model to:", ckpt_path)


#### Evaluate model

In [None]:
import torch
from univi.config import UniVIConfig, ModalityConfig
from univi.models.univi import UniVIMultiModalVAE

device = "cpu"  # or "cuda" if available

ckpt = torch.load(
    "../saved_models/univi_tea_seq_beta_80_gamma_120_60_latent_dims.pt",
    map_location=device,
)

# ---- Rebuild UniVIConfig, making sure modalities are ModalityConfig objects ----
cfg_dict = ckpt["univi_cfg"]

# If this is an OmegaConf object or similar, make sure it's a plain dict
try:
    from omegaconf import DictConfig, OmegaConf
    if isinstance(cfg_dict, DictConfig):
        cfg_dict = OmegaConf.to_container(cfg_dict, resolve=True)
except ImportError:
    # omegaconf is optional; without it the config is used as stored.
    pass

# The stored modalities are plain dicts; turn each one back into a
# ModalityConfig before building the top-level config.
modalities = [ModalityConfig(**m) for m in cfg_dict["modalities"]]
cfg_dict = {**cfg_dict, "modalities": modalities}

univi_cfg_loaded = UniVIConfig(**cfg_dict)

# ---- Rebuild model + load weights ----
model_loaded = UniVIMultiModalVAE(univi_cfg_loaded).to(device)
model_loaded.load_state_dict(ckpt["state_dict"])
# Switch to evaluation mode: the configuration enables batch norm in the
# encoder, which must use its running statistics during inference.
model_loaded.eval()

print("Best epoch was:", ckpt.get("best_epoch"), "val loss =", ckpt.get("best_val_loss"))


In [None]:
# Embed every cell of each modality with the trainer, then store the
# embeddings next to the data under obsm["X_univi"].
z_rna, z_adt, z_atac = (
    trainer.encode_modality(mod_adata, modality=mod_name)
    for mod_adata, mod_name in ((rna, "rna"), (adt, "adt"), (atac, "atac"))
)

for mod_adata, latent in ((rna, z_rna), (adt, z_adt), (atac, z_atac)):
    mod_adata.obsm["X_univi"] = latent


In [None]:
# Peek at the cell-type annotation of the held-out RNA cells.
# NOTE(review): this needs `rna_test_adata` to exist and to carry a "celltype"
# column in .obs; no cell in this notebook assigns a "celltype" annotation —
# confirm where the labels are expected to come from.
print(rna_test_adata.obs['celltype'])


In [None]:
# ============================
# UniVI TEA-seq evaluation (RNA / ADT / ATAC)
# ============================
import os
import numpy as np
import torch
import scipy.sparse as sp
import matplotlib.pyplot as plt
import seaborn as sns

from univi import evaluation as univi_eval
from univi import plotting as univi_plot

# -----------------------------------------
# CONFIG
# -----------------------------------------
FIGDIR = "../figures/teaseq_univi_tri_modal_eval"
os.makedirs(FIGDIR, exist_ok=True)

device = train_cfg.device  # e.g. "cuda" or "cpu"
ct_key = "celltype"        # or "celltype.l2" if that's what you used

# Sanity: make sure the label column exists in every test set.
# The loop variable is deliberately not called `ad`: that name is the anndata
# module alias used by earlier cells and must not be overwritten.
for adata_chk, name in [
    (rna_test_adata, "rna_test_adata"),
    (adt_test_adata, "adt_test_adata"),
    (atac_test_adata, "atac_test_adata"),
]:
    if ct_key not in adata_chk.obs.columns:
        raise KeyError(f"{ct_key!r} not in {name}.obs")

# -----------------------------------------
# 1. Encode latent embeddings for test sets
# -----------------------------------------
print("Encoding test sets into UniVI latent space...")

# One embedding matrix per modality, each produced by that modality's encoder.
z_rna = univi_eval.encode_adata(model, rna_test_adata, device=device, modality="rna")
z_adt = univi_eval.encode_adata(model, adt_test_adata, device=device, modality="adt")
z_atac = univi_eval.encode_adata(model, atac_test_adata, device=device, modality="atac")

# Keep the embeddings with the data they were computed from.
for test_adata, latent in (
    (rna_test_adata, z_rna),
    (adt_test_adata, z_adt),
    (atac_test_adata, z_atac),
):
    test_adata.obsm["X_univi"] = latent

print("Latent shapes (test):")
print("  RNA :", z_rna.shape)
print("  ADT :", z_adt.shape)
print("  ATAC:", z_atac.shape)

# Optional: warn when the three test sets do not list the same cells in the
# same order (RNA is used as the reference).
reference_names = rna_test_adata.obs_names
if not all(
    reference_names.equals(other.obs_names)
    for other in (adt_test_adata, atac_test_adata)
):
    print("WARNING: test adatas are not perfectly obs_names-aligned. "
          "FOSCTTM/label-transfer still work, but they aren't strictly one-to-one.")


# -----------------------------------------
# 2. FOSCTTM (pairwise alignment)
# -----------------------------------------
print("\nComputing FOSCTTM for each modality pair (lower = better)...")
fos_rna_adt = univi_eval.compute_foscttm(z_rna, z_adt)
fos_rna_atac = univi_eval.compute_foscttm(z_rna, z_atac)
fos_adt_atac = univi_eval.compute_foscttm(z_adt, z_atac)

print(f"  RNA  vs ADT : {fos_rna_adt:.4f}")
print(f"  RNA  vs ATAC: {fos_rna_atac:.4f}")
print(f"  ADT  vs ATAC: {fos_adt_atac:.4f}")

# Barplot of FOSCTTM: one bar per modality pair, written to FIGDIR.
foscttm_by_pair = {
    "RNA–ADT": fos_rna_adt,
    "RNA–ATAC": fos_rna_atac,
    "ADT–ATAC": fos_adt_atac,
}
pairs = list(foscttm_by_pair)
vals = list(foscttm_by_pair.values())

plt.figure(figsize=(4, 4))
sns.barplot(x=pairs, y=vals)
plt.ylabel("FOSCTTM (mean)")
plt.title("Tri-modal FOSCTTM (lower = better)")
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "foscttm_barplot.png"))
plt.close()


# -----------------------------------------
# 3. Modality mixing (all three modalities)
# -----------------------------------------
# Stack the three embeddings and remember which modality each row came from.
Z_joint = np.concatenate([z_rna, z_adt, z_atac], axis=0)
modality_labels = np.repeat(
    ["rna", "adt", "atac"],
    [z_rna.shape[0], z_adt.shape[0], z_atac.shape[0]],
)

mixing_score = univi_eval.compute_modality_mixing(
    Z_joint,
    modality_labels,
    k=20,
)
print(f"\nModality mixing score (RNA/ADT/ATAC, k=20): {mixing_score:.3f}")

# Simple distribution of nearest-neighbor modality labels
print("Computing neighbor modality distribution for quick sanity check...")
k = 20
from sklearn.neighbors import NearestNeighbors

# Ask for k + 1 neighbours and discard the first column, which is taken to be
# the query point itself.
knn_index = NearestNeighbors(n_neighbors=k + 1)
knn_index.fit(Z_joint)
_, neighbor_idx = knn_index.kneighbors(Z_joint)
neighbor_mods = modality_labels[neighbor_idx[:, 1:]]

# Share of neighbour slots occupied by a cell of the query's own modality.
same_mod_frac = np.mean(neighbor_mods == modality_labels.reshape(-1, 1))
print(f"  Fraction of neighbors with same modality: {same_mod_frac:.3f}")


# -----------------------------------------
# 4. Label transfer between modalities
# -----------------------------------------
# Per-cell labels of each test set, read from the obs column named by ct_key
# and cast to strings; the label-transfer calls in this section consume them.
labels_rna  = rna_test_adata.obs[ct_key].astype(str).values
labels_adt  = adt_test_adata.obs[ct_key].astype(str).values
labels_atac = atac_test_adata.obs[ct_key].astype(str).values

print("\nLabel transfer accuracies (k=15):")

def label_transfer_report(
    Z_src, lab_src,
    Z_tgt, lab_tgt,
    src_name: str,
    tgt_name: str,
    cm_filename: str,
):
    """Run k-NN label transfer (k=15) and report accuracy plus confusion matrix.

    Labels are transferred from the source embedding to the target embedding
    with ``univi_eval.label_transfer_knn``. The accuracy is printed and the
    confusion matrix is drawn with ``univi_plot.plot_confusion_matrix`` and
    saved as ``cm_filename`` inside ``FIGDIR``. Nothing is returned.

    ``src_name`` and ``tgt_name`` are only used in the printed line and in the
    figure title. The axis labels passed to the plot are the unique target
    labels.
    NOTE(review): confirm these match the label order of the returned matrix.
    """
    _, accuracy, confusion = univi_eval.label_transfer_knn(
        Z_source=Z_src,
        labels_source=lab_src,
        Z_target=Z_tgt,
        labels_target=lab_tgt,
        k=15,
    )
    print(f"  {src_name} → {tgt_name}: {accuracy:.3f}")

    univi_plot.plot_confusion_matrix(
        confusion,
        labels=np.unique(lab_tgt),
        title=f"{src_name} → {tgt_name} label transfer ({ct_key})",
        savepath=os.path.join(FIGDIR, cm_filename),
    )

# Embedding and labels of each test set, keyed by the name used in reports.
latent_and_labels = {
    "ADT": (z_adt, labels_adt),
    "RNA": (z_rna, labels_rna),
    "ATAC": (z_atac, labels_atac),
}

# Directions to evaluate, in reporting order: ADT as a "clean" source first,
# then RNA ↔ ATAC in both directions.
transfer_directions = [
    ("ADT", "RNA"),
    ("ADT", "ATAC"),
    ("RNA", "ATAC"),
    ("ATAC", "RNA"),
]

for src_name, tgt_name in transfer_directions:
    Z_src, lab_src = latent_and_labels[src_name]
    Z_tgt, lab_tgt = latent_and_labels[tgt_name]
    label_transfer_report(
        Z_src, lab_src,
        Z_tgt, lab_tgt,
        src_name=src_name, tgt_name=tgt_name,
        cm_filename=f"cm_{src_name}_to_{tgt_name}.png",
    )


# -----------------------------------------
# 5. UMAP visualizations (tri-modal)
# -----------------------------------------
# Tag each test set with the modality it came from.
test_sets = {
    "rna": rna_test_adata,
    "adt": adt_test_adata,
    "atac": atac_test_adata,
}
for mod, adata in test_sets.items():
    adata.obs["univi_source"] = mod

# Stack the three test sets into one object; cell names get the modality
# appended after "-".
# NOTE(review): AnnData.concatenate is deprecated in favour of anndata.concat.
combined = rna_test_adata.concatenate(
    adt_test_adata,
    atac_test_adata,
    batch_key="univi_batch",
    batch_categories=["rna", "adt", "atac"],
    index_unique="-",
)

# Joint UMAPs of the shared latent space: first coloured by cell type, then by
# modality.
for color_key, fig_name in (
    (ct_key, "umap_tri_modal_celltype.png"),
    ("univi_source", "umap_tri_modal_modality.png"),
):
    univi_plot.umap_single_adata(
        combined,
        obsm_key="X_univi",
        color=color_key,
        savepath=os.path.join(FIGDIR, fig_name),
    )

# Optional: per-modality UMAPs (just subsets)
for mod in ["rna", "adt", "atac"]:
    from_this_mod = combined.obs["univi_source"] == mod
    sub = combined[from_this_mod].copy()
    univi_plot.umap_single_adata(
        sub,
        obsm_key="X_univi",
        color=ct_key,
        savepath=os.path.join(FIGDIR, f"umap_{mod}_only_celltype.png"),
    )


# -----------------------------------------
# 6. Latent geometry diagnostics
# -----------------------------------------
print("\nLatent geometry diagnostics...")

def latent_norms(z, label):
    """Return the L2 norm of every row of ``z`` and a matching label array.

    The second return value repeats ``label`` once per row, so both arrays can
    be concatenated across groups and passed to a grouped plot.
    """
    row_norms = np.linalg.norm(z, ord=2, axis=1)
    return row_norms, np.repeat(label, row_norms.shape[0])

# Row norms of each modality's embedding, with a label array for plotting.
norms_and_labels = [
    latent_norms(latent, tag)
    for latent, tag in ((z_rna, "RNA"), (z_adt, "ADT"), (z_atac, "ATAC"))
]
(norm_rna, lab_rna), (norm_adt, lab_adt), (norm_atac, lab_atac) = norms_and_labels

norms_all = np.concatenate([norm_rna, norm_adt, norm_atac])
labs_all = np.concatenate([lab_rna, lab_adt, lab_atac])

# Violin plot of the latent norms, one violin per modality.
plt.figure(figsize=(5, 4))
sns.violinplot(x=labs_all, y=norms_all, inner="box")
plt.xlabel("Modality")
plt.ylabel("‖z‖ (L2 norm)")
plt.title("Latent L2-norm distribution by modality")
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "latent_norms_by_modality.png"))
plt.close()

# Optionally: correlation heatmap of latent dims (on test RNA); columns of
# z_rna are treated as the variables.
corr_latent = np.corrcoef(z_rna, rowvar=False)
plt.figure(figsize=(6, 5))
sns.heatmap(corr_latent, center=0, cmap="vlag")
plt.title("RNA latent dimension correlation (test set)")
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "latent_corr_heatmap_rna.png"))
plt.close()


# -----------------------------------------
# 7. Cross-modal reconstruction metrics
# -----------------------------------------
def _to_dense(X):
    """Return ``X`` as a dense NumPy array.

    SciPy sparse matrices are densified with ``toarray``; any other input is
    passed through ``np.asarray``.
    """
    if sp.issparse(X):
        return X.toarray()
    return np.asarray(X)

def cross_modal_metrics(
    model,
    src_adata,
    tgt_adata,
    src_mod: str,
    tgt_mod: str,
    name_prefix: str,
    device: str,
):
    """Predict one modality from another and summarise the per-feature error.

    ``src_adata`` is pushed through ``univi_eval.cross_modal_predict`` to
    obtain predicted ``tgt_mod`` features, which are compared with
    ``tgt_adata.X`` using per-feature MSE and per-feature Pearson correlation.
    NOTE(review): presumably row i of ``src_adata`` and of ``tgt_adata`` must
    be the same cell for these metrics to be meaningful — confirm.

    Side effects: prints the mean metrics and writes, into ``FIGDIR``, a
    histogram of the per-feature correlations
    (``{name_prefix}_corr_hist.png``) and observed-vs-predicted scatter plots
    of the three best and three worst correlated features
    (``{name_prefix}_best_features_scatter.png`` / ``..._worst_...``).

    Features whose correlation is not finite are left out of the printed mean,
    the histogram and the best/worst ranking, so that they can neither turn
    the mean into NaN nor be ranked as "best".

    Returns
    -------
    tuple
        ``(mse_feat, corr_feat)`` exactly as returned by the ``univi_eval``
        helpers (non-finite correlations are not removed from ``corr_feat``).
    """
    Xhat_tgt = univi_eval.cross_modal_predict(
        model,
        adata_src=src_adata,
        src_mod=src_mod,
        tgt_mod=tgt_mod,
        device=device,
        batch_size=512,
    )

    X_tgt = _to_dense(tgt_adata.X)

    mse_feat = univi_eval.mse_per_feature(X_tgt, Xhat_tgt)
    corr_feat = univi_eval.pearson_corr_per_feature(X_tgt, Xhat_tgt)

    # Work on a float view and track which features have a usable correlation.
    corr_arr = np.asarray(corr_feat, dtype=float)
    finite = np.isfinite(corr_arr)
    n_undefined = int((~finite).sum())
    mean_corr = corr_arr[finite].mean() if finite.any() else float("nan")

    print(f"\nCross-modal: {src_mod} → {tgt_mod}")
    print(f"  Mean feature MSE: {mse_feat.mean():.4f}")
    print(f"  Mean feature Pearson r: {mean_corr:.3f}")
    if n_undefined:
        print(f"  ({n_undefined} features with undefined correlation excluded)")

    # Histogram of per-feature correlation
    plt.figure(figsize=(5, 4))
    sns.histplot(corr_arr[finite], bins=40, kde=False)
    plt.xlabel("Per-feature Pearson r")
    plt.ylabel("Count")
    plt.title(f"{src_mod} → {tgt_mod}: per-feature correlation")
    plt.tight_layout()
    plt.savefig(os.path.join(FIGDIR, f"{name_prefix}_corr_hist.png"))
    plt.close()

    # Optional: scatter of observed vs predicted for a few top features
    if hasattr(tgt_adata, "var_names"):
        varnames = np.array(tgt_adata.var_names)
        # Rank only the features with a finite correlation; argsort would
        # otherwise place NaN entries last, i.e. among the "best".
        finite_idx = np.flatnonzero(finite)
        order = finite_idx[np.argsort(corr_arr[finite_idx])]
        best_idx = order[-3:]
        worst_idx = order[:3]

        for idx_set, tag in [(best_idx, "best"), (worst_idx, "worst")]:
            if len(idx_set) == 0:
                # Nothing to plot, and a zero-width figure is not valid.
                continue
            plt.figure(figsize=(4 * len(idx_set), 4))
            for i, j in enumerate(idx_set):
                plt.subplot(1, len(idx_set), i + 1)
                plt.scatter(X_tgt[:, j], Xhat_tgt[:, j], s=4, alpha=0.3)
                plt.xlabel("Observed")
                plt.ylabel("Predicted")
                plt.title(f"{tgt_mod} {varnames[j]}\n r = {corr_arr[j]:.2f}")
            plt.tight_layout()
            plt.savefig(os.path.join(FIGDIR, f"{name_prefix}_{tag}_features_scatter.png"))
            plt.close()

    return mse_feat, corr_feat

# Evaluate key directions on TEST set: RNA → ADT, RNA → ATAC, ADT → RNA and
# ATAC → RNA. The returned metric arrays are not kept; the function prints the
# summary and writes the figures.
test_adata_by_mod = {
    "rna": rna_test_adata,
    "adt": adt_test_adata,
    "atac": atac_test_adata,
}

for src_mod, tgt_mod in (
    ("rna", "adt"),
    ("rna", "atac"),
    ("adt", "rna"),
    ("atac", "rna"),
):
    _ = cross_modal_metrics(
        model,
        test_adata_by_mod[src_mod],
        test_adata_by_mod[tgt_mod],
        src_mod=src_mod,
        tgt_mod=tgt_mod,
        name_prefix=f"{src_mod.upper()}_to_{tgt_mod.upper()}_test",
        device=device,
    )


# -----------------------------------------
# 8. Denoising with decoders (on "unused" cells)
# -----------------------------------------
# Optional step: a modality is processed only if a variable named
# "<modality>_unused" exists in this namespace and is not None.
for mod in ("rna", "adt", "atac"):
    tag = f"{mod}_unused"
    adata = locals().get(tag, None)
    if adata is None:
        continue

    print(f"Denoising {tag} ({mod})...")
    univi_eval.denoise_adata(
        model,
        adata,
        modality=mod,
        device=device,
        batch_size=512,
        out_layer="univi_denoised",
    )

    # Simple per-cell reconstruction error hist: mean squared difference
    # between the input matrix and the "univi_denoised" layer, per cell.
    X_raw = _to_dense(adata.X)
    X_den = _to_dense(adata.layers["univi_denoised"])
    per_cell_mse = np.square(X_raw - X_den).mean(axis=1)

    plt.figure(figsize=(5, 4))
    sns.histplot(per_cell_mse, bins=40)
    plt.xlabel("Per-cell MSE (raw vs denoised)")
    plt.ylabel("Count")
    plt.title(f"Denoising quality: {tag} ({mod})")
    plt.tight_layout()
    plt.savefig(os.path.join(FIGDIR, f"denoise_{tag}_{mod}.png"))
    plt.close()


# -----------------------------------------
# 9. Cell-type–conditional sampling in latent space
# -----------------------------------------
# Optional step: runs only when an `rna_unused` AnnData is available.
if "rna_unused" not in locals() or rna_unused is None:
    print("\nNo rna_unused available, skipping latent Gaussian sampling.")
else:
    print("\nFitting latent Gaussians per cell type (RNA unused)...")

    # Embed the unused RNA cells and fit one latent Gaussian per label.
    labels_rna_unused = rna_unused.obs[ct_key].astype(str).values
    Z_rna_unused = univi_eval.encode_adata(
        model, rna_unused, modality="rna", device=device
    )
    rna_unused.obsm["X_univi"] = Z_rna_unused

    gauss_by_ct = univi_eval.fit_latent_gaussians_by_label(
        Z_rna_unused, labels_rna_unused
    )

    # Number of synthetic cells per label: as many as observed, capped at
    # 1000; labels with 50 cells or fewer are left out.
    unique_cts, counts_ct = np.unique(labels_rna_unused, return_counts=True)
    spec = {}
    for ct, n in zip(unique_cts, counts_ct):
        if n > 50:
            spec[ct] = int(min(1000, n))

    print("\nSampling spec per cell type (RNA latent):")
    for ct, n in spec.items():
        print(f"  {ct}: {n} synthetic cells")

    z_samp_by_ct = univi_eval.sample_from_latent_gaussians(
        gauss_by_ct,
        spec,
        random_state=42,
    )

    def decode_latent_samples(model, z_samp_by_ct, modality: str, device: str = "cpu"):
        """Decode latent samples of every label into one modality.

        Returns a dict that maps each label to the decoded NumPy array taken
        from the ``modality`` entry of ``model.decode_modalities``. The model
        is put in eval mode and gradients are disabled while decoding.
        """
        model.eval()
        decoded = {}
        with torch.no_grad():
            for lab, latent in z_samp_by_ct.items():
                latent_t = torch.as_tensor(latent, dtype=torch.float32, device=device)
                decoded[lab] = model.decode_modalities(latent_t)[modality].cpu().numpy()
        return decoded

    synthetic_adt_by_ct = decode_latent_samples(
        model, z_samp_by_ct, modality="adt", device=device
    )

    print("\nSynthetic ADT samples by cell type (keys):", list(synthetic_adt_by_ct.keys()))

    # Quick QC: correlate the mean synthetic ADT profile with the mean observed
    # profile of the same cell type, for the first few cell types only.
    if "adt_unused" in locals() and adt_unused is not None:
        adt_labels = adt_unused.obs[ct_key].astype(str).values
        X_adt_real = _to_dense(adt_unused.X)
        varnames_adt = np.array(adt_unused.var_names)

        for ct in list(synthetic_adt_by_ct)[:5]:
            rows = np.flatnonzero(adt_labels == ct)
            if rows.size == 0:
                continue
            syn = synthetic_adt_by_ct[ct].mean(axis=0)
            real = X_adt_real[rows].mean(axis=0)

            corr = np.corrcoef(real, syn)[0, 1]
            print(f"  CT {ct}: mean synthetic vs real ADT corr = {corr:.3f}")
