# UniVI TEA-seq tri-modal data integration demonstration/tutorial

Andrew Ashford, Pathways + Omics Group, Oregon Health & Science University - 11/18/2025

This Jupyter Notebook will be used to outline the training steps for a UniVI model using human PBMC TEA-seq tri-modal data.


#### Import modules

In [1]:
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import scipy.sparse as sp
import scanpy as sc
import anndata as ad

from sklearn.decomposition import TruncatedSVD
from sklearn.preprocessing import normalize

import snapatac2 as snap


In [2]:
# -------------------------
# 0. Wire up package import
# -------------------------
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

from univi import (
    UniVIMultiModalVAE,
    ModalityConfig,
    UniVIConfig,
    TrainingConfig,
    matching,
)
from univi.data import MultiModalDataset
from univi.trainer import UniVITrainer


In [3]:
import torch
print("Torch:", torch.__version__)
print("torch.version.cuda:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)


Torch: 2.9.1+cu128
torch.version.cuda: 12.8
CUDA available: True
Using device: cuda


#### Read in and preprocess data as needed

In [4]:
data_dir = Path("../data/TEA-seq_data")
prefix   = "GSM5123953_X066-MP0C1W5_leukopak_perm-cells_tea"

rna_h5         = data_dir / f"{prefix}_200M_cellranger-arc_filtered_feature_bc_matrix.h5"
adt_counts_csv = data_dir / f"{prefix}_48M_adt_counts.csv.gz"
frag_tsv       = data_dir / f"{prefix}_200M_atac_filtered_fragments.tsv.gz"
atac_meta_csv  = data_dir / f"{prefix}_200M_atac_filtered_metadata.csv.gz"


In [5]:
# ----------------------
# 1) RNA
# ----------------------
print("Reading RNA (ARC filtered_feature_bc_matrix.h5)...")
m = sc.read_10x_h5(rna_h5)
print(m)

print("Feature types:", m.var["feature_types"].unique())

rna_adata = m.copy()
rna_adata.var_names_make_unique()
print("Initial RNA shape:", rna_adata.shape)


Reading RNA (ARC filtered_feature_bc_matrix.h5)...


  utils.warn_names_duplicates("var")


AnnData object with n_obs × n_vars = 8496 × 36601
    var: 'gene_ids', 'feature_types', 'genome', 'interval'
Feature types: ['Gene Expression']
Initial RNA shape: (8496, 36601)


  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


In [75]:
# Sanity check
print(rna_adata.X.min())
print(rna_adata.X.max())


0.0
737.0


In [6]:
# ----------------------
# 2) ADT
# ----------------------
print("Reading ADT counts...")
adt_df = pd.read_csv(adt_counts_csv, index_col=0)  # rows = barcodes, cols = ADT names
print("ADT counts shape:", adt_df.shape)

adt_adata = ad.AnnData(
    X=sp.csr_matrix(adt_df.values),
    obs=pd.DataFrame(index=adt_df.index.astype(str)),
    var=pd.DataFrame(index=adt_df.columns.astype(str)),
)
adt_adata.var_names_make_unique()
print("ADT AnnData:", adt_adata.shape)


Reading ADT counts...
ADT counts shape: (719170, 47)
ADT AnnData: (719170, 47)


In [76]:
# Sanity check
print(adt_adata.X.min())
print(adt_adata.X.max())


0
12219


In [7]:
# ----------------------
# 3) ATAC: fragments + metadata -> LSI
# ----------------------
print("Importing ATAC fragments with snapatac2...")
atac_raw = snap.pp.import_data(
    fragment_file=str(frag_tsv),
    chrom_sizes=snap.genome.hg38,   # use chrom_sizes
    sorted_by_barcode=False,        # TEA-seq fragments usually unsorted
)

print("ATAC raw object:", atac_raw)


Importing ATAC fragments with snapatac2...


  atac_raw = snap.pp.import_data(


ATAC raw object: AnnData object with n_obs × n_vars = 7540 × 0
    obs: 'n_fragment', 'frac_dup', 'frac_mito'
    uns: 'reference_sequences'
    obsm: 'fragment_paired'


In [8]:
# ---- 3a. Attach external metadata and map barcodes ----
meta = pd.read_csv(atac_meta_csv)
print("ATAC meta columns:", meta.columns.tolist())

# Index metadata by hex barcodes; they match atac_raw.obs_names
meta = meta.set_index("barcodes")

# Align cells common between fragments + metadata
common_ids = atac_raw.obs_names.intersection(meta.index)
print("ATAC cells with metadata:", len(common_ids), "of", atac_raw.n_obs)

atac_raw = atac_raw[common_ids].copy()
atac_raw.obs = atac_raw.obs.join(meta, how="left")


ATAC meta columns: ['original_barcodes', 'n_fragments', 'n_duplicate', 'n_mito', 'n_unique', 'altius_count', 'altius_frac', 'gene_bodies_count', 'gene_bodies_frac', 'peaks_count', 'peaks_frac', 'tss_count', 'tss_frac', 'barcodes', 'cell_name', 'well_id', 'chip_id', 'batch_id', 'pbmc_sample_id', 'DoubletScore', 'DoubletEnrichment', 'TSSEnrichment']
ATAC cells with metadata: 7540 of 7540




In [9]:
# ----------------------
# 4) Simple ATAC QC (using n_fragments from metadata)
# ----------------------
if "n_fragments" in atac_raw.obs.columns:
    mask = atac_raw.obs["n_fragments"] >= 1500
    print("Keeping", mask.sum(), "of", atac_raw.n_obs, "ATAC cells with n_fragments >= 1000")
    atac_raw = atac_raw[mask].copy()


Keeping 7540 of 7540 ATAC cells with n_fragments >= 1000




In [77]:
# Sanity check
print(atac_raw.X.min())
print(atac_raw.X.max())


0
29


In [10]:
# ----------------------
# 5) Tile matrix + TF–IDF + LSI
# ----------------------
print("Adding ATAC tile matrix...")
snap.pp.add_tile_matrix(atac_raw)
print("Tile matrix shape:", atac_raw.shape, "(cells x genomic tiles)")


Adding ATAC tile matrix...
Tile matrix shape: (7540, 6062095) (cells x genomic tiles)


In [11]:
print("Computing TF–IDF and LSI on ATAC tiles...")
X = atac_raw.X
if not sp.issparse(X):
    X = sp.csr_matrix(X)

n_cells, n_feats = X.shape
print(f"ATAC tile matrix: {n_cells} cells × {n_feats} features")


Computing TF–IDF and LSI on ATAC tiles...
ATAC tile matrix: 7540 cells × 6062095 features


In [12]:
# TF: normalize by per-cell counts
tf = normalize(X, norm="l1", axis=1)

# DF: number of cells with a non-zero in each feature
df = np.array((X > 0).sum(axis=0)).ravel()
idf = np.log1p(n_cells / (1.0 + df))

# TF-IDF
X_tfidf = tf.multiply(idf)


In [17]:
'''
import numpy as np

# X: sparse matrix of shape (n_cells, n_peaks)
min_cells_per_peak = 50  # or 20, depending on sparsity

# number of nonzeros per peak (i.e., per column)
peak_nnz = X.getnnz(axis=0)          # returns a 1D np.array-like
peak_nnz = np.asarray(peak_nnz).ravel()

keep_peaks = peak_nnz >= min_cells_per_peak

X_filtered = X[:, keep_peaks]
'''

'\nimport numpy as np\n\n# X: sparse matrix of shape (n_cells, n_peaks)\nmin_cells_per_peak = 50  # or 20, depending on sparsity\n\n# number of nonzeros per peak (i.e., per column)\npeak_nnz = X.getnnz(axis=0)          # returns a 1D np.array-like\npeak_nnz = np.asarray(peak_nnz).ravel()\n\nkeep_peaks = peak_nnz >= min_cells_per_peak\n\nX_filtered = X[:, keep_peaks]\n'

In [15]:
if X_tfidf.dtype != np.float32:
    X_tfidf = X_tfidf.astype(np.float32)


In [16]:
# SVD → LSI
n_lsi = 100
svd = TruncatedSVD(n_components=n_lsi, random_state=42)
lsi = svd.fit_transform(X_tfidf)


In [18]:
# row L2-normalize
lsi = normalize(lsi, norm="l2", axis=1)


In [19]:
# Wrap as ATAC AnnData for UniVI
atac_adata = ad.AnnData(
    X=lsi.astype(np.float32),
    obs=atac_raw.obs.copy(),   # includes original_barcodes, etc.
)
print("ATAC LSI AnnData:", atac_adata.shape)


ATAC LSI AnnData: (7540, 100)


In [20]:
# ----------------------
# 6) Put everything into shared barcode space
# ----------------------
def strip_suffix(idx):
    # drop trailing "-<number>" if present
    return idx.astype(str).str.replace(r"-\d+$", "", regex=True)

# RNA & ADT already indexed by 10x barcodes with "-1"
rna_adata.obs_names = strip_suffix(rna_adata.obs_names.to_series())
adt_adata.obs_names = strip_suffix(adt_adata.obs_names.to_series())

# For ATAC, map from hex barcodes -> original 10x barcodes, then strip suffix
atac_adata.obs["barcode_10x"] = atac_adata.obs["original_barcodes"].astype(str)
atac_adata.obs_names = strip_suffix(atac_adata.obs["barcode_10x"])

# Also strip suffix from RNA/ADT now for consistency
rna_adata.obs_names = strip_suffix(rna_adata.obs_names.to_series())
adt_adata.obs_names = strip_suffix(adt_adata.obs_names.to_series())


In [21]:
# ----------------------
# 7) Tri-modal intersection
# ----------------------
common_barcodes = (
    set(rna_adata.obs_names)
    & set(adt_adata.obs_names)
    & set(atac_adata.obs_names)
)

print("Common tri-modal cells:", len(common_barcodes))
if len(common_barcodes) == 0:
    raise ValueError("No overlapping barcodes across RNA/ADT/ATAC after mapping.")

common_barcodes = sorted(common_barcodes)

rna_adata  = rna_adata[common_barcodes].copy()
adt_adata  = adt_adata[common_barcodes].copy()
atac_adata = atac_adata[common_barcodes].copy()

print("Aligned shapes:")
print("  RNA :", rna_adata.shape)
print("  ADT :", adt_adata.shape)
print("  ATAC:", atac_adata.shape)


Common tri-modal cells: 7421
Aligned shapes:
  RNA : (7421, 36601)
  ADT : (7421, 47)
  ATAC: (7421, 100)


In [22]:
# ----------------------
# 8) Optional subsampling
# ----------------------
target_n = 10000
n_cells  = rna_adata.n_obs

if n_cells > target_n:
    rng = np.random.default_rng(42)
    keep_idx = rng.choice(n_cells, size=target_n, replace=False)
    keep_barcodes = rna_adata.obs_names[keep_idx]

    rna_adata  = rna_adata[keep_barcodes].copy()
    adt_adata  = adt_adata[keep_barcodes].copy()
    atac_adata = atac_adata[keep_barcodes].copy()

print("After optional subsampling:", rna_adata.n_obs, "cells.")


After optional subsampling: 7421 cells.


#### Preprocess data since we will be using Gaussian decoders in this case to prioritize data alignment

In [23]:
# ----------------------
# 9) Modality-specific preprocessing
# ----------------------
'''
# --- RNA: log-normalize + HVGs ---
rna = rna_adata.copy()
rna.layers["counts"] = rna.X.copy()

sc.pp.normalize_total(rna, target_sum=1e4)
sc.pp.log1p(rna)
sc.pp.highly_variable_genes(rna, n_top_genes=2000, flavor="seurat_v3")
rna = rna[:, rna.var["highly_variable"]].copy()
print("RNA (HVG log1p) shape:", rna.shape)
'''

'\n# --- RNA: log-normalize + HVGs ---\nrna = rna_adata.copy()\nrna.layers["counts"] = rna.X.copy()\n\nsc.pp.normalize_total(rna, target_sum=1e4)\nsc.pp.log1p(rna)\nsc.pp.highly_variable_genes(rna, n_top_genes=2000, flavor="seurat_v3")\nrna = rna[:, rna.var["highly_variable"]].copy()\nprint("RNA (HVG log1p) shape:", rna.shape)\n'

In [24]:
'''
# --- ADT: CLR per cell ---
adt = adt_adata.copy()
adt.layers["counts"] = adt.X.copy()

X = adt.layers["counts"].astype(float)
if sp.issparse(X):
    X = X.toarray()

eps = 1e-6
X_log = np.log1p(X + eps)
X_clr = X_log - X_log.mean(axis=1, keepdims=True)
adt.X = X_clr.astype(np.float32)
print("ADT CLR shape:", adt.shape)
'''

'\n# --- ADT: CLR per cell ---\nadt = adt_adata.copy()\nadt.layers["counts"] = adt.X.copy()\n\nX = adt.layers["counts"].astype(float)\nif sp.issparse(X):\n    X = X.toarray()\n\neps = 1e-6\nX_log = np.log1p(X + eps)\nX_clr = X_log - X_log.mean(axis=1, keepdims=True)\nadt.X = X_clr.astype(np.float32)\nprint("ADT CLR shape:", adt.shape)\n'

In [25]:
'''
# --- ATAC: z-score each LSI dimension ---
atac = atac_adata.copy()
X_atac = atac.X.astype(np.float32)

mean = X_atac.mean(axis=0, keepdims=True)
std  = X_atac.std(axis=0, keepdims=True) + 1e-6
X_z  = (X_atac - mean) / std
atac.X = X_z.astype(np.float32)
print("ATAC LSI-z shape:", atac.shape)
'''

'\n# --- ATAC: z-score each LSI dimension ---\natac = atac_adata.copy()\nX_atac = atac.X.astype(np.float32)\n\nmean = X_atac.mean(axis=0, keepdims=True)\nstd  = X_atac.std(axis=0, keepdims=True) + 1e-6\nX_z  = (X_atac - mean) / std\natac.X = X_z.astype(np.float32)\nprint("ATAC LSI-z shape:", atac.shape)\n'

In [26]:
import scanpy as sc
import numpy as np
import scipy.sparse as sp

# ---------- RNA ----------
rna = rna_adata.copy()
rna.layers["counts"] = rna.X.copy()

# Normalize + log1p
sc.pp.normalize_total(rna, target_sum=1e4)
sc.pp.log1p(rna)

# HVGs
sc.pp.highly_variable_genes(rna, n_top_genes=2000, flavor="seurat_v3")
rna = rna[:, rna.var["highly_variable"]].copy()
print("RNA (HVG log1p) shape:", rna.shape)

# Z-score per gene across cells (for Gaussian decoder)
sc.pp.scale(rna, max_value=10)
print("RNA scaled shape:", rna.shape)


# ---------- ADT ----------
adt = adt_adata.copy()
adt.layers["counts"] = adt.X.copy()

X = adt.layers["counts"].astype(float)
if sp.issparse(X):
    X = X.toarray()

eps = 1e-6
# CLR per cell
X_log = np.log1p(X + eps)
X_clr = X_log - X_log.mean(axis=1, keepdims=True)

# Then per-feature z-score across cells
mean_adt = X_clr.mean(axis=0, keepdims=True)
std_adt  = X_clr.std(axis=0, keepdims=True) + 1e-6
X_clr_z  = (X_clr - mean_adt) / std_adt

adt.X = X_clr_z.astype(np.float32)

print("ADT CLR+z shape:", adt.shape)


# ---------- ATAC ----------
atac = atac_adata.copy()

X_atac = atac.X.astype(np.float32)  # assume this is LSI already
mean_atac = X_atac.mean(axis=0, keepdims=True)
std_atac  = X_atac.std(axis=0, keepdims=True) + 1e-6
X_z = (X_atac - mean_atac) / std_atac

atac.X = X_z.astype(np.float32)
print("ATAC LSI-z shape:", atac.shape)


  return fn(*args_all, **kw)


RNA (HVG log1p) shape: (7421, 2000)


  return dispatch(args[0].__class__)(*args, **kw)


RNA scaled shape: (7421, 2000)
ADT CLR+z shape: (7421, 47)
ATAC LSI-z shape: (7421, 100)


#### Initialize model and data via dataloaders

In [27]:
adata_dict = {
    "rna": rna,
    "adt": adt,
    "atac": atac,
}


In [28]:
from univi.data import MultiModalDataset
from univi.config import UniVIConfig, ModalityConfig, TrainingConfig
from univi.models.univi import UniVIMultiModalVAE
from univi.trainer import UniVITrainer
'''
# ---------- UniVI config (Gaussian for all 3) ----------
univi_cfg = UniVIConfig(
    latent_dim=60,
    beta=100.0,
    gamma=120.0,
    encoder_dropout=0.0,
    decoder_dropout=0.0,
    encoder_batchnorm=True,
    decoder_batchnorm=False,
    kl_anneal_start=0,
    kl_anneal_end=0,
    align_anneal_start=0,
    align_anneal_end=0,
    modalities=[
        ModalityConfig(
            name="rna",
            input_dim=rna.n_vars,
            encoder_hidden=[512, 256],
            decoder_hidden=[256, 512],
            likelihood="gaussian",
        ),
        ModalityConfig(
            name="adt",
            input_dim=adt.n_vars,
            encoder_hidden=[128, 64],
            decoder_hidden=[64, 128],
            likelihood="gaussian",
        ),
        ModalityConfig(
            name="atac",
            input_dim=atac.n_vars,  # n_lsi (e.g. 50)
            encoder_hidden=[128, 64],
            decoder_hidden=[64, 128],
            likelihood="gaussian",
        ),
    ],
)

train_cfg = TrainingConfig(
    n_epochs=200,
    batch_size=256,
    lr=1e-3,
    weight_decay=1e-4,
    #device="cuda",   # or "cpu"
    device=device,
    log_every=10,
    grad_clip=5.0,
    num_workers=0,
    seed=42,
    early_stopping=True,
    patience=20,
    min_delta=0.0,
)
'''

'\n# ---------- UniVI config (Gaussian for all 3) ----------\nunivi_cfg = UniVIConfig(\n    latent_dim=60,\n    beta=100.0,\n    gamma=120.0,\n    encoder_dropout=0.0,\n    decoder_dropout=0.0,\n    encoder_batchnorm=True,\n    decoder_batchnorm=False,\n    kl_anneal_start=0,\n    kl_anneal_end=0,\n    align_anneal_start=0,\n    align_anneal_end=0,\n    modalities=[\n        ModalityConfig(\n            name="rna",\n            input_dim=rna.n_vars,\n            encoder_hidden=[512, 256],\n            decoder_hidden=[256, 512],\n            likelihood="gaussian",\n        ),\n        ModalityConfig(\n            name="adt",\n            input_dim=adt.n_vars,\n            encoder_hidden=[128, 64],\n            decoder_hidden=[64, 128],\n            likelihood="gaussian",\n        ),\n        ModalityConfig(\n            name="atac",\n            input_dim=atac.n_vars,  # n_lsi (e.g. 50)\n            encoder_hidden=[128, 64],\n            decoder_hidden=[64, 128],\n            likelihood

In [176]:
from univi.config import UniVIConfig, ModalityConfig, TrainingConfig

univi_cfg = UniVIConfig(
    latent_dim=120,
    #beta=150.0,
    beta=1.0,
    gamma=30.0,
    encoder_dropout=0.0,
    decoder_dropout=0.0,
    encoder_batchnorm=True,
    decoder_batchnorm=False,
    kl_anneal_start=0,
    kl_anneal_end=0,   # ramp KL up over first 50 epochs
    align_anneal_start=0,  # let reconstructions stabilize a bit first
    align_anneal_end=0,
    modalities=[
        ModalityConfig(
            name="rna",
            input_dim=rna.n_vars,
            encoder_hidden=[512, 256],
            decoder_hidden=[256, 512],
            likelihood="gaussian",
        ),
        ModalityConfig(
            name="adt",
            input_dim=adt.n_vars,
            encoder_hidden=[128, 64],
            decoder_hidden=[64, 128],
            likelihood="gaussian",
        ),
        ModalityConfig(
            name="atac",
            input_dim=atac.n_vars,  # n_lsi
            encoder_hidden=[256, 128],
            decoder_hidden=[128, 256],
            likelihood="gaussian",
        ),
    ],
)

train_cfg = TrainingConfig(
    n_epochs=300,
    batch_size=256,
    lr=1e-3,
    weight_decay=1e-4,
    device=device,      # "cuda"
    log_every=10,
    grad_clip=5.0,
    num_workers=0,
    seed=42,
    early_stopping=True,
    patience=35,
    min_delta=0.0,
)


In [177]:
from torch.utils.data import DataLoader, Subset

dataset = MultiModalDataset(
    adata_dict=adata_dict,
    X_key="X",
    device=train_cfg.device,
)

n_cells = dataset.n_cells
indices = np.arange(n_cells)
rng = np.random.default_rng(42)
rng.shuffle(indices)

frac_train = 0.8
frac_val   = 0.1
n_train = int(frac_train * n_cells)
n_val   = int(frac_val * n_cells)

train_idx = indices[:n_train]
val_idx   = indices[n_train:n_train + n_val]
test_idx  = indices[n_train + n_val:]

train_ds = Subset(dataset, train_idx)
val_ds   = Subset(dataset, val_idx)

train_loader = DataLoader(
    train_ds,
    batch_size=train_cfg.batch_size,
    shuffle=True,
    num_workers=train_cfg.num_workers,
)

val_loader = DataLoader(
    val_ds,
    batch_size=train_cfg.batch_size,
    shuffle=False,
    num_workers=train_cfg.num_workers,
)

model = UniVIMultiModalVAE(univi_cfg).to(train_cfg.device)
trainer = UniVITrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    train_cfg=train_cfg,
    device=train_cfg.device,
)


[2025-11-20 10:44:24,797] [UniVITrainer] [INFO] TrainingConfig:
2025-11-20 10:44:24 - INFO - TrainingConfig:
[2025-11-20 10:44:24,803] [UniVITrainer] [INFO]   n_epochs: 300
2025-11-20 10:44:24 - INFO -   n_epochs: 300
[2025-11-20 10:44:24,804] [UniVITrainer] [INFO]   batch_size: 256
2025-11-20 10:44:24 - INFO -   batch_size: 256
[2025-11-20 10:44:24,805] [UniVITrainer] [INFO]   lr: 0.001
2025-11-20 10:44:24 - INFO -   lr: 0.001
[2025-11-20 10:44:24,806] [UniVITrainer] [INFO]   weight_decay: 0.0001
2025-11-20 10:44:24 - INFO -   weight_decay: 0.0001
[2025-11-20 10:44:24,807] [UniVITrainer] [INFO]   device: cuda
2025-11-20 10:44:24 - INFO -   device: cuda
[2025-11-20 10:44:24,809] [UniVITrainer] [INFO]   log_every: 10
2025-11-20 10:44:24 - INFO -   log_every: 10
[2025-11-20 10:44:24,810] [UniVITrainer] [INFO]   grad_clip: 5.0
2025-11-20 10:44:24 - INFO -   grad_clip: 5.0
[2025-11-20 10:44:24,811] [UniVITrainer] [INFO]   num_workers: 0
2025-11-20 10:44:24 - INFO -   num_workers: 0
[2025-1

#### Train model

In [None]:
# ---------- train ----------
history = trainer.fit()


Training UniVI:   0%|          | 0/300 [00:00<?, ?it/s]

[2025-11-20 10:44:40,347] [UniVITrainer] [INFO] [Epoch 001] Train loss: 854.7620 (beta=1.000, gamma=30.000)
2025-11-20 10:44:40 - INFO - [Epoch 001] Train loss: 854.7620 (beta=1.000, gamma=30.000)
[2025-11-20 10:44:41,599] [UniVITrainer] [INFO] [Epoch 001] Val loss: 778.2465 (beta=1.000, gamma=30.000)
2025-11-20 10:44:41 - INFO - [Epoch 001] Val loss: 778.2465 (beta=1.000, gamma=30.000)
[2025-11-20 10:44:41,792] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 778.2465
2025-11-20 10:44:41 - INFO - [Epoch 001] New best val loss: 778.2465
[2025-11-20 10:44:52,062] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 730.9342
2025-11-20 10:44:52 - INFO - [Epoch 002] New best val loss: 730.9342
[2025-11-20 10:45:01,764] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 714.0325
2025-11-20 10:45:01 - INFO - [Epoch 003] New best val loss: 714.0325
[2025-11-20 10:45:11,255] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 708.4523
2025-11-20 10:45:11 - INFO - [Epoch 004] New best v

2025-11-20 10:54:08 - INFO - [Epoch 060] Val loss: 651.2702 (beta=1.000, gamma=30.000)
[2025-11-20 10:54:08,995] [UniVITrainer] [INFO] [Epoch 060] New best val loss: 651.2702
2025-11-20 10:54:08 - INFO - [Epoch 060] New best val loss: 651.2702
[2025-11-20 10:54:47,931] [UniVITrainer] [INFO] [Epoch 064] New best val loss: 650.4043
2025-11-20 10:54:47 - INFO - [Epoch 064] New best val loss: 650.4043
[2025-11-20 10:55:45,038] [UniVITrainer] [INFO] [Epoch 070] Train loss: 643.5834 (beta=1.000, gamma=30.000)
2025-11-20 10:55:45 - INFO - [Epoch 070] Train loss: 643.5834 (beta=1.000, gamma=30.000)
[2025-11-20 10:55:46,128] [UniVITrainer] [INFO] [Epoch 070] Val loss: 650.8521 (beta=1.000, gamma=30.000)
2025-11-20 10:55:46 - INFO - [Epoch 070] Val loss: 650.8521 (beta=1.000, gamma=30.000)
[2025-11-20 10:56:25,667] [UniVITrainer] [INFO] [Epoch 074] New best val loss: 649.1459
2025-11-20 10:56:25 - INFO - [Epoch 074] New best val loss: 649.1459
[2025-11-20 10:57:23,133] [UniVITrainer] [INFO] [Epo

In [None]:
import matplotlib.pyplot as plt

# Quick training curves
fig, ax = plt.subplots()
ax.plot(history["train_loss"], label="train")
ax.plot(history["val_loss"], label="val")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("UniVI Multiome training curves")
ax.legend()
plt.tight_layout()
plt.show()

fig, ax = plt.subplots()
ax.plot(history["beta"], label="beta")
ax.plot(history["gamma"], label="gamma")
ax.set_xlabel("Epoch")
ax.set_ylabel("Weight")
ax.set_title("KL / alignment annealing")
ax.legend()
plt.tight_layout()
plt.show()


In [None]:
from dataclasses import asdict
import torch

os.makedirs("../saved_models", exist_ok=True)

# trainer.model already has the best weights (because we restored best_state_dict)
ckpt_path = "../saved_models/univi_tea_seq_beta_1_gamma_40_latent_dims_120.pt"
torch.save(
    {
        "state_dict": trainer.model.state_dict(),
        "univi_cfg": asdict(univi_cfg),
        "best_epoch": trainer.best_epoch,
        "best_val_loss": trainer.best_val_loss,
    },
    ckpt_path,
)
print("Saved best model to:", ckpt_path)


#### Evaluate model

In [None]:
import torch
from univi.config import UniVIConfig, ModalityConfig
from univi.models.univi import UniVIMultiModalVAE

#device = "cpu"  # or "cuda" if available
device = 'cuda'

ckpt = torch.load(
    "../saved_models/univi_tea_seq_beta_1_gamma_40_latent_dims_120.pt",
    map_location=device,
)

# ---- Rebuild UniVIConfig, making sure modalities are ModalityConfig objects ----
cfg_dict = ckpt["univi_cfg"]

# If this is an OmegaConf object or similar, make sure it's a plain dict
try:
    from omegaconf import DictConfig, OmegaConf
    if isinstance(cfg_dict, DictConfig):
        cfg_dict = OmegaConf.to_container(cfg_dict, resolve=True)
except ImportError:
    pass

# Now rehydrate each modality
modalities = [ModalityConfig(**m) for m in cfg_dict["modalities"]]
cfg_dict = {**cfg_dict, "modalities": modalities}

univi_cfg_loaded = UniVIConfig(**cfg_dict)

# ---- Rebuild model + load weights ----
model_loaded = UniVIMultiModalVAE(univi_cfg_loaded).to(device)
model_loaded.load_state_dict(ckpt["state_dict"])

print("Best epoch was:", ckpt.get("best_epoch"), "val loss =", ckpt.get("best_val_loss"))


In [None]:
z_rna  = trainer.encode_modality(rna,  modality="rna")
z_adt  = trainer.encode_modality(adt,  modality="adt")
z_atac = trainer.encode_modality(atac, modality="atac")

rna.obsm["X_univi"]  = z_rna
adt.obsm["X_univi"]  = z_adt
atac.obsm["X_univi"] = z_atac


In [None]:
print(history)


In [None]:
# These should be the *same* cells across modalities
assert np.array_equal(rna.obs_names, adt.obs_names)
assert np.array_equal(rna.obs_names, atac.obs_names)


In [None]:
# ------------------------------
# Build train / val / test adatas
# ------------------------------

rna_train_adata  = rna[train_idx].copy()
rna_val_adata    = rna[val_idx].copy()
rna_test_adata   = rna[test_idx].copy()

adt_train_adata  = adt[train_idx].copy()
adt_val_adata    = adt[val_idx].copy()
adt_test_adata   = adt[test_idx].copy()

atac_train_adata = atac[train_idx].copy()
atac_val_adata   = atac[val_idx].copy()
atac_test_adata  = atac[test_idx].copy()


In [None]:
print(rna_test_adata)
print(adt_test_adata)
print(atac_test_adata)


In [None]:
import scanpy as sc

# Make plots a bit bigger / nicer
sc.settings.set_figure_params(dpi=200, figsize=(16, 14))


In [None]:
# ============================
# UniVI TEA-seq evaluation (RNA / ADT / ATAC) – label-free, tri-modal
# ============================
import os
import numpy as np
import torch
import scipy.sparse as sp
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc

from sklearn.metrics import silhouette_score
from sklearn.neighbors import NearestNeighbors

from univi import evaluation as univi_eval
from univi import plotting as univi_plot  # not strictly needed, but left for convenience

# -----------------------------------------
# CONFIG
# -----------------------------------------
FIGDIR = "../figures/teaseq_univi_tri_modal_eval"
os.makedirs(FIGDIR, exist_ok=True)

device = train_cfg.device  # e.g. "cuda" or "cpu"

# -----------------------------------------
# 0. Sanity checks
# -----------------------------------------
assert (
    rna_test_adata.n_obs == adt_test_adata.n_obs == atac_test_adata.n_obs
), "RNA / ADT / ATAC TEST sets must have same #cells"
assert np.array_equal(rna_test_adata.obs_names, adt_test_adata.obs_names), (
    "RNA and ADT obs_names must match 1:1 for pairwise metrics."
)
assert np.array_equal(rna_test_adata.obs_names, atac_test_adata.obs_names), (
    "RNA and ATAC obs_names must match 1:1 for pairwise metrics."
)

print(f"Test cells: {rna_test_adata.n_obs}")

# -----------------------------------------
# 1. Encode latent embeddings for test sets
# -----------------------------------------
print("\nEncoding test sets into UniVI latent space...")

z_rna  = univi_eval.encode_adata(model, rna_test_adata,  modality="rna",  device=device)
z_adt  = univi_eval.encode_adata(model, adt_test_adata,  modality="adt",  device=device)
z_atac = univi_eval.encode_adata(model, atac_test_adata, modality="atac", device=device)

rna_test_adata.obsm["X_univi"]  = z_rna
adt_test_adata.obsm["X_univi"]  = z_adt
atac_test_adata.obsm["X_univi"] = z_atac

print("Latent shapes (test):")
print("  RNA :", z_rna.shape)
print("  ADT :", z_adt.shape)
print("  ATAC:", z_atac.shape)

# -----------------------------------------
# 2. FOSCTTM (pairwise alignment)
# -----------------------------------------
print("\nComputing FOSCTTM for each modality pair (lower = better)...")
fos_rna_adt  = univi_eval.compute_foscttm(z_rna,  z_adt)
fos_rna_atac = univi_eval.compute_foscttm(z_rna,  z_atac)
fos_adt_atac = univi_eval.compute_foscttm(z_adt,  z_atac)

print(f"  RNA  vs ADT : {fos_rna_adt:.4f}")
print(f"  RNA  vs ATAC: {fos_rna_atac:.4f}")
print(f"  ADT  vs ATAC: {fos_adt_atac:.4f}")

plt.figure(figsize=(4, 4))
pairs = ["RNA–ADT", "RNA–ATAC", "ADT–ATAC"]
vals = [fos_rna_adt, fos_rna_atac, fos_adt_atac]
sns.barplot(x=pairs, y=vals)
plt.ylabel("FOSCTTM (mean)")
plt.title("Tri-modal FOSCTTM (lower = better)")
#plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "foscttm_barplot.png"))
plt.show()
plt.close()

# -----------------------------------------
# 3. Modality mixing (all three modalities)
# -----------------------------------------
Z_joint = np.concatenate([z_rna, z_adt, z_atac], axis=0)
modality_labels = np.array(
    ["rna"]  * z_rna.shape[0]
    + ["adt"]  * z_adt.shape[0]
    + ["atac"] * z_atac.shape[0]
)

mixing_score = univi_eval.compute_modality_mixing(
    Z_joint,
    modality_labels,
    k=20,
)
print(f"\nGlobal modality mixing score (RNA/ADT/ATAC, k=20): {mixing_score:.3f}")

# kNN modality composition heatmap
print("Computing kNN modality composition...")
k = 30
nn = NearestNeighbors(n_neighbors=k + 1)
nn.fit(Z_joint)
_, idx = nn.kneighbors(Z_joint)

idx_neighbors = idx[:, 1:]  # drop self
neighbor_mods = modality_labels[idx_neighbors]

modalities = np.array(["rna", "adt", "atac"])
comp_matrix = np.zeros((len(modalities), len(modalities)))  # row = center, col = neighbor

for i, m_center in enumerate(modalities):
    mask_center = modality_labels == m_center
    neigh_for_center = neighbor_mods[mask_center].reshape(-1)
    for j, m_neigh in enumerate(modalities):
        comp_matrix[i, j] = (neigh_for_center == m_neigh).mean()

same_mod_frac = (neighbor_mods == modality_labels[:, None]).mean()
print(f"  Fraction of neighbors with same modality (k={k}): {same_mod_frac:.3f}")

plt.figure(figsize=(5, 4))
sns.heatmap(
    comp_matrix,
    annot=True,
    fmt=".2f",
    xticklabels=modalities,
    yticklabels=modalities,
    cmap="viridis",
)
plt.xlabel("Neighbor modality")
plt.ylabel("Center modality")
plt.title(f"kNN modality composition (k={k})")
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "knn_modality_composition.png"))
plt.show()
plt.close()

# -----------------------------------------
# 4. UMAP on UniVI latent (tri-modal)
# -----------------------------------------
print("\nBuilding tri-modal UMAP on UniVI latent...")

# Tag each test set with modality
rna_tmp  = rna_test_adata.copy()
adt_tmp  = adt_test_adata.copy()
atac_tmp = atac_test_adata.copy()

rna_tmp.obs["univi_source"]  = "rna"
adt_tmp.obs["univi_source"]  = "adt"
atac_tmp.obs["univi_source"] = "atac"

combined = rna_tmp.concatenate(
    adt_tmp,
    atac_tmp,
    join="outer",
    batch_key="univi_batch",
    batch_categories=["rna", "adt", "atac"],
    index_unique=None,
)

# ensure latent is correctly stacked
combined.obsm["X_univi"] = np.vstack([
    rna_test_adata.obsm["X_univi"],
    adt_test_adata.obsm["X_univi"],
    atac_test_adata.obsm["X_univi"],
])

# neighbors / UMAP / Leiden
sc.pp.neighbors(combined, use_rep="X_univi", n_neighbors=30)
sc.tl.umap(combined)
sc.tl.leiden(combined, key_added="univi_leiden", resolution=0.75)

# UMAP colored by modality
sc.pl.umap(
    combined,
    color="univi_source",
    size=65,
    alpha=0.8,
    show=False,
)
plt.savefig(os.path.join(FIGDIR, "umap_tri_modal_modality.png"), bbox_inches="tight")
plt.show()
plt.close()

# UMAP colored by Leiden clusters (pseudo-clusters)
sc.pl.umap(
    combined,
    color="univi_leiden",
    size=65,
    alpha=0.8,
    show=False,
)
plt.savefig(os.path.join(FIGDIR, "umap_tri_modal_leiden.png"), bbox_inches="tight")
plt.show()
plt.close()

# Per-modality UMAPs, colored by Leiden
for mod in ["rna", "adt", "atac"]:
    sub = combined[combined.obs["univi_source"] == mod].copy()
    sc.pl.umap(
        sub,
        color="univi_leiden",
        size=65,
        alpha=0.8,
        show=False,
    )
    plt.savefig(
        os.path.join(FIGDIR, f"umap_{mod}_only_leiden.png"),
        bbox_inches="tight",
    )
    plt.show()
    plt.close()

# -----------------------------------------
# 5. Latent geometry diagnostics
# -----------------------------------------
print("\nLatent geometry diagnostics...")

def latent_norms(z, label):
    norms = np.linalg.norm(z, axis=1)
    return norms, np.repeat(label, len(norms))

norm_rna,  lab_rna  = latent_norms(z_rna,  "RNA")
norm_adt,  lab_adt  = latent_norms(z_adt,  "ADT")
norm_atac, lab_atac = latent_norms(z_atac, "ATAC")

norms_all = np.concatenate([norm_rna, norm_adt, norm_atac])
labs_all  = np.concatenate([lab_rna, lab_adt, lab_atac])

plt.figure(figsize=(5, 4))
sns.violinplot(x=labs_all, y=norms_all, inner="box")
plt.ylabel("‖z‖ (L2 norm)")
plt.xlabel("Modality")
plt.title("Latent L2-norm distribution by modality")
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "latent_norms_by_modality.png"))
plt.show()
plt.close()

# Latent correlation heatmap of RNA latent dims
corr_latent = np.corrcoef(z_rna, rowvar=False)
plt.figure(figsize=(6, 5))
sns.heatmap(corr_latent, cmap="vlag", center=0)
plt.title("RNA latent dimension correlation (test set)")
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "latent_corr_heatmap_rna.png"))
plt.show()
plt.close()

# Silhouette score on modality (lower = better mixing)
if len(np.unique(modality_labels)) > 1:
    sil_mod = silhouette_score(Z_joint, modality_labels)
else:
    sil_mod = np.nan
print(f"Silhouette (modality) on UniVI latent: {sil_mod:.3f}")

# -----------------------------------------
# 6. Local modality entropy (kNN) + UMAP
# -----------------------------------------
print("\nComputing local modality entropy...")

n_neighbors_local = 30
nn_local = NearestNeighbors(n_neighbors=n_neighbors_local, metric="euclidean")
nn_local.fit(Z_joint)
_, idx_local = nn_local.kneighbors(Z_joint)

mods = modality_labels
local_modality_entropy = []

for i in range(Z_joint.shape[0]):
    neigh = idx_local[i, 1:]  # drop self
    neigh_mod = mods[neigh]

    # empirical distribution over modalities
    ent = 0.0
    for m in modalities:
        p = (neigh_mod == m).mean()
        if p > 0:
            ent -= p * np.log2(p)
    local_modality_entropy.append(ent)

local_modality_entropy = np.asarray(local_modality_entropy)
combined.obs["local_modality_entropy"] = local_modality_entropy

print(f"Mean local modality entropy (k={n_neighbors_local}): {local_modality_entropy.mean():.3f}")

plt.figure(figsize=(5, 4))
plt.hist(local_modality_entropy, bins=30)
plt.xlabel("Local modality entropy (bits)")
plt.ylabel("Cells")
plt.title("kNN modality entropy")
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "local_modality_entropy_hist.png"))
plt.show()
plt.close()

# UMAP colored by local modality entropy
sc.pl.umap(
    combined,
    color="local_modality_entropy",
    size=65,
    alpha=0.8,
    show=False,
)
plt.savefig(os.path.join(FIGDIR, "umap_local_modality_entropy.png"), bbox_inches="tight")
plt.show()
plt.close()

# -----------------------------------------
# 7. Pairwise matching metrics (top-k) for all modality pairs
# -----------------------------------------
print("\nPairwise matching metrics (top-1 / top-5 / top-10)...")

def topk_matching(z_src, z_tgt, pair_name: str, k_match: int = 10):
    nn = NearestNeighbors(n_neighbors=k_match, metric="euclidean")
    nn.fit(z_tgt)
    _, idx_knn = nn.kneighbors(z_src)

    true_idx = np.arange(z_src.shape[0])
    top1_hits  = (idx_knn[:, 0] == true_idx)
    top5_hits  = (idx_knn[:, :5] == true_idx[:, None]).any(axis=1)
    top10_hits = (idx_knn[:, :10] == true_idx[:, None]).any(axis=1)

    print(f"  {pair_name}:")
    print(f"    Top-1 accuracy:  {top1_hits.mean():.3f}")
    print(f"    Top-5 accuracy:  {top5_hits.mean():.3f}")
    print(f"    Top-10 accuracy: {top10_hits.mean():.3f}")

    plt.figure(figsize=(5, 4))
    plt.bar(
        ["Top-1", "Top-5", "Top-10"],
        [top1_hits.mean(), top5_hits.mean(), top10_hits.mean()],
    )
    plt.ylabel("Fraction of correctly matched pairs")
    plt.title(f"Cross-modal matching accuracy ({pair_name})")
    plt.tight_layout()
    fname = f"matching_{pair_name.replace(' ', '_').replace('→','to')}.png"
    plt.savefig(os.path.join(FIGDIR, fname))
    plt.show()
    plt.close()

topk_matching(z_rna,  z_adt,  "RNA→ADT")
topk_matching(z_adt,  z_rna,  "ADT→RNA")
topk_matching(z_rna,  z_atac, "RNA→ATAC")
topk_matching(z_atac, z_rna,  "ATAC→RNA")
topk_matching(z_adt,  z_atac, "ADT→ATAC")
topk_matching(z_atac, z_adt,  "ATAC→ADT")

# -----------------------------------------
# 8. Cross-modal reconstruction metrics
# -----------------------------------------
def _to_dense(X):
    return X.toarray() if sp.issparse(X) else np.asarray(X)

def cross_modal_metrics(
    model,
    src_adata,
    tgt_adata,
    src_mod: str,
    tgt_mod: str,
    name_prefix: str,
    device: str,
):
    Xhat_tgt = univi_eval.cross_modal_predict(
        model,
        adata_src=src_adata,
        src_mod=src_mod,
        tgt_mod=tgt_mod,
        device=device,
        batch_size=512,
    )

    X_tgt = _to_dense(tgt_adata.X)

    mse_feat  = univi_eval.mse_per_feature(X_tgt, Xhat_tgt)
    corr_feat = univi_eval.pearson_corr_per_feature(X_tgt, Xhat_tgt)

    print(f"\nCross-modal reconstruction: {src_mod} → {tgt_mod}")
    print(f"  Mean feature MSE: {mse_feat.mean():.4f}")
    print(f"  Mean feature Pearson r: {corr_feat.mean():.3f}")

    # Histogram of per-feature correlation
    plt.figure(figsize=(5, 4))
    sns.histplot(corr_feat, bins=40, kde=False)
    plt.xlabel("Per-feature Pearson r")
    plt.ylabel("Count")
    plt.title(f"{src_mod} → {tgt_mod}: feature-wise correlation")
    plt.tight_layout()
    plt.savefig(os.path.join(FIGDIR, f"{name_prefix}_corr_hist.png"))
    plt.show()
    plt.close()

    # Histogram of per-feature MSE
    plt.figure(figsize=(5, 4))
    sns.histplot(mse_feat, bins=40, kde=False)
    plt.xlabel("Per-feature MSE")
    plt.ylabel("Count")
    plt.title(f"{src_mod} → {tgt_mod}: feature-wise MSE")
    plt.tight_layout()
    plt.savefig(os.path.join(FIGDIR, f"{name_prefix}_mse_hist.png"))
    plt.show()
    plt.close()

    return mse_feat, corr_feat

# Evaluate key directions on TEST set
_ = cross_modal_metrics(model, rna_test_adata, adt_test_adata,
                        src_mod="rna", tgt_mod="adt",
                        name_prefix="RNA_to_ADT_test", device=device)

_ = cross_modal_metrics(model, rna_test_adata, atac_test_adata,
                        src_mod="rna", tgt_mod="atac",
                        name_prefix="RNA_to_ATAC_test", device=device)

_ = cross_modal_metrics(model, adt_test_adata, rna_test_adata,
                        src_mod="adt", tgt_mod="rna",
                        name_prefix="ADT_to_RNA_test", device=device)

_ = cross_modal_metrics(model, atac_test_adata, rna_test_adata,
                        src_mod="atac", tgt_mod="rna",
                        name_prefix="ATAC_to_RNA_test", device=device)

# -----------------------------------------
# 9. Denoising with decoders (on "unused" cells)
# -----------------------------------------
print("\nDenoising on *_unused sets (if present)...")

for adata, mod, tag in [
    (locals().get("rna_unused",  None), "rna",  "rna_unused"),
    (locals().get("adt_unused",  None), "adt",  "adt_unused"),
    (locals().get("atac_unused", None), "atac", "atac_unused"),
]:
    if adata is None or adata.n_obs == 0:
        print(f"  Skipping {tag} ({mod}) – no cells.")
        continue

    print(f"  Denoising {tag} ({mod})...")
    univi_eval.denoise_adata(
        model, adata, modality=mod,
        device=device, batch_size=512,
        out_layer="univi_denoised",
    )
    X_raw = _to_dense(adata.X)
    X_den = _to_dense(adata.layers["univi_denoised"])
    per_cell_mse = ((X_raw - X_den) ** 2).mean(axis=1)

    plt.figure(figsize=(5, 4))
    sns.histplot(per_cell_mse, bins=40)
    plt.xlabel("Per-cell MSE (raw vs denoised)")
    plt.ylabel("Count")
    plt.title(f"Denoising quality: {tag} ({mod})")
    plt.tight_layout()
    plt.savefig(os.path.join(FIGDIR, f"denoise_{tag}_{mod}.png"))
    plt.show()
    plt.close()

print("\nNo curated celltypes; all metrics are label-free (modality-based only).")


In [None]:
# ============================
# 10. Summarize TEA-seq metrics & save as JSON
# ============================
import json

teaseq_metrics = {
    # Pairwise FOSCTTM
    "foscttm_rna_adt": float(fos_rna_adt),
    "foscttm_rna_atac": float(fos_rna_atac),
    "foscttm_adt_atac": float(fos_adt_atac),

    # Global mixing
    "modality_mixing_k20": float(mixing_score),
    "same_modality_neighbor_frac_k20": float(same_mod_frac),

    # Latent geometry
    "silhouette_modality": float(sil_mod) if not np.isnan(sil_mod) else None,
    "mean_local_modality_entropy_k20": float(local_modality_entropy.mean()),
    "median_local_modality_entropy_k20": float(np.median(local_modality_entropy)),

    # Dataset sizes
    "n_cells_test": int(rna_test_adata.n_obs),
    "n_genes_rna": int(rna_test_adata.n_vars),
    "n_features_adt": int(adt_test_adata.n_vars),
    "n_features_atac": int(atac_test_adata.n_vars),
}

'''
metrics_path = os.path.join(FIGDIR, "teaseq_univi_metrics.json")
with open(metrics_path, "w") as f:
    json.dump(teaseq_metrics, f, indent=2)

print(f"\n[TEA-seq] Saved benchmark metrics to: {metrics_path}")
'''

# ============================
# 11. kNN distance diagnostics: same vs cross-modality neighbors
# ============================
print("\nComputing kNN distance distributions (same vs cross-modality)...")

from sklearn.neighbors import NearestNeighbors

k_dist = 30
nn_dist = NearestNeighbors(n_neighbors=k_dist + 1, metric="euclidean")
nn_dist.fit(Z_joint)
dists_all, idx_all = nn_dist.kneighbors(Z_joint)

# drop self
dists_neighbors = dists_all[:, 1:]              # (n_cells, k_dist)
idx_neighbors_dist = idx_all[:, 1:]             # already had neighbor_mods from earlier
flat_dists = dists_neighbors.reshape(-1)

# Center & neighbor modalities (flattened to per-edge view)
center_mods_flat = np.repeat(modality_labels, k_dist)
neighbor_mods_flat = neighbor_mods.reshape(-1)

same_mod_edge = neighbor_mods_flat == center_mods_flat

plt.figure(figsize=(6, 4))
sns.kdeplot(flat_dists[same_mod_edge], label="same modality", fill=True, alpha=0.6)
sns.kdeplot(flat_dists[~same_mod_edge], label="different modality", fill=True, alpha=0.6)
plt.xlabel("Euclidean distance in UniVI latent")
plt.ylabel("Density")
plt.title(f"kNN distance distribution (k={k_dist})")
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "knn_distance_same_vs_cross_modality.png"), dpi=200)
plt.show()
plt.close()


# ============================
# 12. Cross-modal neighbor fraction per cell (and UMAP)
# ============================
print("\nComputing per-cell cross-modal neighbor fraction...")

# fraction of neighbors NOT sharing the cell's modality
cross_mod_neighbor_frac = 1.0 - (neighbor_mods == modality_labels[:, None]).mean(axis=1)
combined.obs["cross_mod_neighbor_frac"] = cross_mod_neighbor_frac

plt.figure(figsize=(5, 4))
sns.histplot(cross_mod_neighbor_frac, bins=30)
plt.xlabel("Fraction of neighbors with different modality")
plt.ylabel("Cells")
plt.title(f"Cross-modal neighbor fraction (k={k})")
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "cross_mod_neighbor_fraction_hist.png"), dpi=200)
plt.show()
plt.close()

# UMAP colored by cross-modal neighbor fraction
sc.pl.umap(
    combined,
    color="cross_mod_neighbor_frac",
    size=65,
    alpha=0.8,
    cmap="viridis",
    show=False,
)
plt.title("UMAP – cross-modal neighbor fraction")
plt.savefig(os.path.join(FIGDIR, "umap_cross_mod_neighbor_fraction.png"),
            bbox_inches="tight", dpi=200)
plt.show()
plt.close()


# ============================
# 13. Per-cluster modality composition (Leiden clusters)
# ============================
print("\nComputing modality composition per Leiden cluster...")

if "univi_leiden" in combined.obs.columns:
    comp_ct = pd.crosstab(combined.obs["univi_leiden"], combined.obs["univi_source"])
    comp_prop = comp_ct.div(comp_ct.sum(axis=1), axis=0)  # row-normalize

    plt.figure(figsize=(7, 5))
    comp_prop.sort_index().plot(
        kind="bar",
        stacked=True,
        width=0.9,
        colormap="tab10",
    )
    plt.xlabel("UniVI Leiden cluster")
    plt.ylabel("Fraction of cells")
    plt.title("Modality composition per UniVI Leiden cluster")
    plt.legend(title="Modality", bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()
    plt.savefig(os.path.join(FIGDIR, "cluster_modality_composition_stacked_bar.png"),
                dpi=200)
    plt.show()
    plt.close()
else:
    print("  WARNING: 'univi_leiden' not found in combined.obs; skipping cluster composition plot.")


# ============================
# 14. Feature-level cross-modal scatter plots (RNA → ADT / ATAC)
# ============================
print("\nFeature-level scatter plots for cross-modal prediction (RNA→ADT / RNA→ATAC)...")

# Example: RNA → ADT for a small subset of cells & selected markers
# (adjust marker names to what you actually have in adt_test_adata.var_names)

n_scatter_cells = min(5000, adt_test_adata.n_obs)  # subsample if huge
cell_idx = np.random.default_rng(42).choice(adt_test_adata.n_obs, n_scatter_cells, replace=False)

# recompute predictions just for this subset to avoid reusing big arrays
Xhat_adt_sub = univi_eval.cross_modal_predict(
    model,
    adata_src=rna_test_adata[cell_idx],
    src_mod="rna",
    tgt_mod="adt",
    device=device,
    batch_size=512,
)
X_adt_sub = _to_dense(adt_test_adata[cell_idx].X)

adt_markers_to_plot = [
    # put your favorite ADT markers here
    # e.g. "CD3", "CD4", "CD8A", "CD56", ...
]

for marker in adt_markers_to_plot:
    if marker not in adt_test_adata.var_names:
        print(f"  [RNA→ADT] Marker '{marker}' not found in adt_test_adata.var_names; skipping.")
        continue

    j = np.where(adt_test_adata.var_names == marker)[0][0]
    y_true = X_adt_sub[:, j]
    y_pred = Xhat_adt_sub[:, j]

    plt.figure(figsize=(4.5, 4))
    plt.hexbin(y_true, y_pred, gridsize=50, mincnt=1)
    plt.xlabel(f"True ADT ({marker})")
    plt.ylabel(f"Predicted ADT ({marker})")
    plt.title(f"RNA→ADT prediction for {marker}")
    plt.tight_layout()
    fname = f"scatter_RNA_to_ADT_{marker}.png".replace(" ", "_")
    plt.savefig(os.path.join(FIGDIR, fname), dpi=200)
    plt.show()
    plt.close()


# Example: RNA → ATAC for a few peaks / gene-body features (if named)
# (This is more exploratory since ATAC features are often peaks; choose a few named ones if available.)

n_scatter_cells_atac = min(5000, atac_test_adata.n_obs)
cell_idx_atac = np.random.default_rng(123).choice(atac_test_adata.n_obs, n_scatter_cells_atac, replace=False)

Xhat_atac_sub = univi_eval.cross_modal_predict(
    model,
    adata_src=rna_test_adata[cell_idx_atac],
    src_mod="rna",
    tgt_mod="atac",
    device=device,
    batch_size=512,
)
X_atac_sub = _to_dense(atac_test_adata[cell_idx_atac].X)

# If you have named ATAC features (e.g. gene body aggregates), you can list them here.
atac_features_to_plot = [
    # e.g. "TNFRSF4_body", "IFNG_enh", ...
]

for feat in atac_features_to_plot:
    if feat not in atac_test_adata.var_names:
        print(f"  [RNA→ATAC] Feature '{feat}' not found in atac_test_adata.var_names; skipping.")
        continue

    j = np.where(atac_test_adata.var_names == feat)[0][0]
    y_true = X_atac_sub[:, j]
    y_pred = Xhat_atac_sub[:, j]

    plt.figure(figsize=(4.5, 4))
    plt.hexbin(y_true, y_pred, gridsize=50, mincnt=1)
    plt.xlabel(f"True ATAC ({feat})")
    plt.ylabel(f"Predicted ATAC ({feat})")
    plt.title(f"RNA→ATAC prediction for {feat}")
    plt.tight_layout()
    fname = f"scatter_RNA_to_ATAC_{feat}.png".replace(" ", "_")
    plt.savefig(os.path.join(FIGDIR, fname), dpi=200)
    plt.show()
    plt.close()
