In [1]:

import pandas as pd
import numpy as np
import autoencodix as acx
from autoencodix.data.datapackage import DataPackage
from autoencodix.configs.stackix_config import StackixConfig
from autoencodix.configs.default_config import DataCase

# --- Step 1: generate synthetic data ---
np.random.seed(42)
n_samples = 100
modalities = ["rna", "protein", "atac", "cytokine", "metabolite"]

# The first 65% of sample ids are present in every modality (paired); each modality
# additionally draws a random 70% of the remaining ids, so those are only partially paired.
n_paired = int(n_samples * 0.65)
paired_ids = [f"sample_{i}" for i in range(n_paired)]
remaining_ids = [f"sample_{i}" for i in range(n_paired, n_samples)]

data_frames = {}
for mod in modalities:
    mod_ids = paired_ids + list(
        np.random.choice(remaining_ids, size=int(len(remaining_ids) * 0.7), replace=False)
    )
    data_frames[mod] = pd.DataFrame(
        np.random.randn(len(mod_ids), 10),
        index=mod_ids,
        columns=[f"{mod}_feature_{i}" for i in range(10)],
    )

# --- Step 2: run Stackix with unpaired samples allowed ---
# Each modality's annotation is simply a copy of its own data frame, keyed by modality name.
dp = DataPackage(
    multi_bulk=data_frames,
    annotation={mod: df.copy() for mod, df in data_frames.items()},
)

stackix_config = StackixConfig(
    data_case=DataCase.MULTI_BULK,
    requires_paired=False,
)

stackix = acx.Stackix(data=dp, config=stackix_config)
result = stackix.run()

splits_by_name = {
    "train": result.datasets.train,
    "valid": result.datasets.valid,
    "test": result.datasets.test,
}

# --- Step 3: every sample of every modality is assigned to a split ---
for mod in modalities:
    train_len, valid_len, test_len = (
        len(splits_by_name[name].datasets[mod].sample_ids) for name in ("train", "valid", "test")
    )
    total_len = train_len + valid_len + test_len
    print(f"Modality {mod}: train={train_len}, valid={valid_len}, test={test_len}, total={total_len}")
    assert total_len == len(data_frames[mod]), "All samples must be assigned"

# --- Step 4: no sample leaks across splits ---
# Pairing-aware splitting must place a sample that occurs in several modalities in the
# same split for all of them; otherwise it would leak between train/valid/test. A sample
# that occurs in only one modality trivially lands in a single split. Hence every sample id
# must be found in exactly one split.
all_sample_ids = set().union(*[df.index for df in data_frames.values()])

# Union each split's ids over all modalities once, so the per-sample check is a set lookup
# instead of a scan over every modality's id list.
split_id_sets = {
    name: set().union(*(ds.datasets[mod].sample_ids for mod in modalities))
    for name, ds in splits_by_name.items()
}

for sid in sorted(all_sample_ids, key=str):
    splits = sorted(name for name, ids in split_id_sets.items() if sid in ids)
    # Distinguish "missing everywhere" from "leaked into several splits" — the original
    # single assertion reported both as "multiple splits".
    assert splits, f"Sample {sid} is not assigned to any split"
    assert len(splits) == 1, f"Sample {sid} appears in multiple splits across modalities: {splits}"

print("Integration test passed!")

in handle_direct_user_data with data: <class 'autoencodix.data.datapackage.DataPackage'>
--- Running Pairing-Aware Split ---
Training each modality model...
Training modality: rna
Training modality: rna
Epoch 1 - Train Loss: 8.4255
Sub-losses: recon_loss: 8.4255, var_loss: 0.0000, anneal_factor: 0.0000, effective_beta_factor: 0.0000
Epoch 1 - Valid Loss: 5.1928
Sub-losses: recon_loss: 5.1928, var_loss: 0.0000, anneal_factor: 0.0000, effective_beta_factor: 0.0000
Epoch 2 - Train Loss: 7.6681
Sub-losses: recon_loss: 7.6567, var_loss: 0.0114, anneal_factor: 0.0344, effective_beta_factor: 0.0034
Epoch 2 - Valid Loss: 5.7376
Sub-losses: recon_loss: 5.7361, var_loss: 0.0015, anneal_factor: 0.0344, effective_beta_factor: 0.0034
Epoch 3 - Train Loss: 8.5381
Sub-losses: recon_loss: 8.2019, var_loss: 0.3362, anneal_factor: 0.9656, effective_beta_factor: 0.0966
Epoch 3 - Valid Loss: 5.1187
Sub-losses: recon_loss: 5.0709, var_loss: 0.0478, anneal_factor: 0.9656, effective_beta_factor: 0.0966
Train

In [2]:
import pandas as pd
import numpy as np
import autoencodix as acx
from autoencodix.data.datapackage import DataPackage
from autoencodix.configs.stackix_config import StackixConfig
from autoencodix.configs.default_config import DataCase


def test_single_global_annotation():
    """Test with a single global annotation file containing all samples."""
    print("\n=== TEST 1: Single Global Annotation ===")
    
    np.random.seed(42)
    n_samples = 100
    modalities = ["rna", "protein", "atac", "cytokine", "metabolite"]
    
    # Create paired and unpaired samples
    n_paired = int(n_samples * 0.65)
    paired_ids = [f"sample_{i}" for i in range(n_paired)]
    remaining_ids = [f"sample_{i}" for i in range(n_paired, n_samples)]
    
    # Build modality data frames with partial overlap
    data_frames = {}
    for mod in modalities:
        mod_ids = paired_ids + list(
            np.random.choice(remaining_ids, size=int(len(remaining_ids)*0.7), replace=False)
        )
        df = pd.DataFrame(
            np.random.randn(len(mod_ids), 10),
            index=mod_ids,
            columns=[f"{mod}_feature_{i}" for i in range(10)]
        )
        data_frames[mod] = df
    
    # Create ONE global annotation containing all unique sample IDs
    all_sample_ids = sorted(set().union(*[df.index for df in data_frames.values()]))
    global_annotation = pd.DataFrame(
        {"metadata": np.random.choice(["A", "B", "C"], size=len(all_sample_ids))},
        index=all_sample_ids
    )
    
    dp = DataPackage(
        multi_bulk=data_frames,
        annotation={"global": global_annotation}  # Single global annotation
    )
    
    stackix_config = StackixConfig(
        data_case=DataCase.MULTI_BULK,
        requires_paired=False
    )
    
    stackix = acx.Stackix(data=dp, config=stackix_config)
    result = stackix.run()
    
    # Test 1: All samples in each modality must be assigned to a split
    print("\nTest 1.1: All modality samples assigned")
    for mod in modalities:
        train_len = len(result.datasets.train.datasets[mod].sample_ids)
        valid_len = len(result.datasets.valid.datasets[mod].sample_ids)
        test_len = len(result.datasets.test.datasets[mod].sample_ids)
        total_len = train_len + valid_len + test_len
        print(f"  {mod}: train={train_len}, valid={valid_len}, test={test_len}, total={total_len}")
        assert total_len == len(data_frames[mod]), f"All samples in {mod} must be assigned"
    
    # Test 2: Pairing-aware splitting - each sample appears in only one split
    print("\nTest 1.2: Pairing-aware splitting (no leakage)")
    all_sample_ids_set = set().union(*[df.index for df in data_frames.values()])
    for sid in all_sample_ids_set:
        splits = set()
        for split_name, dataset in zip(
            ["train", "valid", "test"],
            [result.datasets.train, result.datasets.valid, result.datasets.test]
        ):
            for mod in modalities:
                if sid in dataset.datasets[mod].sample_ids:
                    splits.add(split_name)
        assert len(splits) == 1, f"Sample {sid} appears in multiple splits: {splits}"
    print("  ✓ All samples appear in exactly one split across modalities")
    
    # Test 3: Annotation split contains all assigned samples
    print("\nTest 1.3: Global annotation split consistency")
    anno_train = result.datasets.train.annotation["global"]
    anno_valid = result.datasets.valid.annotation["global"]
    anno_test = result.datasets.test.annotation["global"]
    
    # Collect all sample IDs from annotation splits
    anno_train_ids = set(anno_train.index)
    anno_valid_ids = set(anno_valid.index)
    anno_test_ids = set(anno_test.index)
    
    print(f"  Annotation splits: train={len(anno_train_ids)}, valid={len(anno_valid_ids)}, test={len(anno_test_ids)}")
    
    # Check no overlap between annotation splits
    assert len(anno_train_ids & anno_valid_ids) == 0, "Train/valid annotation overlap"
    assert len(anno_train_ids & anno_test_ids) == 0, "Train/test annotation overlap"
    assert len(anno_valid_ids & anno_test_ids) == 0, "Valid/test annotation overlap"
    print("  ✓ No overlap between annotation splits")
    
    # Verify annotation IDs match modality IDs
    all_modality_ids = set()
    for mod in modalities:
        all_modality_ids.update(result.datasets.train.datasets[mod].sample_ids)
        all_modality_ids.update(result.datasets.valid.datasets[mod].sample_ids)
        all_modality_ids.update(result.datasets.test.datasets[mod].sample_ids)
    
    all_annotation_ids = anno_train_ids | anno_valid_ids | anno_test_ids
    assert all_modality_ids == all_annotation_ids, "Annotation and modality IDs must match"
    print("  ✓ Annotation IDs match modality IDs exactly")
    
    print("\n✓ TEST 1 PASSED: Single global annotation works correctly\n")


def test_multiple_per_modality_annotations():
    """Test with separate annotation files for each modality."""
    print("\n=== TEST 2: Multiple Per-Modality Annotations ===")
    
    np.random.seed(42)
    n_samples = 100
    modalities = ["rna", "protein", "atac"]
    
    # Create paired and unpaired samples
    n_paired = int(n_samples * 0.6)
    paired_ids = [f"sample_{i}" for i in range(n_paired)]
    remaining_ids = [f"sample_{i}" for i in range(n_paired, n_samples)]
    
    # Build modality data frames with partial overlap
    data_frames = {}
    annotations = {}
    
    for mod in modalities:
        # Each modality has different sample coverage
        mod_ids = paired_ids + list(
            np.random.choice(remaining_ids, size=int(len(remaining_ids)*0.6), replace=False)
        )
        
        # Modality data
        df = pd.DataFrame(
            np.random.randn(len(mod_ids), 10),
            index=mod_ids,
            columns=[f"{mod}_feature_{i}" for i in range(10)]
        )
        data_frames[mod] = df
        
        # Per-modality annotation (matching child key)
        anno_df = pd.DataFrame(
            {"label": np.random.choice(["typeA", "typeB"], size=len(mod_ids))},
            index=mod_ids
        )
        annotations[mod] = anno_df
    
    dp = DataPackage(
        multi_bulk=data_frames,
        annotation=annotations  # Multiple annotations, one per modality
    )
    
    stackix_config = StackixConfig(
        data_case=DataCase.MULTI_BULK,
        requires_paired=False
    )
    
    stackix = acx.Stackix(data=dp, config=stackix_config)
    result = stackix.run()
    
    # Test 1: All samples assigned
    print("\nTest 2.1: All modality samples assigned")
    for mod in modalities:
        train_len = len(result.datasets.train.datasets[mod].sample_ids)
        valid_len = len(result.datasets.valid.datasets[mod].sample_ids)
        test_len = len(result.datasets.test.datasets[mod].sample_ids)
        total_len = train_len + valid_len + test_len
        print(f"  {mod}: train={train_len}, valid={valid_len}, test={test_len}, total={total_len}")
        assert total_len == len(data_frames[mod]), f"All samples in {mod} must be assigned"
    
    # Test 2: Pairing-aware splitting
    print("\nTest 2.2: Pairing-aware splitting (no leakage)")
    all_sample_ids = set().union(*[df.index for df in data_frames.values()])
    for sid in all_sample_ids:
        splits = set()
        for split_name, dataset in zip(
            ["train", "valid", "test"],
            [result.datasets.train, result.datasets.valid, result.datasets.test]
        ):
            for mod in modalities:
                if sid in dataset.datasets[mod].sample_ids:
                    splits.add(split_name)
        assert len(splits) == 1, f"Sample {sid} appears in multiple splits: {splits}"
    print("  ✓ All samples appear in exactly one split across modalities")
    
    # Test 3: Each annotation matches its modality
    print("\nTest 2.3: Per-modality annotation consistency")
    for mod in modalities:
        # Get modality sample IDs from each split
        mod_train_ids = set(result.datasets.train.datasets[mod].sample_ids)
        mod_valid_ids = set(result.datasets.valid.datasets[mod].sample_ids)
        mod_test_ids = set(result.datasets.test.datasets[mod].sample_ids)
        
        # Get annotation sample IDs from each split
        anno_train_ids = set(result.datasets.train.annotation[mod].index)
        anno_valid_ids = set(result.datasets.valid.annotation[mod].index)
        anno_test_ids = set(result.datasets.test.annotation[mod].index)
        
        # Verify they match exactly
        assert mod_train_ids == anno_train_ids, f"{mod}: train IDs don't match"
        assert mod_valid_ids == anno_valid_ids, f"{mod}: valid IDs don't match"
        assert mod_test_ids == anno_test_ids, f"{mod}: test IDs don't match"
        
        print(f"  {mod}: ✓ Annotation splits match modality splits")
        
        # Verify no overlap in annotation splits
        assert len(anno_train_ids & anno_valid_ids) == 0, f"{mod}: train/valid overlap"
        assert len(anno_train_ids & anno_test_ids) == 0, f"{mod}: train/test overlap"
        assert len(anno_valid_ids & anno_test_ids) == 0, f"{mod}: valid/test overlap"
    
    print("\n✓ TEST 2 PASSED: Multiple per-modality annotations work correctly\n")


def test_mixed_annotation_coverage():
    """Test where annotation doesn't cover all modality samples (should drop uncovered)."""
    print("\n=== TEST 3: Partial Annotation Coverage ===")
    
    np.random.seed(123)
    modalities = ["rna", "protein"]
    
    # Create samples
    all_ids = [f"sample_{i}" for i in range(50)]
    annotation_ids = all_ids[:40]  # Annotation only covers 40 samples
    
    data_frames = {}
    for mod in modalities:
        # Modalities have all 50 samples
        df = pd.DataFrame(
            np.random.randn(len(all_ids), 5),
            index=all_ids,
            columns=[f"{mod}_f{i}" for i in range(5)]
        )
        data_frames[mod] = df
    
    # Global annotation only covers 40 samples
    global_annotation = pd.DataFrame(
        {"info": ["x"] * len(annotation_ids)},
        index=annotation_ids
    )
    
    dp = DataPackage(
        multi_bulk=data_frames,
        annotation={"meta": global_annotation}
    )
    
    stackix_config = StackixConfig(
        data_case=DataCase.MULTI_BULK,
        requires_paired=False
    )
    
    stackix = acx.Stackix(data=dp, config=stackix_config)
    result = stackix.run()
    
    # Test: Only annotated samples should be in splits
    print("\nTest 3.1: Only annotated samples included")
    for mod in modalities:
        all_split_ids = set()
        all_split_ids.update(result.datasets.train.datasets[mod].sample_ids)
        all_split_ids.update(result.datasets.valid.datasets[mod].sample_ids)
        all_split_ids.update(result.datasets.test.datasets[mod].sample_ids)
        
        print(f"  {mod}: {len(all_split_ids)} samples (expected 40)")
        assert len(all_split_ids) == 40, f"{mod} should only have annotated samples"
        assert all_split_ids == set(annotation_ids), f"{mod} IDs must match annotation"
    
    print("  ✓ Unannotated samples correctly dropped")
    
    print("\n✓ TEST 3 PASSED: Partial annotation coverage handled correctly\n")


if __name__ == "__main__":
    # Run every annotation scenario in order; any failing assertion stops the run.
    for run_scenario in (
        test_single_global_annotation,
        test_multiple_per_modality_annotations,
        test_mixed_annotation_coverage,
    ):
        run_scenario()

    rule = "=" * 60
    print("\n" + rule)
    print("ALL TESTS PASSED! ✓")
    print(rule)


=== TEST 1: Single Global Annotation ===
in handle_direct_user_data with data: <class 'autoencodix.data.datapackage.DataPackage'>
--- Running Pairing-Aware Split ---


KeyError: 'annotation'