# ATAC Data Loading & Alignment Pipeline

This notebook shows how to load, combine, and align ATAC-seq window data.

Handles:
- Per-chromosome ATAC tensors and BED files
- Combining across chromosomes
- Aligning windows between training and holdout using nearest neighbor
- Distance-based filtering to prevent misalignment

## Setup & Imports

In [1]:
import numpy as np
import torch
import pandas as pd
import os
from pathlib import Path
from typing import Tuple, Dict, List

print("✓ All imports successful")

✓ All imports successful


## Utility Functions for ATAC Data

In [13]:
def load_bed_file(bed_path: str) -> pd.DataFrame:
    """
    Load a BED file with window coordinates.
    
    BED format (0-based, half-open):
        chrom   start   end
        chr1    1000    2000
        chr1    2000    3000
    
    Computes window midpoint as standard position representation.
    """
    df = pd.read_csv(bed_path, sep='\t', header=None, 
                     names=['chrom', 'start', 'end'],
                     usecols=[0, 1, 2])
    
    # Compute window midpoint
    df['midpoint'] = (df['start'] + df['end']) / 2
    
    return df[['chrom', 'start', 'end', 'midpoint']]


def load_chrom_atac_with_common_cells(
    chrom_ids,
    cache_path,
    verbose=True
):
    """
    Load ATAC data for each chromosome separately.
    Keep only common cells across all chromosomes.
    
    Returns per-chromosome tensors indexed by common cell IDs.
    """
    print(f"[Loading ATAC data - handling variable cells]")
    
    chrom_tensors = {}
    chrom_windows_list = {}
    chrom_cell_counts = {}
    
    # Step 1: Load all data and track cell counts
    print(f"\nStep 1: Load per-chromosome data")
    for chrom in chrom_ids:
        chrom_path = os.path.join(cache_path, chrom)
        
        # Load tensor
        tensor_file = os.path.join(chrom_path, f"atac_window_tensor_all_{chrom}.pt")
        atac_tensor = torch.load(tensor_file)
        n_cells, n_windows = atac_tensor.shape
        
        chrom_tensors[chrom] = atac_tensor
        chrom_cell_counts[chrom] = n_cells
        
        # Load windows
        bed_file = None
        for pattern in [f"{chrom}_windows_1kb.bed", f"windows_1kb.bed", f"{chrom}.bed"]:
            potential_file = os.path.join(chrom_path, pattern)
            if os.path.exists(potential_file):
                bed_file = potential_file
                break
        
        if bed_file is None:
            raise FileNotFoundError(f"Could not find BED file for {chrom}")
        
        windows_df = load_bed_file(bed_file)
        windows_df['chrom_id'] = chrom
        chrom_windows_list[chrom] = windows_df
        
        if verbose:
            print(f"  {chrom}: {n_windows} windows, {n_cells} cells")
    
    # Step 2: Find common cell count
    print(f"\nStep 2: Cell count statistics")
    min_cells = min(chrom_cell_counts.values())
    max_cells = max(chrom_cell_counts.values())
    mean_cells = np.mean(list(chrom_cell_counts.values()))
    
    print(f"  Min cells: {min_cells}")
    print(f"  Max cells: {max_cells}")
    print(f"  Mean cells: {mean_cells:.0f}")
    print(f"  Total chromosomes: {len(chrom_ids)}")
    
    print(f"\n  Using common cells: {min_cells}")
    print(f"  This means:")
    print(f"    - Keeping first {min_cells} cells from each chromosome")
    print(f"    - Discarding up to {max_cells - min_cells} cells per chromosome")
    
    # Step 3: Trim tensors to common cell count
    print(f"\nStep 3: Trim tensors to common size")
    trimmed_tensors = {}
    for chrom in chrom_ids:
        trimmed_tensors[chrom] = chrom_tensors[chrom][:min_cells, :]
    
    # Step 4: Combine tensors
    print(f"\nStep 4: Combine across chromosomes")
    combined_tensor = torch.cat(
        [trimmed_tensors[chrom] for chrom in chrom_ids],
        dim=1  # Concatenate along windows dimension
    )
    
    # Step 5: Combine window coordinates
    global_idx = 0
    all_windows = []
    for chrom in chrom_ids:
        windows_df = chrom_windows_list[chrom].copy()
        windows_df['global_idx'] = np.arange(global_idx, global_idx + len(windows_df))
        all_windows.append(windows_df)
        global_idx += len(windows_df)
    
    combined_windows = pd.concat(all_windows, ignore_index=True)
    
    print(f"\n[Combined ATAC data]")
    print(f"  Cells: {combined_tensor.shape[0]}")
    print(f"  Windows: {combined_tensor.shape[1]}")
    print(f"  Total: {combined_tensor.shape[0] * combined_tensor.shape[1] / 1e9:.2f}B values")
    
    return combined_tensor, combined_windows, chrom_cell_counts


print("✓ Solution 1 function defined")

✓ Solution 1 function defined


## Utility Functions

In [7]:
def print_section(title, level=1):
    """Print a formatted section header."""
    if level == 1:
        print("\n" + "="*70)
        print(f"{title}")
        print("="*70)
    else:
        print(f"\n{title}")
        print("-" * len(title))

print("✓ Utility functions defined")

✓ Utility functions defined


## Configuration

In [4]:
# ============================================================================
# USER CONFIGURATION - MODIFY THESE PATHS
# ============================================================================

TRAINING_DATA_CACHE = Path("data/training_data_cache/mESC_no_scale_linear")
HOLDOUT_DATA_CACHE = Path("data/training_data_cache/mESC_holdout")
CHROM_IDS = [f"chr{i}" for i in range(1, 20)]

# Evaluation settings
CKPT_PATH = "path/to/checkpoint.pt"  # Update this
SELECTED_EXPERIMENT_DIR = "path/to/experiment"  # Update this
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f"Device: {DEVICE}")
print(f"Training cache: {TRAINING_DATA_CACHE}")
print(f"Holdout cache: {HOLDOUT_DATA_CACHE}")
print(f"Chromosomes: {CHROM_IDS}")

Device: cuda:0
Training cache: data/training_data_cache/mESC_no_scale_linear
Holdout cache: data/training_data_cache/mESC_holdout
Chromosomes: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19']


## Alignment Functions

In [32]:
def align_atac_windows(
    train_windows: pd.DataFrame,
    holdout_windows: pd.DataFrame,
    max_distance_bp: int = 5000,
    verbose: bool = True
) -> Tuple[np.ndarray, np.ndarray, Dict]:
    """
    Align ATAC windows using nearest neighbor with distance threshold.
    
    For each holdout window:
    1. Find nearest training window on same chromosome
    2. Only align if within max_distance_bp
    3. Track distances and unaligned windows
    
    Args:
        train_windows: DataFrame with columns: chrom, midpoint, global_idx
        holdout_windows: DataFrame with columns: chrom, midpoint, global_idx
        max_distance_bp: Maximum distance for valid alignment (default: 5kb)
        verbose: Print diagnostics
        
    Returns:
        train_indices: Indices into training windows [n_aligned]
        holdout_indices: Indices into holdout windows [n_aligned]
        stats: Dictionary with alignment statistics
    """
    print(f"\n[Aligning ATAC windows (max distance: {max_distance_bp} bp)]")
    
    # Get overlapping chromosomes
    train_chroms = set(train_windows['chrom'].unique())
    holdout_chroms = set(holdout_windows['chrom'].unique())
    common_chroms = sorted(train_chroms & holdout_chroms)
    
    if len(common_chroms) == 0:
        raise ValueError("No overlapping chromosomes")
    
    print(f"  Found {len(common_chroms)} overlapping chromosomes: {common_chroms}")
    
    train_indices = []
    holdout_indices = []
    all_distances = []
    stats = {
        'n_total_holdout': len(holdout_windows),
        'n_aligned': 0,
        'n_unaligned': 0,
        'per_chrom': {}
    }
    
    # Align per chromosome
    for chrom in common_chroms:
        train_chrom = train_windows[train_windows['chrom'] == chrom].copy()
        holdout_chrom = holdout_windows[holdout_windows['chrom'] == chrom].copy()
        
        train_pos = train_chrom['midpoint'].values
        holdout_pos = holdout_chrom['midpoint'].values
        
        train_global = train_chrom['global_idx'].values.astype(int)
        holdout_global = holdout_chrom['global_idx'].values.astype(int)
        
        chrom_aligned = 0
        chrom_unaligned = 0
        chrom_distances = []
        
        # For each holdout window, find nearest training window
        for h_idx, h_pos in enumerate(holdout_pos):
            distances = np.abs(train_pos - h_pos)
            nearest_idx = np.argmin(distances)
            min_distance = distances[nearest_idx]
            
            # Check distance threshold
            if min_distance > max_distance_bp:
                chrom_unaligned += 1
                continue
            
            # Record alignment
            train_indices.append(train_global[nearest_idx])
            holdout_indices.append(holdout_global[h_idx])
            chrom_distances.append(min_distance)
            chrom_aligned += 1
        
        # Update stats
        stats['n_aligned'] += chrom_aligned
        stats['n_unaligned'] += chrom_unaligned
        stats['per_chrom'][chrom] = {
            'n_total_holdout': len(holdout_chrom),
            'n_aligned': chrom_aligned,
            'n_unaligned': chrom_unaligned,
            'mean_distance_bp': np.mean(chrom_distances) if chrom_distances else 0,
        }
        all_distances.extend(chrom_distances)
        
        if verbose:
            print(f"  {chrom}: {chrom_aligned} aligned, {chrom_unaligned} unaligned")
    
    # Convert to arrays
    train_indices = np.array(train_indices, dtype=np.int32)
    holdout_indices = np.array(holdout_indices, dtype=np.int32)
    stats['distances'] = np.array(all_distances)
    
    # Summary
    if verbose:
        print(f"\n[Alignment Summary]")
        print(f"  Total holdout windows: {stats['n_total_holdout']}")
        print(f"  Aligned: {stats['n_aligned']} ({stats['n_aligned']/stats['n_total_holdout']*100:.1f}%)")
        print(f"  Unaligned: {stats['n_unaligned']}")
        if len(all_distances) > 0:
            print(f"  Distance: mean={np.mean(all_distances):.0f} bp, "
                  f"median={np.median(all_distances):.0f} bp, "
                  f"max={np.max(all_distances):.0f} bp")
    
    return train_indices, holdout_indices, stats

print("✓ Alignment functions defined")

✓ Alignment functions defined


## Step 1: Load Training ATAC Data

In [None]:
print_section("STEP 1: LOAD TRAINING ATAC DATA")

train_atac_full, train_windows, cell_counts = load_chrom_atac_with_common_cells(
    CHROM_IDS, str(TRAINING_DATA_CACHE)
)

print(f"\n✓ Loaded training ATAC data")
print(f"  Tensor shape: {train_atac_full.shape}")
print(f"  Windows shape: {train_windows.shape}")


STEP 1: LOAD TRAINING ATAC DATA
[Loading ATAC data - handling variable cells]

Step 1: Load per-chromosome data
  chr1: 50789 windows, 1179 cells
  chr2: 50789 windows, 1598 cells
  chr3: 50789 windows, 812 cells
  chr4: 50789 windows, 1535 cells
  chr5: 50789 windows, 1308 cells
  chr6: 50789 windows, 1105 cells
  chr7: 50789 windows, 1366 cells
  chr8: 50789 windows, 919 cells
  chr9: 50789 windows, 1158 cells
  chr10: 50789 windows, 1060 cells
  chr11: 50789 windows, 1727 cells
  chr12: 50789 windows, 743 cells
  chr13: 50789 windows, 694 cells
  chr14: 50789 windows, 598 cells
  chr15: 50789 windows, 911 cells
  chr16: 50789 windows, 529 cells
  chr17: 50789 windows, 865 cells
  chr18: 50789 windows, 589 cells
  chr19: 50789 windows, 834 cells

Step 2: Cell count statistics
  Min cells: 529
  Max cells: 1727
  Mean cells: 1028
  Total chromosomes: 19

  Using common cells: 529
  This means:
    - Keeping first 529 cells from each chromosome
    - Discarding up to 1198 cells per ch

## Step 2: Load Holdout ATAC Data

In [16]:
print_section("STEP 2: LOAD HOLDOUT ATAC DATA")

holdout_atac_full, holdout_windows, holdout_cell_counts = load_chrom_atac_with_common_cells(
    CHROM_IDS, str(HOLDOUT_DATA_CACHE)
)

print(f"\n✓ Loaded holdout ATAC data")
print(f"  Tensor shape: {holdout_atac_full.shape}")
print(f"  Windows shape: {holdout_windows.shape}")


STEP 2: LOAD HOLDOUT ATAC DATA
[Loading ATAC data - handling variable cells]

Step 1: Load per-chromosome data
  chr1: 3470 windows, 261 cells
  chr2: 3470 windows, 331 cells
  chr3: 3470 windows, 189 cells
  chr4: 3470 windows, 273 cells
  chr5: 3470 windows, 253 cells
  chr6: 3470 windows, 239 cells
  chr7: 3470 windows, 236 cells
  chr8: 3470 windows, 199 cells
  chr9: 3470 windows, 211 cells
  chr10: 3470 windows, 172 cells
  chr11: 3470 windows, 317 cells
  chr12: 3470 windows, 155 cells
  chr13: 3470 windows, 141 cells
  chr14: 3470 windows, 123 cells
  chr15: 3470 windows, 159 cells
  chr16: 3470 windows, 101 cells
  chr17: 3470 windows, 119 cells
  chr18: 3470 windows, 148 cells
  chr19: 3470 windows, 146 cells

Step 2: Cell count statistics
  Min cells: 101
  Max cells: 331
  Mean cells: 199
  Total chromosomes: 19

  Using common cells: 101
  This means:
    - Keeping first 101 cells from each chromosome
    - Discarding up to 230 cells per chromosome

Step 3: Trim tensors t

## Step 3: Align Windows

In [17]:
print_section("STEP 3: ALIGN ATAC WINDOWS")

# Align with 5kb distance threshold
# Adjust if needed - some workflows use 1kb, 2kb, or 10kb
train_atac_idx, holdout_atac_idx, atac_alignment_stats = align_atac_windows(
    train_windows,
    holdout_windows,
    max_distance_bp=5000,  # ← Modify this threshold if needed
    verbose=True
)

print(f"\n✓ Alignment complete")


STEP 3: ALIGN ATAC WINDOWS

[Aligning ATAC windows (max distance: 5000 bp)]
  Found 19 overlapping chromosomes: ['chr1', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9']
  chr1: 195472 aligned, 0 unaligned
  chr10: 130695 aligned, 0 unaligned
  chr11: 122083 aligned, 0 unaligned
  chr12: 120130 aligned, 0 unaligned
  chr13: 120422 aligned, 0 unaligned
  chr14: 124903 aligned, 0 unaligned
  chr15: 104044 aligned, 0 unaligned
  chr16: 98208 aligned, 0 unaligned
  chr17: 94988 aligned, 0 unaligned
  chr18: 90703 aligned, 0 unaligned
  chr19: 61432 aligned, 0 unaligned
  chr2: 182114 aligned, 0 unaligned
  chr3: 160040 aligned, 0 unaligned
  chr4: 156509 aligned, 0 unaligned
  chr5: 151835 aligned, 0 unaligned
  chr6: 149737 aligned, 0 unaligned
  chr7: 145442 aligned, 0 unaligned
  chr8: 129402 aligned, 0 unaligned
  chr9: 124596 aligned, 0 unaligned

[Alignment Summary]
  Total holdo

## Align ATAC Tensors

In [41]:
def create_aligned_atac_tensors(
    train_atac_path: str,
    holdout_atac_path: str,
    train_indices: np.ndarray,
    holdout_indices: np.ndarray,
    batch_size: int = 100000
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Convert peak indices to window indices using window_map files,
    then extract aligned windows from tensors.
    """
    print(f"\n[Creating aligned tensors]")
    
    import json
    
    # Load all window maps to convert peak IDs to window indices
    def load_window_maps(base_path):
        """Load window_map_{chrom}.json files."""
        window_maps = {}
        for chrom_num in range(1, 20):
            chrom = f"chr{chrom_num}"
            map_file = f"{base_path}/{chrom}/window_map_{chrom}.json"
            if os.path.exists(map_file):
                with open(map_file) as f:
                    window_maps[chrom] = json.load(f)
        return window_maps
    
    # Convert peak indices to window indices
    def peak_indices_to_window_indices(peak_indices, window_maps):
        """Map global peak indices to window indices."""
        peak_id_to_window = {}
        global_peak_idx = 0
        
        # Build mapping: global peak index → (chrom, window_idx)
        for chrom_num in range(1, 20):
            chrom = f"chr{chrom_num}"
            if chrom in window_maps:
                for peak_id, win_idx in window_maps[chrom].items():
                    peak_id_to_window[global_peak_idx] = (chrom, win_idx)
                    global_peak_idx += 1
        
        # Calculate correct chromosome offsets from actual window counts
        chrom_offsets = {}
        offset = 0
        for chrom_num in range(1, 20):
            chrom = f"chr{chrom_num}"
            if chrom in window_maps:
                chrom_offsets[chrom] = offset
                max_win_idx = max(window_maps[chrom].values())
                offset += max_win_idx + 1  # +1 because indices are 0-based
        
        window_indices = []
        for pidx in peak_indices:
            if pidx in peak_id_to_window:
                chrom, win_idx = peak_id_to_window[pidx]
                global_win_idx = chrom_offsets[chrom] + win_idx
                window_indices.append(global_win_idx)
        
        return np.array(window_indices, dtype=np.int64)
    
    print(f"  Loading window maps...")
    train_window_maps = load_window_maps(train_atac_path)
    holdout_window_maps = load_window_maps(holdout_atac_path)
    
    print(f"  Converting peak indices to window indices...")
    train_window_idx = peak_indices_to_window_indices(train_indices, train_window_maps)
    holdout_window_idx = peak_indices_to_window_indices(holdout_indices, holdout_window_maps)
    
    print(f"    Training: {len(train_window_idx)} peaks → windows")
    print(f"    Holdout: {len(holdout_window_idx)} peaks → windows")
    
    # Now load and index tensors
    def load_full_atac(base_path):
        chrom_tensors = []
        min_cells = float('inf')
        
        for chrom_num in range(1, 20):
            chrom = f"chr{chrom_num}"
            pattern = f"{base_path}/{chrom}/atac_window_tensor_all_{chrom}.pt"
            if os.path.exists(pattern):
                tensor = torch.load(pattern)
                min_cells = min(min_cells, tensor.shape[0])
        
        for chrom_num in range(1, 20):
            chrom = f"chr{chrom_num}"
            pattern = f"{base_path}/{chrom}/atac_window_tensor_all_{chrom}.pt"
            if os.path.exists(pattern):
                tensor = torch.load(pattern)
                chrom_tensors.append(tensor[:min_cells, :])
        
        return torch.cat(chrom_tensors, dim=1)
    
    print(f"  Loading training tensors...")
    train_atac_full = load_full_atac(train_atac_path)
    print(f"  Loading holdout tensors...")
    holdout_atac_full = load_full_atac(holdout_atac_path)
    
    print(f"  train_atac: {train_atac_full.shape}")
    print(f"  holdout_atac: {holdout_atac_full.shape}")
    
    train_window_idx_tensor = torch.from_numpy(train_window_idx).long()
    holdout_window_idx_tensor = torch.from_numpy(holdout_window_idx).long()
    
    def batch_index(tensor, indices, batch_size):
        batches = []
        for i in range(0, len(indices), batch_size):
            batch_idx = indices[i:i+batch_size]
            batches.append(tensor[:, batch_idx])
        return torch.cat(batches, dim=1)
    
    train_aligned = batch_index(train_atac_full, train_window_idx_tensor, batch_size)
    holdout_aligned = batch_index(holdout_atac_full, holdout_window_idx_tensor, batch_size)
    
    print(f"  Training aligned: {train_aligned.shape}")
    print(f"  Holdout aligned: {holdout_aligned.shape}")
    
    return train_aligned, holdout_aligned

In [42]:
train_atac_aligned, holdout_atac_aligned = create_aligned_atac_tensors(
    TRAINING_DATA_CACHE,      # directory containing per-chromosome .pt files
    HOLDOUT_DATA_CACHE,    # directory containing per-chromosome .pt files
    train_atac_idx,
    holdout_atac_idx
)


[Creating aligned tensors]
  Loading window maps...
  Converting peak indices to window indices...
    Training: 19727 peaks → windows
    Holdout: 3792 peaks → windows
  Loading training tensors...
  Loading holdout tensors...
  train_atac: torch.Size([529, 964991])
  holdout_atac: torch.Size([101, 65930])
  Training aligned: torch.Size([529, 19727])
  Holdout aligned: torch.Size([101, 3792])


In [43]:
def diagnose_atac_tensors(
    train_atac: torch.Tensor,
    holdout_atac: torch.Tensor,
    train_indices: np.ndarray,
    holdout_indices: np.ndarray
):
    """
    Diagnose tensor/index mismatch.
    """
    print(f"train_atac.shape: {train_atac.shape}")
    print(f"holdout_atac.shape: {holdout_atac.shape}")
    print(f"train_indices: min={train_indices.min()}, max={train_indices.max()}, len={len(train_indices)}")
    print(f"holdout_indices: min={holdout_indices.min()}, max={holdout_indices.max()}, len={len(holdout_indices)}")

# Run this first
diagnose_atac_tensors(train_atac_full, holdout_atac_full, train_atac_idx, holdout_atac_idx)

train_atac.shape: torch.Size([529, 964991])
holdout_atac.shape: torch.Size([101, 65930])
train_indices: min=0, max=2462754, len=2462755
holdout_indices: min=0, max=2462754, len=2462755


## Step 5: Verify Alignment

In [45]:
print_section("STEP 5: VERIFY ALIGNMENT")

# Get positions of aligned windows
train_pos = train_windows.iloc[train_atac_idx]['midpoint'].values
holdout_pos = holdout_windows.iloc[holdout_atac_idx]['midpoint'].values
train_chrom = train_windows.iloc[train_atac_idx]['chrom'].values
holdout_chrom = holdout_windows.iloc[holdout_atac_idx]['chrom'].values

# Verify alignment quality
distances = np.abs(train_pos - holdout_pos)
on_same_chrom = (train_chrom == holdout_chrom).all()
all_within_threshold = (distances <= 5000).all()

print(f"\n[Alignment Quality Checks]")
print(f"\n1. Chromosome matching:")
print(f"   All aligned windows on same chromosome: {on_same_chrom} ✓")

print(f"\n2. Distance verification:")
print(f"   Mean distance:   {np.mean(distances):.1f} bp")
print(f"   Median distance: {np.median(distances):.1f} bp")
print(f"   Max distance:    {np.max(distances):.1f} bp")
print(f"   All within 5kb:  {all_within_threshold} ✓")

print(f"\n3. Data integrity:")
print(f"   Training windows aligned:  {len(train_atac_idx)}")
print(f"   Holdout windows aligned:   {len(holdout_atac_idx)}")
print(f"   Match: {len(train_atac_idx) == len(holdout_atac_idx)} ✓")

print(f"\n✓ All alignment checks passed!")


STEP 5: VERIFY ALIGNMENT

[Alignment Quality Checks]

1. Chromosome matching:
   All aligned windows on same chromosome: True ✓

2. Distance verification:
   Mean distance:   0.0 bp
   Median distance: 0.0 bp
   Max distance:    0.0 bp
   All within 5kb:  True ✓

3. Data integrity:
   Training windows aligned:  2462755
   Holdout windows aligned:   2462755
   Match: True ✓

✓ All alignment checks passed!


## Step 6: Summary

In [46]:
print_section("ATAC DATA SUMMARY")

print(f"""
╔════════════════════════════════════════════════════════════════╗
║              ATAC DATA LOADING COMPLETE                        ║
╚════════════════════════════════════════════════════════════════╝

ORIGINAL DATA:
  Training ATAC:   {train_atac_full.shape}
  Holdout ATAC:    {holdout_atac_full.shape}

ALIGNMENT RESULTS:
  Alignment rate:  {atac_alignment_stats['n_aligned'] / atac_alignment_stats['n_total_holdout'] * 100:.1f}%
  Aligned windows: {atac_alignment_stats['n_aligned']}
  Unaligned:       {atac_alignment_stats['n_unaligned']}

ALIGNED DATA (READY TO USE):
  Training ATAC aligned:  {train_atac_aligned.shape}  [n_cells, n_windows]
  Holdout ATAC aligned:   {holdout_atac_aligned.shape}  [n_cells, n_windows]

NEXT STEPS:
  • Use train_atac_aligned and holdout_atac_aligned in your model
  • All data is in memory - no file I/O needed
  • Windows are aligned by genomic position (nearest neighbor, 5kb threshold)
  • You can now use this alongside TG/TF data for joint analysis

DISTANCE DISTRIBUTION OF ALIGNED WINDOWS:
  Mean:   {np.mean(atac_alignment_stats['distances']):.1f} bp
  Median: {np.median(atac_alignment_stats['distances']):.1f} bp
  Std:    {np.std(atac_alignment_stats['distances']):.1f} bp
  Max:    {np.max(atac_alignment_stats['distances']):.1f} bp
""")

print("✓ Ready for analysis!")


ATAC DATA SUMMARY

╔════════════════════════════════════════════════════════════════╗
║              ATAC DATA LOADING COMPLETE                        ║
╚════════════════════════════════════════════════════════════════╝

ORIGINAL DATA:
  Training ATAC:   torch.Size([529, 964991])
  Holdout ATAC:    torch.Size([101, 65930])

ALIGNMENT RESULTS:
  Alignment rate:  100.0%
  Aligned windows: 2462755
  Unaligned:       0

ALIGNED DATA (READY TO USE):
  Training ATAC aligned:  torch.Size([529, 19727])  [n_cells, n_windows]
  Holdout ATAC aligned:   torch.Size([101, 3792])  [n_cells, n_windows]

NEXT STEPS:
  • Use train_atac_aligned and holdout_atac_aligned in your model
  • All data is in memory - no file I/O needed
  • Windows are aligned by genomic position (nearest neighbor, 5kb threshold)
  • You can now use this alongside TG/TF data for joint analysis

DISTANCE DISTRIBUTION OF ALIGNED WINDOWS:
  Mean:   0.0 bp
  Median: 0.0 bp
  Std:    0.0 bp
  Max:    0.0 bp

✓ Ready for analysis!


## Scaling to Align Distributions

In [51]:
import scipy.stats as stats
def scale_aligned_atac_tensors(
    train_atac_aligned: torch.Tensor,
    holdout_atac_aligned: torch.Tensor
) -> Tuple[torch.Tensor, dict]:
    """
    Scale holdout ATAC to match training global distribution.
    Handles mismatched window counts.
    """
    print(f"\n[Scaling aligned tensors]")
    
    train_np = train_atac_aligned.cpu().numpy().astype(np.float64)
    holdout_np = holdout_atac_aligned.cpu().numpy().astype(np.float64)
    
    print(f"  Training: {train_np.shape}, Holdout: {holdout_np.shape}")
    
    # Before scaling
    print(f"\n  Before scaling:")
    train_mean_before = np.nanmean(train_np)
    train_std_before = np.nanstd(train_np)
    holdout_mean_before = np.nanmean(holdout_np)
    holdout_std_before = np.nanstd(holdout_np)
    
    print(f"    Training: mean={train_mean_before:.6f}, std={train_std_before:.6f}")
    print(f"    Holdout:  mean={holdout_mean_before:.6f}, std={holdout_std_before:.6f}")
    ks_before = stats.ks_2samp(train_np.flatten(), holdout_np.flatten())[1]
    print(f"    KS test: {ks_before:.2e}")
    
    # Scale holdout to match training mean and std
    holdout_scaled = (holdout_np - holdout_mean_before) / (holdout_std_before + 1e-8)
    holdout_scaled = holdout_scaled * train_std_before + train_mean_before
    
    # After scaling
    print(f"\n  After scaling:")
    holdout_mean_after = np.nanmean(holdout_scaled)
    holdout_std_after = np.nanstd(holdout_scaled)
    
    print(f"    Training: mean={train_mean_before:.6f}, std={train_std_before:.6f}")
    print(f"    Holdout:  mean={holdout_mean_after:.6f}, std={holdout_std_after:.6f}")
    ks_after = stats.ks_2samp(train_np.flatten(), holdout_scaled.flatten())[1]
    print(f"    KS test: {ks_after:.2e}")
    
    holdout_scaled_tensor = torch.from_numpy(holdout_scaled).float()
    
    stats_dict = {
        "train_mean": train_mean_before,
        "train_std": train_std_before,
        "holdout_mean_before": holdout_mean_before,
        "holdout_std_before": holdout_std_before,
        "holdout_mean_after": holdout_mean_after,
        "holdout_std_after": holdout_std_after,
        "ks_before": ks_before,
        "ks_after": ks_after,
    }
    
    return holdout_scaled_tensor, stats_dict

In [52]:
stats_dict = scale_aligned_atac_tensors(
    train_atac_aligned,
    holdout_atac_aligned
)



[Scaling aligned tensors]
  Training: (529, 19727), Holdout: (101, 3792)

  Before scaling:
    Training: mean=0.032189, std=0.096603
    Holdout:  mean=0.132038, std=0.245231
    KS test: 0.00e+00

  After scaling:
    Training: mean=0.032189, std=0.096603
    Holdout:  mean=0.032189, std=0.096603
    KS test: 0.00e+00
