# UniVI TEA-seq tri-modal data integration demonstration/tutorial - Added hyperparameter optimization code

Andrew Ashford, Pathways + Omics Group, Oregon Health & Science University - 11/18/2025

This Jupyter notebook walks through preprocessing human PBMC TEA-seq tri-modal data (RNA, ADT, and ATAC), training a UniVI model on it, and running a random hyperparameter search over the model and training settings.


#### Import modules

In [1]:
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import scipy.sparse as sp
import scanpy as sc
import anndata as ad

from sklearn.decomposition import TruncatedSVD
from sklearn.preprocessing import normalize

import snapatac2 as snap


In [2]:
# -------------------------
# 0. Wire up package import
# -------------------------
# Put the parent of the current working directory on sys.path (once) so the
# package imports that follow can be resolved from a source checkout.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if sys.path.count(project_root) == 0:
    sys.path.append(project_root)

from univi import (
    UniVIMultiModalVAE,
    ModalityConfig,
    UniVIConfig,
    TrainingConfig,
    matching,
)
from univi.data import MultiModalDataset
from univi.trainer import UniVITrainer


In [3]:
import torch
# Report the installed PyTorch build and whether CUDA can be used.
print("Torch:", torch.__version__)
print("torch.version.cuda:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())

# Use the GPU whenever one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)


Torch: 2.9.1+cu128
torch.version.cuda: 12.8
CUDA available: True
Using device: cuda


#### Read in and preprocess data as needed

In [4]:
# Directory holding the TEA-seq files and the file-name prefix they all share.
data_dir = Path("../data/TEA-seq_data")
prefix   = "GSM5123953_X066-MP0C1W5_leukopak_perm-cells_tea"

# One input file per data source; every name is "<prefix><suffix>".
rna_h5         = data_dir.joinpath(f"{prefix}_200M_cellranger-arc_filtered_feature_bc_matrix.h5")
adt_counts_csv = data_dir.joinpath(f"{prefix}_48M_adt_counts.csv.gz")
frag_tsv       = data_dir.joinpath(f"{prefix}_200M_atac_filtered_fragments.tsv.gz")
atac_meta_csv  = data_dir.joinpath(f"{prefix}_200M_atac_filtered_metadata.csv.gz")


In [5]:
# ----------------------
# 1) RNA
# ----------------------
# Read the filtered feature-barcode matrix (10x HDF5) into an AnnData.
print("Reading RNA (ARC filtered_feature_bc_matrix.h5)...")
m = sc.read_10x_h5(rna_h5)
print(m)

# Sanity check: show which feature types the file contains.
print("Feature types:", m.var["feature_types"].unique())

# Work on a copy so `m` stays exactly as read, then de-duplicate var_names so
# they can be used as unique keys.
rna_adata = m.copy()
rna_adata.var_names_make_unique()
print("Initial RNA shape:", rna_adata.shape)


Reading RNA (ARC filtered_feature_bc_matrix.h5)...


  utils.warn_names_duplicates("var")


AnnData object with n_obs × n_vars = 8496 × 36601
    var: 'gene_ids', 'feature_types', 'genome', 'interval'
Feature types: ['Gene Expression']
Initial RNA shape: (8496, 36601)


  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


In [6]:
# ----------------------
# 2) ADT
# ----------------------
# Read the ADT count table; the first CSV column becomes the row index.
print("Reading ADT counts...")
adt_df = pd.read_csv(adt_counts_csv, index_col=0)  # rows = barcodes, cols = ADT names
print("ADT counts shape:", adt_df.shape)

# Wrap the table as a sparse AnnData: obs index = row labels, var index =
# column labels, both forced to strings.
# NOTE(review): no cell filtering is applied to this table here; rows are
# restricted only later when barcodes are intersected across modalities.
adt_adata = ad.AnnData(
    X=sp.csr_matrix(adt_df.values),
    obs=pd.DataFrame(index=adt_df.index.astype(str)),
    var=pd.DataFrame(index=adt_df.columns.astype(str)),
)
adt_adata.var_names_make_unique()
print("ADT AnnData:", adt_adata.shape)


Reading ADT counts...
ADT counts shape: (719170, 47)
ADT AnnData: (719170, 47)


In [7]:
# ----------------------
# 3) ATAC: fragments + metadata -> LSI
# ----------------------
# Import the fragment file with snapatac2 using the hg38 chromosome sizes.
# No output file is given, so the result is held in the `atac_raw` variable.
print("Importing ATAC fragments with snapatac2...")
atac_raw = snap.pp.import_data(
    fragment_file=str(frag_tsv),
    chrom_sizes=snap.genome.hg38,   # use chrom_sizes
    sorted_by_barcode=False,        # TEA-seq fragments usually unsorted
)

# At this point the object carries fragments only; a feature matrix is added
# by a later tile-matrix step.
print("ATAC raw object:", atac_raw)


Importing ATAC fragments with snapatac2...


  atac_raw = snap.pp.import_data(


ATAC raw object: AnnData object with n_obs × n_vars = 7540 × 0
    obs: 'n_fragment', 'frac_dup', 'frac_mito'
    uns: 'reference_sequences'
    obsm: 'fragment_paired'


In [8]:
# ---- 3a. Attach external metadata and map barcodes ----
# Per-cell ATAC metadata table; its columns are printed for reference.
meta = pd.read_csv(atac_meta_csv)
print("ATAC meta columns:", meta.columns.tolist())

# Index metadata by hex barcodes; they match atac_raw.obs_names
meta = meta.set_index("barcodes")

# Align cells common between fragments + metadata
common_ids = atac_raw.obs_names.intersection(meta.index)
print("ATAC cells with metadata:", len(common_ids), "of", atac_raw.n_obs)

# Keep only cells present in both, then add the metadata columns to obs.
# The left join keeps every remaining cell and matches rows on the obs index.
atac_raw = atac_raw[common_ids].copy()
atac_raw.obs = atac_raw.obs.join(meta, how="left")


ATAC meta columns: ['original_barcodes', 'n_fragments', 'n_duplicate', 'n_mito', 'n_unique', 'altius_count', 'altius_frac', 'gene_bodies_count', 'gene_bodies_frac', 'peaks_count', 'peaks_frac', 'tss_count', 'tss_frac', 'barcodes', 'cell_name', 'well_id', 'chip_id', 'batch_id', 'pbmc_sample_id', 'DoubletScore', 'DoubletEnrichment', 'TSSEnrichment']
ATAC cells with metadata: 7540 of 7540




In [9]:
# ----------------------
# 4) Simple ATAC QC (using n_fragments from metadata)
# ----------------------
# Minimum number of fragments a cell needs in order to be kept. The threshold
# is defined once and reused in the log message so the two cannot drift apart
# (the message previously reported 1000 while the filter used 1500).
min_fragments = 1500
if "n_fragments" in atac_raw.obs.columns:
    mask = atac_raw.obs["n_fragments"] >= min_fragments
    print("Keeping", mask.sum(), "of", atac_raw.n_obs, f"ATAC cells with n_fragments >= {min_fragments}")
    atac_raw = atac_raw[mask].copy()
else:
    # Make the skipped QC visible instead of silently doing nothing.
    print("Column 'n_fragments' not found in ATAC metadata; skipping fragment-count QC.")


Keeping 7540 of 7540 ATAC cells with n_fragments >= 1000




In [10]:
# ----------------------
# 5) Tile matrix + TF–IDF + LSI
# ----------------------
# Build the cells x genomic-tile matrix. The call's return value is not used;
# `atac_raw` itself gains the features (its shape is printed right after).
# The tile size is left at the snapatac2 default.
print("Adding ATAC tile matrix...")
snap.pp.add_tile_matrix(atac_raw)
print("Tile matrix shape:", atac_raw.shape, "(cells x genomic tiles)")


Adding ATAC tile matrix...
Tile matrix shape: (7540, 6062095) (cells x genomic tiles)


In [11]:
print("Computing TF–IDF and LSI on ATAC tiles...")
# Take the tile matrix and make sure it is sparse (CSR) before the TF-IDF math.
X = atac_raw.X
if not sp.issparse(X):
    X = sp.csr_matrix(X)

# Rows are cells, columns are tiles; `n_cells` is reused by the IDF formula.
n_cells, n_feats = X.shape
print(f"ATAC tile matrix: {n_cells} cells × {n_feats} features")


Computing TF–IDF and LSI on ATAC tiles...
ATAC tile matrix: 7540 cells × 6062095 features


In [12]:
# TF: normalize by per-cell counts
# (each row is divided by its L1 norm, i.e. the row sum for non-negative data)
tf = normalize(X, norm="l1", axis=1)

# DF: number of cells with a non-zero in each feature
# (note: this `df` is a 1-D NumPy array, not a pandas DataFrame)
df = np.array((X > 0).sum(axis=0)).ravel()
# Smoothed inverse document frequency: log(1 + n_cells / (1 + df)).
idf = np.log1p(n_cells / (1.0 + df))

# TF-IDF
# One IDF weight per feature, broadcast across all cells (rows).
X_tfidf = tf.multiply(idf)


In [13]:
'''
import numpy as np

# X: sparse matrix of shape (n_cells, n_peaks)
min_cells_per_peak = 50  # or 20, depending on sparsity

# number of nonzeros per peak (i.e., per column)
peak_nnz = X.getnnz(axis=0)          # returns a 1D np.array-like
peak_nnz = np.asarray(peak_nnz).ravel()

keep_peaks = peak_nnz >= min_cells_per_peak

X_filtered = X[:, keep_peaks]
'''

'\nimport numpy as np\n\n# X: sparse matrix of shape (n_cells, n_peaks)\nmin_cells_per_peak = 50  # or 20, depending on sparsity\n\n# number of nonzeros per peak (i.e., per column)\npeak_nnz = X.getnnz(axis=0)          # returns a 1D np.array-like\npeak_nnz = np.asarray(peak_nnz).ravel()\n\nkeep_peaks = peak_nnz >= min_cells_per_peak\n\nX_filtered = X[:, keep_peaks]\n'

In [14]:
# Store the TF-IDF matrix as float32 (skipped when it already has that dtype).
if X_tfidf.dtype != np.float32:
    X_tfidf = X_tfidf.astype(np.float32)


In [15]:
# SVD → LSI
# Reduce the TF-IDF matrix to `n_lsi` components per cell; the fixed
# random_state makes the randomized solver reproducible.
n_lsi = 100
svd = TruncatedSVD(n_components=n_lsi, random_state=42)
lsi = svd.fit_transform(X_tfidf)


In [16]:
# row L2-normalize
# (each cell's LSI vector is rescaled to unit Euclidean length)
lsi = normalize(lsi, norm="l2", axis=1)


In [17]:
# Wrap as ATAC AnnData for UniVI
# X holds the LSI embedding (cells x n_lsi); rows are in the same order as
# `atac_raw`, whose obs table is copied over unchanged. No `var` is supplied.
atac_adata = ad.AnnData(
    X=lsi.astype(np.float32),
    obs=atac_raw.obs.copy(),   # includes original_barcodes, etc.
)
print("ATAC LSI AnnData:", atac_adata.shape)


ATAC LSI AnnData: (7540, 100)


In [18]:
# ----------------------
# 6) Put everything into shared barcode space
# ----------------------
def strip_suffix(idx):
    """Return `idx` cast to strings with a trailing "-<digits>" removed, if any.

    `idx` must support `.astype(str)` and the pandas `.str` accessor (callers
    pass a Series). Values without such a suffix are returned unchanged.
    """
    as_text = idx.astype(str)
    return as_text.str.replace(r"-\d+$", "", regex=True)

# RNA & ADT are indexed by 10x barcodes that may carry a "-<number>" suffix
# (e.g. "-1"). Strip it once so all modalities share the same key format.
# (The stripping is idempotent, so the second pass that used to follow the
# ATAC mapping was redundant and has been removed.)
rna_adata.obs_names = strip_suffix(rna_adata.obs_names.to_series())
adt_adata.obs_names = strip_suffix(adt_adata.obs_names.to_series())

# For ATAC, map from hex barcodes -> original 10x barcodes, then strip suffix.
# The unstripped 10x barcode is kept in obs["barcode_10x"] for reference.
atac_adata.obs["barcode_10x"] = atac_adata.obs["original_barcodes"].astype(str)
atac_adata.obs_names = strip_suffix(atac_adata.obs["barcode_10x"])


In [19]:
# ----------------------
# 7) Tri-modal intersection
# ----------------------
# Barcodes present in all three modalities.
common_barcodes = (
    set(rna_adata.obs_names)
    & set(adt_adata.obs_names)
    & set(atac_adata.obs_names)
)

print("Common tri-modal cells:", len(common_barcodes))
if len(common_barcodes) == 0:
    raise ValueError("No overlapping barcodes across RNA/ADT/ATAC after mapping.")

# Sorting turns the set into a list with a deterministic order, so the three
# objects below end up with identical row order.
common_barcodes = sorted(common_barcodes)

# Subset by name; this relies on obs_names being unique within each object.
rna_adata  = rna_adata[common_barcodes].copy()
adt_adata  = adt_adata[common_barcodes].copy()
atac_adata = atac_adata[common_barcodes].copy()

print("Aligned shapes:")
print("  RNA :", rna_adata.shape)
print("  ADT :", adt_adata.shape)
print("  ATAC:", atac_adata.shape)


Common tri-modal cells: 7421
Aligned shapes:
  RNA : (7421, 36601)
  ADT : (7421, 47)
  ATAC: (7421, 100)


In [20]:
# ----------------------
# 8) Optional subsampling
# ----------------------
# Cap the dataset at `target_n` cells; nothing happens when there are already
# `target_n` cells or fewer.
# Note: this rebinds the notebook-level name `n_cells` to the RNA cell count.
target_n = 10000
n_cells  = rna_adata.n_obs

if n_cells > target_n:
    # Seeded draw without replacement. The sampled positions are not sorted,
    # so the retained cells come out in the drawn (random) order.
    rng = np.random.default_rng(42)
    keep_idx = rng.choice(n_cells, size=target_n, replace=False)
    keep_barcodes = rna_adata.obs_names[keep_idx]

    # Apply the same barcode selection to all modalities to keep rows aligned.
    rna_adata  = rna_adata[keep_barcodes].copy()
    adt_adata  = adt_adata[keep_barcodes].copy()
    atac_adata = atac_adata[keep_barcodes].copy()

print("After optional subsampling:", rna_adata.n_obs, "cells.")


After optional subsampling: 7421 cells.


#### Preprocess each modality (we use Gaussian decoders here to prioritize cross-modality alignment, so the inputs are normalized and standardized)

In [21]:
# ----------------------
# 9) Modality-specific preprocessing
# ----------------------
'''
# --- RNA: log-normalize + HVGs ---
rna = rna_adata.copy()
rna.layers["counts"] = rna.X.copy()

sc.pp.normalize_total(rna, target_sum=1e4)
sc.pp.log1p(rna)
sc.pp.highly_variable_genes(rna, n_top_genes=2000, flavor="seurat_v3")
rna = rna[:, rna.var["highly_variable"]].copy()
print("RNA (HVG log1p) shape:", rna.shape)
'''

'\n# --- RNA: log-normalize + HVGs ---\nrna = rna_adata.copy()\nrna.layers["counts"] = rna.X.copy()\n\nsc.pp.normalize_total(rna, target_sum=1e4)\nsc.pp.log1p(rna)\nsc.pp.highly_variable_genes(rna, n_top_genes=2000, flavor="seurat_v3")\nrna = rna[:, rna.var["highly_variable"]].copy()\nprint("RNA (HVG log1p) shape:", rna.shape)\n'

In [22]:
'''
# --- ADT: CLR per cell ---
adt = adt_adata.copy()
adt.layers["counts"] = adt.X.copy()

X = adt.layers["counts"].astype(float)
if sp.issparse(X):
    X = X.toarray()

eps = 1e-6
X_log = np.log1p(X + eps)
X_clr = X_log - X_log.mean(axis=1, keepdims=True)
adt.X = X_clr.astype(np.float32)
print("ADT CLR shape:", adt.shape)
'''

'\n# --- ADT: CLR per cell ---\nadt = adt_adata.copy()\nadt.layers["counts"] = adt.X.copy()\n\nX = adt.layers["counts"].astype(float)\nif sp.issparse(X):\n    X = X.toarray()\n\neps = 1e-6\nX_log = np.log1p(X + eps)\nX_clr = X_log - X_log.mean(axis=1, keepdims=True)\nadt.X = X_clr.astype(np.float32)\nprint("ADT CLR shape:", adt.shape)\n'

In [23]:
'''
# --- ATAC: z-score each LSI dimension ---
atac = atac_adata.copy()
X_atac = atac.X.astype(np.float32)

mean = X_atac.mean(axis=0, keepdims=True)
std  = X_atac.std(axis=0, keepdims=True) + 1e-6
X_z  = (X_atac - mean) / std
atac.X = X_z.astype(np.float32)
print("ATAC LSI-z shape:", atac.shape)
'''

'\n# --- ATAC: z-score each LSI dimension ---\natac = atac_adata.copy()\nX_atac = atac.X.astype(np.float32)\n\nmean = X_atac.mean(axis=0, keepdims=True)\nstd  = X_atac.std(axis=0, keepdims=True) + 1e-6\nX_z  = (X_atac - mean) / std\natac.X = X_z.astype(np.float32)\nprint("ATAC LSI-z shape:", atac.shape)\n'

In [24]:
import scanpy as sc
import numpy as np
import scipy.sparse as sp

# ---------- RNA ----------
# Work on a copy and keep the matrix as loaded in layers["counts"].
rna = rna_adata.copy()
rna.layers["counts"] = rna.X.copy()

# Normalize + log1p
sc.pp.normalize_total(rna, target_sum=1e4)
sc.pp.log1p(rna)

# HVGs
# NOTE(review): flavor="seurat_v3" is documented as expecting raw counts, but
# X has already been normalized and log-transformed above — confirm whether
# layer="counts" should be passed here.
sc.pp.highly_variable_genes(rna, n_top_genes=2000, flavor="seurat_v3")
rna = rna[:, rna.var["highly_variable"]].copy()
print("RNA (HVG log1p) shape:", rna.shape)

# Z-score per gene across cells (for Gaussian decoder)
# max_value=10 caps the scaled values. The statistics come from every cell
# currently in `rna`; no train/validation separation exists at this point.
sc.pp.scale(rna, max_value=10)
print("RNA scaled shape:", rna.shape)


# ---------- ADT ----------
adt = adt_adata.copy()
adt.layers["counts"] = adt.X.copy()

# Dense float copy of the counts. Note that this rebinds the notebook-level
# name `X`.
X = adt.layers["counts"].astype(float)
if sp.issparse(X):
    X = X.toarray()

eps = 1e-6
# CLR per cell
# i.e. log1p(count + eps) minus that cell's mean over all ADT features.
X_log = np.log1p(X + eps)
X_clr = X_log - X_log.mean(axis=1, keepdims=True)

# Then per-feature z-score across cells
# (the 1e-6 added to the std guards against division by zero)
mean_adt = X_clr.mean(axis=0, keepdims=True)
std_adt  = X_clr.std(axis=0, keepdims=True) + 1e-6
X_clr_z  = (X_clr - mean_adt) / std_adt

adt.X = X_clr_z.astype(np.float32)
print("ADT CLR+z shape:", adt.shape)


# ---------- ATAC ----------
atac = atac_adata.copy()

# Z-score every column of the ATAC matrix across cells, same recipe as ADT.
X_atac = atac.X.astype(np.float32)  # assume this is LSI already
mean_atac = X_atac.mean(axis=0, keepdims=True)
std_atac  = X_atac.std(axis=0, keepdims=True) + 1e-6
X_z = (X_atac - mean_atac) / std_atac

atac.X = X_z.astype(np.float32)
print("ATAC LSI-z shape:", atac.shape)


  return fn(*args_all, **kw)


RNA (HVG log1p) shape: (7421, 2000)


  return dispatch(args[0].__class__)(*args, **kw)


RNA scaled shape: (7421, 2000)
ADT CLR+z shape: (7421, 47)
ATAC LSI-z shape: (7421, 100)


#### Hyperparameter optimization code for TEA-seq

In [26]:
# ==============================================
# UniVI hyperparameter search (TEA-seq: RNA / ADT / ATAC, tri-modal)
# ==============================================
import itertools
import json
import time
from copy import deepcopy

import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

import torch
from torch.utils.data import DataLoader, Subset

from univi.config import UniVIConfig, ModalityConfig, TrainingConfig
from univi.data import MultiModalDataset
from univi.models.univi import UniVIMultiModalVAE
from univi.trainer import UniVITrainer
from univi import evaluation as univi_eval

# Global seaborn plotting style for any figures drawn from here on.
sns.set(style="whitegrid")

# ------------------------------------------------------
# 0. Assumes you already have:
#    - rna  : AnnData (RNA, HVGs or other preproc features)
#    - adt  : AnnData (ADT features)
#    - atac : AnnData (ATAC features, e.g. LSI / gene-body)
#      * obs_names aligned across all three:
#        rna.obs_names == adt.obs_names == atac.obs_names
#    - TEA-seq is label-free in this setup (no celltype key used)
# ------------------------------------------------------

# Compute device used for the dataset, the models and the trainer below:
# CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# ------------------------------------------------------
# 1. Random train/val/test split (label-free)
# ------------------------------------------------------

def make_random_split(
    n_cells,
    frac_train=0.8,
    frac_val=0.1,
    seed=42,
):
    """
    Randomly partition the positions 0..n_cells-1 into train/val/test.

    Parameters
    ----------
    n_cells : int
        Number of cells (rows) to split.
    frac_train, frac_val : float
        Fractions of cells assigned to train and validation. Each must lie in
        [0, 1] and together they must not exceed 1. Sizes are truncated
        (`int(frac * n)`), and the test split receives every remaining cell.
    seed : int
        Seed for the NumPy generator, so the split is reproducible.

    Returns
    -------
    train_idx, val_idx, test_idx : np.ndarray
        Disjoint index arrays (in shuffled order) that together cover all cells.

    Raises
    ------
    ValueError
        If a fraction is outside [0, 1] or `frac_train + frac_val` exceeds 1.
        Such values previously produced truncated or empty splits silently.
    """
    if not (0.0 <= frac_train <= 1.0) or not (0.0 <= frac_val <= 1.0):
        raise ValueError(
            f"frac_train and frac_val must be in [0, 1]; got {frac_train} and {frac_val}"
        )
    # Small tolerance so sums such as 0.7 + 0.3 are not rejected because of
    # floating-point rounding.
    if frac_train + frac_val > 1.0 + 1e-9:
        raise ValueError(
            f"frac_train + frac_val must be <= 1; got {frac_train + frac_val}"
        )

    rng = np.random.default_rng(seed)
    idx = np.arange(n_cells)
    rng.shuffle(idx)

    n = idx.shape[0]
    n_train = int(frac_train * n)
    n_val = int(frac_val * n)

    # The test split is simply whatever is left after train and val.
    train_idx = idx[:n_train]
    val_idx   = idx[n_train:n_train + n_val]
    test_idx  = idx[n_train + n_val:]

    print(f"Random split over {n} cells")
    print(f"  Train: {len(train_idx)}")
    print(f"  Val  : {len(val_idx)}")
    print(f"  Test : {len(test_idx)}")

    return train_idx, val_idx, test_idx


# Use the unified names here
# Guard: all three modalities must contain the same cells in the same order,
# because the split below is positional.
n_cells_total = rna.n_obs
assert n_cells_total == adt.n_obs == atac.n_obs, "n_obs mismatch across modalities!"
assert np.array_equal(rna.obs_names, adt.obs_names), "rna.obs_names != adt.obs_names"
assert np.array_equal(rna.obs_names, atac.obs_names), "rna.obs_names != atac.obs_names"

# One seeded 80/10/10 split, shared by every modality.
train_idx, val_idx, test_idx = make_random_split(
    n_cells_total,
    frac_train=0.8,
    frac_val=0.1,
    seed=42,
)

# Build view-specific AnnDatas for the split (so we can reuse them)
rna_train  = rna[train_idx].copy()
rna_val    = rna[val_idx].copy()
rna_test   = rna[test_idx].copy()

adt_train  = adt[train_idx].copy()
adt_val    = adt[val_idx].copy()
adt_test   = adt[test_idx].copy()

atac_train = atac[train_idx].copy()
atac_val   = atac[val_idx].copy()
atac_test  = atac[test_idx].copy()

# Consistency checks
# (each split must list the same cells, in the same order, in every modality)
assert np.all(rna_train.obs_names  == adt_train.obs_names)
assert np.all(rna_val.obs_names    == adt_val.obs_names)
assert np.all(rna_test.obs_names   == adt_test.obs_names)
assert np.all(rna_train.obs_names  == atac_train.obs_names)
assert np.all(rna_val.obs_names    == atac_val.obs_names)
assert np.all(rna_test.obs_names   == atac_test.obs_names)

print("TEA-seq tri-modal train/val/test splits prepared.")


Using device: cuda
Random split over 7421 cells
  Train: 5936
  Val  : 742
  Test : 743
TEA-seq tri-modal train/val/test splits prepared.


In [27]:
# ------------------------------------------------------
# 2. Build a base MultiModalDataset (train+val only)
#    – we will reuse this across all hyperparameters.
# ------------------------------------------------------

# Train rows first, then validation rows, in each modality; test cells are
# left out of this dataset entirely.
adata_trainval = {
    "rna":  rna[np.concatenate([train_idx, val_idx])].copy(),
    "adt":  adt[np.concatenate([train_idx, val_idx])].copy(),
    "atac": atac[np.concatenate([train_idx, val_idx])].copy(),
}

# The three train+val objects must list the same cells in the same order.
trainval_obs_names = adata_trainval["rna"].obs_names.to_numpy()
assert np.array_equal(trainval_obs_names, adata_trainval["adt"].obs_names)
assert np.array_equal(trainval_obs_names, adata_trainval["atac"].obs_names)

dataset = MultiModalDataset(
    adata_dict=adata_trainval,
    X_key="X",          # rna.X, adt.X, atac.X used as UniVI inputs
    device=device,
)

n_cells_tv = dataset.n_cells
print("Train+Val cells in dataset:", n_cells_tv)

# Remap train/val indices into [0..n_cells_tv)
# The lookup goes through cell names, so it assumes obs_names are unique
# (duplicate names would collapse to a single dict entry).
name_to_pos = {name: i for i, name in enumerate(trainval_obs_names)}
train_idx_ds = np.array([name_to_pos[n] for n in rna_train.obs_names])
val_idx_ds   = np.array([name_to_pos[n] for n in rna_val.obs_names])


Train+Val cells in dataset: 6678


In [41]:
# ------------------------------------------------------
# 3. Hyperparameter search space (arch + regularization)
# ------------------------------------------------------

# Architecture options; tweak as needed
# Each option is a dict with a short "name" (used in logs) and the hidden
# layer widths of the encoder ("enc") and decoder ("dec") for that modality.
rna_arch_options = [
    {"name": "rna_med2",  "enc": [512, 256],         "dec": [256, 512]},
    {"name": "rna_wide2", "enc": [1024, 512],        "dec": [512, 1024]},
    {"name": "rna_wide3", "enc": [1024, 512, 256],   "dec": [256, 512, 1024]},
]

adt_arch_options = [
    {"name": "adt_small2", "enc": [128, 64],      "dec": [64, 128]},
    {"name": "adt_med2",   "enc": [256, 128],     "dec": [128, 256]},
]

atac_arch_options = [
    {"name": "atac_small2", "enc": [128, 64],      "dec": [64, 128]},
    {"name": "atac_med2",   "enc": [256, 128],     "dec": [128, 256]},
    {"name": "atac_wide2",  "enc": [512, 256],     "dec": [256, 512]},
]

# Candidate values per hyperparameter. Every key is sampled independently, so
# the space is the full cross product of these lists.
# "beta" and "gamma" are forwarded unchanged to UniVIConfig(beta=..., gamma=...).
search_space = {
    "latent_dim":        [20, 32, 40, 50, 64, 72, 86, 100, 124, 156, 200],
    "beta":              [0.0, 1.0, 10.0, 40.0, 60.0, 80.0, 100.0, 140.0, 180.0, 240.0, 300.0, 400.0],
    "gamma":             [0.0, 40.0, 60.0, 80.0, 100.0, 140.0, 180.0, 240.0, 300.0, 400.0, 500.0, 1000.0],
    "lr":                [1e-3, 5e-4],
    "weight_decay":      [1e-4, 1e-5],
    "encoder_dropout":   [0.0, 0.1],
    "decoder_batchnorm": [False, True],
    "rna_arch":          rna_arch_options,
    "adt_arch":          adt_arch_options,
    "atac_arch":         atac_arch_options,
}

MAX_CONFIGS = 120  # how many random configs to try


def iter_hparam_configs(space_dict, max_configs=MAX_CONFIGS, seed=0):
    """
    Lazily yield `max_configs` randomly drawn hyperparameter dicts.

    For each yielded config, every key of `space_dict` gets one entry picked
    at random from that key's list of options. Draws are independent across
    keys and across configs, so the same combination may be yielded twice.
    A given `seed` always reproduces the same sequence of configs.
    """
    rng = np.random.default_rng(seed)
    names = list(space_dict.keys())
    for _ in range(max_configs):
        # One index draw per key, in key order, so the RNG stream is consumed
        # in a fixed sequence.
        yield {
            name: space_dict[name][rng.integers(len(space_dict[name]))]
            for name in names
        }


In [42]:
# ------------------------------------------------------
# 4. Helper to build UniVI + TrainingConfig from hparams
# ------------------------------------------------------

def build_univi_and_train_cfg(hp):
    """
    Turn one sampled hyperparameter dict into (UniVIConfig, TrainingConfig).

    `hp` must provide: latent_dim, beta, gamma, lr, weight_decay,
    encoder_dropout, decoder_batchnorm, and rna_arch / adt_arch / atac_arch
    (each a dict with "enc" and "dec" hidden-layer lists). Input dimensions
    are read from the notebook-level `rna`, `adt` and `atac` objects, and the
    training device from the notebook-level `device`.
    """
    # Unpack everything up front so a missing key fails before any config
    # object is built.
    (
        latent_dim,
        beta,
        gamma,
        lr,
        weight_decay,
        encoder_dropout,
        decoder_batchnorm,
        rna_arch,
        adt_arch,
        atac_arch,
    ) = (
        hp[key]
        for key in (
            "latent_dim",
            "beta",
            "gamma",
            "lr",
            "weight_decay",
            "encoder_dropout",
            "decoder_batchnorm",
            "rna_arch",
            "adt_arch",
            "atac_arch",
        )
    )

    # -------- modality configs (no dropout here) --------
    # All three modalities use a Gaussian likelihood; only the input width
    # and hidden-layer sizes differ.
    rna_enc, rna_dec = rna_arch["enc"], rna_arch["dec"]
    mod_rna = ModalityConfig(
        name="rna",
        input_dim=rna.n_vars,
        encoder_hidden=rna_enc,
        decoder_hidden=rna_dec,
        likelihood="gaussian",       # or "nb" if rna.X are counts
    )

    adt_enc, adt_dec = adt_arch["enc"], adt_arch["dec"]
    mod_adt = ModalityConfig(
        name="adt",
        input_dim=adt.X.shape[1],
        encoder_hidden=adt_enc,
        decoder_hidden=adt_dec,
        likelihood="gaussian",       # CLR / standardized ADT
    )

    atac_enc, atac_dec = atac_arch["enc"], atac_arch["dec"]
    mod_atac = ModalityConfig(
        name="atac",
        input_dim=atac.X.shape[1],
        encoder_hidden=atac_enc,
        decoder_hidden=atac_dec,
        likelihood="gaussian",       # LSI / gene-body
    )

    # -------- global UniVI config (dropout + batchnorm here) --------
    # Decoder dropout and encoder batchnorm are fixed; they are not part of
    # the search space. Annealing options are left at the UniVIConfig defaults.
    model_settings = dict(
        latent_dim=latent_dim,
        modalities=[mod_rna, mod_adt, mod_atac],
        beta=beta,
        gamma=gamma,
        encoder_dropout=encoder_dropout,
        decoder_dropout=0.0,
        encoder_batchnorm=True,
        decoder_batchnorm=decoder_batchnorm,
    )
    univi_cfg = UniVIConfig(**model_settings)

    # -------- training config --------
    # Only lr and weight_decay are searched; the epoch budget, batch size,
    # seed and early-stopping settings are the same for every configuration.
    training_settings = dict(
        n_epochs=200,
        batch_size=256,
        lr=lr,
        weight_decay=weight_decay,
        device=device,
        log_every=10,
        num_workers=0,
        seed=42,
        early_stopping=True,
        patience=20,
        min_delta=0.0,
    )
    train_cfg = TrainingConfig(**training_settings)

    return univi_cfg, train_cfg


In [43]:
# ------------------------------------------------------
# 5. Train + evaluate one hyperparameter configuration (TEA-seq tri-modal)
# ------------------------------------------------------

def evaluate_config(hp, config_id):
    """
    Train a UniVI tri-modal model with hyperparameters `hp` and score it on
    the validation cells.

    Relies on notebook-level objects: `dataset`, `train_idx_ds`, `val_idx_ds`,
    `rna_val`, `adt_val`, `atac_val`, `device`, and `build_univi_and_train_cfg`.

    Parameters
    ----------
    hp : dict
        One sampled hyperparameter configuration (see `search_space`).
    config_id : int
        Identifier used in log lines and stored in the result.

    Returns
    -------
    dict with:
      - best_val_loss, final_train_loss, final_beta, final_gamma
      - FOSCTTM for each pair (RNA–ADT, RNA–ATAC, ADT–ATAC) on val, and their mean
      - global modality mixing score on val
      - composite score for hyperparam selection (lower = better)
      - wall-clock minutes, the training history, and a deep copy of hp
    """
    print("\n" + "=" * 80)
    print(f"[Config {config_id}] Hyperparameters:")
    # Architecture dicts are shown by name only to keep the log readable.
    pretty_hp = {
        **{k: v for k, v in hp.items() if k not in ("rna_arch", "adt_arch", "atac_arch")},
        "rna_arch": hp["rna_arch"]["name"],
        "adt_arch": hp["adt_arch"]["name"],
        "atac_arch": hp["atac_arch"]["name"],
    }
    print(json.dumps(pretty_hp, indent=2))
    print("=" * 80)

    univi_cfg, train_cfg = build_univi_and_train_cfg(hp)

    model = UniVIMultiModalVAE(univi_cfg).to(device)

    # DataLoaders: both are views (Subset) onto the shared train+val dataset.
    train_ds = Subset(dataset, train_idx_ds)
    val_ds   = Subset(dataset, val_idx_ds)

    train_loader = DataLoader(
        train_ds,
        batch_size=train_cfg.batch_size,
        shuffle=True,
        num_workers=train_cfg.num_workers,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=train_cfg.batch_size,
        shuffle=False,
        num_workers=train_cfg.num_workers,
    )

    trainer = UniVITrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        train_cfg=train_cfg,
        device=device,
    )

    t0 = time.time()
    history = trainer.fit()
    t1 = time.time()

    best_val    = float(min(trainer.history["val_loss"]))
    final_train = float(trainer.history["train_loss"][-1])
    final_beta  = float(trainer.history["beta"][-1])
    final_gamma = float(trainer.history["gamma"][-1])
    elapsed_min = (t1 - t0) / 60.0

    # Latent embeddings for validation cells (from val AnnDatas, not the dataset)
    # NOTE(review): whether these come from the best-validation weights or the
    # last-epoch weights depends on what UniVITrainer.fit() leaves loaded — confirm.
    z_rna_val  = trainer.encode_modality(rna_val,  modality="rna")
    z_adt_val  = trainer.encode_modality(adt_val,  modality="adt")
    z_atac_val = trainer.encode_modality(atac_val, modality="atac")

    # Pairwise FOSCTTM
    fos_rna_adt  = float(univi_eval.compute_foscttm(z_rna_val,  z_adt_val))
    fos_rna_atac = float(univi_eval.compute_foscttm(z_rna_val,  z_atac_val))
    fos_adt_atac = float(univi_eval.compute_foscttm(z_adt_val, z_atac_val))

    fos_mean = float((fos_rna_adt + fos_rna_atac + fos_adt_atac) / 3.0)

    # Global mixing score (lower = better, same convention as your CITE-seq code)
    # The three embeddings are stacked row-wise with a matching modality label
    # per row.
    Z_joint_val = np.concatenate([z_rna_val, z_adt_val, z_atac_val], axis=0)
    modality_labels_val = np.array(
        ["rna"]  * z_rna_val.shape[0]
        + ["adt"]  * z_adt_val.shape[0]
        + ["atac"] * z_atac_val.shape[0]
    )

    mixing_score = float(univi_eval.compute_modality_mixing(
        Z_joint_val,
        modality_labels_val,
        k=20,
    ))

    # Composite score: lower val_loss, FOSCTTM, and mixing are all better
    # NOTE(review): the product only ranks as intended while best_val is
    # positive; with a negative loss, worse FOSCTTM/mixing would lower the
    # score. The validation loss presumably also scales with beta/gamma, which
    # vary across configs — confirm that it is comparable between them.
    score = best_val * (1.0 + fos_mean) * (1.0 + mixing_score)

    result = {
        "config_id": config_id,
        "best_val_loss": best_val,
        "final_train_loss": final_train,
        "final_beta": final_beta,
        "final_gamma": final_gamma,
        "fos_rna_adt_val": fos_rna_adt,
        "fos_rna_atac_val": fos_rna_atac,
        "fos_adt_atac_val": fos_adt_atac,
        "fos_mean_val": fos_mean,
        "mixing_score_val": mixing_score,
        "score": float(score),
        "minutes": elapsed_min,
        "history": history,
        "hp": deepcopy(hp),
    }

    print(f"[Config {config_id}] Done in {elapsed_min:.1f} min")
    print(f"  best_val_loss              = {best_val:.3f}")
    print(f"  FOSCTTM (RNA vs ADT, val)  = {fos_rna_adt:.4f}")
    print(f"  FOSCTTM (RNA vs ATAC, val) = {fos_rna_atac:.4f}")
    # Same indented format as the sibling lines (a stray "[Config N]" prefix
    # used to break the alignment of this summary block).
    print(f"  FOSCTTM (ADT vs ATAC, val) = {fos_adt_atac:.4f}")
    print(f"  Mean FOSCTTM (3 pairs)     = {fos_mean:.4f}")
    print(f"  Modality mixing (k=20)     = {mixing_score:.4f}")
    print(f"  Composite score            = {score:.2f}")

    return result


In [None]:
# ------------------------------------------------------
# 6. Run the search
# ------------------------------------------------------

# Results of every evaluated config, plus the best finite-scoring one so far.
all_results = []
best_score = None
best_result = None

for i, hp in enumerate(iter_hparam_configs(search_space), start=1):
    res = evaluate_config(hp, config_id=i)
    s = res["score"]
    all_results.append(res)

    # A diverged run yields a NaN/inf score. NaN compares False against
    # everything, so once it became `best_score` no later config could replace
    # it. Such runs stay in `all_results` but are never selected as best.
    if not np.isfinite(s):
        print(f"--> Config {i} has a non-finite score ({s}); not eligible as best config")
        continue

    if best_score is None or s < best_score:
        best_score = s
        best_result = res
        print(f"--> New best config (id={i}) with score={s:.3f}")

print("\n========================================")
print("TEA-seq Hyperparameter search finished.")
print("========================================")

# Sort configs by score (lower = better). Non-finite scores are ranked last;
# leaving NaN in the sort key would make the resulting order undefined.
all_results_sorted = sorted(
    all_results,
    key=lambda r: r["score"] if np.isfinite(r["score"]) else float("inf"),
)
for r in all_results_sorted:
    hp = r["hp"]
    print(
        f"Config {r['config_id']:02d} | "
        f"latent={hp['latent_dim']:>2d}, "
        f"beta={hp['beta']:>5.1f}, "
        f"gamma={hp['gamma']:>5.1f}, "
        f"lr={hp['lr']:.0e}, "
        f"wd={hp['weight_decay']:.0e}, "
        f"enc_drop={hp['encoder_dropout']:.2f}, "
        f"dec_bn={hp['decoder_batchnorm']} | "
        f"rna_arch={hp['rna_arch']['name']}, "
        f"adt_arch={hp['adt_arch']['name']}, "
        f"atac_arch={hp['atac_arch']['name']} | "
        f"val_loss={r['best_val_loss']:.2f}, "
        f"FOS_RNA-ADT={r['fos_rna_adt_val']:.4f}, "
        f"FOS_RNA-ATAC={r['fos_rna_atac_val']:.4f}, "
        f"FOS_ADT-ATAC={r['fos_adt_atac_val']:.4f}, "
        f"mixing={r['mixing_score_val']:.4f}, "
        f"score={r['score']:.2f}"
    )

if best_result is None:
    # Happens when no config ran or every score was non-finite; there is
    # nothing to dereference in that case.
    print("\nNo configuration produced a finite score; no best configuration to report.")
else:
    print("\nBest configuration hyperparameters:")
    pretty_best = {
        **{k: v for k, v in best_result["hp"].items() if k not in ("rna_arch", "adt_arch", "atac_arch")},
        "rna_arch": best_result["hp"]["rna_arch"]["name"],
        "adt_arch": best_result["hp"]["adt_arch"]["name"],
        "atac_arch": best_result["hp"]["atac_arch"]["name"],
    }
    print(json.dumps(pretty_best, indent=2))
    print(
        f"Best val_loss={best_result['best_val_loss']:.3f}, "
        f"FOS_mean={best_result['fos_mean_val']:.4f}, "
        f"mixing={best_result['mixing_score_val']:.4f}, "
        f"score={best_result['score']:.2f}"
    )


[2025-11-19 23:11:44,140] [UniVITrainer] [INFO] TrainingConfig:
2025-11-19 23:11:44 - INFO - TrainingConfig:
[2025-11-19 23:11:44,142] [UniVITrainer] [INFO]   n_epochs: 200
2025-11-19 23:11:44 - INFO -   n_epochs: 200
[2025-11-19 23:11:44,146] [UniVITrainer] [INFO]   batch_size: 256
2025-11-19 23:11:44 - INFO -   batch_size: 256
[2025-11-19 23:11:44,156] [UniVITrainer] [INFO]   lr: 0.001
2025-11-19 23:11:44 - INFO -   lr: 0.001
[2025-11-19 23:11:44,158] [UniVITrainer] [INFO]   weight_decay: 0.0001
2025-11-19 23:11:44 - INFO -   weight_decay: 0.0001
[2025-11-19 23:11:44,165] [UniVITrainer] [INFO]   device: cuda
2025-11-19 23:11:44 - INFO -   device: cuda
[2025-11-19 23:11:44,167] [UniVITrainer] [INFO]   log_every: 10
2025-11-19 23:11:44 - INFO -   log_every: 10
[2025-11-19 23:11:44,168] [UniVITrainer] [INFO]   grad_clip: None
2025-11-19 23:11:44 - INFO -   grad_clip: None
[2025-11-19 23:11:44,169] [UniVITrainer] [INFO]   num_workers: 0
2025-11-19 23:11:44 - INFO -   num_workers: 0
[2025


[Config 1] Hyperparameters:
{
  "latent_dim": 156,
  "beta": 240.0,
  "gamma": 180.0,
  "lr": 0.001,
  "weight_decay": 0.0001,
  "encoder_dropout": 0.0,
  "decoder_batchnorm": false,
  "rna_arch": "rna_med2",
  "adt_arch": "adt_small2",
  "atac_arch": "atac_wide2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-19 23:11:51,112] [UniVITrainer] [INFO] [Epoch 001] Train loss: 6900.0410 (beta=240.000, gamma=180.000)
2025-11-19 23:11:51 - INFO - [Epoch 001] Train loss: 6900.0410 (beta=240.000, gamma=180.000)
[2025-11-19 23:11:51,953] [UniVITrainer] [INFO] [Epoch 001] Val loss: 2012.9840 (beta=240.000, gamma=180.000)
2025-11-19 23:11:51 - INFO - [Epoch 001] Val loss: 2012.9840 (beta=240.000, gamma=180.000)
[2025-11-19 23:11:51,969] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 2012.9840
2025-11-19 23:11:51 - INFO - [Epoch 001] New best val loss: 2012.9840
[2025-11-19 23:12:00,004] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1203.1804
2025-11-19 23:12:00 - INFO - [Epoch 002] New best val loss: 1203.1804
[2025-11-19 23:12:08,021] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 1005.3835
2025-11-19 23:12:08 - INFO - [Epoch 003] New best val loss: 1005.3835
[2025-11-19 23:12:15,801] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 944.9690
2025-11-19 23:12:15 - INFO - 

2025-11-19 23:19:39 - INFO - [Epoch 050] Train loss: 782.9580 (beta=240.000, gamma=180.000)
[2025-11-19 23:19:41,077] [UniVITrainer] [INFO] [Epoch 050] Val loss: 784.6361 (beta=240.000, gamma=180.000)
2025-11-19 23:19:41 - INFO - [Epoch 050] Val loss: 784.6361 (beta=240.000, gamma=180.000)
[2025-11-19 23:19:51,392] [UniVITrainer] [INFO] [Epoch 051] New best val loss: 783.2460
2025-11-19 23:19:51 - INFO - [Epoch 051] New best val loss: 783.2460
[2025-11-19 23:20:11,604] [UniVITrainer] [INFO] [Epoch 053] New best val loss: 782.9938
2025-11-19 23:20:11 - INFO - [Epoch 053] New best val loss: 782.9938
[2025-11-19 23:20:21,924] [UniVITrainer] [INFO] [Epoch 054] New best val loss: 782.2261
2025-11-19 23:20:21 - INFO - [Epoch 054] New best val loss: 782.2261
[2025-11-19 23:21:12,082] [UniVITrainer] [INFO] [Epoch 059] New best val loss: 780.1767
2025-11-19 23:21:12 - INFO - [Epoch 059] New best val loss: 780.1767
[2025-11-19 23:21:21,100] [UniVITrainer] [INFO] [Epoch 060] Train loss: 770.6018 

[2025-11-19 23:31:04,579] [UniVITrainer] [INFO] [Epoch 117] New best val loss: 772.1107
2025-11-19 23:31:04 - INFO - [Epoch 117] New best val loss: 772.1107
[2025-11-19 23:31:14,411] [UniVITrainer] [INFO] [Epoch 118] New best val loss: 771.8345
2025-11-19 23:31:14 - INFO - [Epoch 118] New best val loss: 771.8345
[2025-11-19 23:31:24,457] [UniVITrainer] [INFO] [Epoch 119] New best val loss: 771.6207
2025-11-19 23:31:24 - INFO - [Epoch 119] New best val loss: 771.6207
[2025-11-19 23:31:32,881] [UniVITrainer] [INFO] [Epoch 120] Train loss: 765.7484 (beta=240.000, gamma=180.000)
2025-11-19 23:31:32 - INFO - [Epoch 120] Train loss: 765.7484 (beta=240.000, gamma=180.000)
[2025-11-19 23:31:33,945] [UniVITrainer] [INFO] [Epoch 120] Val loss: 771.7509 (beta=240.000, gamma=180.000)
2025-11-19 23:31:33 - INFO - [Epoch 120] Val loss: 771.7509 (beta=240.000, gamma=180.000)
[2025-11-19 23:31:43,971] [UniVITrainer] [INFO] [Epoch 121] New best val loss: 771.5199
2025-11-19 23:31:43 - INFO - [Epoch 121

[2025-11-19 23:44:33,822] [UniVITrainer] [INFO] TrainingConfig:
2025-11-19 23:44:33 - INFO - TrainingConfig:
[2025-11-19 23:44:33,824] [UniVITrainer] [INFO]   n_epochs: 200
2025-11-19 23:44:33 - INFO -   n_epochs: 200
[2025-11-19 23:44:33,825] [UniVITrainer] [INFO]   batch_size: 256
2025-11-19 23:44:33 - INFO -   batch_size: 256
[2025-11-19 23:44:33,826] [UniVITrainer] [INFO]   lr: 0.0005
2025-11-19 23:44:33 - INFO -   lr: 0.0005
[2025-11-19 23:44:33,827] [UniVITrainer] [INFO]   weight_decay: 1e-05
2025-11-19 23:44:33 - INFO -   weight_decay: 1e-05
[2025-11-19 23:44:33,828] [UniVITrainer] [INFO]   device: cuda
2025-11-19 23:44:33 - INFO -   device: cuda
[2025-11-19 23:44:33,829] [UniVITrainer] [INFO]   log_every: 10
2025-11-19 23:44:33 - INFO -   log_every: 10
[2025-11-19 23:44:33,830] [UniVITrainer] [INFO]   grad_clip: None
2025-11-19 23:44:33 - INFO -   grad_clip: None
[2025-11-19 23:44:33,831] [UniVITrainer] [INFO]   num_workers: 0
2025-11-19 23:44:33 - INFO -   num_workers: 0
[2025

[Config 1] Done in 32.7 min
  best_val_loss              = 768.685
  FOSCTTM (RNA vs ADT, val)  = 0.5396
  FOSCTTM (RNA vs ATAC, val) = 0.5196
[Config 1] FOSCTTM (ADT vs ATAC, val) = 0.4995
  Mean FOSCTTM (3 pairs)     = 0.5196
  Modality mixing (k=20)     = 0.0095
  Composite score            = 1179.14
--> New best config (id=1) with score=1179.136

[Config 2] Hyperparameters:
{
  "latent_dim": 100,
  "beta": 500.0,
  "gamma": 180.0,
  "lr": 0.0005,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": true,
  "rna_arch": "rna_wide2",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_wide2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-19 23:44:42,863] [UniVITrainer] [INFO] [Epoch 001] Train loss: 9315.6695 (beta=500.000, gamma=180.000)
2025-11-19 23:44:42 - INFO - [Epoch 001] Train loss: 9315.6695 (beta=500.000, gamma=180.000)
[2025-11-19 23:44:43,935] [UniVITrainer] [INFO] [Epoch 001] Val loss: 3017.1510 (beta=500.000, gamma=180.000)
2025-11-19 23:44:43 - INFO - [Epoch 001] Val loss: 3017.1510 (beta=500.000, gamma=180.000)
[2025-11-19 23:44:43,970] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 3017.1510
2025-11-19 23:44:43 - INFO - [Epoch 001] New best val loss: 3017.1510
[2025-11-19 23:44:53,716] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1470.7642
2025-11-19 23:44:53 - INFO - [Epoch 002] New best val loss: 1470.7642
[2025-11-19 23:45:03,444] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 1200.8636
2025-11-19 23:45:03 - INFO - [Epoch 003] New best val loss: 1200.8636
[2025-11-19 23:45:13,039] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 1079.7780
2025-11-19 23:45:13 - INFO -

[2025-11-19 23:53:02,026] [UniVITrainer] [INFO] [Epoch 060] Val loss: 819.0218 (beta=500.000, gamma=180.000)
2025-11-19 23:53:02 - INFO - [Epoch 060] Val loss: 819.0218 (beta=500.000, gamma=180.000)
[2025-11-19 23:53:25,488] [UniVITrainer] [INFO] [Epoch 063] New best val loss: 814.0481
2025-11-19 23:53:25 - INFO - [Epoch 063] New best val loss: 814.0481
[2025-11-19 23:54:14,824] [UniVITrainer] [INFO] [Epoch 070] Train loss: 798.5115 (beta=500.000, gamma=180.000)
2025-11-19 23:54:14 - INFO - [Epoch 070] Train loss: 798.5115 (beta=500.000, gamma=180.000)
[2025-11-19 23:54:15,572] [UniVITrainer] [INFO] [Epoch 070] Val loss: 814.5974 (beta=500.000, gamma=180.000)
2025-11-19 23:54:15 - INFO - [Epoch 070] Val loss: 814.5974 (beta=500.000, gamma=180.000)
[2025-11-19 23:54:21,730] [UniVITrainer] [INFO] [Epoch 071] New best val loss: 811.8329
2025-11-19 23:54:21 - INFO - [Epoch 071] New best val loss: 811.8329
[2025-11-19 23:54:29,210] [UniVITrainer] [INFO] [Epoch 072] New best val loss: 811.41

2025-11-20 00:06:06 - INFO - [Epoch 170] Val loss: 786.9663 (beta=500.000, gamma=180.000)
[2025-11-20 00:06:29,294] [UniVITrainer] [INFO] [Epoch 173] New best val loss: 784.6318
2025-11-20 00:06:29 - INFO - [Epoch 173] New best val loss: 784.6318
[2025-11-20 00:06:44,399] [UniVITrainer] [INFO] [Epoch 175] New best val loss: 784.2018
2025-11-20 00:06:44 - INFO - [Epoch 175] New best val loss: 784.2018
[2025-11-20 00:06:51,751] [UniVITrainer] [INFO] [Epoch 176] New best val loss: 784.1520
2025-11-20 00:06:51 - INFO - [Epoch 176] New best val loss: 784.1520
[2025-11-20 00:07:20,076] [UniVITrainer] [INFO] [Epoch 180] Train loss: 772.2194 (beta=500.000, gamma=180.000)
2025-11-20 00:07:20 - INFO - [Epoch 180] Train loss: 772.2194 (beta=500.000, gamma=180.000)
[2025-11-20 00:07:20,757] [UniVITrainer] [INFO] [Epoch 180] Val loss: 784.8333 (beta=500.000, gamma=180.000)
2025-11-20 00:07:20 - INFO - [Epoch 180] Val loss: 784.8333 (beta=500.000, gamma=180.000)
[2025-11-20 00:08:09,124] [UniVITrain

[Config 2] Done in 25.1 min
  best_val_loss              = 780.677
  FOSCTTM (RNA vs ADT, val)  = 0.6390
  FOSCTTM (RNA vs ATAC, val) = 0.5004
[Config 2] FOSCTTM (ADT vs ATAC, val) = 0.4969
  Mean FOSCTTM (3 pairs)     = 0.5454
  Modality mixing (k=20)     = 0.0005
  Composite score            = 1207.08

[Config 3] Hyperparameters:
{
  "latent_dim": 50,
  "beta": 400.0,
  "gamma": 300.0,
  "lr": 0.001,
  "weight_decay": 0.0001,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": true,
  "rna_arch": "rna_med2",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_wide2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 00:09:49,563] [UniVITrainer] [INFO] [Epoch 001] Train loss: 4315.9104 (beta=400.000, gamma=300.000)
2025-11-20 00:09:49 - INFO - [Epoch 001] Train loss: 4315.9104 (beta=400.000, gamma=300.000)
[2025-11-20 00:09:50,378] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1268.7218 (beta=400.000, gamma=300.000)
2025-11-20 00:09:50 - INFO - [Epoch 001] Val loss: 1268.7218 (beta=400.000, gamma=300.000)
[2025-11-20 00:09:50,616] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1268.7218
2025-11-20 00:09:50 - INFO - [Epoch 001] New best val loss: 1268.7218
[2025-11-20 00:09:58,183] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1036.5960
2025-11-20 00:09:58 - INFO - [Epoch 002] New best val loss: 1036.5960
[2025-11-20 00:10:05,966] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 924.1252
2025-11-20 00:10:05 - INFO - [Epoch 003] New best val loss: 924.1252
[2025-11-20 00:10:13,469] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 883.4719
2025-11-20 00:10:13 - INFO - [E

[2025-11-20 00:17:00,885] [UniVITrainer] [INFO] [Epoch 060] Val loss: 780.3376 (beta=400.000, gamma=300.000)
2025-11-20 00:17:00 - INFO - [Epoch 060] Val loss: 780.3376 (beta=400.000, gamma=300.000)
[2025-11-20 00:17:53,055] [UniVITrainer] [INFO] [Epoch 067] New best val loss: 777.0358
2025-11-20 00:17:53 - INFO - [Epoch 067] New best val loss: 777.0358
[2025-11-20 00:18:13,042] [UniVITrainer] [INFO] [Epoch 070] Train loss: 775.6922 (beta=400.000, gamma=300.000)
2025-11-20 00:18:13 - INFO - [Epoch 070] Train loss: 775.6922 (beta=400.000, gamma=300.000)
[2025-11-20 00:18:13,878] [UniVITrainer] [INFO] [Epoch 070] Val loss: 777.8863 (beta=400.000, gamma=300.000)
2025-11-20 00:18:13 - INFO - [Epoch 070] Val loss: 777.8863 (beta=400.000, gamma=300.000)
[2025-11-20 00:18:19,492] [UniVITrainer] [INFO] [Epoch 071] New best val loss: 776.0653
2025-11-20 00:18:19 - INFO - [Epoch 071] New best val loss: 776.0653
[2025-11-20 00:18:27,546] [UniVITrainer] [INFO] [Epoch 072] New best val loss: 775.62

2025-11-20 00:27:44 - INFO - [Epoch 148] New best val loss: 767.8772
[2025-11-20 00:27:52,376] [UniVITrainer] [INFO] [Epoch 149] New best val loss: 767.8339
2025-11-20 00:27:52 - INFO - [Epoch 149] New best val loss: 767.8339
[2025-11-20 00:27:59,343] [UniVITrainer] [INFO] [Epoch 150] Train loss: 771.6810 (beta=400.000, gamma=300.000)
2025-11-20 00:27:59 - INFO - [Epoch 150] Train loss: 771.6810 (beta=400.000, gamma=300.000)
[2025-11-20 00:28:00,180] [UniVITrainer] [INFO] [Epoch 150] Val loss: 767.8877 (beta=400.000, gamma=300.000)
2025-11-20 00:28:00 - INFO - [Epoch 150] Val loss: 767.8877 (beta=400.000, gamma=300.000)
[2025-11-20 00:28:08,210] [UniVITrainer] [INFO] [Epoch 151] New best val loss: 767.7776
2025-11-20 00:28:08 - INFO - [Epoch 151] New best val loss: 767.7776
[2025-11-20 00:28:59,519] [UniVITrainer] [INFO] [Epoch 158] New best val loss: 767.4820
2025-11-20 00:28:59 - INFO - [Epoch 158] New best val loss: 767.4820
[2025-11-20 00:29:13,523] [UniVITrainer] [INFO] [Epoch 160

[Config 3] Done in 24.4 min
  best_val_loss              = 766.871
  FOSCTTM (RNA vs ADT, val)  = 0.5191
  FOSCTTM (RNA vs ATAC, val) = 0.5203
[Config 3] FOSCTTM (ADT vs ATAC, val) = 0.4517
  Mean FOSCTTM (3 pairs)     = 0.4970
  Modality mixing (k=20)     = 0.0000
  Composite score            = 1148.02
--> New best config (id=3) with score=1148.017

[Config 4] Hyperparameters:
{
  "latent_dim": 156,
  "beta": 60.0,
  "gamma": 40.0,
  "lr": 0.0005,
  "weight_decay": 0.0001,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": false,
  "rna_arch": "rna_med2",
  "adt_arch": "adt_small2",
  "atac_arch": "atac_med2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 00:34:14,607] [UniVITrainer] [INFO] [Epoch 001] Train loss: 3438.2481 (beta=60.000, gamma=40.000)
2025-11-20 00:34:14 - INFO - [Epoch 001] Train loss: 3438.2481 (beta=60.000, gamma=40.000)
[2025-11-20 00:34:15,410] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1936.3228 (beta=60.000, gamma=40.000)
2025-11-20 00:34:15 - INFO - [Epoch 001] Val loss: 1936.3228 (beta=60.000, gamma=40.000)
[2025-11-20 00:34:15,534] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1936.3228
2025-11-20 00:34:15 - INFO - [Epoch 001] New best val loss: 1936.3228
[2025-11-20 00:34:23,299] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1143.9238
2025-11-20 00:34:23 - INFO - [Epoch 002] New best val loss: 1143.9238
[2025-11-20 00:34:30,388] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 943.6554
2025-11-20 00:34:30 - INFO - [Epoch 003] New best val loss: 943.6554
[2025-11-20 00:34:38,160] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 886.7045
2025-11-20 00:34:38 - INFO - [Epoch 004

[2025-11-20 00:42:19,803] [UniVITrainer] [INFO] [Epoch 067] New best val loss: 766.0083
2025-11-20 00:42:19 - INFO - [Epoch 067] New best val loss: 766.0083
[2025-11-20 00:42:41,495] [UniVITrainer] [INFO] [Epoch 070] Train loss: 771.1906 (beta=60.000, gamma=40.000)
2025-11-20 00:42:41 - INFO - [Epoch 070] Train loss: 771.1906 (beta=60.000, gamma=40.000)
[2025-11-20 00:42:42,328] [UniVITrainer] [INFO] [Epoch 070] Val loss: 769.8859 (beta=60.000, gamma=40.000)
2025-11-20 00:42:42 - INFO - [Epoch 070] Val loss: 769.8859 (beta=60.000, gamma=40.000)
[2025-11-20 00:43:54,035] [UniVITrainer] [INFO] [Epoch 080] Train loss: 773.1522 (beta=60.000, gamma=40.000)
2025-11-20 00:43:54 - INFO - [Epoch 080] Train loss: 773.1522 (beta=60.000, gamma=40.000)
[2025-11-20 00:43:54,524] [UniVITrainer] [INFO] [Epoch 080] Val loss: 770.8414 (beta=60.000, gamma=40.000)
2025-11-20 00:43:54 - INFO - [Epoch 080] Val loss: 770.8414 (beta=60.000, gamma=40.000)
[2025-11-20 00:44:46,771] [UniVITrainer] [INFO] Early s

[Config 4] Done in 10.6 min
  best_val_loss              = 766.008
  FOSCTTM (RNA vs ADT, val)  = 0.2176
  FOSCTTM (RNA vs ATAC, val) = 0.2234
[Config 4] FOSCTTM (ADT vs ATAC, val) = 0.2329
  Mean FOSCTTM (3 pairs)     = 0.2246
  Modality mixing (k=20)     = 0.1448
  Composite score            = 1073.90
--> New best config (id=4) with score=1073.902

[Config 5] Hyperparameters:
{
  "latent_dim": 64,
  "beta": 0.0,
  "gamma": 0.0,
  "lr": 0.001,
  "weight_decay": 0.0001,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": true,
  "rna_arch": "rna_wide2",
  "adt_arch": "adt_small2",
  "atac_arch": "atac_med2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 00:44:54,366] [UniVITrainer] [INFO] [Epoch 001] Train loss: 782.8636 (beta=0.000, gamma=0.000)
2025-11-20 00:44:54 - INFO - [Epoch 001] Train loss: 782.8636 (beta=0.000, gamma=0.000)
[2025-11-20 00:44:55,186] [UniVITrainer] [INFO] [Epoch 001] Val loss: 731.5600 (beta=0.000, gamma=0.000)
2025-11-20 00:44:55 - INFO - [Epoch 001] Val loss: 731.5600 (beta=0.000, gamma=0.000)
[2025-11-20 00:44:55,291] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 731.5600
2025-11-20 00:44:55 - INFO - [Epoch 001] New best val loss: 731.5600
[2025-11-20 00:45:02,224] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 648.3118
2025-11-20 00:45:02 - INFO - [Epoch 002] New best val loss: 648.3118
[2025-11-20 00:45:09,186] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 642.5913
2025-11-20 00:45:09 - INFO - [Epoch 003] New best val loss: 642.5913
[2025-11-20 00:45:16,620] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 640.9319
2025-11-20 00:45:16 - INFO - [Epoch 004] New best val l

2025-11-20 00:51:19 - INFO -   patience: 20
[2025-11-20 00:51:19,501] [UniVITrainer] [INFO]   min_delta: 0.0
2025-11-20 00:51:19 - INFO -   min_delta: 0.0


[Config 5] Done in 6.5 min
  best_val_loss              = 571.090
  FOSCTTM (RNA vs ADT, val)  = 0.2039
  FOSCTTM (RNA vs ATAC, val) = 0.3829
[Config 5] FOSCTTM (ADT vs ATAC, val) = 0.4202
  Mean FOSCTTM (3 pairs)     = 0.3357
  Modality mixing (k=20)     = 0.0149
  Composite score            = 774.19
--> New best config (id=5) with score=774.193

[Config 6] Hyperparameters:
{
  "latent_dim": 124,
  "beta": 100.0,
  "gamma": 140.0,
  "lr": 0.0005,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": false,
  "rna_arch": "rna_wide3",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_med2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 00:51:26,316] [UniVITrainer] [INFO] [Epoch 001] Train loss: 5561.9546 (beta=100.000, gamma=140.000)
2025-11-20 00:51:26 - INFO - [Epoch 001] Train loss: 5561.9546 (beta=100.000, gamma=140.000)
[2025-11-20 00:51:27,135] [UniVITrainer] [INFO] [Epoch 001] Val loss: 2317.4333 (beta=100.000, gamma=140.000)
2025-11-20 00:51:27 - INFO - [Epoch 001] Val loss: 2317.4333 (beta=100.000, gamma=140.000)
[2025-11-20 00:51:27,316] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 2317.4333
2025-11-20 00:51:27 - INFO - [Epoch 001] New best val loss: 2317.4333
[2025-11-20 00:51:34,660] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1195.9583
2025-11-20 00:51:34 - INFO - [Epoch 002] New best val loss: 1195.9583
[2025-11-20 00:51:40,907] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 997.3190
2025-11-20 00:51:40 - INFO - [Epoch 003] New best val loss: 997.3190
[2025-11-20 00:51:48,314] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 947.8478
2025-11-20 00:51:48 - INFO - [E

2025-11-20 00:56:53 - INFO - [Epoch 046] New best val loss: 777.6694
[2025-11-20 00:57:01,124] [UniVITrainer] [INFO] [Epoch 047] New best val loss: 776.5288
2025-11-20 00:57:01 - INFO - [Epoch 047] New best val loss: 776.5288
[2025-11-20 00:57:08,545] [UniVITrainer] [INFO] [Epoch 048] New best val loss: 775.4659
2025-11-20 00:57:08 - INFO - [Epoch 048] New best val loss: 775.4659
[2025-11-20 00:57:22,152] [UniVITrainer] [INFO] [Epoch 050] Train loss: 784.8402 (beta=100.000, gamma=140.000)
2025-11-20 00:57:22 - INFO - [Epoch 050] Train loss: 784.8402 (beta=100.000, gamma=140.000)
[2025-11-20 00:57:22,964] [UniVITrainer] [INFO] [Epoch 050] Val loss: 775.1205 (beta=100.000, gamma=140.000)
2025-11-20 00:57:22 - INFO - [Epoch 050] Val loss: 775.1205 (beta=100.000, gamma=140.000)
[2025-11-20 00:57:23,142] [UniVITrainer] [INFO] [Epoch 050] New best val loss: 775.1205
2025-11-20 00:57:23 - INFO - [Epoch 050] New best val loss: 775.1205
[2025-11-20 00:57:30,890] [UniVITrainer] [INFO] [Epoch 051

2025-11-20 01:03:44 - INFO - [Epoch 102] New best val loss: 767.1903
[2025-11-20 01:04:27,369] [UniVITrainer] [INFO] [Epoch 108] New best val loss: 767.0944
2025-11-20 01:04:27 - INFO - [Epoch 108] New best val loss: 767.0944
[2025-11-20 01:04:41,552] [UniVITrainer] [INFO] [Epoch 110] Train loss: 767.1356 (beta=100.000, gamma=140.000)
2025-11-20 01:04:41 - INFO - [Epoch 110] Train loss: 767.1356 (beta=100.000, gamma=140.000)
[2025-11-20 01:04:42,389] [UniVITrainer] [INFO] [Epoch 110] Val loss: 767.1335 (beta=100.000, gamma=140.000)
2025-11-20 01:04:42 - INFO - [Epoch 110] Val loss: 767.1335 (beta=100.000, gamma=140.000)
[2025-11-20 01:04:50,373] [UniVITrainer] [INFO] [Epoch 111] New best val loss: 767.0328
2025-11-20 01:04:50 - INFO - [Epoch 111] New best val loss: 767.0328
[2025-11-20 01:05:20,389] [UniVITrainer] [INFO] [Epoch 115] New best val loss: 767.0259
2025-11-20 01:05:20 - INFO - [Epoch 115] New best val loss: 767.0259
[2025-11-20 01:05:26,527] [UniVITrainer] [INFO] [Epoch 116

2025-11-20 01:15:38 - INFO - [Epoch 200] Val loss: 766.4101 (beta=100.000, gamma=140.000)
[2025-11-20 01:15:38,412] [UniVITrainer] [INFO] Restored best model from epoch 192 (val loss = 766.3941)
2025-11-20 01:15:38 - INFO - Restored best model from epoch 192 (val loss = 766.3941)
[2025-11-20 01:15:40,896] [UniVITrainer] [INFO] TrainingConfig:
2025-11-20 01:15:40 - INFO - TrainingConfig:
[2025-11-20 01:15:40,897] [UniVITrainer] [INFO]   n_epochs: 200
2025-11-20 01:15:40 - INFO -   n_epochs: 200
[2025-11-20 01:15:40,905] [UniVITrainer] [INFO]   batch_size: 256
2025-11-20 01:15:40 - INFO -   batch_size: 256
[2025-11-20 01:15:40,909] [UniVITrainer] [INFO]   lr: 0.001
2025-11-20 01:15:40 - INFO -   lr: 0.001
[2025-11-20 01:15:40,916] [UniVITrainer] [INFO]   weight_decay: 1e-05
2025-11-20 01:15:40 - INFO -   weight_decay: 1e-05
[2025-11-20 01:15:40,917] [UniVITrainer] [INFO]   device: cuda
2025-11-20 01:15:40 - INFO -   device: cuda
[2025-11-20 01:15:40,918] [UniVITrainer] [INFO]   log_every

[Config 6] Done in 24.3 min
  best_val_loss              = 766.394
  FOSCTTM (RNA vs ADT, val)  = 0.4772
  FOSCTTM (RNA vs ATAC, val) = 0.4847
[Config 6] FOSCTTM (ADT vs ATAC, val) = 0.4869
  Mean FOSCTTM (3 pairs)     = 0.4829
  Modality mixing (k=20)     = 0.0002
  Composite score            = 1136.79

[Config 7] Hyperparameters:
{
  "latent_dim": 156,
  "beta": 300.0,
  "gamma": 300.0,
  "lr": 0.001,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.0,
  "decoder_batchnorm": true,
  "rna_arch": "rna_wide3",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_med2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 01:15:47,788] [UniVITrainer] [INFO] [Epoch 001] Train loss: 8937.6331 (beta=300.000, gamma=300.000)
2025-11-20 01:15:47 - INFO - [Epoch 001] Train loss: 8937.6331 (beta=300.000, gamma=300.000)
[2025-11-20 01:15:48,607] [UniVITrainer] [INFO] [Epoch 001] Val loss: 2409.8827 (beta=300.000, gamma=300.000)
2025-11-20 01:15:48 - INFO - [Epoch 001] Val loss: 2409.8827 (beta=300.000, gamma=300.000)
[2025-11-20 01:15:48,877] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 2409.8827
2025-11-20 01:15:48 - INFO - [Epoch 001] New best val loss: 2409.8827
[2025-11-20 01:15:56,460] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1395.4587
2025-11-20 01:15:56 - INFO - [Epoch 002] New best val loss: 1395.4587
[2025-11-20 01:16:03,866] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 1149.9595
2025-11-20 01:16:03 - INFO - [Epoch 003] New best val loss: 1149.9595
[2025-11-20 01:16:10,735] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 1052.6732
2025-11-20 01:16:10 - INFO -

[2025-11-20 01:22:59,675] [UniVITrainer] [INFO] [Epoch 060] Val loss: 789.5868 (beta=300.000, gamma=300.000)
2025-11-20 01:22:59 - INFO - [Epoch 060] Val loss: 789.5868 (beta=300.000, gamma=300.000)
[2025-11-20 01:22:59,919] [UniVITrainer] [INFO] [Epoch 060] New best val loss: 789.5868
2025-11-20 01:22:59 - INFO - [Epoch 060] New best val loss: 789.5868
[2025-11-20 01:23:15,322] [UniVITrainer] [INFO] [Epoch 062] New best val loss: 788.8138
2025-11-20 01:23:15 - INFO - [Epoch 062] New best val loss: 788.8138
[2025-11-20 01:23:22,877] [UniVITrainer] [INFO] [Epoch 063] New best val loss: 787.7903
2025-11-20 01:23:22 - INFO - [Epoch 063] New best val loss: 787.7903
[2025-11-20 01:23:30,561] [UniVITrainer] [INFO] [Epoch 064] New best val loss: 787.7324
2025-11-20 01:23:30 - INFO - [Epoch 064] New best val loss: 787.7324
[2025-11-20 01:23:38,415] [UniVITrainer] [INFO] [Epoch 065] New best val loss: 786.3551
2025-11-20 01:23:38 - INFO - [Epoch 065] New best val loss: 786.3551
[2025-11-20 01:2

[2025-11-20 01:33:01,067] [UniVITrainer] [INFO] [Epoch 141] New best val loss: 771.7521
2025-11-20 01:33:01 - INFO - [Epoch 141] New best val loss: 771.7521
[2025-11-20 01:33:08,902] [UniVITrainer] [INFO] [Epoch 142] New best val loss: 771.6725
2025-11-20 01:33:08 - INFO - [Epoch 142] New best val loss: 771.6725
[2025-11-20 01:33:29,264] [UniVITrainer] [INFO] [Epoch 145] New best val loss: 770.9633
2025-11-20 01:33:29 - INFO - [Epoch 145] New best val loss: 770.9633
[2025-11-20 01:34:03,761] [UniVITrainer] [INFO] [Epoch 150] Train loss: 770.1188 (beta=300.000, gamma=300.000)
2025-11-20 01:34:03 - INFO - [Epoch 150] Train loss: 770.1188 (beta=300.000, gamma=300.000)
[2025-11-20 01:34:04,581] [UniVITrainer] [INFO] [Epoch 150] Val loss: 772.0460 (beta=300.000, gamma=300.000)
2025-11-20 01:34:04 - INFO - [Epoch 150] Val loss: 772.0460 (beta=300.000, gamma=300.000)
[2025-11-20 01:34:19,387] [UniVITrainer] [INFO] [Epoch 152] New best val loss: 770.6269
2025-11-20 01:34:19 - INFO - [Epoch 152

[Config 7] Done in 24.5 min
  best_val_loss              = 769.069
  FOSCTTM (RNA vs ADT, val)  = 0.5799
  FOSCTTM (RNA vs ATAC, val) = 0.5184
[Config 7] FOSCTTM (ADT vs ATAC, val) = 0.4877
  Mean FOSCTTM (3 pairs)     = 0.5287
  Modality mixing (k=20)     = 0.0028
  Composite score            = 1178.97

[Config 8] Hyperparameters:
{
  "latent_dim": 64,
  "beta": 80.0,
  "gamma": 140.0,
  "lr": 0.001,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": false,
  "rna_arch": "rna_wide3",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_med2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 01:40:20,903] [UniVITrainer] [INFO] [Epoch 001] Train loss: 2345.9685 (beta=80.000, gamma=140.000)
2025-11-20 01:40:20 - INFO - [Epoch 001] Train loss: 2345.9685 (beta=80.000, gamma=140.000)
[2025-11-20 01:40:21,737] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1023.5419 (beta=80.000, gamma=140.000)
2025-11-20 01:40:21 - INFO - [Epoch 001] Val loss: 1023.5419 (beta=80.000, gamma=140.000)
[2025-11-20 01:40:21,789] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1023.5419
2025-11-20 01:40:21 - INFO - [Epoch 001] New best val loss: 1023.5419
[2025-11-20 01:40:28,881] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 855.2831
2025-11-20 01:40:28 - INFO - [Epoch 002] New best val loss: 855.2831
[2025-11-20 01:40:36,693] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 821.8358
2025-11-20 01:40:36 - INFO - [Epoch 003] New best val loss: 821.8358
[2025-11-20 01:40:44,275] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 807.3366
2025-11-20 01:40:44 - INFO - [Epoch 0

2025-11-20 01:45:56 - INFO - [Epoch 046] New best val loss: 767.8908
[2025-11-20 01:46:03,413] [UniVITrainer] [INFO] [Epoch 047] New best val loss: 767.7378
2025-11-20 01:46:03 - INFO - [Epoch 047] New best val loss: 767.7378
[2025-11-20 01:46:23,769] [UniVITrainer] [INFO] [Epoch 050] Train loss: 770.7273 (beta=80.000, gamma=140.000)
2025-11-20 01:46:23 - INFO - [Epoch 050] Train loss: 770.7273 (beta=80.000, gamma=140.000)
[2025-11-20 01:46:23,903] [UniVITrainer] [INFO] [Epoch 050] Val loss: 767.7397 (beta=80.000, gamma=140.000)
2025-11-20 01:46:23 - INFO - [Epoch 050] Val loss: 767.7397 (beta=80.000, gamma=140.000)
[2025-11-20 01:46:31,743] [UniVITrainer] [INFO] [Epoch 051] New best val loss: 767.6363
2025-11-20 01:46:31 - INFO - [Epoch 051] New best val loss: 767.6363
[2025-11-20 01:46:39,561] [UniVITrainer] [INFO] [Epoch 052] New best val loss: 767.4896
2025-11-20 01:46:39 - INFO - [Epoch 052] New best val loss: 767.4896
[2025-11-20 01:46:46,449] [UniVITrainer] [INFO] [Epoch 053] Ne

[2025-11-20 01:55:05,703] [UniVITrainer] [INFO] [Epoch 120] Val loss: 766.4034 (beta=80.000, gamma=140.000)
2025-11-20 01:55:05 - INFO - [Epoch 120] Val loss: 766.4034 (beta=80.000, gamma=140.000)
[2025-11-20 01:55:13,448] [UniVITrainer] [INFO] [Epoch 121] New best val loss: 766.3757
2025-11-20 01:55:13 - INFO - [Epoch 121] New best val loss: 766.3757
[2025-11-20 01:55:20,196] [UniVITrainer] [INFO] [Epoch 122] New best val loss: 766.3545
2025-11-20 01:55:20 - INFO - [Epoch 122] New best val loss: 766.3545
[2025-11-20 01:56:17,857] [UniVITrainer] [INFO] [Epoch 130] Train loss: 769.0411 (beta=80.000, gamma=140.000)
2025-11-20 01:56:17 - INFO - [Epoch 130] Train loss: 769.0411 (beta=80.000, gamma=140.000)
[2025-11-20 01:56:18,667] [UniVITrainer] [INFO] [Epoch 130] Val loss: 766.3407 (beta=80.000, gamma=140.000)
2025-11-20 01:56:18 - INFO - [Epoch 130] Val loss: 766.3407 (beta=80.000, gamma=140.000)
[2025-11-20 01:56:18,821] [UniVITrainer] [INFO] [Epoch 130] New best val loss: 766.3407
202

[Config 8] Done in 22.2 min
  best_val_loss              = 766.300
  FOSCTTM (RNA vs ADT, val)  = 0.4878
  FOSCTTM (RNA vs ATAC, val) = 0.4649
[Config 8] FOSCTTM (ADT vs ATAC, val) = 0.4953
  Mean FOSCTTM (3 pairs)     = 0.4827
  Modality mixing (k=20)     = 0.0000
  Composite score            = 1136.18

[Config 9] Hyperparameters:
{
  "latent_dim": 100,
  "beta": 180.0,
  "gamma": 80.0,
  "lr": 0.001,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": true,
  "rna_arch": "rna_wide2",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_med2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 02:02:37,497] [UniVITrainer] [INFO] [Epoch 001] Train loss: 3300.5204 (beta=180.000, gamma=80.000)
2025-11-20 02:02:37 - INFO - [Epoch 001] Train loss: 3300.5204 (beta=180.000, gamma=80.000)
[2025-11-20 02:02:38,333] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1159.9046 (beta=180.000, gamma=80.000)
2025-11-20 02:02:38 - INFO - [Epoch 001] Val loss: 1159.9046 (beta=180.000, gamma=80.000)
[2025-11-20 02:02:38,500] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1159.9046
2025-11-20 02:02:38 - INFO - [Epoch 001] New best val loss: 1159.9046
[2025-11-20 02:02:46,481] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 989.6286
2025-11-20 02:02:46 - INFO - [Epoch 002] New best val loss: 989.6286
[2025-11-20 02:02:54,347] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 919.9830
2025-11-20 02:02:54 - INFO - [Epoch 003] New best val loss: 919.9830
[2025-11-20 02:03:01,733] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 878.2198
2025-11-20 02:03:01 - INFO - [Epoch 0

2025-11-20 02:08:18 - INFO - [Epoch 047] New best val loss: 769.6745
[2025-11-20 02:08:26,191] [UniVITrainer] [INFO] [Epoch 048] New best val loss: 769.5408
2025-11-20 02:08:26 - INFO - [Epoch 048] New best val loss: 769.5408
[2025-11-20 02:08:32,370] [UniVITrainer] [INFO] [Epoch 049] New best val loss: 769.4303
2025-11-20 02:08:32 - INFO - [Epoch 049] New best val loss: 769.4303
[2025-11-20 02:08:39,277] [UniVITrainer] [INFO] [Epoch 050] Train loss: 775.3228 (beta=180.000, gamma=80.000)
2025-11-20 02:08:39 - INFO - [Epoch 050] Train loss: 775.3228 (beta=180.000, gamma=80.000)
[2025-11-20 02:08:40,116] [UniVITrainer] [INFO] [Epoch 050] Val loss: 769.5475 (beta=180.000, gamma=80.000)
2025-11-20 02:08:40 - INFO - [Epoch 050] Val loss: 769.5475 (beta=180.000, gamma=80.000)
[2025-11-20 02:08:47,791] [UniVITrainer] [INFO] [Epoch 051] New best val loss: 768.8100
2025-11-20 02:08:47 - INFO - [Epoch 051] New best val loss: 768.8100
[2025-11-20 02:09:02,976] [UniVITrainer] [INFO] [Epoch 053] Ne

2025-11-20 02:18:24 - INFO - [Epoch 130] Train loss: 768.0847 (beta=180.000, gamma=80.000)
[2025-11-20 02:18:25,841] [UniVITrainer] [INFO] [Epoch 130] Val loss: 766.5190 (beta=180.000, gamma=80.000)
2025-11-20 02:18:25 - INFO - [Epoch 130] Val loss: 766.5190 (beta=180.000, gamma=80.000)
[2025-11-20 02:18:39,830] [UniVITrainer] [INFO] [Epoch 132] New best val loss: 766.4932
2025-11-20 02:18:39 - INFO - [Epoch 132] New best val loss: 766.4932
[2025-11-20 02:19:35,728] [UniVITrainer] [INFO] [Epoch 140] Train loss: 764.4784 (beta=180.000, gamma=80.000)
2025-11-20 02:19:35 - INFO - [Epoch 140] Train loss: 764.4784 (beta=180.000, gamma=80.000)
[2025-11-20 02:19:36,127] [UniVITrainer] [INFO] [Epoch 140] Val loss: 766.6237 (beta=180.000, gamma=80.000)
2025-11-20 02:19:36 - INFO - [Epoch 140] Val loss: 766.6237 (beta=180.000, gamma=80.000)
[2025-11-20 02:19:57,574] [UniVITrainer] [INFO] [Epoch 143] New best val loss: 766.4482
2025-11-20 02:19:57 - INFO - [Epoch 143] New best val loss: 766.4482


[Config 9] Done in 20.4 min
  best_val_loss              = 766.401
  FOSCTTM (RNA vs ADT, val)  = 0.5038
  FOSCTTM (RNA vs ATAC, val) = 0.4367
[Config 9] FOSCTTM (ADT vs ATAC, val) = 0.4917
  Mean FOSCTTM (3 pairs)     = 0.4774
  Modality mixing (k=20)     = 0.0001
  Composite score            = 1132.35

[Config 10] Hyperparameters:
{
  "latent_dim": 50,
  "beta": 500.0,
  "gamma": 80.0,
  "lr": 0.001,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": false,
  "rna_arch": "rna_med2",
  "adt_arch": "adt_small2",
  "atac_arch": "atac_wide2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 02:23:07,000] [UniVITrainer] [INFO] [Epoch 001] Train loss: 3839.6963 (beta=500.000, gamma=80.000)
2025-11-20 02:23:07 - INFO - [Epoch 001] Train loss: 3839.6963 (beta=500.000, gamma=80.000)
[2025-11-20 02:23:07,836] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1206.8608 (beta=500.000, gamma=80.000)
2025-11-20 02:23:07 - INFO - [Epoch 001] Val loss: 1206.8608 (beta=500.000, gamma=80.000)
[2025-11-20 02:23:08,038] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1206.8608
2025-11-20 02:23:08 - INFO - [Epoch 001] New best val loss: 1206.8608
[2025-11-20 02:23:15,826] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 936.4517
2025-11-20 02:23:15 - INFO - [Epoch 002] New best val loss: 936.4517
[2025-11-20 02:23:23,648] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 859.1669
2025-11-20 02:23:23 - INFO - [Epoch 003] New best val loss: 859.1669
[2025-11-20 02:23:30,803] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 835.9442
2025-11-20 02:23:30 - INFO - [Epoch 0

[2025-11-20 02:29:42,318] [UniVITrainer] [INFO] [Epoch 054] New best val loss: 773.3922
2025-11-20 02:29:42 - INFO - [Epoch 054] New best val loss: 773.3922
[2025-11-20 02:30:27,599] [UniVITrainer] [INFO] [Epoch 060] Train loss: 772.1972 (beta=500.000, gamma=80.000)
2025-11-20 02:30:27 - INFO - [Epoch 060] Train loss: 772.1972 (beta=500.000, gamma=80.000)
[2025-11-20 02:30:28,404] [UniVITrainer] [INFO] [Epoch 060] Val loss: 773.3474 (beta=500.000, gamma=80.000)
2025-11-20 02:30:28 - INFO - [Epoch 060] Val loss: 773.3474 (beta=500.000, gamma=80.000)
[2025-11-20 02:30:28,592] [UniVITrainer] [INFO] [Epoch 060] New best val loss: 773.3474
2025-11-20 02:30:28 - INFO - [Epoch 060] New best val loss: 773.3474
[2025-11-20 02:30:44,200] [UniVITrainer] [INFO] [Epoch 062] New best val loss: 773.2745
2025-11-20 02:30:44 - INFO - [Epoch 062] New best val loss: 773.2745
[2025-11-20 02:30:58,323] [UniVITrainer] [INFO] [Epoch 064] New best val loss: 773.2562
2025-11-20 02:30:58 - INFO - [Epoch 064] Ne

[Config 10] Done in 12.5 min
  best_val_loss              = 771.155
  FOSCTTM (RNA vs ADT, val)  = 0.5031
  FOSCTTM (RNA vs ATAC, val) = 0.4757
[Config 10] FOSCTTM (ADT vs ATAC, val) = 0.4989
  Mean FOSCTTM (3 pairs)     = 0.4926
  Modality mixing (k=20)     = 0.0300
  Composite score            = 1185.48

[Config 11] Hyperparameters:
{
  "latent_dim": 64,
  "beta": 400.0,
  "gamma": 80.0,
  "lr": 0.001,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": false,
  "rna_arch": "rna_med2",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_med2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 02:35:39,736] [UniVITrainer] [INFO] [Epoch 001] Train loss: 3862.3606 (beta=400.000, gamma=80.000)
2025-11-20 02:35:39 - INFO - [Epoch 001] Train loss: 3862.3606 (beta=400.000, gamma=80.000)
[2025-11-20 02:35:40,558] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1160.6260 (beta=400.000, gamma=80.000)
2025-11-20 02:35:40 - INFO - [Epoch 001] Val loss: 1160.6260 (beta=400.000, gamma=80.000)
[2025-11-20 02:35:40,654] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1160.6260
2025-11-20 02:35:40 - INFO - [Epoch 001] New best val loss: 1160.6260
[2025-11-20 02:35:48,467] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 962.4774
2025-11-20 02:35:48 - INFO - [Epoch 002] New best val loss: 962.4774
[2025-11-20 02:35:55,576] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 866.5616
2025-11-20 02:35:55 - INFO - [Epoch 003] New best val loss: 866.5616
[2025-11-20 02:36:03,424] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 842.7882
2025-11-20 02:36:03 - INFO - [Epoch 0

2025-11-20 02:43:28 - INFO - [Epoch 066] New best val loss: 772.7369
[2025-11-20 02:43:50,476] [UniVITrainer] [INFO] [Epoch 069] New best val loss: 772.4170
2025-11-20 02:43:50 - INFO - [Epoch 069] New best val loss: 772.4170
[2025-11-20 02:43:57,066] [UniVITrainer] [INFO] [Epoch 070] Train loss: 774.5637 (beta=400.000, gamma=80.000)
2025-11-20 02:43:57 - INFO - [Epoch 070] Train loss: 774.5637 (beta=400.000, gamma=80.000)
[2025-11-20 02:43:57,859] [UniVITrainer] [INFO] [Epoch 070] Val loss: 771.6141 (beta=400.000, gamma=80.000)
2025-11-20 02:43:57 - INFO - [Epoch 070] Val loss: 771.6141 (beta=400.000, gamma=80.000)
[2025-11-20 02:43:57,946] [UniVITrainer] [INFO] [Epoch 070] New best val loss: 771.6141
2025-11-20 02:43:57 - INFO - [Epoch 070] New best val loss: 771.6141
[2025-11-20 02:44:05,586] [UniVITrainer] [INFO] [Epoch 071] New best val loss: 771.2326
2025-11-20 02:44:05 - INFO - [Epoch 071] New best val loss: 771.2326
[2025-11-20 02:44:20,103] [UniVITrainer] [INFO] [Epoch 073] Ne

2025-11-20 02:50:19 - INFO - [Epoch 123] New best val loss: 766.6147
[2025-11-20 02:50:42,342] [UniVITrainer] [INFO] [Epoch 126] New best val loss: 766.5751
2025-11-20 02:50:42 - INFO - [Epoch 126] New best val loss: 766.5751
[2025-11-20 02:51:11,374] [UniVITrainer] [INFO] [Epoch 130] Train loss: 766.7717 (beta=400.000, gamma=80.000)
2025-11-20 02:51:11 - INFO - [Epoch 130] Train loss: 766.7717 (beta=400.000, gamma=80.000)
[2025-11-20 02:51:12,182] [UniVITrainer] [INFO] [Epoch 130] Val loss: 766.5674 (beta=400.000, gamma=80.000)
2025-11-20 02:51:12 - INFO - [Epoch 130] Val loss: 766.5674 (beta=400.000, gamma=80.000)
[2025-11-20 02:51:12,245] [UniVITrainer] [INFO] [Epoch 130] New best val loss: 766.5674
2025-11-20 02:51:12 - INFO - [Epoch 130] New best val loss: 766.5674
[2025-11-20 02:51:34,598] [UniVITrainer] [INFO] [Epoch 133] New best val loss: 766.5673
2025-11-20 02:51:34 - INFO - [Epoch 133] New best val loss: 766.5673
[2025-11-20 02:51:49,622] [UniVITrainer] [INFO] [Epoch 135] Ne

[Config 11] Done in 22.3 min
  best_val_loss              = 766.375
  FOSCTTM (RNA vs ADT, val)  = 0.5204
  FOSCTTM (RNA vs ATAC, val) = 0.5112
[Config 11] FOSCTTM (ADT vs ATAC, val) = 0.5148
  Mean FOSCTTM (3 pairs)     = 0.5155
  Modality mixing (k=20)     = 0.0000
  Composite score            = 1161.42

[Config 12] Hyperparameters:
{
  "latent_dim": 86,
  "beta": 40.0,
  "gamma": 500.0,
  "lr": 0.001,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": true,
  "rna_arch": "rna_med2",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_small2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 02:57:57,710] [UniVITrainer] [INFO] [Epoch 001] Train loss: 6842.1815 (beta=40.000, gamma=500.000)
2025-11-20 02:57:57 - INFO - [Epoch 001] Train loss: 6842.1815 (beta=40.000, gamma=500.000)
[2025-11-20 02:57:58,524] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1845.0099 (beta=40.000, gamma=500.000)
2025-11-20 02:57:58 - INFO - [Epoch 001] Val loss: 1845.0099 (beta=40.000, gamma=500.000)
[2025-11-20 02:57:58,662] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1845.0099
2025-11-20 02:57:58 - INFO - [Epoch 001] New best val loss: 1845.0099
[2025-11-20 02:58:05,901] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1164.0013
2025-11-20 02:58:05 - INFO - [Epoch 002] New best val loss: 1164.0013
[2025-11-20 02:58:13,591] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 997.6691
2025-11-20 02:58:13 - INFO - [Epoch 003] New best val loss: 997.6691
[2025-11-20 02:58:21,159] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 933.7699
2025-11-20 02:58:21 - INFO - [Epoch

2025-11-20 03:03:25 - INFO - [Epoch 045] New best val loss: 769.9369
[2025-11-20 03:03:40,953] [UniVITrainer] [INFO] [Epoch 047] New best val loss: 769.4641
2025-11-20 03:03:40 - INFO - [Epoch 047] New best val loss: 769.4641
[2025-11-20 03:03:56,178] [UniVITrainer] [INFO] [Epoch 049] New best val loss: 769.4073
2025-11-20 03:03:56 - INFO - [Epoch 049] New best val loss: 769.4073
[2025-11-20 03:04:02,996] [UniVITrainer] [INFO] [Epoch 050] Train loss: 776.8928 (beta=40.000, gamma=500.000)
2025-11-20 03:04:02 - INFO - [Epoch 050] Train loss: 776.8928 (beta=40.000, gamma=500.000)
[2025-11-20 03:04:03,875] [UniVITrainer] [INFO] [Epoch 050] Val loss: 769.7122 (beta=40.000, gamma=500.000)
2025-11-20 03:04:03 - INFO - [Epoch 050] Val loss: 769.7122 (beta=40.000, gamma=500.000)
[2025-11-20 03:04:11,823] [UniVITrainer] [INFO] [Epoch 051] New best val loss: 769.2517
2025-11-20 03:04:11 - INFO - [Epoch 051] New best val loss: 769.2517
[2025-11-20 03:04:18,656] [UniVITrainer] [INFO] [Epoch 052] Ne

[2025-11-20 03:14:54,801] [UniVITrainer] [INFO] [Epoch 140] Val loss: 766.5759 (beta=40.000, gamma=500.000)
2025-11-20 03:14:54 - INFO - [Epoch 140] Val loss: 766.5759 (beta=40.000, gamma=500.000)
[2025-11-20 03:15:31,292] [UniVITrainer] [INFO] [Epoch 145] New best val loss: 766.4356
2025-11-20 03:15:31 - INFO - [Epoch 145] New best val loss: 766.4356
[2025-11-20 03:16:05,933] [UniVITrainer] [INFO] [Epoch 150] Train loss: 764.7147 (beta=40.000, gamma=500.000)
2025-11-20 03:16:05 - INFO - [Epoch 150] Train loss: 764.7147 (beta=40.000, gamma=500.000)
[2025-11-20 03:16:06,740] [UniVITrainer] [INFO] [Epoch 150] Val loss: 766.6903 (beta=40.000, gamma=500.000)
2025-11-20 03:16:06 - INFO - [Epoch 150] Val loss: 766.6903 (beta=40.000, gamma=500.000)
[2025-11-20 03:16:21,690] [UniVITrainer] [INFO] [Epoch 152] New best val loss: 766.3920
2025-11-20 03:16:21 - INFO - [Epoch 152] New best val loss: 766.3920
[2025-11-20 03:16:50,176] [UniVITrainer] [INFO] [Epoch 156] New best val loss: 766.3887
202

[Config 12] Done in 24.2 min
  best_val_loss              = 766.339
  FOSCTTM (RNA vs ADT, val)  = 0.5020
  FOSCTTM (RNA vs ATAC, val) = 0.1281
[Config 12] FOSCTTM (ADT vs ATAC, val) = 0.1180
  Mean FOSCTTM (3 pairs)     = 0.2493
  Modality mixing (k=20)     = 0.0001
  Composite score            = 957.55

[Config 13] Hyperparameters:
{
  "latent_dim": 86,
  "beta": 100.0,
  "gamma": 1000.0,
  "lr": 0.001,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.0,
  "decoder_batchnorm": true,
  "rna_arch": "rna_wide2",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_small2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 03:22:10,435] [UniVITrainer] [INFO] [Epoch 001] Train loss: 9377.3006 (beta=100.000, gamma=1000.000)
2025-11-20 03:22:10 - INFO - [Epoch 001] Train loss: 9377.3006 (beta=100.000, gamma=1000.000)
[2025-11-20 03:22:11,250] [UniVITrainer] [INFO] [Epoch 001] Val loss: 2789.3747 (beta=100.000, gamma=1000.000)
2025-11-20 03:22:11 - INFO - [Epoch 001] Val loss: 2789.3747 (beta=100.000, gamma=1000.000)
[2025-11-20 03:22:11,408] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 2789.3747
2025-11-20 03:22:11 - INFO - [Epoch 001] New best val loss: 2789.3747
[2025-11-20 03:22:19,062] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1514.6497
2025-11-20 03:22:19 - INFO - [Epoch 002] New best val loss: 1514.6497
[2025-11-20 03:22:25,881] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 1244.8735
2025-11-20 03:22:25 - INFO - [Epoch 003] New best val loss: 1244.8735
[2025-11-20 03:22:33,584] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 1086.3174
2025-11-20 03:22:33 - IN

2025-11-20 03:27:59 - INFO - [Epoch 049] New best val loss: 784.0862
[2025-11-20 03:28:05,882] [UniVITrainer] [INFO] [Epoch 050] Train loss: 781.0376 (beta=100.000, gamma=1000.000)
2025-11-20 03:28:05 - INFO - [Epoch 050] Train loss: 781.0376 (beta=100.000, gamma=1000.000)
[2025-11-20 03:28:06,690] [UniVITrainer] [INFO] [Epoch 050] Val loss: 785.4004 (beta=100.000, gamma=1000.000)
2025-11-20 03:28:06 - INFO - [Epoch 050] Val loss: 785.4004 (beta=100.000, gamma=1000.000)
[2025-11-20 03:28:22,215] [UniVITrainer] [INFO] [Epoch 052] New best val loss: 783.0244
2025-11-20 03:28:22 - INFO - [Epoch 052] New best val loss: 783.0244
[2025-11-20 03:28:37,643] [UniVITrainer] [INFO] [Epoch 054] New best val loss: 781.1301
2025-11-20 03:28:37 - INFO - [Epoch 054] New best val loss: 781.1301
[2025-11-20 03:29:00,850] [UniVITrainer] [INFO] [Epoch 057] New best val loss: 780.1903
2025-11-20 03:29:00 - INFO - [Epoch 057] New best val loss: 780.1903
[2025-11-20 03:29:21,858] [UniVITrainer] [INFO] [Epoch

[2025-11-20 03:39:32,413] [UniVITrainer] [INFO] [Epoch 144] New best val loss: 770.8977
2025-11-20 03:39:32 - INFO - [Epoch 144] New best val loss: 770.8977
[2025-11-20 03:40:17,090] [UniVITrainer] [INFO] [Epoch 150] Train loss: 771.1219 (beta=100.000, gamma=1000.000)
2025-11-20 03:40:17 - INFO - [Epoch 150] Train loss: 771.1219 (beta=100.000, gamma=1000.000)
[2025-11-20 03:40:17,961] [UniVITrainer] [INFO] [Epoch 150] Val loss: 773.0557 (beta=100.000, gamma=1000.000)
2025-11-20 03:40:17 - INFO - [Epoch 150] Val loss: 773.0557 (beta=100.000, gamma=1000.000)
[2025-11-20 03:41:31,574] [UniVITrainer] [INFO] [Epoch 160] Train loss: 773.2674 (beta=100.000, gamma=1000.000)
2025-11-20 03:41:31 - INFO - [Epoch 160] Train loss: 773.2674 (beta=100.000, gamma=1000.000)
[2025-11-20 03:41:32,406] [UniVITrainer] [INFO] [Epoch 160] Val loss: 771.4453 (beta=100.000, gamma=1000.000)
2025-11-20 03:41:32 - INFO - [Epoch 160] Val loss: 771.4453 (beta=100.000, gamma=1000.000)
[2025-11-20 03:41:55,136] [UniV

[Config 13] Done in 22.4 min
  best_val_loss              = 769.723
  FOSCTTM (RNA vs ADT, val)  = 0.5518
  FOSCTTM (RNA vs ATAC, val) = 0.4869
[Config 13] FOSCTTM (ADT vs ATAC, val) = 0.4775
  Mean FOSCTTM (3 pairs)     = 0.5054
  Modality mixing (k=20)     = 0.0162
  Composite score            = 1177.45

[Config 14] Hyperparameters:
{
  "latent_dim": 156,
  "beta": 300.0,
  "gamma": 500.0,
  "lr": 0.001,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": false,
  "rna_arch": "rna_wide2",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_small2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 03:44:33,582] [UniVITrainer] [INFO] [Epoch 001] Train loss: 15461.9050 (beta=300.000, gamma=500.000)
2025-11-20 03:44:33 - INFO - [Epoch 001] Train loss: 15461.9050 (beta=300.000, gamma=500.000)
[2025-11-20 03:44:34,257] [UniVITrainer] [INFO] [Epoch 001] Val loss: 3330.2022 (beta=300.000, gamma=500.000)
2025-11-20 03:44:34 - INFO - [Epoch 001] Val loss: 3330.2022 (beta=300.000, gamma=500.000)
[2025-11-20 03:44:34,424] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 3330.2022
2025-11-20 03:44:34 - INFO - [Epoch 001] New best val loss: 3330.2022
[2025-11-20 03:44:41,998] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1729.6381
2025-11-20 03:44:41 - INFO - [Epoch 002] New best val loss: 1729.6381
[2025-11-20 03:44:49,576] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 1366.0947
2025-11-20 03:44:49 - INFO - [Epoch 003] New best val loss: 1366.0947
[2025-11-20 03:44:57,037] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 1210.5511
2025-11-20 03:44:57 - INFO

[2025-11-20 03:51:08,903] [UniVITrainer] [INFO] [Epoch 054] New best val loss: 771.7796
2025-11-20 03:51:08 - INFO - [Epoch 054] New best val loss: 771.7796
[2025-11-20 03:51:36,079] [UniVITrainer] [INFO] [Epoch 058] New best val loss: 771.7263
2025-11-20 03:51:36 - INFO - [Epoch 058] New best val loss: 771.7263
[2025-11-20 03:51:48,501] [UniVITrainer] [INFO] [Epoch 060] Train loss: 781.7937 (beta=300.000, gamma=500.000)
2025-11-20 03:51:48 - INFO - [Epoch 060] Train loss: 781.7937 (beta=300.000, gamma=500.000)
[2025-11-20 03:51:49,182] [UniVITrainer] [INFO] [Epoch 060] Val loss: 771.5208 (beta=300.000, gamma=500.000)
2025-11-20 03:51:49 - INFO - [Epoch 060] Val loss: 771.5208 (beta=300.000, gamma=500.000)
[2025-11-20 03:51:49,348] [UniVITrainer] [INFO] [Epoch 060] New best val loss: 771.5208
2025-11-20 03:51:49 - INFO - [Epoch 060] New best val loss: 771.5208
[2025-11-20 03:51:56,843] [UniVITrainer] [INFO] [Epoch 061] New best val loss: 771.5095
2025-11-20 03:51:56 - INFO - [Epoch 061

[2025-11-20 04:02:42,388] [UniVITrainer] [INFO] [Epoch 150] New best val loss: 766.9038
2025-11-20 04:02:42 - INFO - [Epoch 150] New best val loss: 766.9038
[2025-11-20 04:02:49,996] [UniVITrainer] [INFO] [Epoch 151] New best val loss: 766.8493
2025-11-20 04:02:49 - INFO - [Epoch 151] New best val loss: 766.8493
[2025-11-20 04:03:54,806] [UniVITrainer] [INFO] [Epoch 160] Train loss: 766.8905 (beta=300.000, gamma=500.000)
2025-11-20 04:03:54 - INFO - [Epoch 160] Train loss: 766.8905 (beta=300.000, gamma=500.000)
[2025-11-20 04:03:55,643] [UniVITrainer] [INFO] [Epoch 160] Val loss: 766.9240 (beta=300.000, gamma=500.000)
2025-11-20 04:03:55 - INFO - [Epoch 160] Val loss: 766.9240 (beta=300.000, gamma=500.000)
[2025-11-20 04:05:04,543] [UniVITrainer] [INFO] [Epoch 170] Train loss: 766.8541 (beta=300.000, gamma=500.000)
2025-11-20 04:05:04 - INFO - [Epoch 170] Train loss: 766.8541 (beta=300.000, gamma=500.000)
[2025-11-20 04:05:04,657] [UniVITrainer] [INFO] [Epoch 170] Val loss: 767.2031 (b

[Config 14] Done in 20.7 min
  best_val_loss              = 766.849
  FOSCTTM (RNA vs ADT, val)  = 0.5376
  FOSCTTM (RNA vs ATAC, val) = 0.0633
[Config 14] FOSCTTM (ADT vs ATAC, val) = 0.0920
  Mean FOSCTTM (3 pairs)     = 0.2310
  Modality mixing (k=20)     = 0.0000
  Composite score            = 943.96

[Config 15] Hyperparameters:
{
  "latent_dim": 72,
  "beta": 240.0,
  "gamma": 400.0,
  "lr": 0.0005,
  "weight_decay": 0.0001,
  "encoder_dropout": 0.0,
  "decoder_batchnorm": false,
  "rna_arch": "rna_wide3",
  "adt_arch": "adt_small2",
  "atac_arch": "atac_med2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 04:05:21,287] [UniVITrainer] [INFO] [Epoch 001] Train loss: 7574.5838 (beta=240.000, gamma=400.000)
2025-11-20 04:05:21 - INFO - [Epoch 001] Train loss: 7574.5838 (beta=240.000, gamma=400.000)
[2025-11-20 04:05:22,119] [UniVITrainer] [INFO] [Epoch 001] Val loss: 3522.7370 (beta=240.000, gamma=400.000)
2025-11-20 04:05:22 - INFO - [Epoch 001] Val loss: 3522.7370 (beta=240.000, gamma=400.000)
[2025-11-20 04:05:22,312] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 3522.7370
2025-11-20 04:05:22 - INFO - [Epoch 001] New best val loss: 3522.7370
[2025-11-20 04:05:30,029] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1886.8941
2025-11-20 04:05:30 - INFO - [Epoch 002] New best val loss: 1886.8941
[2025-11-20 04:05:37,867] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 1306.4879
2025-11-20 04:05:37 - INFO - [Epoch 003] New best val loss: 1306.4879
[2025-11-20 04:05:45,586] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 1103.2411
2025-11-20 04:05:45 - INFO -

2025-11-20 04:10:49 - INFO - [Epoch 044] New best val loss: 800.6073
[2025-11-20 04:10:57,438] [UniVITrainer] [INFO] [Epoch 045] New best val loss: 799.4038
2025-11-20 04:10:57 - INFO - [Epoch 045] New best val loss: 799.4038
[2025-11-20 04:11:12,034] [UniVITrainer] [INFO] [Epoch 047] New best val loss: 799.2847
2025-11-20 04:11:12 - INFO - [Epoch 047] New best val loss: 799.2847
[2025-11-20 04:11:19,412] [UniVITrainer] [INFO] [Epoch 048] New best val loss: 798.4652
2025-11-20 04:11:19 - INFO - [Epoch 048] New best val loss: 798.4652
[2025-11-20 04:11:27,306] [UniVITrainer] [INFO] [Epoch 049] New best val loss: 797.9280
2025-11-20 04:11:27 - INFO - [Epoch 049] New best val loss: 797.9280
[2025-11-20 04:11:34,313] [UniVITrainer] [INFO] [Epoch 050] Train loss: 778.6434 (beta=240.000, gamma=400.000)
2025-11-20 04:11:34 - INFO - [Epoch 050] Train loss: 778.6434 (beta=240.000, gamma=400.000)
[2025-11-20 04:11:35,124] [UniVITrainer] [INFO] [Epoch 050] Val loss: 798.2448 (beta=240.000, gamma=

[2025-11-20 04:19:08,144] [UniVITrainer] [INFO] [Epoch 112] New best val loss: 785.7546
2025-11-20 04:19:08 - INFO - [Epoch 112] New best val loss: 785.7546
[2025-11-20 04:19:15,741] [UniVITrainer] [INFO] [Epoch 113] New best val loss: 785.6065
2025-11-20 04:19:15 - INFO - [Epoch 113] New best val loss: 785.6065
[2025-11-20 04:19:28,880] [UniVITrainer] [INFO] [Epoch 115] New best val loss: 784.9755
2025-11-20 04:19:28 - INFO - [Epoch 115] New best val loss: 784.9755
[2025-11-20 04:19:57,738] [UniVITrainer] [INFO] [Epoch 119] New best val loss: 784.6952
2025-11-20 04:19:57 - INFO - [Epoch 119] New best val loss: 784.6952
[2025-11-20 04:20:04,533] [UniVITrainer] [INFO] [Epoch 120] Train loss: 767.7931 (beta=240.000, gamma=400.000)
2025-11-20 04:20:04 - INFO - [Epoch 120] Train loss: 767.7931 (beta=240.000, gamma=400.000)
[2025-11-20 04:20:05,349] [UniVITrainer] [INFO] [Epoch 120] Val loss: 784.7262 (beta=240.000, gamma=400.000)
2025-11-20 04:20:05 - INFO - [Epoch 120] Val loss: 784.7262 

[2025-11-20 04:29:45,810] [UniVITrainer] [INFO] [Epoch 199] New best val loss: 778.6556
2025-11-20 04:29:45 - INFO - [Epoch 199] New best val loss: 778.6556
[2025-11-20 04:29:52,707] [UniVITrainer] [INFO] [Epoch 200] Train loss: 769.2086 (beta=240.000, gamma=400.000)
2025-11-20 04:29:52 - INFO - [Epoch 200] Train loss: 769.2086 (beta=240.000, gamma=400.000)
[2025-11-20 04:29:53,511] [UniVITrainer] [INFO] [Epoch 200] Val loss: 778.5659 (beta=240.000, gamma=400.000)
2025-11-20 04:29:53 - INFO - [Epoch 200] Val loss: 778.5659 (beta=240.000, gamma=400.000)
[2025-11-20 04:29:53,660] [UniVITrainer] [INFO] [Epoch 200] New best val loss: 778.5659
2025-11-20 04:29:53 - INFO - [Epoch 200] New best val loss: 778.5659
[2025-11-20 04:29:53,700] [UniVITrainer] [INFO] Restored best model from epoch 200 (val loss = 778.5659)
2025-11-20 04:29:53 - INFO - Restored best model from epoch 200 (val loss = 778.5659)
[2025-11-20 04:29:56,195] [UniVITrainer] [INFO] TrainingConfig:
2025-11-20 04:29:56 - INFO - 

[Config 15] Done in 24.7 min
  best_val_loss              = 778.566
  FOSCTTM (RNA vs ADT, val)  = 0.5131
  FOSCTTM (RNA vs ATAC, val) = 0.5114
[Config 15] FOSCTTM (ADT vs ATAC, val) = 0.4986
  Mean FOSCTTM (3 pairs)     = 0.5077
  Modality mixing (k=20)     = 0.0424
  Composite score            = 1223.61

[Config 16] Hyperparameters:
{
  "latent_dim": 20,
  "beta": 140.0,
  "gamma": 1000.0,
  "lr": 0.0005,
  "weight_decay": 0.0001,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": true,
  "rna_arch": "rna_wide3",
  "adt_arch": "adt_small2",
  "atac_arch": "atac_med2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 04:30:02,172] [UniVITrainer] [INFO] [Epoch 001] Train loss: 4963.0020 (beta=140.000, gamma=1000.000)
2025-11-20 04:30:02 - INFO - [Epoch 001] Train loss: 4963.0020 (beta=140.000, gamma=1000.000)
[2025-11-20 04:30:02,918] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1781.3182 (beta=140.000, gamma=1000.000)
2025-11-20 04:30:02 - INFO - [Epoch 001] Val loss: 1781.3182 (beta=140.000, gamma=1000.000)
[2025-11-20 04:30:03,067] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1781.3182
2025-11-20 04:30:03 - INFO - [Epoch 001] New best val loss: 1781.3182
[2025-11-20 04:30:11,035] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1376.5470
2025-11-20 04:30:11 - INFO - [Epoch 002] New best val loss: 1376.5470
[2025-11-20 04:30:18,947] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 1119.5001
2025-11-20 04:30:18 - INFO - [Epoch 003] New best val loss: 1119.5001
[2025-11-20 04:30:26,522] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 994.0175
2025-11-20 04:30:26 - INF

2025-11-20 04:35:41 - INFO - [Epoch 047] New best val loss: 774.6148
[2025-11-20 04:35:49,618] [UniVITrainer] [INFO] [Epoch 048] New best val loss: 774.5492
2025-11-20 04:35:49 - INFO - [Epoch 048] New best val loss: 774.5492
[2025-11-20 04:35:57,091] [UniVITrainer] [INFO] [Epoch 049] New best val loss: 773.9579
2025-11-20 04:35:57 - INFO - [Epoch 049] New best val loss: 773.9579
[2025-11-20 04:36:03,906] [UniVITrainer] [INFO] [Epoch 050] Train loss: 794.1900 (beta=140.000, gamma=1000.000)
2025-11-20 04:36:03 - INFO - [Epoch 050] Train loss: 794.1900 (beta=140.000, gamma=1000.000)
[2025-11-20 04:36:04,716] [UniVITrainer] [INFO] [Epoch 050] Val loss: 773.5939 (beta=140.000, gamma=1000.000)
2025-11-20 04:36:04 - INFO - [Epoch 050] Val loss: 773.5939 (beta=140.000, gamma=1000.000)
[2025-11-20 04:36:04,859] [UniVITrainer] [INFO] [Epoch 050] New best val loss: 773.5939
2025-11-20 04:36:04 - INFO - [Epoch 050] New best val loss: 773.5939
[2025-11-20 04:36:19,775] [UniVITrainer] [INFO] [Epoch

2025-11-20 04:43:21 - INFO - [Epoch 109] New best val loss: 767.4904
[2025-11-20 04:43:28,325] [UniVITrainer] [INFO] [Epoch 110] Train loss: 768.9355 (beta=140.000, gamma=1000.000)
2025-11-20 04:43:28 - INFO - [Epoch 110] Train loss: 768.9355 (beta=140.000, gamma=1000.000)
[2025-11-20 04:43:29,136] [UniVITrainer] [INFO] [Epoch 110] Val loss: 767.4761 (beta=140.000, gamma=1000.000)
2025-11-20 04:43:29 - INFO - [Epoch 110] Val loss: 767.4761 (beta=140.000, gamma=1000.000)
[2025-11-20 04:43:29,298] [UniVITrainer] [INFO] [Epoch 110] New best val loss: 767.4761
2025-11-20 04:43:29 - INFO - [Epoch 110] New best val loss: 767.4761
[2025-11-20 04:43:52,187] [UniVITrainer] [INFO] [Epoch 113] New best val loss: 767.4667
2025-11-20 04:43:52 - INFO - [Epoch 113] New best val loss: 767.4667
[2025-11-20 04:44:15,055] [UniVITrainer] [INFO] [Epoch 116] New best val loss: 767.3659
2025-11-20 04:44:15 - INFO - [Epoch 116] New best val loss: 767.3659
[2025-11-20 04:44:35,208] [UniVITrainer] [INFO] [Epoch

[2025-11-20 04:54:27,952] [UniVITrainer] [INFO] TrainingConfig:
2025-11-20 04:54:27 - INFO - TrainingConfig:
[2025-11-20 04:54:27,954] [UniVITrainer] [INFO]   n_epochs: 200
2025-11-20 04:54:27 - INFO -   n_epochs: 200
[2025-11-20 04:54:27,961] [UniVITrainer] [INFO]   batch_size: 256
2025-11-20 04:54:27 - INFO -   batch_size: 256
[2025-11-20 04:54:27,965] [UniVITrainer] [INFO]   lr: 0.001
2025-11-20 04:54:27 - INFO -   lr: 0.001
[2025-11-20 04:54:27,969] [UniVITrainer] [INFO]   weight_decay: 0.0001
2025-11-20 04:54:27 - INFO -   weight_decay: 0.0001
[2025-11-20 04:54:27,971] [UniVITrainer] [INFO]   device: cuda
2025-11-20 04:54:27 - INFO -   device: cuda
[2025-11-20 04:54:27,972] [UniVITrainer] [INFO]   log_every: 10
2025-11-20 04:54:27 - INFO -   log_every: 10
[2025-11-20 04:54:27,980] [UniVITrainer] [INFO]   grad_clip: None
2025-11-20 04:54:27 - INFO -   grad_clip: None
[2025-11-20 04:54:27,981] [UniVITrainer] [INFO]   num_workers: 0
2025-11-20 04:54:27 - INFO -   num_workers: 0
[2025

[Config 16] Done in 24.5 min
  best_val_loss              = 766.520
  FOSCTTM (RNA vs ADT, val)  = 0.4305
  FOSCTTM (RNA vs ATAC, val) = 0.4869
[Config 16] FOSCTTM (ADT vs ATAC, val) = 0.4883
  Mean FOSCTTM (3 pairs)     = 0.4686
  Modality mixing (k=20)     = 0.1023
  Composite score            = 1240.88

[Config 17] Hyperparameters:
{
  "latent_dim": 156,
  "beta": 400.0,
  "gamma": 100.0,
  "lr": 0.001,
  "weight_decay": 0.0001,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": false,
  "rna_arch": "rna_wide3",
  "adt_arch": "adt_small2",
  "atac_arch": "atac_med2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 04:54:34,375] [UniVITrainer] [INFO] [Epoch 001] Train loss: 9907.5175 (beta=400.000, gamma=100.000)
2025-11-20 04:54:34 - INFO - [Epoch 001] Train loss: 9907.5175 (beta=400.000, gamma=100.000)
[2025-11-20 04:54:35,184] [UniVITrainer] [INFO] [Epoch 001] Val loss: 2949.1383 (beta=400.000, gamma=100.000)
2025-11-20 04:54:35 - INFO - [Epoch 001] Val loss: 2949.1383 (beta=400.000, gamma=100.000)
[2025-11-20 04:54:35,436] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 2949.1383
2025-11-20 04:54:35 - INFO - [Epoch 001] New best val loss: 2949.1383
[2025-11-20 04:54:43,192] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1400.1966
2025-11-20 04:54:43 - INFO - [Epoch 002] New best val loss: 1400.1966
[2025-11-20 04:54:50,674] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 1090.1801
2025-11-20 04:54:50 - INFO - [Epoch 003] New best val loss: 1090.1801
[2025-11-20 04:54:57,964] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 988.8232
2025-11-20 04:54:57 - INFO - 

2025-11-20 05:01:41 - INFO - [Epoch 060] Val loss: 798.2920 (beta=400.000, gamma=100.000)
[2025-11-20 05:01:49,744] [UniVITrainer] [INFO] [Epoch 061] New best val loss: 792.8147
2025-11-20 05:01:49 - INFO - [Epoch 061] New best val loss: 792.8147
[2025-11-20 05:02:12,339] [UniVITrainer] [INFO] [Epoch 064] New best val loss: 790.8530
2025-11-20 05:02:12 - INFO - [Epoch 064] New best val loss: 790.8530
[2025-11-20 05:02:20,126] [UniVITrainer] [INFO] [Epoch 065] New best val loss: 790.1571
2025-11-20 05:02:20 - INFO - [Epoch 065] New best val loss: 790.1571
[2025-11-20 05:02:55,628] [UniVITrainer] [INFO] [Epoch 070] Train loss: 781.1760 (beta=400.000, gamma=100.000)
2025-11-20 05:02:55 - INFO - [Epoch 070] Train loss: 781.1760 (beta=400.000, gamma=100.000)
[2025-11-20 05:02:56,452] [UniVITrainer] [INFO] [Epoch 070] Val loss: 792.9806 (beta=400.000, gamma=100.000)
2025-11-20 05:02:56 - INFO - [Epoch 070] Val loss: 792.9806 (beta=400.000, gamma=100.000)
[2025-11-20 05:03:04,113] [UniVITrain

2025-11-20 05:12:02 - INFO -   device: cuda
[2025-11-20 05:12:02,326] [UniVITrainer] [INFO]   log_every: 10
2025-11-20 05:12:02 - INFO -   log_every: 10
[2025-11-20 05:12:02,327] [UniVITrainer] [INFO]   grad_clip: None
2025-11-20 05:12:02 - INFO -   grad_clip: None
[2025-11-20 05:12:02,328] [UniVITrainer] [INFO]   num_workers: 0
2025-11-20 05:12:02 - INFO -   num_workers: 0
[2025-11-20 05:12:02,330] [UniVITrainer] [INFO]   seed: 42
2025-11-20 05:12:02 - INFO -   seed: 42
[2025-11-20 05:12:02,331] [UniVITrainer] [INFO]   early_stopping: True
2025-11-20 05:12:02 - INFO -   early_stopping: True
[2025-11-20 05:12:02,332] [UniVITrainer] [INFO]   patience: 20
2025-11-20 05:12:02 - INFO -   patience: 20
[2025-11-20 05:12:02,334] [UniVITrainer] [INFO]   min_delta: 0.0
2025-11-20 05:12:02 - INFO -   min_delta: 0.0


[Config 17] Done in 17.5 min
  best_val_loss              = 772.388
  FOSCTTM (RNA vs ADT, val)  = 0.6448
  FOSCTTM (RNA vs ATAC, val) = 0.4868
[Config 17] FOSCTTM (ADT vs ATAC, val) = 0.5055
  Mean FOSCTTM (3 pairs)     = 0.5457
  Modality mixing (k=20)     = 0.0003
  Composite score            = 1194.29

[Config 18] Hyperparameters:
{
  "latent_dim": 50,
  "beta": 300.0,
  "gamma": 300.0,
  "lr": 0.0005,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": false,
  "rna_arch": "rna_med2",
  "adt_arch": "adt_small2",
  "atac_arch": "atac_wide2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 05:12:08,585] [UniVITrainer] [INFO] [Epoch 001] Train loss: 5734.4715 (beta=300.000, gamma=300.000)
2025-11-20 05:12:08 - INFO - [Epoch 001] Train loss: 5734.4715 (beta=300.000, gamma=300.000)
[2025-11-20 05:12:09,421] [UniVITrainer] [INFO] [Epoch 001] Val loss: 2641.1670 (beta=300.000, gamma=300.000)
2025-11-20 05:12:09 - INFO - [Epoch 001] Val loss: 2641.1670 (beta=300.000, gamma=300.000)
[2025-11-20 05:12:09,620] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 2641.1670
2025-11-20 05:12:09 - INFO - [Epoch 001] New best val loss: 2641.1670
[2025-11-20 05:12:15,944] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1348.9335
2025-11-20 05:12:15 - INFO - [Epoch 002] New best val loss: 1348.9335
[2025-11-20 05:12:21,846] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 1070.5120
2025-11-20 05:12:21 - INFO - [Epoch 003] New best val loss: 1070.5120
[2025-11-20 05:12:29,268] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 961.1479
2025-11-20 05:12:29 - INFO - 

2025-11-20 05:18:20 - INFO - [Epoch 050] Val loss: 785.4802 (beta=300.000, gamma=300.000)
[2025-11-20 05:18:20,638] [UniVITrainer] [INFO] [Epoch 050] New best val loss: 785.4802
2025-11-20 05:18:20 - INFO - [Epoch 050] New best val loss: 785.4802
[2025-11-20 05:18:42,686] [UniVITrainer] [INFO] [Epoch 053] New best val loss: 784.2936
2025-11-20 05:18:42 - INFO - [Epoch 053] New best val loss: 784.2936
[2025-11-20 05:19:02,703] [UniVITrainer] [INFO] [Epoch 056] New best val loss: 783.7566
2025-11-20 05:19:02 - INFO - [Epoch 056] New best val loss: 783.7566
[2025-11-20 05:19:17,166] [UniVITrainer] [INFO] [Epoch 058] New best val loss: 783.1834
2025-11-20 05:19:17 - INFO - [Epoch 058] New best val loss: 783.1834
[2025-11-20 05:19:24,707] [UniVITrainer] [INFO] [Epoch 059] New best val loss: 783.1524
2025-11-20 05:19:24 - INFO - [Epoch 059] New best val loss: 783.1524
[2025-11-20 05:19:31,433] [UniVITrainer] [INFO] [Epoch 060] Train loss: 787.3987 (beta=300.000, gamma=300.000)
2025-11-20 05:

2025-11-20 05:28:43 - INFO - [Epoch 136] New best val loss: 775.4063
[2025-11-20 05:29:12,253] [UniVITrainer] [INFO] [Epoch 140] Train loss: 771.0152 (beta=300.000, gamma=300.000)
2025-11-20 05:29:12 - INFO - [Epoch 140] Train loss: 771.0152 (beta=300.000, gamma=300.000)
[2025-11-20 05:29:13,055] [UniVITrainer] [INFO] [Epoch 140] Val loss: 776.1035 (beta=300.000, gamma=300.000)
2025-11-20 05:29:13 - INFO - [Epoch 140] Val loss: 776.1035 (beta=300.000, gamma=300.000)
[2025-11-20 05:29:50,431] [UniVITrainer] [INFO] [Epoch 145] New best val loss: 775.0332
2025-11-20 05:29:50 - INFO - [Epoch 145] New best val loss: 775.0332
[2025-11-20 05:30:26,046] [UniVITrainer] [INFO] [Epoch 150] Train loss: 767.6386 (beta=300.000, gamma=300.000)
2025-11-20 05:30:26 - INFO - [Epoch 150] Train loss: 767.6386 (beta=300.000, gamma=300.000)
[2025-11-20 05:30:26,854] [UniVITrainer] [INFO] [Epoch 150] Val loss: 776.4685 (beta=300.000, gamma=300.000)
2025-11-20 05:30:26 - INFO - [Epoch 150] Val loss: 776.4685 

[Config 18] Done in 24.4 min
  best_val_loss              = 771.868
  FOSCTTM (RNA vs ADT, val)  = 0.4995
  FOSCTTM (RNA vs ATAC, val) = 0.4898
[Config 18] FOSCTTM (ADT vs ATAC, val) = 0.4922
  Mean FOSCTTM (3 pairs)     = 0.4938
  Modality mixing (k=20)     = 0.0088
  Composite score            = 1163.23

[Config 19] Hyperparameters:
{
  "latent_dim": 200,
  "beta": 1000.0,
  "gamma": 300.0,
  "lr": 0.0005,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.0,
  "decoder_batchnorm": false,
  "rna_arch": "rna_wide3",
  "adt_arch": "adt_small2",
  "atac_arch": "atac_wide2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 05:36:33,518] [UniVITrainer] [INFO] [Epoch 001] Train loss: 37377.3228 (beta=1000.000, gamma=300.000)
2025-11-20 05:36:33 - INFO - [Epoch 001] Train loss: 37377.3228 (beta=1000.000, gamma=300.000)
[2025-11-20 05:36:34,312] [UniVITrainer] [INFO] [Epoch 001] Val loss: 16094.8968 (beta=1000.000, gamma=300.000)
2025-11-20 05:36:34 - INFO - [Epoch 001] Val loss: 16094.8968 (beta=1000.000, gamma=300.000)
[2025-11-20 05:36:34,655] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 16094.8968
2025-11-20 05:36:34 - INFO - [Epoch 001] New best val loss: 16094.8968
[2025-11-20 05:36:42,305] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 4472.1134
2025-11-20 05:36:42 - INFO - [Epoch 002] New best val loss: 4472.1134
[2025-11-20 05:36:49,565] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 2546.8315
2025-11-20 05:36:49 - INFO - [Epoch 003] New best val loss: 2546.8315
[2025-11-20 05:36:57,456] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 1931.2391
2025-11-20 05:36:5

[2025-11-20 05:42:27,596] [UniVITrainer] [INFO] [Epoch 048] New best val loss: 944.4327
2025-11-20 05:42:27 - INFO - [Epoch 048] New best val loss: 944.4327
[2025-11-20 05:42:41,581] [UniVITrainer] [INFO] [Epoch 050] Train loss: 810.7212 (beta=1000.000, gamma=300.000)
2025-11-20 05:42:41 - INFO - [Epoch 050] Train loss: 810.7212 (beta=1000.000, gamma=300.000)
[2025-11-20 05:42:42,417] [UniVITrainer] [INFO] [Epoch 050] Val loss: 944.7485 (beta=1000.000, gamma=300.000)
2025-11-20 05:42:42 - INFO - [Epoch 050] Val loss: 944.7485 (beta=1000.000, gamma=300.000)
[2025-11-20 05:42:57,825] [UniVITrainer] [INFO] [Epoch 052] New best val loss: 939.9303
2025-11-20 05:42:57 - INFO - [Epoch 052] New best val loss: 939.9303
[2025-11-20 05:43:05,947] [UniVITrainer] [INFO] [Epoch 053] New best val loss: 939.8447
2025-11-20 05:43:05 - INFO - [Epoch 053] New best val loss: 939.8447
[2025-11-20 05:43:14,103] [UniVITrainer] [INFO] [Epoch 054] New best val loss: 936.9316
2025-11-20 05:43:14 - INFO - [Epoch

[2025-11-20 05:51:12,941] [UniVITrainer] [INFO] [Epoch 120] New best val loss: 860.5443
2025-11-20 05:51:12 - INFO - [Epoch 120] New best val loss: 860.5443
[2025-11-20 05:51:19,315] [UniVITrainer] [INFO] [Epoch 121] New best val loss: 856.1553
2025-11-20 05:51:19 - INFO - [Epoch 121] New best val loss: 856.1553
[2025-11-20 05:51:34,843] [UniVITrainer] [INFO] [Epoch 123] New best val loss: 848.7888
2025-11-20 05:51:34 - INFO - [Epoch 123] New best val loss: 848.7888
[2025-11-20 05:51:42,596] [UniVITrainer] [INFO] [Epoch 124] New best val loss: 847.8397
2025-11-20 05:51:42 - INFO - [Epoch 124] New best val loss: 847.8397
[2025-11-20 05:52:18,783] [UniVITrainer] [INFO] [Epoch 129] New best val loss: 841.6419
2025-11-20 05:52:18 - INFO - [Epoch 129] New best val loss: 841.6419
[2025-11-20 05:52:25,097] [UniVITrainer] [INFO] [Epoch 130] Train loss: 777.9441 (beta=1000.000, gamma=300.000)
2025-11-20 05:52:25 - INFO - [Epoch 130] Train loss: 777.9441 (beta=1000.000, gamma=300.000)
[2025-11-2

[2025-11-20 05:59:01,869] [UniVITrainer] [INFO] [Epoch 184] New best val loss: 810.1832
2025-11-20 05:59:01 - INFO - [Epoch 184] New best val loss: 810.1832
[2025-11-20 05:59:17,075] [UniVITrainer] [INFO] [Epoch 186] New best val loss: 808.9061
2025-11-20 05:59:17 - INFO - [Epoch 186] New best val loss: 808.9061
[2025-11-20 05:59:45,013] [UniVITrainer] [INFO] [Epoch 190] Train loss: 777.3308 (beta=1000.000, gamma=300.000)
2025-11-20 05:59:45 - INFO - [Epoch 190] Train loss: 777.3308 (beta=1000.000, gamma=300.000)
[2025-11-20 05:59:45,818] [UniVITrainer] [INFO] [Epoch 190] Val loss: 806.9960 (beta=1000.000, gamma=300.000)
2025-11-20 05:59:45 - INFO - [Epoch 190] Val loss: 806.9960 (beta=1000.000, gamma=300.000)
[2025-11-20 05:59:46,269] [UniVITrainer] [INFO] [Epoch 190] New best val loss: 806.9960
2025-11-20 05:59:46 - INFO - [Epoch 190] New best val loss: 806.9960
[2025-11-20 05:59:53,253] [UniVITrainer] [INFO] [Epoch 191] New best val loss: 805.7068
2025-11-20 05:59:53 - INFO - [Epoch

[Config 19] Done in 24.6 min
  best_val_loss              = 802.440
  FOSCTTM (RNA vs ADT, val)  = 0.4837
  FOSCTTM (RNA vs ATAC, val) = 0.5046
[Config 19] FOSCTTM (ADT vs ATAC, val) = 0.4953
  Mean FOSCTTM (3 pairs)     = 0.4945
  Modality mixing (k=20)     = 0.0106
  Composite score            = 1212.00

[Config 20] Hyperparameters:
{
  "latent_dim": 156,
  "beta": 1000.0,
  "gamma": 100.0,
  "lr": 0.001,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": false,
  "rna_arch": "rna_wide3",
  "adt_arch": "adt_small2",
  "atac_arch": "atac_wide2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 06:01:09,165] [UniVITrainer] [INFO] [Epoch 001] Train loss: 17408.9101 (beta=1000.000, gamma=100.000)
2025-11-20 06:01:09 - INFO - [Epoch 001] Train loss: 17408.9101 (beta=1000.000, gamma=100.000)
[2025-11-20 06:01:09,971] [UniVITrainer] [INFO] [Epoch 001] Val loss: 3217.0339 (beta=1000.000, gamma=100.000)
2025-11-20 06:01:09 - INFO - [Epoch 001] Val loss: 3217.0339 (beta=1000.000, gamma=100.000)
[2025-11-20 06:01:10,327] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 3217.0339
2025-11-20 06:01:10 - INFO - [Epoch 001] New best val loss: 3217.0339
[2025-11-20 06:01:17,661] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1563.6121
2025-11-20 06:01:17 - INFO - [Epoch 002] New best val loss: 1563.6121
[2025-11-20 06:01:25,277] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 1236.1479
2025-11-20 06:01:25 - INFO - [Epoch 003] New best val loss: 1236.1479
[2025-11-20 06:01:33,210] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 1128.3312
2025-11-20 06:01:33 - 

[Config 20] Done in 4.4 min
  best_val_loss              = 1011.754
  FOSCTTM (RNA vs ADT, val)  = 0.4616
  FOSCTTM (RNA vs ATAC, val) = 0.5163
[Config 20] FOSCTTM (ADT vs ATAC, val) = 0.4955
  Mean FOSCTTM (3 pairs)     = 0.4911
  Modality mixing (k=20)     = 0.1604
  Composite score            = 1750.67

[Config 21] Hyperparameters:
{
  "latent_dim": 40,
  "beta": 140.0,
  "gamma": 80.0,
  "lr": 0.001,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": false,
  "rna_arch": "rna_wide3",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_small2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 06:05:37,046] [UniVITrainer] [INFO] [Epoch 001] Train loss: 1790.9643 (beta=140.000, gamma=80.000)
2025-11-20 06:05:37 - INFO - [Epoch 001] Train loss: 1790.9643 (beta=140.000, gamma=80.000)
[2025-11-20 06:05:37,707] [UniVITrainer] [INFO] [Epoch 001] Val loss: 976.7466 (beta=140.000, gamma=80.000)
2025-11-20 06:05:37 - INFO - [Epoch 001] Val loss: 976.7466 (beta=140.000, gamma=80.000)
[2025-11-20 06:05:37,875] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 976.7466
2025-11-20 06:05:37 - INFO - [Epoch 001] New best val loss: 976.7466
[2025-11-20 06:05:44,699] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 821.3446
2025-11-20 06:05:44 - INFO - [Epoch 002] New best val loss: 821.3446
[2025-11-20 06:05:50,244] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 803.1092
2025-11-20 06:05:50 - INFO - [Epoch 003] New best val loss: 803.1092
[2025-11-20 06:05:58,053] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 794.1638
2025-11-20 06:05:58 - INFO - [Epoch 004] 

2025-11-20 06:11:42 - INFO - [Epoch 050] Train loss: 768.6560 (beta=140.000, gamma=80.000)
[2025-11-20 06:11:43,046] [UniVITrainer] [INFO] [Epoch 050] Val loss: 766.8261 (beta=140.000, gamma=80.000)
2025-11-20 06:11:43 - INFO - [Epoch 050] Val loss: 766.8261 (beta=140.000, gamma=80.000)
[2025-11-20 06:11:43,188] [UniVITrainer] [INFO] [Epoch 050] New best val loss: 766.8261
2025-11-20 06:11:43 - INFO - [Epoch 050] New best val loss: 766.8261
[2025-11-20 06:11:50,150] [UniVITrainer] [INFO] [Epoch 051] New best val loss: 766.7563
2025-11-20 06:11:50 - INFO - [Epoch 051] New best val loss: 766.7563
[2025-11-20 06:12:03,451] [UniVITrainer] [INFO] [Epoch 053] New best val loss: 766.7423
2025-11-20 06:12:03 - INFO - [Epoch 053] New best val loss: 766.7423
[2025-11-20 06:12:11,390] [UniVITrainer] [INFO] [Epoch 054] New best val loss: 766.7115
2025-11-20 06:12:11 - INFO - [Epoch 054] New best val loss: 766.7115
[2025-11-20 06:12:17,134] [UniVITrainer] [INFO] [Epoch 055] New best val loss: 766.6

2025-11-20 06:22:15 - INFO - Early stopping at epoch 136 (best val loss = 766.2912)
[2025-11-20 06:22:15,914] [UniVITrainer] [INFO] Restored best model from epoch 116 (val loss = 766.2912)
2025-11-20 06:22:15 - INFO - Restored best model from epoch 116 (val loss = 766.2912)
[2025-11-20 06:22:18,462] [UniVITrainer] [INFO] TrainingConfig:
2025-11-20 06:22:18 - INFO - TrainingConfig:
[2025-11-20 06:22:18,465] [UniVITrainer] [INFO]   n_epochs: 200
2025-11-20 06:22:18 - INFO -   n_epochs: 200
[2025-11-20 06:22:18,467] [UniVITrainer] [INFO]   batch_size: 256
2025-11-20 06:22:18 - INFO -   batch_size: 256
[2025-11-20 06:22:18,471] [UniVITrainer] [INFO]   lr: 0.001
2025-11-20 06:22:18 - INFO -   lr: 0.001
[2025-11-20 06:22:18,474] [UniVITrainer] [INFO]   weight_decay: 0.0001
2025-11-20 06:22:18 - INFO -   weight_decay: 0.0001
[2025-11-20 06:22:18,476] [UniVITrainer] [INFO]   device: cuda
2025-11-20 06:22:18 - INFO -   device: cuda
[2025-11-20 06:22:18,479] [UniVITrainer] [INFO]   log_every: 10

[Config 21] Done in 16.8 min
  best_val_loss              = 766.291
  FOSCTTM (RNA vs ADT, val)  = 0.5126
  FOSCTTM (RNA vs ATAC, val) = 0.4876
[Config 21] FOSCTTM (ADT vs ATAC, val) = 0.4901
  Mean FOSCTTM (3 pairs)     = 0.4968
  Modality mixing (k=20)     = 0.0007
  Composite score            = 1147.81

[Config 22] Hyperparameters:
{
  "latent_dim": 64,
  "beta": 180.0,
  "gamma": 240.0,
  "lr": 0.001,
  "weight_decay": 0.0001,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": true,
  "rna_arch": "rna_med2",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_wide2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 06:22:25,012] [UniVITrainer] [INFO] [Epoch 001] Train loss: 3632.8436 (beta=180.000, gamma=240.000)
2025-11-20 06:22:25 - INFO - [Epoch 001] Train loss: 3632.8436 (beta=180.000, gamma=240.000)
[2025-11-20 06:22:25,705] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1201.8418 (beta=180.000, gamma=240.000)
2025-11-20 06:22:25 - INFO - [Epoch 001] Val loss: 1201.8418 (beta=180.000, gamma=240.000)
[2025-11-20 06:22:25,901] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1201.8418
2025-11-20 06:22:25 - INFO - [Epoch 001] New best val loss: 1201.8418
[2025-11-20 06:22:33,507] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 978.3422
2025-11-20 06:22:33 - INFO - [Epoch 002] New best val loss: 978.3422
[2025-11-20 06:22:40,960] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 905.2354
2025-11-20 06:22:40 - INFO - [Epoch 003] New best val loss: 905.2354
[2025-11-20 06:22:48,603] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 870.9789
2025-11-20 06:22:48 - INFO - [Epo

[2025-11-20 06:29:20,835] [UniVITrainer] [INFO] [Epoch 056] New best val loss: 771.3457
2025-11-20 06:29:20 - INFO - [Epoch 056] New best val loss: 771.3457
[2025-11-20 06:29:28,635] [UniVITrainer] [INFO] [Epoch 057] New best val loss: 771.2859
2025-11-20 06:29:28 - INFO - [Epoch 057] New best val loss: 771.2859
[2025-11-20 06:29:42,681] [UniVITrainer] [INFO] [Epoch 059] New best val loss: 770.8811
2025-11-20 06:29:42 - INFO - [Epoch 059] New best val loss: 770.8811
[2025-11-20 06:29:49,590] [UniVITrainer] [INFO] [Epoch 060] Train loss: 773.9175 (beta=180.000, gamma=240.000)
2025-11-20 06:29:49 - INFO - [Epoch 060] Train loss: 773.9175 (beta=180.000, gamma=240.000)
[2025-11-20 06:29:50,428] [UniVITrainer] [INFO] [Epoch 060] Val loss: 771.9095 (beta=180.000, gamma=240.000)
2025-11-20 06:29:50 - INFO - [Epoch 060] Val loss: 771.9095 (beta=180.000, gamma=240.000)
[2025-11-20 06:29:58,386] [UniVITrainer] [INFO] [Epoch 061] New best val loss: 770.7283
2025-11-20 06:29:58 - INFO - [Epoch 061

[2025-11-20 06:37:09,101] [UniVITrainer] [INFO] [Epoch 120] Val loss: 766.7021 (beta=180.000, gamma=240.000)
2025-11-20 06:37:09 - INFO - [Epoch 120] Val loss: 766.7021 (beta=180.000, gamma=240.000)
[2025-11-20 06:37:09,308] [UniVITrainer] [INFO] [Epoch 120] New best val loss: 766.7021
2025-11-20 06:37:09 - INFO - [Epoch 120] New best val loss: 766.7021
[2025-11-20 06:37:16,398] [UniVITrainer] [INFO] [Epoch 121] New best val loss: 766.6596
2025-11-20 06:37:16 - INFO - [Epoch 121] New best val loss: 766.6596
[2025-11-20 06:37:52,240] [UniVITrainer] [INFO] [Epoch 126] New best val loss: 766.6383
2025-11-20 06:37:52 - INFO - [Epoch 126] New best val loss: 766.6383
[2025-11-20 06:38:00,074] [UniVITrainer] [INFO] [Epoch 127] New best val loss: 766.6127
2025-11-20 06:38:00 - INFO - [Epoch 127] New best val loss: 766.6127
[2025-11-20 06:38:07,713] [UniVITrainer] [INFO] [Epoch 128] New best val loss: 766.5999
2025-11-20 06:38:07 - INFO - [Epoch 128] New best val loss: 766.5999
[2025-11-20 06:3

[Config 22] Done in 24.6 min
  best_val_loss              = 766.390
  FOSCTTM (RNA vs ADT, val)  = 0.4993
  FOSCTTM (RNA vs ATAC, val) = 0.4458
[Config 22] FOSCTTM (ADT vs ATAC, val) = 0.4884
  Mean FOSCTTM (3 pairs)     = 0.4778
  Modality mixing (k=20)     = 0.0000
  Composite score            = 1132.63

[Config 23] Hyperparameters:
{
  "latent_dim": 40,
  "beta": 240.0,
  "gamma": 180.0,
  "lr": 0.001,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": false,
  "rna_arch": "rna_med2",
  "adt_arch": "adt_small2",
  "atac_arch": "atac_wide2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 06:47:01,083] [UniVITrainer] [INFO] [Epoch 001] Train loss: 2632.2980 (beta=240.000, gamma=180.000)
2025-11-20 06:47:01 - INFO - [Epoch 001] Train loss: 2632.2980 (beta=240.000, gamma=180.000)
[2025-11-20 06:47:01,862] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1078.4996 (beta=240.000, gamma=180.000)
2025-11-20 06:47:01 - INFO - [Epoch 001] Val loss: 1078.4996 (beta=240.000, gamma=180.000)
[2025-11-20 06:47:02,061] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1078.4996
2025-11-20 06:47:02 - INFO - [Epoch 001] New best val loss: 1078.4996
[2025-11-20 06:47:09,565] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 875.6692
2025-11-20 06:47:09 - INFO - [Epoch 002] New best val loss: 875.6692
[2025-11-20 06:47:16,028] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 830.6809
2025-11-20 06:47:16 - INFO - [Epoch 003] New best val loss: 830.6809
[2025-11-20 06:47:23,627] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 812.2211
2025-11-20 06:47:23 - INFO - [Epo

[2025-11-20 06:54:11,940] [UniVITrainer] [INFO] [Epoch 060] Val loss: 769.2252 (beta=240.000, gamma=180.000)
2025-11-20 06:54:11 - INFO - [Epoch 060] Val loss: 769.2252 (beta=240.000, gamma=180.000)
[2025-11-20 06:54:26,442] [UniVITrainer] [INFO] [Epoch 062] New best val loss: 768.6997
2025-11-20 06:54:26 - INFO - [Epoch 062] New best val loss: 768.6997
[2025-11-20 06:54:55,630] [UniVITrainer] [INFO] [Epoch 066] New best val loss: 768.3905
2025-11-20 06:54:55 - INFO - [Epoch 066] New best val loss: 768.3905
[2025-11-20 06:55:02,870] [UniVITrainer] [INFO] [Epoch 067] New best val loss: 768.3362
2025-11-20 06:55:02 - INFO - [Epoch 067] New best val loss: 768.3362
[2025-11-20 06:55:15,317] [UniVITrainer] [INFO] [Epoch 069] New best val loss: 768.1684
2025-11-20 06:55:15 - INFO - [Epoch 069] New best val loss: 768.1684
[2025-11-20 06:55:21,969] [UniVITrainer] [INFO] [Epoch 070] Train loss: 767.3963 (beta=240.000, gamma=180.000)
2025-11-20 06:55:21 - INFO - [Epoch 070] Train loss: 767.3963 

2025-11-20 07:01:54 - INFO - [Epoch 124] New best val loss: 766.6983
[2025-11-20 07:02:02,249] [UniVITrainer] [INFO] [Epoch 125] New best val loss: 766.6593
2025-11-20 07:02:02 - INFO - [Epoch 125] New best val loss: 766.6593
[2025-11-20 07:02:25,569] [UniVITrainer] [INFO] [Epoch 128] New best val loss: 766.6539
2025-11-20 07:02:25 - INFO - [Epoch 128] New best val loss: 766.6539
[2025-11-20 07:02:33,293] [UniVITrainer] [INFO] [Epoch 129] New best val loss: 766.6369
2025-11-20 07:02:33 - INFO - [Epoch 129] New best val loss: 766.6369
[2025-11-20 07:02:40,166] [UniVITrainer] [INFO] [Epoch 130] Train loss: 765.8719 (beta=240.000, gamma=180.000)
2025-11-20 07:02:40 - INFO - [Epoch 130] Train loss: 765.8719 (beta=240.000, gamma=180.000)
[2025-11-20 07:02:40,990] [UniVITrainer] [INFO] [Epoch 130] Val loss: 766.6419 (beta=240.000, gamma=180.000)
2025-11-20 07:02:40 - INFO - [Epoch 130] Val loss: 766.6419 (beta=240.000, gamma=180.000)
[2025-11-20 07:03:10,714] [UniVITrainer] [INFO] [Epoch 134

[Config 23] Done in 24.2 min
  best_val_loss              = 766.388
  FOSCTTM (RNA vs ADT, val)  = 0.5319
  FOSCTTM (RNA vs ATAC, val) = 0.5082
[Config 23] FOSCTTM (ADT vs ATAC, val) = 0.4989
  Mean FOSCTTM (3 pairs)     = 0.5130
  Modality mixing (k=20)     = 0.0004
  Composite score            = 1160.00

[Config 24] Hyperparameters:
{
  "latent_dim": 32,
  "beta": 180.0,
  "gamma": 500.0,
  "lr": 0.0005,
  "weight_decay": 0.0001,
  "encoder_dropout": 0.0,
  "decoder_batchnorm": false,
  "rna_arch": "rna_wide3",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_small2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 07:11:13,197] [UniVITrainer] [INFO] [Epoch 001] Train loss: 3624.9808 (beta=180.000, gamma=500.000)
2025-11-20 07:11:13 - INFO - [Epoch 001] Train loss: 3624.9808 (beta=180.000, gamma=500.000)
[2025-11-20 07:11:13,935] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1740.6882 (beta=180.000, gamma=500.000)
2025-11-20 07:11:13 - INFO - [Epoch 001] Val loss: 1740.6882 (beta=180.000, gamma=500.000)
[2025-11-20 07:11:14,096] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1740.6882
2025-11-20 07:11:14 - INFO - [Epoch 001] New best val loss: 1740.6882
[2025-11-20 07:11:21,598] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1191.5010
2025-11-20 07:11:21 - INFO - [Epoch 002] New best val loss: 1191.5010
[2025-11-20 07:11:29,143] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 1013.4858
2025-11-20 07:11:29 - INFO - [Epoch 003] New best val loss: 1013.4858
[2025-11-20 07:11:33,962] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 942.8653
2025-11-20 07:11:33 - INFO - 

2025-11-20 07:16:05 - INFO - [Epoch 041] New best val loss: 779.8676
[2025-11-20 07:16:13,159] [UniVITrainer] [INFO] [Epoch 042] New best val loss: 779.7656
2025-11-20 07:16:13 - INFO - [Epoch 042] New best val loss: 779.7656
[2025-11-20 07:16:20,877] [UniVITrainer] [INFO] [Epoch 043] New best val loss: 779.3015
2025-11-20 07:16:20 - INFO - [Epoch 043] New best val loss: 779.3015
[2025-11-20 07:16:27,572] [UniVITrainer] [INFO] [Epoch 044] New best val loss: 778.7939
2025-11-20 07:16:27 - INFO - [Epoch 044] New best val loss: 778.7939
[2025-11-20 07:16:50,237] [UniVITrainer] [INFO] [Epoch 047] New best val loss: 777.9942
2025-11-20 07:16:50 - INFO - [Epoch 047] New best val loss: 777.9942
[2025-11-20 07:16:57,684] [UniVITrainer] [INFO] [Epoch 048] New best val loss: 776.6966
2025-11-20 07:16:57 - INFO - [Epoch 048] New best val loss: 776.6966
[2025-11-20 07:17:12,102] [UniVITrainer] [INFO] [Epoch 050] Train loss: 776.3354 (beta=180.000, gamma=500.000)
2025-11-20 07:17:12 - INFO - [Epoch

2025-11-20 07:23:25 - INFO - [Epoch 100] Val loss: 769.6188 (beta=180.000, gamma=500.000)
[2025-11-20 07:23:33,213] [UniVITrainer] [INFO] [Epoch 101] New best val loss: 769.4762
2025-11-20 07:23:33 - INFO - [Epoch 101] New best val loss: 769.4762
[2025-11-20 07:23:40,976] [UniVITrainer] [INFO] [Epoch 102] New best val loss: 769.3500
2025-11-20 07:23:40 - INFO - [Epoch 102] New best val loss: 769.3500
[2025-11-20 07:23:48,370] [UniVITrainer] [INFO] [Epoch 103] New best val loss: 769.1810
2025-11-20 07:23:48 - INFO - [Epoch 103] New best val loss: 769.1810
[2025-11-20 07:24:03,759] [UniVITrainer] [INFO] [Epoch 105] New best val loss: 769.0280
2025-11-20 07:24:03 - INFO - [Epoch 105] New best val loss: 769.0280
[2025-11-20 07:24:33,221] [UniVITrainer] [INFO] [Epoch 109] New best val loss: 768.9472
2025-11-20 07:24:33 - INFO - [Epoch 109] New best val loss: 768.9472
[2025-11-20 07:24:40,188] [UniVITrainer] [INFO] [Epoch 110] Train loss: 765.4574 (beta=180.000, gamma=500.000)
2025-11-20 07:

[2025-11-20 07:36:17,305] [UniVITrainer] [INFO] [Epoch 191] New best val loss: 767.2202
2025-11-20 07:36:17 - INFO - [Epoch 191] New best val loss: 767.2202
[2025-11-20 07:36:37,002] [UniVITrainer] [INFO] [Epoch 193] New best val loss: 767.1708
2025-11-20 07:36:37 - INFO - [Epoch 193] New best val loss: 767.1708
[2025-11-20 07:37:40,971] [UniVITrainer] [INFO] [Epoch 200] Train loss: 764.7915 (beta=180.000, gamma=500.000)
2025-11-20 07:37:40 - INFO - [Epoch 200] Train loss: 764.7915 (beta=180.000, gamma=500.000)
[2025-11-20 07:37:42,039] [UniVITrainer] [INFO] [Epoch 200] Val loss: 767.1850 (beta=180.000, gamma=500.000)
2025-11-20 07:37:42 - INFO - [Epoch 200] Val loss: 767.1850 (beta=180.000, gamma=500.000)
[2025-11-20 07:37:42,086] [UniVITrainer] [INFO] Restored best model from epoch 193 (val loss = 767.1708)
2025-11-20 07:37:42 - INFO - Restored best model from epoch 193 (val loss = 767.1708)
[2025-11-20 07:37:44,657] [UniVITrainer] [INFO] TrainingConfig:
2025-11-20 07:37:44 - INFO - 

[Config 24] Done in 26.6 min
  best_val_loss              = 767.171
  FOSCTTM (RNA vs ADT, val)  = 0.4204
  FOSCTTM (RNA vs ATAC, val) = 0.4901
[Config 24] FOSCTTM (ADT vs ATAC, val) = 0.5023
  Mean FOSCTTM (3 pairs)     = 0.4709
  Modality mixing (k=20)     = 0.1076
  Composite score            = 1249.86

[Config 25] Hyperparameters:
{
  "latent_dim": 100,
  "beta": 100.0,
  "gamma": 60.0,
  "lr": 0.001,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": false,
  "rna_arch": "rna_wide2",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_small2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 07:37:53,366] [UniVITrainer] [INFO] [Epoch 001] Train loss: 2503.7605 (beta=100.000, gamma=60.000)
2025-11-20 07:37:53 - INFO - [Epoch 001] Train loss: 2503.7605 (beta=100.000, gamma=60.000)
[2025-11-20 07:37:54,440] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1097.0956 (beta=100.000, gamma=60.000)
2025-11-20 07:37:54 - INFO - [Epoch 001] Val loss: 1097.0956 (beta=100.000, gamma=60.000)
[2025-11-20 07:37:54,589] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1097.0956
2025-11-20 07:37:54 - INFO - [Epoch 001] New best val loss: 1097.0956
[2025-11-20 07:38:04,614] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 879.2995
2025-11-20 07:38:04 - INFO - [Epoch 002] New best val loss: 879.2995
[2025-11-20 07:38:14,387] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 839.6522
2025-11-20 07:38:14 - INFO - [Epoch 003] New best val loss: 839.6522
[2025-11-20 07:38:24,361] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 823.5655
2025-11-20 07:38:24 - INFO - [Epoch 0

2025-11-20 07:44:00 - INFO - [Epoch 043] New best val loss: 767.1606
[2025-11-20 07:44:09,775] [UniVITrainer] [INFO] [Epoch 044] New best val loss: 767.1515
2025-11-20 07:44:09 - INFO - [Epoch 044] New best val loss: 767.1515
[2025-11-20 07:44:26,669] [UniVITrainer] [INFO] [Epoch 046] New best val loss: 767.1148
2025-11-20 07:44:26 - INFO - [Epoch 046] New best val loss: 767.1148
[2025-11-20 07:44:34,553] [UniVITrainer] [INFO] [Epoch 047] New best val loss: 767.1140
2025-11-20 07:44:34 - INFO - [Epoch 047] New best val loss: 767.1140
[2025-11-20 07:44:42,978] [UniVITrainer] [INFO] [Epoch 048] New best val loss: 767.1003
2025-11-20 07:44:42 - INFO - [Epoch 048] New best val loss: 767.1003
[2025-11-20 07:44:52,215] [UniVITrainer] [INFO] [Epoch 049] New best val loss: 767.0043
2025-11-20 07:44:52 - INFO - [Epoch 049] New best val loss: 767.0043
[2025-11-20 07:45:00,580] [UniVITrainer] [INFO] [Epoch 050] Train loss: 771.4664 (beta=100.000, gamma=60.000)
2025-11-20 07:45:00 - INFO - [Epoch 

[2025-11-20 07:54:37,824] [UniVITrainer] [INFO] [Epoch 140] Train loss: 764.8024 (beta=100.000, gamma=60.000)
2025-11-20 07:54:37 - INFO - [Epoch 140] Train loss: 764.8024 (beta=100.000, gamma=60.000)
[2025-11-20 07:54:38,059] [UniVITrainer] [INFO] [Epoch 140] Val loss: 768.4365 (beta=100.000, gamma=60.000)
2025-11-20 07:54:38 - INFO - [Epoch 140] Val loss: 768.4365 (beta=100.000, gamma=60.000)
[2025-11-20 07:54:52,132] [UniVITrainer] [INFO] Early stopping at epoch 143 (best val loss = 766.3418)
2025-11-20 07:54:52 - INFO - Early stopping at epoch 143 (best val loss = 766.3418)
[2025-11-20 07:54:52,162] [UniVITrainer] [INFO] Restored best model from epoch 123 (val loss = 766.3418)
2025-11-20 07:54:52 - INFO - Restored best model from epoch 123 (val loss = 766.3418)
[2025-11-20 07:54:53,999] [UniVITrainer] [INFO] TrainingConfig:
2025-11-20 07:54:53 - INFO - TrainingConfig:
[2025-11-20 07:54:54,002] [UniVITrainer] [INFO]   n_epochs: 200
2025-11-20 07:54:54 - INFO -   n_epochs: 200
[2025-

[Config 25] Done in 17.1 min
  best_val_loss              = 766.342
  FOSCTTM (RNA vs ADT, val)  = 0.5206
  FOSCTTM (RNA vs ATAC, val) = 0.4177
[Config 25] FOSCTTM (ADT vs ATAC, val) = 0.4753
  Mean FOSCTTM (3 pairs)     = 0.4712
  Modality mixing (k=20)     = 0.0000
  Composite score            = 1127.45

[Config 26] Hyperparameters:
{
  "latent_dim": 40,
  "beta": 60.0,
  "gamma": 60.0,
  "lr": 0.0005,
  "weight_decay": 0.0001,
  "encoder_dropout": 0.0,
  "decoder_batchnorm": false,
  "rna_arch": "rna_med2",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_small2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 07:54:58,248] [UniVITrainer] [INFO] [Epoch 001] Train loss: 1531.1146 (beta=60.000, gamma=60.000)
2025-11-20 07:54:58 - INFO - [Epoch 001] Train loss: 1531.1146 (beta=60.000, gamma=60.000)
[2025-11-20 07:54:58,747] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1087.4734 (beta=60.000, gamma=60.000)
2025-11-20 07:54:58 - INFO - [Epoch 001] Val loss: 1087.4734 (beta=60.000, gamma=60.000)
[2025-11-20 07:54:58,842] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1087.4734
2025-11-20 07:54:58 - INFO - [Epoch 001] New best val loss: 1087.4734
[2025-11-20 07:55:02,494] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 884.9745
2025-11-20 07:55:02 - INFO - [Epoch 002] New best val loss: 884.9745
[2025-11-20 07:55:07,327] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 825.0750
2025-11-20 07:55:07 - INFO - [Epoch 003] New best val loss: 825.0750
[2025-11-20 07:55:12,050] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 807.0419
2025-11-20 07:55:12 - INFO - [Epoch 004] 

2025-11-20 08:02:17 - INFO - TrainingConfig:
[2025-11-20 08:02:17,946] [UniVITrainer] [INFO]   n_epochs: 200
2025-11-20 08:02:17 - INFO -   n_epochs: 200
[2025-11-20 08:02:17,948] [UniVITrainer] [INFO]   batch_size: 256
2025-11-20 08:02:17 - INFO -   batch_size: 256
[2025-11-20 08:02:17,949] [UniVITrainer] [INFO]   lr: 0.0005
2025-11-20 08:02:17 - INFO -   lr: 0.0005
[2025-11-20 08:02:17,950] [UniVITrainer] [INFO]   weight_decay: 1e-05
2025-11-20 08:02:17 - INFO -   weight_decay: 1e-05
[2025-11-20 08:02:17,951] [UniVITrainer] [INFO]   device: cuda
2025-11-20 08:02:17 - INFO -   device: cuda
[2025-11-20 08:02:17,952] [UniVITrainer] [INFO]   log_every: 10
2025-11-20 08:02:17 - INFO -   log_every: 10
[2025-11-20 08:02:17,953] [UniVITrainer] [INFO]   grad_clip: None
2025-11-20 08:02:17 - INFO -   grad_clip: None
[2025-11-20 08:02:17,954] [UniVITrainer] [INFO]   num_workers: 0
2025-11-20 08:02:17 - INFO -   num_workers: 0
[2025-11-20 08:02:17,955] [UniVITrainer] [INFO]   seed: 42
2025-11-20

[Config 26] Done in 7.4 min
  best_val_loss              = 762.325
  FOSCTTM (RNA vs ADT, val)  = 0.2086
  FOSCTTM (RNA vs ATAC, val) = 0.2515
[Config 26] FOSCTTM (ADT vs ATAC, val) = 0.2455
  Mean FOSCTTM (3 pairs)     = 0.2352
  Modality mixing (k=20)     = 0.3831
  Composite score            = 1302.36

[Config 27] Hyperparameters:
{
  "latent_dim": 124,
  "beta": 240.0,
  "gamma": 500.0,
  "lr": 0.0005,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": false,
  "rna_arch": "rna_wide2",
  "adt_arch": "adt_small2",
  "atac_arch": "atac_small2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 08:02:24,553] [UniVITrainer] [INFO] [Epoch 001] Train loss: 16786.9140 (beta=240.000, gamma=500.000)
2025-11-20 08:02:24 - INFO - [Epoch 001] Train loss: 16786.9140 (beta=240.000, gamma=500.000)
[2025-11-20 08:02:25,345] [UniVITrainer] [INFO] [Epoch 001] Val loss: 5889.7700 (beta=240.000, gamma=500.000)
2025-11-20 08:02:25 - INFO - [Epoch 001] Val loss: 5889.7700 (beta=240.000, gamma=500.000)
[2025-11-20 08:02:25,509] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 5889.7700
2025-11-20 08:02:25 - INFO - [Epoch 001] New best val loss: 5889.7700
[2025-11-20 08:02:32,482] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 3337.2332
2025-11-20 08:02:32 - INFO - [Epoch 002] New best val loss: 3337.2332
[2025-11-20 08:02:40,029] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 2120.0328
2025-11-20 08:02:40 - INFO - [Epoch 003] New best val loss: 2120.0328
[2025-11-20 08:02:46,815] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 1656.5616
2025-11-20 08:02:46 - INFO

2025-11-20 08:08:37 - INFO - [Epoch 047] New best val loss: 791.5356
[2025-11-20 08:08:56,949] [UniVITrainer] [INFO] [Epoch 049] New best val loss: 790.3651
2025-11-20 08:08:56 - INFO - [Epoch 049] New best val loss: 790.3651
[2025-11-20 08:09:05,484] [UniVITrainer] [INFO] [Epoch 050] Train loss: 836.8710 (beta=240.000, gamma=500.000)
2025-11-20 08:09:05 - INFO - [Epoch 050] Train loss: 836.8710 (beta=240.000, gamma=500.000)
[2025-11-20 08:09:06,545] [UniVITrainer] [INFO] [Epoch 050] Val loss: 789.0023 (beta=240.000, gamma=500.000)
2025-11-20 08:09:06 - INFO - [Epoch 050] Val loss: 789.0023 (beta=240.000, gamma=500.000)
[2025-11-20 08:09:06,695] [UniVITrainer] [INFO] [Epoch 050] New best val loss: 789.0023
2025-11-20 08:09:06 - INFO - [Epoch 050] New best val loss: 789.0023
[2025-11-20 08:09:16,586] [UniVITrainer] [INFO] [Epoch 051] New best val loss: 787.4659
2025-11-20 08:09:16 - INFO - [Epoch 051] New best val loss: 787.4659
[2025-11-20 08:09:26,326] [UniVITrainer] [INFO] [Epoch 052

[2025-11-20 08:19:16,664] [UniVITrainer] [INFO] [Epoch 113] New best val loss: 769.6005
2025-11-20 08:19:16 - INFO - [Epoch 113] New best val loss: 769.6005
[2025-11-20 08:19:22,111] [UniVITrainer] [INFO] [Epoch 114] New best val loss: 769.3230
2025-11-20 08:19:22 - INFO - [Epoch 114] New best val loss: 769.3230
[2025-11-20 08:19:26,796] [UniVITrainer] [INFO] [Epoch 115] New best val loss: 769.2602
2025-11-20 08:19:26 - INFO - [Epoch 115] New best val loss: 769.2602
[2025-11-20 08:20:01,146] [UniVITrainer] [INFO] [Epoch 120] Train loss: 774.9872 (beta=240.000, gamma=500.000)
2025-11-20 08:20:01 - INFO - [Epoch 120] Train loss: 774.9872 (beta=240.000, gamma=500.000)
[2025-11-20 08:20:01,942] [UniVITrainer] [INFO] [Epoch 120] Val loss: 769.4606 (beta=240.000, gamma=500.000)
2025-11-20 08:20:01 - INFO - [Epoch 120] Val loss: 769.4606 (beta=240.000, gamma=500.000)
[2025-11-20 08:20:31,404] [UniVITrainer] [INFO] [Epoch 124] New best val loss: 769.2380
2025-11-20 08:20:31 - INFO - [Epoch 124

2025-11-20 08:29:49 - INFO -   grad_clip: None
[2025-11-20 08:29:49,054] [UniVITrainer] [INFO]   num_workers: 0
2025-11-20 08:29:49 - INFO -   num_workers: 0
[2025-11-20 08:29:49,055] [UniVITrainer] [INFO]   seed: 42
2025-11-20 08:29:49 - INFO -   seed: 42
[2025-11-20 08:29:49,056] [UniVITrainer] [INFO]   early_stopping: True
2025-11-20 08:29:49 - INFO -   early_stopping: True
[2025-11-20 08:29:49,057] [UniVITrainer] [INFO]   patience: 20
2025-11-20 08:29:49 - INFO -   patience: 20
[2025-11-20 08:29:49,059] [UniVITrainer] [INFO]   min_delta: 0.0
2025-11-20 08:29:49 - INFO -   min_delta: 0.0


[Config 27] Done in 27.5 min
  best_val_loss              = 767.432
  FOSCTTM (RNA vs ADT, val)  = 0.5288
  FOSCTTM (RNA vs ATAC, val) = 0.3994
[Config 27] FOSCTTM (ADT vs ATAC, val) = 0.3951
  Mean FOSCTTM (3 pairs)     = 0.4411
  Modality mixing (k=20)     = 0.0006
  Composite score            = 1106.65

[Config 28] Hyperparameters:
{
  "latent_dim": 64,
  "beta": 100.0,
  "gamma": 140.0,
  "lr": 0.0005,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": true,
  "rna_arch": "rna_wide3",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_med2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 08:29:55,706] [UniVITrainer] [INFO] [Epoch 001] Train loss: 3398.5123 (beta=100.000, gamma=140.000)
2025-11-20 08:29:55 - INFO - [Epoch 001] Train loss: 3398.5123 (beta=100.000, gamma=140.000)
[2025-11-20 08:29:56,492] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1599.6244 (beta=100.000, gamma=140.000)
2025-11-20 08:29:56 - INFO - [Epoch 001] Val loss: 1599.6244 (beta=100.000, gamma=140.000)
[2025-11-20 08:29:56,737] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1599.6244
2025-11-20 08:29:56 - INFO - [Epoch 001] New best val loss: 1599.6244
[2025-11-20 08:30:04,347] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1070.1698
2025-11-20 08:30:04 - INFO - [Epoch 002] New best val loss: 1070.1698
[2025-11-20 08:30:11,866] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 929.7274
2025-11-20 08:30:11 - INFO - [Epoch 003] New best val loss: 929.7274
[2025-11-20 08:30:19,466] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 879.9719
2025-11-20 08:30:19 - INFO - [E

[2025-11-20 08:36:37,894] [UniVITrainer] [INFO] [Epoch 055] New best val loss: 774.6937
2025-11-20 08:36:37 - INFO - [Epoch 055] New best val loss: 774.6937
[2025-11-20 08:36:52,401] [UniVITrainer] [INFO] [Epoch 057] New best val loss: 774.5969
2025-11-20 08:36:52 - INFO - [Epoch 057] New best val loss: 774.5969
[2025-11-20 08:36:57,875] [UniVITrainer] [INFO] [Epoch 058] New best val loss: 774.1371
2025-11-20 08:36:57 - INFO - [Epoch 058] New best val loss: 774.1371
[2025-11-20 08:37:10,270] [UniVITrainer] [INFO] [Epoch 060] Train loss: 777.9722 (beta=100.000, gamma=140.000)
2025-11-20 08:37:10 - INFO - [Epoch 060] Train loss: 777.9722 (beta=100.000, gamma=140.000)
[2025-11-20 08:37:11,074] [UniVITrainer] [INFO] [Epoch 060] Val loss: 774.1716 (beta=100.000, gamma=140.000)
2025-11-20 08:37:11 - INFO - [Epoch 060] Val loss: 774.1716 (beta=100.000, gamma=140.000)
[2025-11-20 08:37:29,302] [UniVITrainer] [INFO] [Epoch 063] New best val loss: 773.4042
2025-11-20 08:37:29 - INFO - [Epoch 063

2025-11-20 08:42:50 - INFO - [Epoch 123] New best val loss: 767.1243
[2025-11-20 08:43:05,040] [UniVITrainer] [INFO] [Epoch 126] New best val loss: 766.9914
2025-11-20 08:43:05 - INFO - [Epoch 126] New best val loss: 766.9914
[2025-11-20 08:43:23,060] [UniVITrainer] [INFO] [Epoch 130] Train loss: 763.3913 (beta=100.000, gamma=140.000)
2025-11-20 08:43:23 - INFO - [Epoch 130] Train loss: 763.3913 (beta=100.000, gamma=140.000)
[2025-11-20 08:43:23,540] [UniVITrainer] [INFO] [Epoch 130] Val loss: 767.0899 (beta=100.000, gamma=140.000)
2025-11-20 08:43:23 - INFO - [Epoch 130] Val loss: 767.0899 (beta=100.000, gamma=140.000)
[2025-11-20 08:43:33,190] [UniVITrainer] [INFO] [Epoch 132] New best val loss: 766.9526
2025-11-20 08:43:33 - INFO - [Epoch 132] New best val loss: 766.9526
[2025-11-20 08:43:42,600] [UniVITrainer] [INFO] [Epoch 134] New best val loss: 766.9445
2025-11-20 08:43:42 - INFO - [Epoch 134] New best val loss: 766.9445
[2025-11-20 08:43:47,249] [UniVITrainer] [INFO] [Epoch 135

[Config 28] Done in 19.4 min
  best_val_loss              = 766.417
  FOSCTTM (RNA vs ADT, val)  = 0.4267
  FOSCTTM (RNA vs ATAC, val) = 0.4172
[Config 28] FOSCTTM (ADT vs ATAC, val) = 0.4374
  Mean FOSCTTM (3 pairs)     = 0.4271
  Modality mixing (k=20)     = 0.0176
  Composite score            = 1112.98

[Config 29] Hyperparameters:
{
  "latent_dim": 20,
  "beta": 180.0,
  "gamma": 60.0,
  "lr": 0.0005,
  "weight_decay": 0.0001,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": true,
  "rna_arch": "rna_med2",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_med2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 08:49:22,694] [UniVITrainer] [INFO] [Epoch 001] Train loss: 1690.3035 (beta=180.000, gamma=60.000)
2025-11-20 08:49:22 - INFO - [Epoch 001] Train loss: 1690.3035 (beta=180.000, gamma=60.000)
[2025-11-20 08:49:23,499] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1123.2552 (beta=180.000, gamma=60.000)
2025-11-20 08:49:23 - INFO - [Epoch 001] Val loss: 1123.2552 (beta=180.000, gamma=60.000)
[2025-11-20 08:49:23,592] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1123.2552
2025-11-20 08:49:23 - INFO - [Epoch 001] New best val loss: 1123.2552
[2025-11-20 08:49:30,847] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 879.2043
2025-11-20 08:49:30 - INFO - [Epoch 002] New best val loss: 879.2043
[2025-11-20 08:49:38,239] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 836.9519
2025-11-20 08:49:38 - INFO - [Epoch 003] New best val loss: 836.9519
[2025-11-20 08:49:44,703] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 819.6089
2025-11-20 08:49:44 - INFO - [Epoch 0

[2025-11-20 08:55:59,453] [UniVITrainer] [INFO] [Epoch 059] New best val loss: 771.7233
2025-11-20 08:55:59 - INFO - [Epoch 059] New best val loss: 771.7233
[2025-11-20 08:56:06,114] [UniVITrainer] [INFO] [Epoch 060] Train loss: 775.8984 (beta=180.000, gamma=60.000)
2025-11-20 08:56:06 - INFO - [Epoch 060] Train loss: 775.8984 (beta=180.000, gamma=60.000)
[2025-11-20 08:56:06,897] [UniVITrainer] [INFO] [Epoch 060] Val loss: 771.9242 (beta=180.000, gamma=60.000)
2025-11-20 08:56:06 - INFO - [Epoch 060] Val loss: 771.9242 (beta=180.000, gamma=60.000)
[2025-11-20 08:56:39,439] [UniVITrainer] [INFO] [Epoch 065] New best val loss: 771.5744
2025-11-20 08:56:39 - INFO - [Epoch 065] New best val loss: 771.5744
[2025-11-20 08:56:53,547] [UniVITrainer] [INFO] [Epoch 067] New best val loss: 771.3670
2025-11-20 08:56:53 - INFO - [Epoch 067] New best val loss: 771.3670
[2025-11-20 08:57:08,238] [UniVITrainer] [INFO] [Epoch 069] New best val loss: 771.0732
2025-11-20 08:57:08 - INFO - [Epoch 069] Ne

2025-11-20 09:05:45 - INFO - [Epoch 150] Val loss: 768.8871 (beta=180.000, gamma=60.000)
[2025-11-20 09:05:59,554] [UniVITrainer] [INFO] [Epoch 153] New best val loss: 768.6102
2025-11-20 09:05:59 - INFO - [Epoch 153] New best val loss: 768.6102
[2025-11-20 09:06:06,891] [UniVITrainer] [INFO] [Epoch 155] New best val loss: 768.4896
2025-11-20 09:06:06 - INFO - [Epoch 155] New best val loss: 768.4896
[2025-11-20 09:06:21,441] [UniVITrainer] [INFO] [Epoch 160] Train loss: 770.3677 (beta=180.000, gamma=60.000)
2025-11-20 09:06:21 - INFO - [Epoch 160] Train loss: 770.3677 (beta=180.000, gamma=60.000)
[2025-11-20 09:06:21,941] [UniVITrainer] [INFO] [Epoch 160] Val loss: 768.5466 (beta=180.000, gamma=60.000)
2025-11-20 09:06:21 - INFO - [Epoch 160] Val loss: 768.5466 (beta=180.000, gamma=60.000)
[2025-11-20 09:06:40,316] [UniVITrainer] [INFO] [Epoch 164] New best val loss: 768.3952
2025-11-20 09:06:40 - INFO - [Epoch 164] New best val loss: 768.3952
[2025-11-20 09:06:54,537] [UniVITrainer] [

[Config 29] Done in 20.2 min
  best_val_loss              = 767.822
  FOSCTTM (RNA vs ADT, val)  = 0.1684
  FOSCTTM (RNA vs ATAC, val) = 0.1758
[Config 29] FOSCTTM (ADT vs ATAC, val) = 0.2053
  Mean FOSCTTM (3 pairs)     = 0.1832
  Modality mixing (k=20)     = 0.0232
  Composite score            = 929.56

[Config 30] Hyperparameters:
{
  "latent_dim": 20,
  "beta": 500.0,
  "gamma": 1000.0,
  "lr": 0.001,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": true,
  "rna_arch": "rna_wide2",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_wide2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 09:09:34,798] [UniVITrainer] [INFO] [Epoch 001] Train loss: 4182.2584 (beta=500.000, gamma=1000.000)
2025-11-20 09:09:34 - INFO - [Epoch 001] Train loss: 4182.2584 (beta=500.000, gamma=1000.000)
[2025-11-20 09:09:35,301] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1364.9265 (beta=500.000, gamma=1000.000)
2025-11-20 09:09:35 - INFO - [Epoch 001] Val loss: 1364.9265 (beta=500.000, gamma=1000.000)
[2025-11-20 09:09:35,344] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1364.9265
2025-11-20 09:09:35 - INFO - [Epoch 001] New best val loss: 1364.9265
[2025-11-20 09:09:40,122] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1047.9574
2025-11-20 09:09:40 - INFO - [Epoch 002] New best val loss: 1047.9574
[2025-11-20 09:09:44,972] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 931.4203
2025-11-20 09:09:44 - INFO - [Epoch 003] New best val loss: 931.4203
[2025-11-20 09:09:49,631] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 883.5238
2025-11-20 09:09:49 - INFO 

[2025-11-20 09:13:28,104] [UniVITrainer] [INFO] [Epoch 050] New best val loss: 769.1137
2025-11-20 09:13:28 - INFO - [Epoch 050] New best val loss: 769.1137
[2025-11-20 09:13:37,541] [UniVITrainer] [INFO] [Epoch 052] New best val loss: 769.0675
2025-11-20 09:13:37 - INFO - [Epoch 052] New best val loss: 769.0675
[2025-11-20 09:13:42,247] [UniVITrainer] [INFO] [Epoch 053] New best val loss: 768.8837
2025-11-20 09:13:42 - INFO - [Epoch 053] New best val loss: 768.8837
[2025-11-20 09:13:51,777] [UniVITrainer] [INFO] [Epoch 055] New best val loss: 768.4408
2025-11-20 09:13:51 - INFO - [Epoch 055] New best val loss: 768.4408
[2025-11-20 09:14:01,201] [UniVITrainer] [INFO] [Epoch 057] New best val loss: 768.2605
2025-11-20 09:14:01 - INFO - [Epoch 057] New best val loss: 768.2605
[2025-11-20 09:14:06,064] [UniVITrainer] [INFO] [Epoch 058] New best val loss: 768.0854
2025-11-20 09:14:06 - INFO - [Epoch 058] New best val loss: 768.0854
[2025-11-20 09:14:14,936] [UniVITrainer] [INFO] [Epoch 060

2025-11-20 09:25:53 - INFO - [Epoch 170] Val loss: 766.5153 (beta=500.000, gamma=1000.000)
[2025-11-20 09:25:53,433] [UniVITrainer] [INFO] Early stopping at epoch 170 (best val loss = 766.3970)
2025-11-20 09:25:53 - INFO - Early stopping at epoch 170 (best val loss = 766.3970)
[2025-11-20 09:25:53,491] [UniVITrainer] [INFO] Restored best model from epoch 150 (val loss = 766.3970)
2025-11-20 09:25:53 - INFO - Restored best model from epoch 150 (val loss = 766.3970)
[2025-11-20 09:25:55,907] [UniVITrainer] [INFO] TrainingConfig:
2025-11-20 09:25:55 - INFO - TrainingConfig:
[2025-11-20 09:25:55,908] [UniVITrainer] [INFO]   n_epochs: 200
2025-11-20 09:25:55 - INFO -   n_epochs: 200
[2025-11-20 09:25:55,909] [UniVITrainer] [INFO]   batch_size: 256
2025-11-20 09:25:55 - INFO -   batch_size: 256
[2025-11-20 09:25:55,910] [UniVITrainer] [INFO]   lr: 0.001
2025-11-20 09:25:55 - INFO -   lr: 0.001
[2025-11-20 09:25:55,911] [UniVITrainer] [INFO]   weight_decay: 0.0001
2025-11-20 09:25:55 - INFO -

[Config 30] Done in 16.4 min
  best_val_loss              = 766.397
  FOSCTTM (RNA vs ADT, val)  = 0.4953
  FOSCTTM (RNA vs ATAC, val) = 0.4842
[Config 30] FOSCTTM (ADT vs ATAC, val) = 0.4983
  Mean FOSCTTM (3 pairs)     = 0.4926
  Modality mixing (k=20)     = 0.0071
  Composite score            = 1152.09

[Config 31] Hyperparameters:
{
  "latent_dim": 32,
  "beta": 0.0,
  "gamma": 40.0,
  "lr": 0.001,
  "weight_decay": 0.0001,
  "encoder_dropout": 0.0,
  "decoder_batchnorm": false,
  "rna_arch": "rna_wide2",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_small2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 09:26:04,488] [UniVITrainer] [INFO] [Epoch 001] Train loss: 891.4039 (beta=0.000, gamma=40.000)
2025-11-20 09:26:04 - INFO - [Epoch 001] Train loss: 891.4039 (beta=0.000, gamma=40.000)
[2025-11-20 09:26:05,675] [UniVITrainer] [INFO] [Epoch 001] Val loss: 764.7662 (beta=0.000, gamma=40.000)
2025-11-20 09:26:05 - INFO - [Epoch 001] Val loss: 764.7662 (beta=0.000, gamma=40.000)
[2025-11-20 09:26:05,792] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 764.7662
2025-11-20 09:26:05 - INFO - [Epoch 001] New best val loss: 764.7662
[2025-11-20 09:26:14,209] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 711.1315
2025-11-20 09:26:14 - INFO - [Epoch 002] New best val loss: 711.1315
[2025-11-20 09:26:24,085] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 667.6153
2025-11-20 09:26:24 - INFO - [Epoch 003] New best val loss: 667.6153
[2025-11-20 09:26:31,668] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 657.8009
2025-11-20 09:26:31 - INFO - [Epoch 004] New best v

2025-11-20 09:35:09 - INFO - [Epoch 075] New best val loss: 614.1552
[2025-11-20 09:35:44,274] [UniVITrainer] [INFO] [Epoch 080] Train loss: 549.7147 (beta=0.000, gamma=40.000)
2025-11-20 09:35:44 - INFO - [Epoch 080] Train loss: 549.7147 (beta=0.000, gamma=40.000)
[2025-11-20 09:35:45,056] [UniVITrainer] [INFO] [Epoch 080] Val loss: 615.0278 (beta=0.000, gamma=40.000)
2025-11-20 09:35:45 - INFO - [Epoch 080] Val loss: 615.0278 (beta=0.000, gamma=40.000)
[2025-11-20 09:35:52,385] [UniVITrainer] [INFO] [Epoch 081] New best val loss: 612.2406
2025-11-20 09:35:52 - INFO - [Epoch 081] New best val loss: 612.2406
[2025-11-20 09:36:35,541] [UniVITrainer] [INFO] [Epoch 087] New best val loss: 612.1973
2025-11-20 09:36:35 - INFO - [Epoch 087] New best val loss: 612.1973
[2025-11-20 09:36:58,153] [UniVITrainer] [INFO] [Epoch 090] Train loss: 539.5264 (beta=0.000, gamma=40.000)
2025-11-20 09:36:58 - INFO - [Epoch 090] Train loss: 539.5264 (beta=0.000, gamma=40.000)
[2025-11-20 09:36:59,213] [Uni

[Config 31] Done in 14.3 min
  best_val_loss              = 610.963
  FOSCTTM (RNA vs ADT, val)  = 0.0729
  FOSCTTM (RNA vs ATAC, val) = 0.0906
[Config 31] FOSCTTM (ADT vs ATAC, val) = 0.1012
  Mean FOSCTTM (3 pairs)     = 0.0882
  Modality mixing (k=20)     = 0.4295
  Composite score            = 950.42

[Config 32] Hyperparameters:
{
  "latent_dim": 200,
  "beta": 300.0,
  "gamma": 240.0,
  "lr": 0.0005,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.0,
  "decoder_batchnorm": false,
  "rna_arch": "rna_wide3",
  "adt_arch": "adt_small2",
  "atac_arch": "atac_small2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 09:40:25,101] [UniVITrainer] [INFO] [Epoch 001] Train loss: 19342.2047 (beta=300.000, gamma=240.000)
2025-11-20 09:40:25 - INFO - [Epoch 001] Train loss: 19342.2047 (beta=300.000, gamma=240.000)
[2025-11-20 09:40:26,161] [UniVITrainer] [INFO] [Epoch 001] Val loss: 9768.4284 (beta=300.000, gamma=240.000)
2025-11-20 09:40:26 - INFO - [Epoch 001] Val loss: 9768.4284 (beta=300.000, gamma=240.000)
[2025-11-20 09:40:26,363] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 9768.4284
2025-11-20 09:40:26 - INFO - [Epoch 001] New best val loss: 9768.4284
[2025-11-20 09:40:34,657] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 4843.1672
2025-11-20 09:40:34 - INFO - [Epoch 002] New best val loss: 4843.1672
[2025-11-20 09:40:41,343] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 2534.2644
2025-11-20 09:40:41 - INFO - [Epoch 003] New best val loss: 2534.2644
[2025-11-20 09:40:49,380] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 1803.0614
2025-11-20 09:40:49 - INFO

2025-11-20 09:47:14 - INFO - [Epoch 047] New best val loss: 845.8028
[2025-11-20 09:47:24,082] [UniVITrainer] [INFO] [Epoch 048] New best val loss: 843.9423
2025-11-20 09:47:24 - INFO - [Epoch 048] New best val loss: 843.9423
[2025-11-20 09:47:33,466] [UniVITrainer] [INFO] [Epoch 049] New best val loss: 841.1292
2025-11-20 09:47:33 - INFO - [Epoch 049] New best val loss: 841.1292
[2025-11-20 09:47:42,107] [UniVITrainer] [INFO] [Epoch 050] Train loss: 797.8342 (beta=300.000, gamma=240.000)
2025-11-20 09:47:42 - INFO - [Epoch 050] Train loss: 797.8342 (beta=300.000, gamma=240.000)
[2025-11-20 09:47:43,170] [UniVITrainer] [INFO] [Epoch 050] Val loss: 840.3665 (beta=300.000, gamma=240.000)
2025-11-20 09:47:43 - INFO - [Epoch 050] Val loss: 840.3665 (beta=300.000, gamma=240.000)
[2025-11-20 09:47:43,389] [UniVITrainer] [INFO] [Epoch 050] New best val loss: 840.3665
2025-11-20 09:47:43 - INFO - [Epoch 050] New best val loss: 840.3665
[2025-11-20 09:47:53,265] [UniVITrainer] [INFO] [Epoch 051

[2025-11-20 09:58:10,506] [UniVITrainer] [INFO] [Epoch 130] Val loss: 806.8873 (beta=300.000, gamma=240.000)
2025-11-20 09:58:10 - INFO - [Epoch 130] Val loss: 806.8873 (beta=300.000, gamma=240.000)
[2025-11-20 09:58:10,687] [UniVITrainer] [INFO] [Epoch 130] New best val loss: 806.8873
2025-11-20 09:58:10 - INFO - [Epoch 130] New best val loss: 806.8873
[2025-11-20 09:58:46,513] [UniVITrainer] [INFO] [Epoch 135] New best val loss: 806.5604
2025-11-20 09:58:46 - INFO - [Epoch 135] New best val loss: 806.5604
[2025-11-20 09:58:54,037] [UniVITrainer] [INFO] [Epoch 136] New best val loss: 806.3702
2025-11-20 09:58:54 - INFO - [Epoch 136] New best val loss: 806.3702
[2025-11-20 09:59:22,843] [UniVITrainer] [INFO] [Epoch 140] Train loss: 774.6080 (beta=300.000, gamma=240.000)
2025-11-20 09:59:22 - INFO - [Epoch 140] Train loss: 774.6080 (beta=300.000, gamma=240.000)
[2025-11-20 09:59:23,299] [UniVITrainer] [INFO] [Epoch 140] Val loss: 806.5336 (beta=300.000, gamma=240.000)
2025-11-20 09:59:2

[Config 32] Done in 27.0 min
  best_val_loss              = 795.962
  FOSCTTM (RNA vs ADT, val)  = 0.4856
  FOSCTTM (RNA vs ATAC, val) = 0.5037
[Config 32] FOSCTTM (ADT vs ATAC, val) = 0.4956
  Mean FOSCTTM (3 pairs)     = 0.4950
  Modality mixing (k=20)     = 0.0141
  Composite score            = 1206.70

[Config 33] Hyperparameters:
{
  "latent_dim": 72,
  "beta": 100.0,
  "gamma": 140.0,
  "lr": 0.001,
  "weight_decay": 0.0001,
  "encoder_dropout": 0.0,
  "decoder_batchnorm": false,
  "rna_arch": "rna_wide3",
  "adt_arch": "adt_med2",
  "atac_arch": "atac_wide2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 10:07:27,187] [UniVITrainer] [INFO] [Epoch 001] Train loss: 2060.1318 (beta=100.000, gamma=140.000)
2025-11-20 10:07:27 - INFO - [Epoch 001] Train loss: 2060.1318 (beta=100.000, gamma=140.000)
[2025-11-20 10:07:28,282] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1003.6471 (beta=100.000, gamma=140.000)
2025-11-20 10:07:28 - INFO - [Epoch 001] Val loss: 1003.6471 (beta=100.000, gamma=140.000)
[2025-11-20 10:07:28,586] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1003.6471
2025-11-20 10:07:28 - INFO - [Epoch 001] New best val loss: 1003.6471
[2025-11-20 10:07:38,587] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 858.8225
2025-11-20 10:07:38 - INFO - [Epoch 002] New best val loss: 858.8225
[2025-11-20 10:07:47,937] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 823.2824
2025-11-20 10:07:47 - INFO - [Epoch 003] New best val loss: 823.2824
[2025-11-20 10:07:58,194] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 808.4771
2025-11-20 10:07:58 - INFO - [Epo

[2025-11-20 10:18:08,771] [UniVITrainer] [INFO] [Epoch 059] New best val loss: 772.4293
2025-11-20 10:18:08 - INFO - [Epoch 059] New best val loss: 772.4293
[2025-11-20 10:18:17,580] [UniVITrainer] [INFO] [Epoch 060] Train loss: 770.8941 (beta=100.000, gamma=140.000)
2025-11-20 10:18:17 - INFO - [Epoch 060] Train loss: 770.8941 (beta=100.000, gamma=140.000)
[2025-11-20 10:18:18,675] [UniVITrainer] [INFO] [Epoch 060] Val loss: 772.6317 (beta=100.000, gamma=140.000)
2025-11-20 10:18:18 - INFO - [Epoch 060] Val loss: 772.6317 (beta=100.000, gamma=140.000)
[2025-11-20 10:18:33,962] [UniVITrainer] [INFO] [Epoch 061] New best val loss: 772.4088
2025-11-20 10:18:33 - INFO - [Epoch 061] New best val loss: 772.4088
[2025-11-20 10:19:53,698] [UniVITrainer] [INFO] [Epoch 066] New best val loss: 772.0888
2025-11-20 10:19:53 - INFO - [Epoch 066] New best val loss: 772.0888
[2025-11-20 10:20:43,813] [UniVITrainer] [INFO] [Epoch 069] New best val loss: 772.0533
2025-11-20 10:20:43 - INFO - [Epoch 069

[2025-11-20 10:31:26,814] [UniVITrainer] [INFO] [Epoch 140] Train loss: 765.0983 (beta=100.000, gamma=140.000)
2025-11-20 10:31:26 - INFO - [Epoch 140] Train loss: 765.0983 (beta=100.000, gamma=140.000)
[2025-11-20 10:31:27,530] [UniVITrainer] [INFO] [Epoch 140] Val loss: 769.2179 (beta=100.000, gamma=140.000)
2025-11-20 10:31:27 - INFO - [Epoch 140] Val loss: 769.2179 (beta=100.000, gamma=140.000)
[2025-11-20 10:31:52,507] [UniVITrainer] [INFO] [Epoch 143] New best val loss: 769.0420
2025-11-20 10:31:52 - INFO - [Epoch 143] New best val loss: 769.0420
[2025-11-20 10:32:32,061] [UniVITrainer] [INFO] [Epoch 148] New best val loss: 768.9637
2025-11-20 10:32:32 - INFO - [Epoch 148] New best val loss: 768.9637
[2025-11-20 10:32:40,167] [UniVITrainer] [INFO] [Epoch 149] New best val loss: 768.8677
2025-11-20 10:32:40 - INFO - [Epoch 149] New best val loss: 768.8677
[2025-11-20 10:32:47,252] [UniVITrainer] [INFO] [Epoch 150] Train loss: 769.4236 (beta=100.000, gamma=140.000)
2025-11-20 10:32

[Config 33] Done in 33.2 min
  best_val_loss              = 767.810
  FOSCTTM (RNA vs ADT, val)  = 0.4914
  FOSCTTM (RNA vs ATAC, val) = 0.5564
[Config 33] FOSCTTM (ADT vs ATAC, val) = 0.5110
  Mean FOSCTTM (3 pairs)     = 0.5196
  Modality mixing (k=20)     = 0.0093
  Composite score            = 1177.59

[Config 34] Hyperparameters:
{
  "latent_dim": 20,
  "beta": 80.0,
  "gamma": 1000.0,
  "lr": 0.001,
  "weight_decay": 1e-05,
  "encoder_dropout": 0.1,
  "decoder_batchnorm": true,
  "rna_arch": "rna_wide3",
  "adt_arch": "adt_small2",
  "atac_arch": "atac_med2"
}


Training UniVI:   0%|          | 0/200 [00:00<?, ?it/s]

[2025-11-20 10:40:42,499] [UniVITrainer] [INFO] [Epoch 001] Train loss: 3617.7135 (beta=80.000, gamma=1000.000)
2025-11-20 10:40:42 - INFO - [Epoch 001] Train loss: 3617.7135 (beta=80.000, gamma=1000.000)
[2025-11-20 10:40:43,217] [UniVITrainer] [INFO] [Epoch 001] Val loss: 1278.2282 (beta=80.000, gamma=1000.000)
2025-11-20 10:40:43 - INFO - [Epoch 001] Val loss: 1278.2282 (beta=80.000, gamma=1000.000)
[2025-11-20 10:40:43,428] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 1278.2282
2025-11-20 10:40:43 - INFO - [Epoch 001] New best val loss: 1278.2282
[2025-11-20 10:40:49,913] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 1000.9229
2025-11-20 10:40:49 - INFO - [Epoch 002] New best val loss: 1000.9229
[2025-11-20 10:40:52,521] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 885.9991
2025-11-20 10:40:52 - INFO - [Epoch 003] New best val loss: 885.9991
[2025-11-20 10:40:56,473] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 849.4610
2025-11-20 10:40:56 - INFO - [E

2025-11-20 10:47:10 - INFO - [Epoch 050] Train loss: 771.5643 (beta=80.000, gamma=1000.000)
[2025-11-20 10:47:11,682] [UniVITrainer] [INFO] [Epoch 050] Val loss: 767.6783 (beta=80.000, gamma=1000.000)
2025-11-20 10:47:11 - INFO - [Epoch 050] Val loss: 767.6783 (beta=80.000, gamma=1000.000)
[2025-11-20 10:47:30,998] [UniVITrainer] [INFO] [Epoch 052] New best val loss: 767.5096
2025-11-20 10:47:30 - INFO - [Epoch 052] New best val loss: 767.5096
[2025-11-20 10:47:40,834] [UniVITrainer] [INFO] [Epoch 053] New best val loss: 767.4125
2025-11-20 10:47:40 - INFO - [Epoch 053] New best val loss: 767.4125
[2025-11-20 10:48:09,882] [UniVITrainer] [INFO] [Epoch 056] New best val loss: 767.3844
2025-11-20 10:48:09 - INFO - [Epoch 056] New best val loss: 767.3844
[2025-11-20 10:48:29,491] [UniVITrainer] [INFO] [Epoch 058] New best val loss: 767.2456
2025-11-20 10:48:29 - INFO - [Epoch 058] New best val loss: 767.2456
[2025-11-20 10:48:47,489] [UniVITrainer] [INFO] [Epoch 060] Train loss: 769.8746 

In [None]:
# ------------------------------------------------------
# 7. Hyperparameter search diagnostics & plots (TEA-seq)
# ------------------------------------------------------
# Flatten each entry of `all_results` into one table row: the config id,
# the hyperparameters stored under "hp", then the validation metrics and
# the composite score.

# Hyperparameters copied into the table as-is.
_HP_SCALAR_KEYS = (
    "latent_dim",
    "beta",
    "gamma",
    "lr",
    "weight_decay",
    "encoder_dropout",
    "decoder_batchnorm",
)
# Architecture hyperparameters: only their "name" entry is tabulated.
_HP_ARCH_KEYS = ("rna_arch", "adt_arch", "atac_arch")
# DataFrame column name -> key under which the result dict stores the value.
_METRIC_SOURCES = {
    "val_loss": "best_val_loss",
    "fos_rna_adt": "fos_rna_adt_val",
    "fos_rna_atac": "fos_rna_atac_val",
    "fos_adt_atac": "fos_adt_atac_val",
    "fos_mean": "fos_mean_val",
    "mixing_score": "mixing_score_val",
    "score": "score",
}


def _flatten_result(result):
    """Return one flat dict (a DataFrame row) for a single search result."""
    hp = result["hp"]
    row = {"config_id": result["config_id"]}
    row.update({key: hp[key] for key in _HP_SCALAR_KEYS})
    row.update({key: hp[key]["name"] for key in _HP_ARCH_KEYS})
    row.update({col: result[src] for col, src in _METRIC_SOURCES.items()})
    return row


rows = [_flatten_result(r) for r in all_results]

df = pd.DataFrame(rows)
print("TEA-seq hyperparameter search results (head):")
print(df.head())

# 7.1 Metric relationships: pairplot
# Pairwise scatter of the four summary metrics, with a KDE on the diagonal.
metrics = ["val_loss", "fos_mean", "mixing_score", "score"]
g = sns.pairplot(data=df.loc[:, metrics], diag_kind="kde")
# y slightly above 1 places the title just over the top row of panels.
g.fig.suptitle(t="Metric relationships across TEA-seq configs", y=1.02)
plt.show()

# 7.2 FOS_mean vs mixing, colored by score
# One point per config; the colour encodes the composite score.
fig_fos_mix, ax_fos_mix = plt.subplots(figsize=(6, 5))
scat = ax_fos_mix.scatter(
    df["fos_mean"],
    df["mixing_score"],
    c=df["score"],
    cmap="viridis",
    s=70,
    alpha=0.8,
    edgecolor="k",
)
fig_fos_mix.colorbar(scat, ax=ax_fos_mix, label="Composite score (lower = better)")
ax_fos_mix.set_xlabel("Mean FOSCTTM across pairs (lower = better)")
ax_fos_mix.set_ylabel("Modality mixing score (lower = better)")
ax_fos_mix.set_title("TEA-seq: FOS_mean vs mixing, colored by score")
fig_fos_mix.tight_layout()
plt.show()

# 7.3 Score vs individual numeric hyperparameters
# One panel per numeric hyperparameter, laid out on a 2x3 grid.
num_hps = ["latent_dim", "beta", "gamma", "lr", "weight_decay", "encoder_dropout"]
_log_scaled_hps = ("lr", "weight_decay")  # these get a log x-axis

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(16, 8))
axes = axes.reshape(-1)  # flatten the grid so it pairs 1:1 with num_hps

for ax, hp_name in zip(axes, num_hps):
    ax.scatter(df[hp_name], df["score"], s=60, alpha=0.8)
    ax.set(xlabel=hp_name, ylabel="score", title=f"Score vs {hp_name}")
    if hp_name in _log_scaled_hps:
        ax.set_xscale("log")

fig.suptitle("TEA-seq: Score vs numeric hyperparameters", y=1.02)
fig.tight_layout()
plt.show()

# 7.4 Categorical hyperparams vs score (boxplots)
# One figure per categorical hyperparameter. Each spec is
# (df column, figure width, title, rotate x tick labels?).
_boxplot_specs = [
    ("decoder_batchnorm", 8, "TEA-seq: effect of decoder_batchnorm on score", False),
    ("rna_arch", 10, "TEA-seq: score distribution by RNA architecture", True),
    ("adt_arch", 10, "TEA-seq: score distribution by ADT architecture", True),
    ("atac_arch", 10, "TEA-seq: score distribution by ATAC architecture", True),
]

for _col, _width, _title, _rotate in _boxplot_specs:
    plt.figure(figsize=(_width, 4))
    sns.boxplot(data=df, x=_col, y="score")
    plt.xlabel(_col)
    plt.ylabel("score")
    plt.title(_title)
    if _rotate:
        # Only the architecture plots tilt their tick labels.
        plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()

# 7.5 Beta vs gamma with score as color (2D hyperparam landscape)
# Each config is a point at its (beta, gamma); the colour encodes the score.
fig_landscape, ax_landscape = plt.subplots(figsize=(6, 5))
scat = ax_landscape.scatter(
    df["beta"],
    df["gamma"],
    c=df["score"],
    cmap="viridis",
    s=80,
    alpha=0.9,
    edgecolor="k",
)
fig_landscape.colorbar(scat, ax=ax_landscape, label="Composite score (lower = better)")
ax_landscape.set_xlabel("beta")
ax_landscape.set_ylabel("gamma")
ax_landscape.set_title("TEA-seq: beta vs gamma hyperparameter landscape")
fig_landscape.tight_layout()
plt.show()

# 7.6 Correlation heatmap: hyperparams & metrics
# Numeric hyperparameters first, then the summary metrics, so the heatmap
# groups the two families together.
_corr_hp_cols = ["latent_dim", "beta", "gamma", "lr", "weight_decay", "encoder_dropout"]
_corr_metric_cols = ["val_loss", "fos_mean", "mixing_score", "score"]
corr_cols = _corr_hp_cols + _corr_metric_cols
corr = df.loc[:, corr_cols].corr()  # pandas default correlation method

fig_corr, ax_corr = plt.subplots(figsize=(9, 7))
sns.heatmap(
    corr,
    ax=ax_corr,
    cmap="vlag",
    center=0.0,  # diverging colormap centred on zero correlation
    annot=True,
    fmt=".2f",
    square=True,
)
ax_corr.set_title("TEA-seq: correlation between hyperparameters and metrics")
fig_corr.tight_layout()
plt.show()

# 7.7 Training curves for top K configs (by score)
# "Top" means the K smallest composite scores in `df`.
top_k = 3  # tweak as desired
top_ids = df.nsmallest(n=top_k, columns="score")["config_id"].tolist()
print(f"Top {top_k} TEA-seq configs by score:", top_ids)

for cid in top_ids:
    # First entry of all_results carrying this config id.
    res_c = next(entry for entry in all_results if entry["config_id"] == cid)
    hist_c = res_c["history"]
    # 1-based epoch axis, sized from the train-loss history.
    epochs_c = np.arange(len(hist_c["train_loss"])) + 1

    fig_c, ax_c = plt.subplots(figsize=(6, 4))
    for _key, _label in (("train_loss", "train"), ("val_loss", "val")):
        ax_c.plot(epochs_c, hist_c[_key], label=_label)
    ax_c.set_xlabel("Epoch")
    ax_c.set_ylabel("Loss")
    ax_c.set_title(f"TEA-seq training curves – Config {cid} (score={res_c['score']:.2f})")
    ax_c.legend()
    fig_c.tight_layout()
    plt.show()

# 7.8 Overlay val-loss curves for all configs
# One faint line per config so the overall spread of validation curves is
# visible at a glance.
plt.figure(figsize=(7, 5))
for r in all_results:
    hist_r = r["history"]
    val_curve_r = hist_r["val_loss"]
    # Size the epoch axis from the series that is actually plotted. Taking the
    # length from "train_loss" would make plt.plot raise on any length mismatch
    # between the two histories, even though train loss is not drawn here.
    epochs_r = np.arange(1, len(val_curve_r) + 1)
    plt.plot(
        epochs_r,
        val_curve_r,
        alpha=0.3,
        linewidth=1.0,
    )

plt.xlabel("Epoch")
plt.ylabel("Val loss")
plt.title("TEA-seq: val loss curves for all configs (overlay)")
plt.tight_layout()
plt.show()

# 7.9 Training curve for best config
# Train and validation loss of `best_result` on a shared 1-based epoch axis.
hist = best_result["history"]
epochs = np.arange(len(hist["train_loss"])) + 1

fig_best, ax_best = plt.subplots(figsize=(6, 4))
for _key, _label in (("train_loss", "train"), ("val_loss", "val")):
    ax_best.plot(epochs, hist[_key], label=_label)
ax_best.set_xlabel("Epoch")
ax_best.set_ylabel("Loss")
ax_best.set_title("Best UniVI TEA-seq config training curves (search)")
ax_best.legend()
fig_best.tight_layout()
plt.show()


#### Initialize model and data via dataloaders

In [27]:
# Modality name -> AnnData object. The keys are the same strings used as
# ModalityConfig names in the model config.
adata_dict = dict(rna=rna, adt=adt, atac=atac)


In [28]:
from univi.data import MultiModalDataset
from univi.config import UniVIConfig, ModalityConfig, TrainingConfig
from univi.models.univi import UniVIMultiModalVAE
from univi.trainer import UniVITrainer
# NOTE(review): the triple-quoted string below is an earlier UniVI / training
# configuration kept as disabled code; nothing inside it is executed. Because
# the string is the last expression in this cell, Jupyter echoes it back as the
# cell's Out value. The configuration actually used is assigned to `univi_cfg`
# and `train_cfg` in a later cell. Consider deleting this string, or turning it
# into `#` comments, once it is no longer needed for reference.
'''
# ---------- UniVI config (Gaussian for all 3) ----------
univi_cfg = UniVIConfig(
    latent_dim=60,
    beta=100.0,
    gamma=120.0,
    encoder_dropout=0.0,
    decoder_dropout=0.0,
    encoder_batchnorm=True,
    decoder_batchnorm=False,
    kl_anneal_start=0,
    kl_anneal_end=0,
    align_anneal_start=0,
    align_anneal_end=0,
    modalities=[
        ModalityConfig(
            name="rna",
            input_dim=rna.n_vars,
            encoder_hidden=[512, 256],
            decoder_hidden=[256, 512],
            likelihood="gaussian",
        ),
        ModalityConfig(
            name="adt",
            input_dim=adt.n_vars,
            encoder_hidden=[128, 64],
            decoder_hidden=[64, 128],
            likelihood="gaussian",
        ),
        ModalityConfig(
            name="atac",
            input_dim=atac.n_vars,  # n_lsi (e.g. 50)
            encoder_hidden=[128, 64],
            decoder_hidden=[64, 128],
            likelihood="gaussian",
        ),
    ],
)

train_cfg = TrainingConfig(
    n_epochs=200,
    batch_size=256,
    lr=1e-3,
    weight_decay=1e-4,
    #device="cuda",   # or "cpu"
    device=device,
    log_every=10,
    grad_clip=5.0,
    num_workers=0,
    seed=42,
    early_stopping=True,
    patience=20,
    min_delta=0.0,
)
'''

'\n# ---------- UniVI config (Gaussian for all 3) ----------\nunivi_cfg = UniVIConfig(\n    latent_dim=60,\n    beta=100.0,\n    gamma=120.0,\n    encoder_dropout=0.0,\n    decoder_dropout=0.0,\n    encoder_batchnorm=True,\n    decoder_batchnorm=False,\n    kl_anneal_start=0,\n    kl_anneal_end=0,\n    align_anneal_start=0,\n    align_anneal_end=0,\n    modalities=[\n        ModalityConfig(\n            name="rna",\n            input_dim=rna.n_vars,\n            encoder_hidden=[512, 256],\n            decoder_hidden=[256, 512],\n            likelihood="gaussian",\n        ),\n        ModalityConfig(\n            name="adt",\n            input_dim=adt.n_vars,\n            encoder_hidden=[128, 64],\n            decoder_hidden=[64, 128],\n            likelihood="gaussian",\n        ),\n        ModalityConfig(\n            name="atac",\n            input_dim=atac.n_vars,  # n_lsi (e.g. 50)\n            encoder_hidden=[128, 64],\n            decoder_hidden=[64, 128],\n            likelihood

In [43]:
from univi.config import UniVIConfig, ModalityConfig, TrainingConfig

def _gaussian_modality(name, input_dim, encoder_hidden):
    """Return a Gaussian-likelihood ``ModalityConfig`` with a mirrored decoder.

    ``decoder_hidden`` is ``encoder_hidden`` reversed, which is the layout every
    modality in this notebook uses (e.g. ``[512, 256]`` -> ``[256, 512]``).
    """
    return ModalityConfig(
        name=name,
        input_dim=input_dim,
        encoder_hidden=list(encoder_hidden),
        decoder_hidden=list(reversed(encoder_hidden)),
        likelihood="gaussian",
    )


# ---------- UniVI config (Gaussian likelihood for all 3 modalities) ----------
univi_cfg = UniVIConfig(
    latent_dim=32,
    beta=40.0,          # softer
    gamma=60.0,         # moderate alignment
    encoder_dropout=0.0,
    decoder_dropout=0.0,
    encoder_batchnorm=True,
    decoder_batchnorm=False,
    # start == end == 0, so no KL ramp is configured here; the training log
    # printed by this run reports beta=40.000 already at epoch 1.
    kl_anneal_start=0,
    kl_anneal_end=0,
    align_anneal_start=5,  # let reconstructions stabilize a bit first
    align_anneal_end=15,
    modalities=[
        _gaussian_modality("rna", rna.n_vars, [512, 256]),
        _gaussian_modality("adt", adt.n_vars, [128, 64]),
        _gaussian_modality("atac", atac.n_vars, [128, 64]),  # input is n_lsi wide
    ],
)

# ---------- Optimiser / loop settings ----------
train_cfg = TrainingConfig(
    n_epochs=300,
    batch_size=256,
    lr=1e-3,
    weight_decay=1e-4,
    device=device,      # resolved in the setup cell: "cuda" when available, else "cpu"
    log_every=10,
    grad_clip=5.0,
    num_workers=0,
    seed=42,
    early_stopping=True,
    patience=35,
    min_delta=0.0,
)


In [44]:
from torch.utils.data import DataLoader, Subset

# Wrap the per-modality AnnData objects in one paired dataset.
dataset = MultiModalDataset(adata_dict=adata_dict, X_key="X", device=train_cfg.device)

# Fixed-seed random 80 / 10 / 10 split over cell positions; whatever is left
# after the train and validation cuts becomes the test set.
n_cells = dataset.n_cells
frac_train, frac_val = 0.8, 0.1
n_train = int(frac_train * n_cells)
n_val = int(frac_val * n_cells)

rng = np.random.default_rng(42)
indices = np.arange(n_cells)
rng.shuffle(indices)

train_idx, val_idx, test_idx = np.split(indices, [n_train, n_train + n_val])

train_ds, val_ds = Subset(dataset, train_idx), Subset(dataset, val_idx)

# Both loaders share batch size / worker count; only the training one shuffles.
loader_kwargs = dict(
    batch_size=train_cfg.batch_size,
    num_workers=train_cfg.num_workers,
)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

# Build the model on the training device and hand everything to the trainer.
model = UniVIMultiModalVAE(univi_cfg).to(train_cfg.device)
trainer = UniVITrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    train_cfg=train_cfg,
    device=train_cfg.device,
)


[2025-11-19 21:49:07,215] [UniVITrainer] [INFO] TrainingConfig:
2025-11-19 21:49:07 - INFO - TrainingConfig:
[2025-11-19 21:49:07,217] [UniVITrainer] [INFO]   n_epochs: 300
2025-11-19 21:49:07 - INFO -   n_epochs: 300
[2025-11-19 21:49:07,252] [UniVITrainer] [INFO]   batch_size: 256
2025-11-19 21:49:07 - INFO -   batch_size: 256
[2025-11-19 21:49:07,259] [UniVITrainer] [INFO]   lr: 0.001
2025-11-19 21:49:07 - INFO -   lr: 0.001
[2025-11-19 21:49:07,260] [UniVITrainer] [INFO]   weight_decay: 0.0001
2025-11-19 21:49:07 - INFO -   weight_decay: 0.0001
[2025-11-19 21:49:07,260] [UniVITrainer] [INFO]   device: cuda
2025-11-19 21:49:07 - INFO -   device: cuda
[2025-11-19 21:49:07,263] [UniVITrainer] [INFO]   log_every: 10
2025-11-19 21:49:07 - INFO -   log_every: 10
[2025-11-19 21:49:07,268] [UniVITrainer] [INFO]   grad_clip: 5.0
2025-11-19 21:49:07 - INFO -   grad_clip: 5.0
[2025-11-19 21:49:07,268] [UniVITrainer] [INFO]   num_workers: 0
2025-11-19 21:49:07 - INFO -   num_workers: 0
[2025-1

#### Train model

In [None]:
# ---------- train ----------
# Run the trainer and keep what fit() returns: the plotting cell below indexes
# `history` by "train_loss", "val_loss", "beta" and "gamma".
history = trainer.fit()


Training UniVI:   0%|          | 0/300 [00:00<?, ?it/s]

[2025-11-19 21:49:16,390] [UniVITrainer] [INFO] [Epoch 001] Train loss: 917.2002 (beta=40.000, gamma=0.000)
2025-11-19 21:49:16 - INFO - [Epoch 001] Train loss: 917.2002 (beta=40.000, gamma=0.000)
[2025-11-19 21:49:17,040] [UniVITrainer] [INFO] [Epoch 001] Val loss: 800.3322 (beta=40.000, gamma=0.000)
2025-11-19 21:49:17 - INFO - [Epoch 001] Val loss: 800.3322 (beta=40.000, gamma=0.000)
[2025-11-19 21:49:17,081] [UniVITrainer] [INFO] [Epoch 001] New best val loss: 800.3322
2025-11-19 21:49:17 - INFO - [Epoch 001] New best val loss: 800.3322
[2025-11-19 21:49:24,745] [UniVITrainer] [INFO] [Epoch 002] New best val loss: 771.9595
2025-11-19 21:49:24 - INFO - [Epoch 002] New best val loss: 771.9595
[2025-11-19 21:49:31,834] [UniVITrainer] [INFO] [Epoch 003] New best val loss: 768.6938
2025-11-19 21:49:31 - INFO - [Epoch 003] New best val loss: 768.6938
[2025-11-19 21:49:39,265] [UniVITrainer] [INFO] [Epoch 004] New best val loss: 767.7239
2025-11-19 21:49:39 - INFO - [Epoch 004] New best v

2025-11-19 21:58:31 - INFO - [Epoch 077] New best val loss: 766.5197
[2025-11-19 21:58:53,502] [UniVITrainer] [INFO] [Epoch 080] Train loss: 767.9334 (beta=40.000, gamma=60.000)
2025-11-19 21:58:53 - INFO - [Epoch 080] Train loss: 767.9334 (beta=40.000, gamma=60.000)
[2025-11-19 21:58:54,313] [UniVITrainer] [INFO] [Epoch 080] Val loss: 766.5602 (beta=40.000, gamma=60.000)
2025-11-19 21:58:54 - INFO - [Epoch 080] Val loss: 766.5602 (beta=40.000, gamma=60.000)
[2025-11-19 21:59:01,973] [UniVITrainer] [INFO] [Epoch 081] New best val loss: 766.5141
2025-11-19 21:59:01 - INFO - [Epoch 081] New best val loss: 766.5141
[2025-11-19 21:59:16,139] [UniVITrainer] [INFO] [Epoch 083] New best val loss: 766.4804
2025-11-19 21:59:16 - INFO - [Epoch 083] New best val loss: 766.4804
[2025-11-19 21:59:38,931] [UniVITrainer] [INFO] [Epoch 086] New best val loss: 766.4589
2025-11-19 21:59:38 - INFO - [Epoch 086] New best val loss: 766.4589
[2025-11-19 22:00:08,377] [UniVITrainer] [INFO] [Epoch 090] Train 

In [None]:
import matplotlib.pyplot as plt

# Quick training curves: one value per recorded epoch for each split.
fig, ax = plt.subplots()
ax.plot(history["train_loss"], label="train")
ax.plot(history["val_loss"], label="val")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
# Title corrected: this notebook trains on TEA-seq (RNA + ADT + ATAC), not Multiome.
ax.set_title("UniVI TEA-seq training curves")
ax.legend()
plt.tight_layout()
plt.show()

# Loss-term weights over training: beta scales the KL term, gamma the alignment term.
fig, ax = plt.subplots()
ax.plot(history["beta"], label="beta")
ax.plot(history["gamma"], label="gamma")
ax.set_xlabel("Epoch")
ax.set_ylabel("Weight")
ax.set_title("KL / alignment annealing")
ax.legend()
plt.tight_layout()
plt.show()


In [None]:
from dataclasses import asdict
import torch

# Checkpoint location; the parent directory is created on demand.
ckpt_path = "../saved_models/univi_tea_seq_beta_40_gamma_60_32_latent_dims.pt"
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

# Everything needed to rebuild the model later: weights, the config as a plain
# dict, and the trainer's early-stopping bookkeeping.
# NOTE(review): presumably trainer.model already holds the best weights once
# fit() returns (restored from the best state dict) — confirm in the trainer.
checkpoint = {
    "state_dict": trainer.model.state_dict(),
    "univi_cfg": asdict(univi_cfg),
    "best_epoch": trainer.best_epoch,
    "best_val_loss": trainer.best_val_loss,
}
torch.save(checkpoint, ckpt_path)
print("Saved best model to:", ckpt_path)


#### Evaluate model

In [None]:
import torch
from univi.config import UniVIConfig, ModalityConfig
from univi.models.univi import UniVIMultiModalVAE

# NOTE: this rebinds the notebook-wide `device`; the evaluation cell further
# down sets it again from `train_cfg.device`.
device = "cpu"  # or "cuda" if available

ckpt = torch.load(
    "../saved_models/univi_tea_seq_beta_40_gamma_60_32_latent_dims.pt",
    map_location=device,
)

# ---- Rebuild UniVIConfig, making sure modalities are ModalityConfig objects ----
cfg_dict = ckpt["univi_cfg"]

# If this is an OmegaConf object or similar, make sure it's a plain dict.
# omegaconf is optional: when it is not installed the dict is used as-is.
try:
    from omegaconf import DictConfig, OmegaConf
    if isinstance(cfg_dict, DictConfig):
        cfg_dict = OmegaConf.to_container(cfg_dict, resolve=True)
except ImportError:
    pass

# The checkpoint stores each modality as a plain dict (it was saved through
# dataclasses.asdict), so rehydrate them before building the top-level config.
modalities = [ModalityConfig(**m) for m in cfg_dict["modalities"]]
cfg_dict = {**cfg_dict, "modalities": modalities}

univi_cfg_loaded = UniVIConfig(**cfg_dict)

# ---- Rebuild model + load weights ----
model_loaded = UniVIMultiModalVAE(univi_cfg_loaded).to(device)
model_loaded.load_state_dict(ckpt["state_dict"])
# A freshly constructed module is in training mode. Switch to inference mode so
# BatchNorm (enabled for the encoders in this config) uses its running
# statistics and any dropout is disabled.
model_loaded.eval()

print("Best epoch was:", ckpt.get("best_epoch"), "val loss =", ckpt.get("best_val_loss"))


In [None]:
# Embed each modality with the trained model and attach the result to the
# matching AnnData under obsm["X_univi"].
z_rna = trainer.encode_modality(rna, modality="rna")
rna.obsm["X_univi"] = z_rna

z_adt = trainer.encode_modality(adt, modality="adt")
adt.obsm["X_univi"] = z_adt

z_atac = trainer.encode_modality(atac, modality="atac")
atac.obsm["X_univi"] = z_atac


In [None]:
# Dump the raw object returned by trainer.fit() for inspection.
print(history)


In [None]:
# Paired-cell sanity check: ADT and ATAC must list exactly the same cells, in
# the same order, as RNA.
for paired_adata in (adt, atac):
    assert np.array_equal(rna.obs_names, paired_adata.obs_names)


In [None]:
# ------------------------------
# Build train / val / test adatas
# ------------------------------
# Reuse the shuffled index arrays from the DataLoader split so that every
# modality is cut at exactly the same cells; .copy() materialises each view.
_split_indices = (train_idx, val_idx, test_idx)

rna_train_adata, rna_val_adata, rna_test_adata = (
    rna[split].copy() for split in _split_indices
)
adt_train_adata, adt_val_adata, adt_test_adata = (
    adt[split].copy() for split in _split_indices
)
atac_train_adata, atac_val_adata, atac_test_adata = (
    atac[split].copy() for split in _split_indices
)


In [None]:
# Show the summary of each held-out AnnData (RNA, ADT, ATAC in that order).
for test_adata in (rna_test_adata, adt_test_adata, atac_test_adata):
    print(test_adata)


In [None]:
# ============================
# UniVI TEA-seq evaluation (RNA / ADT / ATAC) – label-free, tri-modal
# ============================
import os
import numpy as np
import torch
import scipy.sparse as sp
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc

from sklearn.metrics import silhouette_score
from sklearn.neighbors import NearestNeighbors

from univi import evaluation as univi_eval
from univi import plotting as univi_plot  # not strictly needed, but left for convenience

# -----------------------------------------
# CONFIG
# -----------------------------------------
# All figures produced by this cell are written below FIGDIR.
FIGDIR = "../figures/teaseq_univi_tri_modal_eval"
os.makedirs(FIGDIR, exist_ok=True)

device = train_cfg.device  # e.g. "cuda" or "cpu"

# -----------------------------------------
# 0. Sanity checks
# -----------------------------------------
# The pairwise metrics below compare row i of one modality with row i of
# another, so the three test sets must hold the same cells in the same order.
n_test_cells = rna_test_adata.n_obs
assert (
    n_test_cells == adt_test_adata.n_obs == atac_test_adata.n_obs
), "RNA / ADT / ATAC TEST sets must have same #cells"

for paired_test, mismatch_msg in (
    (adt_test_adata, "RNA and ADT obs_names must match 1:1 for pairwise metrics."),
    (atac_test_adata, "RNA and ATAC obs_names must match 1:1 for pairwise metrics."),
):
    assert np.array_equal(rna_test_adata.obs_names, paired_test.obs_names), (
        mismatch_msg
    )

print(f"Test cells: {rna_test_adata.n_obs}")

# -----------------------------------------
# 1. Encode latent embeddings for test sets
# -----------------------------------------
print("\nEncoding test sets into UniVI latent space...")

# Encode one modality at a time and store the embedding next to its AnnData.
z_rna = univi_eval.encode_adata(model, rna_test_adata, modality="rna", device=device)
rna_test_adata.obsm["X_univi"] = z_rna

z_adt = univi_eval.encode_adata(model, adt_test_adata, modality="adt", device=device)
adt_test_adata.obsm["X_univi"] = z_adt

z_atac = univi_eval.encode_adata(model, atac_test_adata, modality="atac", device=device)
atac_test_adata.obsm["X_univi"] = z_atac

print("Latent shapes (test):")
for shape_label, latent in (("  RNA :", z_rna), ("  ADT :", z_adt), ("  ATAC:", z_atac)):
    print(shape_label, latent.shape)

# -----------------------------------------
# 2. FOSCTTM (pairwise alignment)
# -----------------------------------------
print("\nComputing FOSCTTM for each modality pair (lower = better)...")
# One score per unordered modality pair, computed on the test-set latents.
fos_rna_adt, fos_rna_atac, fos_adt_atac = (
    univi_eval.compute_foscttm(z_a, z_b)
    for z_a, z_b in ((z_rna, z_adt), (z_rna, z_atac), (z_adt, z_atac))
)

print(f"  RNA  vs ADT : {fos_rna_adt:.4f}")
print(f"  RNA  vs ATAC: {fos_rna_atac:.4f}")
print(f"  ADT  vs ATAC: {fos_adt_atac:.4f}")

# Bar chart of the three scores.
pairs = ["RNA–ADT", "RNA–ATAC", "ADT–ATAC"]
vals = [fos_rna_adt, fos_rna_atac, fos_adt_atac]

plt.figure(figsize=(4, 4))
sns.barplot(x=pairs, y=vals)
plt.title("Tri-modal FOSCTTM (lower = better)")
plt.ylabel("FOSCTTM (mean)")
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "foscttm_barplot.png"))
plt.show()
plt.close()

# -----------------------------------------
# 3. Modality mixing (all three modalities)
# -----------------------------------------
# Stack the three latents row-wise and keep a parallel array recording which
# modality produced each row.
Z_joint = np.concatenate([z_rna, z_adt, z_atac], axis=0)
modality_labels = np.repeat(
    ["rna", "adt", "atac"],
    [z_rna.shape[0], z_adt.shape[0], z_atac.shape[0]],
)

mixing_score = univi_eval.compute_modality_mixing(Z_joint, modality_labels, k=20)
print(f"\nGlobal modality mixing score (RNA/ADT/ATAC, k=20): {mixing_score:.3f}")

# kNN modality composition heatmap
print("Computing kNN modality composition...")
k = 20
nn = NearestNeighbors(n_neighbors=k + 1)  # +1: the query point itself is returned too
nn.fit(Z_joint)
_, idx = nn.kneighbors(Z_joint)

idx_neighbors = idx[:, 1:]  # drop self
neighbor_mods = modality_labels[idx_neighbors]

# comp_matrix[r, c]: among neighbours of cells from modality r, the share that
# come from modality c (row = center, col = neighbor).
modalities = np.array(["rna", "adt", "atac"])
comp_matrix = np.array([
    [
        (neighbor_mods[modality_labels == m_center].reshape(-1) == m_neigh).mean()
        for m_neigh in modalities
    ]
    for m_center in modalities
])

same_mod_frac = (neighbor_mods == modality_labels[:, None]).mean()
print(f"  Fraction of neighbors with same modality (k={k}): {same_mod_frac:.3f}")

plt.figure(figsize=(5, 4))
sns.heatmap(
    comp_matrix,
    annot=True,
    fmt=".2f",
    xticklabels=modalities,
    yticklabels=modalities,
    cmap="viridis",
)
plt.title(f"kNN modality composition (k={k})")
plt.xlabel("Neighbor modality")
plt.ylabel("Center modality")
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "knn_modality_composition.png"))
plt.show()
plt.close()

# -----------------------------------------
# 4. UMAP on UniVI latent (tri-modal)
# -----------------------------------------
print("\nBuilding tri-modal UMAP on UniVI latent...")

# Work on copies so the tagging column does not end up on the test AnnDatas.
rna_tmp, adt_tmp, atac_tmp = (
    test_adata.copy()
    for test_adata in (rna_test_adata, adt_test_adata, atac_test_adata)
)
for tagged, source_name in ((rna_tmp, "rna"), (adt_tmp, "adt"), (atac_tmp, "atac")):
    tagged.obs["univi_source"] = source_name

# NOTE(review): AnnData.concatenate is the legacy concatenation API; newer
# anndata releases point to anndata.concat instead — confirm before upgrading.
combined = rna_tmp.concatenate(
    adt_tmp,
    atac_tmp,
    join="outer",
    batch_key="univi_batch",
    batch_categories=["rna", "adt", "atac"],
    index_unique=None,
)

# ensure latent is correctly stacked (same RNA, ADT, ATAC order as above)
combined.obsm["X_univi"] = np.vstack([
    test_adata.obsm["X_univi"]
    for test_adata in (rna_test_adata, adt_test_adata, atac_test_adata)
])

# neighbors / UMAP / Leiden
sc.pp.neighbors(combined, use_rep="X_univi", n_neighbors=20)
sc.tl.umap(combined)
sc.tl.leiden(combined, key_added="univi_leiden", resolution=0.5)

umap_style = dict(size=8, alpha=0.7, show=False)

# Joint UMAP, first colored by modality, then by Leiden pseudo-cluster.
for color_key, out_name in (
    ("univi_source", "umap_tri_modal_modality.png"),
    ("univi_leiden", "umap_tri_modal_leiden.png"),
):
    sc.pl.umap(combined, color=color_key, **umap_style)
    plt.savefig(os.path.join(FIGDIR, out_name), bbox_inches="tight")
    plt.show()
    plt.close()

# Per-modality UMAPs, colored by Leiden
for mod in ["rna", "adt", "atac"]:
    sub = combined[combined.obs["univi_source"] == mod].copy()
    sc.pl.umap(sub, color="univi_leiden", **umap_style)
    plt.savefig(
        os.path.join(FIGDIR, f"umap_{mod}_only_leiden.png"),
        bbox_inches="tight",
    )
    plt.show()
    plt.close()

# -----------------------------------------
# 5. Latent geometry diagnostics
# -----------------------------------------
print("\nLatent geometry diagnostics...")

def latent_norms(z, label):
    """Return per-row L2 norms of ``z`` and a same-length array filled with ``label``.

    Parameters
    ----------
    z : 2-D array
        One latent vector per row.
    label : str
        Tag repeated once per row (used as the x-axis group when plotting).

    Returns
    -------
    (norms, labels) : tuple of 1-D arrays, both with ``z.shape[0]`` entries.
    """
    row_norms = np.linalg.norm(z, axis=1)
    tags = np.full(row_norms.shape[0], label)
    return row_norms, tags

# Per-cell latent norms for each modality, each paired with its plot label.
(norm_rna, lab_rna), (norm_adt, lab_adt), (norm_atac, lab_atac) = (
    latent_norms(latent, tag)
    for latent, tag in ((z_rna, "RNA"), (z_adt, "ADT"), (z_atac, "ATAC"))
)

norms_all = np.concatenate([norm_rna, norm_adt, norm_atac])
labs_all = np.concatenate([lab_rna, lab_adt, lab_atac])

# Violin plot: distribution of latent norms, one violin per modality.
plt.figure(figsize=(5, 4))
sns.violinplot(x=labs_all, y=norms_all, inner="box")
plt.xlabel("Modality")
plt.ylabel("‖z‖ (L2 norm)")
plt.title("Latent L2-norm distribution by modality")
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "latent_norms_by_modality.png"))
plt.show()
plt.close()

# Latent correlation heatmap of RNA latent dims (columns are the variables)
corr_latent = np.corrcoef(z_rna, rowvar=False)
plt.figure(figsize=(6, 5))
sns.heatmap(corr_latent, cmap="vlag", center=0)
plt.title("RNA latent dimension correlation (test set)")
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "latent_corr_heatmap_rna.png"))
plt.show()
plt.close()

# Silhouette score on modality (lower = better mixing); only computed when at
# least two distinct modality labels are present, otherwise NaN.
sil_mod = (
    silhouette_score(Z_joint, modality_labels)
    if len(np.unique(modality_labels)) > 1
    else np.nan
)
print(f"Silhouette (modality) on UniVI latent: {sil_mod:.3f}")

# -----------------------------------------
# 6. Local modality entropy (kNN) + UMAP
# -----------------------------------------
print("\nComputing local modality entropy...")

n_neighbors_local = 20
# Ask for one extra neighbour: every query point is also in the fitted set, and
# its own hit is dropped below. Without the +1 only 19 real neighbours were
# used even though everything downstream is reported as k=20.
nn_local = NearestNeighbors(n_neighbors=n_neighbors_local + 1, metric="euclidean")
nn_local.fit(Z_joint)
_, idx_local = nn_local.kneighbors(Z_joint)

mods = modality_labels
neigh_mod_local = mods[idx_local[:, 1:]]  # drop self -> (n_cells, n_neighbors_local)

# Empirical distribution over modalities in each cell's neighbourhood, one
# column per entry of `modalities`. Vectorised so that no loop variable
# overwrites notebook-level names.
modality_probs = np.stack(
    [(neigh_mod_local == mod_name).mean(axis=1) for mod_name in modalities],
    axis=1,
)

# Shannon entropy in bits. Terms with p == 0 contribute 0: log2 is evaluated on
# 1.0 for those entries, which also avoids divide-by-zero warnings.
safe_probs = np.where(modality_probs > 0, modality_probs, 1.0)
# "+ 0.0" turns a possible -0.0 (all neighbours from one modality) into 0.0.
local_modality_entropy = -(modality_probs * np.log2(safe_probs)).sum(axis=1) + 0.0
combined.obs["local_modality_entropy"] = local_modality_entropy

print(f"Mean local modality entropy (k={n_neighbors_local}): {local_modality_entropy.mean():.3f}")

plt.figure(figsize=(5, 4))
plt.hist(local_modality_entropy, bins=30)
plt.xlabel("Local modality entropy (bits)")
plt.ylabel("Cells")
plt.title("kNN modality entropy")
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "local_modality_entropy_hist.png"))
plt.show()
plt.close()

# UMAP colored by local modality entropy
sc.pl.umap(
    combined,
    color="local_modality_entropy",
    size=8,
    alpha=0.7,
    show=False,
)
plt.savefig(os.path.join(FIGDIR, "umap_local_modality_entropy.png"), bbox_inches="tight")
plt.show()
plt.close()

# -----------------------------------------
# 7. Pairwise matching metrics (top-k) for all modality pairs
# -----------------------------------------
print("\nPairwise matching metrics (top-1 / top-5 / top-10)...")

def topk_matching(z_src, z_tgt, pair_name: str, k_match: int = 10):
    """Report and plot top-k cross-modal retrieval accuracy for paired cells.

    Row ``i`` of ``z_src`` is treated as the same cell as row ``i`` of
    ``z_tgt``. A source cell counts as a hit at cut-off ``c`` when its true
    partner is among its ``c`` nearest target rows (Euclidean distance).

    Parameters
    ----------
    z_src, z_tgt : 2-D arrays
        Latent embeddings of the source and target modality, row-aligned.
    pair_name : str
        Label used in the printed report, the figure title and the file name.
    k_match : int, default 10
        Number of neighbours retrieved. The reported cut-offs are 1, 5 and 10,
        each capped at the number of neighbours actually retrieved, so a
        smaller ``k_match`` (or a target with fewer rows) is reported as e.g.
        "Top-3" rather than being mislabelled as "Top-5"/"Top-10".

    Returns
    -------
    dict
        Maps each reported cut-off to its accuracy (fraction of source rows).
        Earlier versions returned ``None``; callers that ignore the return
        value are unaffected.
    """
    # Never request more neighbours than there are target rows (kneighbors
    # raises otherwise), and always at least one.
    k_eff = max(1, min(int(k_match), z_tgt.shape[0]))

    knn = NearestNeighbors(n_neighbors=k_eff, metric="euclidean")
    knn.fit(z_tgt)
    _, idx_knn = knn.kneighbors(z_src)

    true_idx = np.arange(z_src.shape[0])
    cutoffs = sorted({min(c, k_eff) for c in (1, 5, 10)})
    accuracies = {
        c: float((idx_knn[:, :c] == true_idx[:, None]).any(axis=1).mean())
        for c in cutoffs
    }

    print(f"  {pair_name}:")
    for c, acc in accuracies.items():
        # Pad the label so the numbers line up as in the original report.
        label = f"Top-{c} accuracy:"
        print(f"    {label:<16} {acc:.3f}")

    plt.figure(figsize=(5, 4))
    plt.bar([f"Top-{c}" for c in cutoffs], [accuracies[c] for c in cutoffs])
    plt.ylabel("Fraction of correctly matched pairs")
    plt.title(f"Cross-modal matching accuracy ({pair_name})")
    plt.tight_layout()
    # File-system friendly name: spaces -> "_", arrow -> "to".
    fname = f"matching_{pair_name.replace(' ', '_').replace('→','to')}.png"
    plt.savefig(os.path.join(FIGDIR, fname))
    plt.show()
    plt.close()

    return accuracies

# Evaluate retrieval in both directions for every modality pair.
for src_latent, tgt_latent, direction in (
    (z_rna, z_adt, "RNA→ADT"),
    (z_adt, z_rna, "ADT→RNA"),
    (z_rna, z_atac, "RNA→ATAC"),
    (z_atac, z_rna, "ATAC→RNA"),
    (z_adt, z_atac, "ADT→ATAC"),
    (z_atac, z_adt, "ATAC→ADT"),
):
    topk_matching(src_latent, tgt_latent, direction)

# -----------------------------------------
# 8. Cross-modal reconstruction metrics
# -----------------------------------------
def _to_dense(X):
    return X.toarray() if sp.issparse(X) else np.asarray(X)

def cross_modal_metrics(
    model,
    src_adata,
    tgt_adata,
    src_mod: str,
    tgt_mod: str,
    name_prefix: str,
    device: str,
):
    """Predict ``tgt_mod`` from ``src_mod`` and score it against ``tgt_adata.X``.

    Prints the mean per-feature MSE and Pearson r, saves one histogram of each
    under ``FIGDIR`` (``<name_prefix>_corr_hist.png`` and
    ``<name_prefix>_mse_hist.png``) and returns ``(mse_feat, corr_feat)``.

    NOTE(review): assumes the rows of ``tgt_adata`` are the same cells, in the
    same order, as ``src_adata`` — confirm at the call site.
    """

    def _save_feature_hist(values, xlabel, title, out_name):
        # 40-bin histogram of one per-feature metric, written under FIGDIR.
        plt.figure(figsize=(5, 4))
        sns.histplot(values, bins=40, kde=False)
        plt.xlabel(xlabel)
        plt.ylabel("Count")
        plt.title(title)
        plt.tight_layout()
        plt.savefig(os.path.join(FIGDIR, out_name))
        plt.show()
        plt.close()

    predicted = univi_eval.cross_modal_predict(
        model,
        adata_src=src_adata,
        src_mod=src_mod,
        tgt_mod=tgt_mod,
        device=device,
        batch_size=512,
    )
    observed = _to_dense(tgt_adata.X)

    mse_feat = univi_eval.mse_per_feature(observed, predicted)
    corr_feat = univi_eval.pearson_corr_per_feature(observed, predicted)

    print(f"\nCross-modal reconstruction: {src_mod} → {tgt_mod}")
    print(f"  Mean feature MSE: {mse_feat.mean():.4f}")
    print(f"  Mean feature Pearson r: {corr_feat.mean():.3f}")

    # Histogram of per-feature correlation
    _save_feature_hist(
        corr_feat,
        "Per-feature Pearson r",
        f"{src_mod} → {tgt_mod}: feature-wise correlation",
        f"{name_prefix}_corr_hist.png",
    )

    # Histogram of per-feature MSE
    _save_feature_hist(
        mse_feat,
        "Per-feature MSE",
        f"{src_mod} → {tgt_mod}: feature-wise MSE",
        f"{name_prefix}_mse_hist.png",
    )

    return mse_feat, corr_feat

# Evaluate key directions on TEST set: RNA -> ADT / ATAC, then ADT / ATAC -> RNA.
for src_test, tgt_test, src_name, tgt_name, prefix_tag in (
    (rna_test_adata, adt_test_adata, "rna", "adt", "RNA_to_ADT_test"),
    (rna_test_adata, atac_test_adata, "rna", "atac", "RNA_to_ATAC_test"),
    (adt_test_adata, rna_test_adata, "adt", "rna", "ADT_to_RNA_test"),
    (atac_test_adata, rna_test_adata, "atac", "rna", "ATAC_to_RNA_test"),
):
    # The returned per-feature arrays are not needed here; only the printed
    # summary and the saved figures are.
    _ = cross_modal_metrics(
        model,
        src_test,
        tgt_test,
        src_mod=src_name,
        tgt_mod=tgt_name,
        name_prefix=prefix_tag,
        device=device,
    )

# -----------------------------------------
# 9. Denoising with decoders (on "unused" cells)
# -----------------------------------------
print("\nDenoising on *_unused sets (if present)...")

# The *_unused AnnDatas are optional: look them up by name in the cell's
# namespace and fall back to None when they were never created.
unused_sets = [
    (locals().get("rna_unused",  None), "rna",  "rna_unused"),
    (locals().get("adt_unused",  None), "adt",  "adt_unused"),
    (locals().get("atac_unused", None), "atac", "atac_unused"),
]

for adata, mod, tag in unused_sets:
    has_cells = adata is not None and adata.n_obs != 0
    if not has_cells:
        print(f"  Skipping {tag} ({mod}) – no cells.")
        continue

    print(f"  Denoising {tag} ({mod})...")
    # The result is read back from adata.layers["univi_denoised"] below.
    univi_eval.denoise_adata(
        model, adata, modality=mod,
        device=device, batch_size=512,
        out_layer="univi_denoised",
    )
    X_raw = _to_dense(adata.X)
    X_den = _to_dense(adata.layers["univi_denoised"])
    residual = X_raw - X_den
    per_cell_mse = (residual ** 2).mean(axis=1)

    plt.figure(figsize=(5, 4))
    sns.histplot(per_cell_mse, bins=40)
    plt.title(f"Denoising quality: {tag} ({mod})")
    plt.xlabel("Per-cell MSE (raw vs denoised)")
    plt.ylabel("Count")
    plt.tight_layout()
    plt.savefig(os.path.join(FIGDIR, f"denoise_{tag}_{mod}.png"))
    plt.show()
    plt.close()

print("\nNo curated celltypes; all metrics are label-free (modality-based only).")


In [None]:
# ============================
# 10. Summarize TEA-seq metrics & save as JSON
# ============================
import json

# Scalar summary of the evaluation cell above, with every value converted to a
# plain Python type so the dict is JSON-serialisable.
teaseq_metrics = {
    # Pairwise FOSCTTM
    "foscttm_rna_adt": float(fos_rna_adt),
    "foscttm_rna_atac": float(fos_rna_atac),
    "foscttm_adt_atac": float(fos_adt_atac),

    # Global mixing
    "modality_mixing_k20": float(mixing_score),
    "same_modality_neighbor_frac_k20": float(same_mod_frac),

    # Latent geometry (NaN silhouette is stored as None)
    "silhouette_modality": float(sil_mod) if not np.isnan(sil_mod) else None,
    "mean_local_modality_entropy_k20": float(local_modality_entropy.mean()),
    "median_local_modality_entropy_k20": float(np.median(local_modality_entropy)),

    # Dataset sizes
    "n_cells_test": int(rna_test_adata.n_obs),
    "n_genes_rna": int(rna_test_adata.n_vars),
    "n_features_adt": int(adt_test_adata.n_vars),
    "n_features_atac": int(atac_test_adata.n_vars),
}

# Opt-in JSON export. This code used to be disabled by wrapping it in a bare
# triple-quoted string; an explicit flag keeps the same default (no file is
# written) while leaving the code runnable and visible to linters.
SAVE_TEASEQ_METRICS_JSON = False
if SAVE_TEASEQ_METRICS_JSON:
    metrics_path = os.path.join(FIGDIR, "teaseq_univi_metrics.json")
    with open(metrics_path, "w") as metrics_file:
        json.dump(teaseq_metrics, metrics_file, indent=2)

    print(f"\n[TEA-seq] Saved benchmark metrics to: {metrics_path}")

# ============================
# 11. kNN distance diagnostics: same vs cross-modality neighbors
# ============================
print("\nComputing kNN distance distributions (same vs cross-modality)...")

from sklearn.neighbors import NearestNeighbors

k_dist = 20
nn_dist = NearestNeighbors(n_neighbors=k_dist + 1, metric="euclidean")
# The estimator must be fitted before it can be queried; without this call
# kneighbors() raises NotFittedError.
nn_dist.fit(Z_joint)
dists_all, idx_all = nn_dist.kneighbors(Z_joint)

# drop self
dists_neighbors = dists_all[:, 1:]              # (n_cells, k_dist)
idx_neighbors_dist = idx_all[:, 1:]             # (n_cells, k_dist)
flat_dists = dists_neighbors.reshape(-1)

# Center & neighbor modalities (flattened to per-edge view). Neighbor labels are
# looked up from THIS query's indices so they always line up with flat_dists,
# instead of relying on `neighbor_mods` left over from an earlier kNN query.
center_mods_flat = np.repeat(modality_labels, k_dist)
neighbor_mods_flat = modality_labels[idx_neighbors_dist].reshape(-1)

same_mod_edge = neighbor_mods_flat == center_mods_flat

plt.figure(figsize=(6, 4))
sns.kdeplot(flat_dists[same_mod_edge], label="same modality", fill=True, alpha=0.6)
sns.kdeplot(flat_dists[~same_mod_edge], label="different modality", fill=True, alpha=0.6)
plt.xlabel("Euclidean distance in UniVI latent")
plt.ylabel("Density")
plt.title(f"kNN distance distribution (k={k_dist})")
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "knn_distance_same_vs_cross_modality.png"), dpi=200)
plt.show()
plt.close()


# ============================
# 12. Cross-modal neighbor fraction per cell (and UMAP)
# ============================
print("\nComputing per-cell cross-modal neighbor fraction...")

# Per cell: share of its neighbours that come from the same modality, then the
# complement, i.e. the fraction of neighbors NOT sharing the cell's modality.
# `neighbor_mods` and `k` are the ones computed in the modality-mixing section
# of the evaluation cell.
same_mod_neighbor_frac = (neighbor_mods == modality_labels[:, None]).mean(axis=1)
cross_mod_neighbor_frac = 1.0 - same_mod_neighbor_frac
combined.obs["cross_mod_neighbor_frac"] = cross_mod_neighbor_frac

plt.figure(figsize=(5, 4))
sns.histplot(cross_mod_neighbor_frac, bins=30)
plt.title(f"Cross-modal neighbor fraction (k={k})")
plt.xlabel("Fraction of neighbors with different modality")
plt.ylabel("Cells")
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, "cross_mod_neighbor_fraction_hist.png"), dpi=200)
plt.show()
plt.close()

# UMAP colored by cross-modal neighbor fraction
umap_cross_mod_kwargs = dict(
    color="cross_mod_neighbor_frac",
    size=8,
    alpha=0.7,
    cmap="viridis",
    show=False,
)
sc.pl.umap(combined, **umap_cross_mod_kwargs)
plt.title("UMAP – cross-modal neighbor fraction")
plt.savefig(
    os.path.join(FIGDIR, "umap_cross_mod_neighbor_fraction.png"),
    bbox_inches="tight",
    dpi=200,
)
plt.show()
plt.close()


# ============================
# 13. Per-cluster modality composition (Leiden clusters)
# ============================
print("\nComputing modality composition per Leiden cluster...")

if "univi_leiden" in combined.obs.columns:
    # Cells per (cluster, modality), then row-normalize so each cluster sums to 1.
    comp_ct = pd.crosstab(combined.obs["univi_leiden"], combined.obs["univi_source"])
    comp_prop = comp_ct.div(comp_ct.sum(axis=1), axis=0)

    # Draw onto an explicit Axes. DataFrame.plot() opens its own figure unless
    # `ax` is given, which previously left the pre-sized figure empty and open
    # and silently ignored the requested figsize.
    fig, ax = plt.subplots(figsize=(7, 5))
    comp_prop.sort_index().plot(
        kind="bar",
        stacked=True,
        width=0.9,
        colormap="tab10",
        ax=ax,
    )
    ax.set_xlabel("UniVI Leiden cluster")
    ax.set_ylabel("Fraction of cells")
    ax.set_title("Modality composition per UniVI Leiden cluster")
    ax.legend(title="Modality", bbox_to_anchor=(1.05, 1), loc="upper left")
    fig.tight_layout()
    fig.savefig(os.path.join(FIGDIR, "cluster_modality_composition_stacked_bar.png"),
                dpi=200)
    plt.show()
    plt.close(fig)
else:
    print("  WARNING: 'univi_leiden' not found in combined.obs; skipping cluster composition plot.")


# ============================
# 14. Feature-level cross-modal scatter plots (RNA → ADT / ATAC)
# ============================
print("\nFeature-level scatter plots for cross-modal prediction (RNA→ADT / RNA→ATAC)...")

# RNA → ADT on a fixed-seed random subset of test cells (at most 5000), for the
# markers listed in `adt_markers_to_plot`.
# (adjust marker names to what you actually have in adt_test_adata.var_names)
n_scatter_cells = min(5000, adt_test_adata.n_obs)  # subsample if huge
adt_subset_rng = np.random.default_rng(42)
cell_idx = adt_subset_rng.choice(adt_test_adata.n_obs, n_scatter_cells, replace=False)

# recompute predictions just for this subset to avoid reusing big arrays
rna_subset_for_adt = rna_test_adata[cell_idx]
Xhat_adt_sub = univi_eval.cross_modal_predict(
    model,
    adata_src=rna_subset_for_adt,
    src_mod="rna",
    tgt_mod="adt",
    device=device,
    batch_size=512,
)
X_adt_sub = _to_dense(adt_test_adata[cell_idx].X)

adt_markers_to_plot = [
    # put your favorite ADT markers here
    # e.g. "CD3", "CD4", "CD8A", "CD56", ...
]

for marker in adt_markers_to_plot:
    if marker in adt_test_adata.var_names:
        # Column of the first feature whose name equals `marker`.
        j = np.flatnonzero(adt_test_adata.var_names == marker)[0]
        y_true, y_pred = X_adt_sub[:, j], Xhat_adt_sub[:, j]

        plt.figure(figsize=(4.5, 4))
        plt.hexbin(y_true, y_pred, gridsize=50, mincnt=1)
        plt.title(f"RNA→ADT prediction for {marker}")
        plt.xlabel(f"True ADT ({marker})")
        plt.ylabel(f"Predicted ADT ({marker})")
        plt.tight_layout()
        fname = f"scatter_RNA_to_ADT_{marker}.png".replace(" ", "_")
        plt.savefig(os.path.join(FIGDIR, fname), dpi=200)
        plt.show()
        plt.close()
    else:
        print(f"  [RNA→ADT] Marker '{marker}' not found in adt_test_adata.var_names; skipping.")


# RNA → ATAC on a second fixed-seed random subset of test cells (at most 5000),
# for the features listed in `atac_features_to_plot`.
# (Exploratory: only useful when the ATAC features carry meaningful names.)
n_scatter_cells_atac = min(5000, atac_test_adata.n_obs)
atac_subset_rng = np.random.default_rng(123)
cell_idx_atac = atac_subset_rng.choice(atac_test_adata.n_obs, n_scatter_cells_atac, replace=False)

rna_subset_for_atac = rna_test_adata[cell_idx_atac]
Xhat_atac_sub = univi_eval.cross_modal_predict(
    model,
    adata_src=rna_subset_for_atac,
    src_mod="rna",
    tgt_mod="atac",
    device=device,
    batch_size=512,
)
X_atac_sub = _to_dense(atac_test_adata[cell_idx_atac].X)

# If you have named ATAC features (e.g. gene body aggregates), you can list them here.
atac_features_to_plot = [
    # e.g. "TNFRSF4_body", "IFNG_enh", ...
]

for feat in atac_features_to_plot:
    if feat in atac_test_adata.var_names:
        # Column of the first feature whose name equals `feat`.
        j = np.flatnonzero(atac_test_adata.var_names == feat)[0]
        y_true, y_pred = X_atac_sub[:, j], Xhat_atac_sub[:, j]

        plt.figure(figsize=(4.5, 4))
        plt.hexbin(y_true, y_pred, gridsize=50, mincnt=1)
        plt.title(f"RNA→ATAC prediction for {feat}")
        plt.xlabel(f"True ATAC ({feat})")
        plt.ylabel(f"Predicted ATAC ({feat})")
        plt.tight_layout()
        fname = f"scatter_RNA_to_ATAC_{feat}.png".replace(" ", "_")
        plt.savefig(os.path.join(FIGDIR, fname), dpi=200)
        plt.show()
        plt.close()
    else:
        print(f"  [RNA→ATAC] Feature '{feat}' not found in atac_test_adata.var_names; skipping.")
