# Fast Compositional Null Calculator (Optimized Version)

This notebook is an **optimized version** that addresses HPC performance issues:

## Key Optimizations

1. **Vectorized batch predictions** - Process many degree pairs simultaneously
2. **Lookup table caching** - Cache the null value for each (source degree, target degree) pair so repeated pairs are computed only once
3. **Memory-efficient processing** - Process in chunks to avoid memory issues
4. **Error handling** - Handle missing models and data gracefully
5. **Checkpointing** - Save intermediate results for recovery

## Performance Target

- Original: ~10 pairs/second
- Target: 1000+ pairs/second (100x speedup)
- HPC time: Complete within 2 hours

In [None]:
# Papermill parameters
# Default values for this run; presumably overridden when the notebook is
# executed through papermill — confirm against the pipeline that runs it.
test_metapath = "CbGpPW"  # Test metapath for validation (also used in output file names)
test_edge_types = ["CbG", "GpPW"]  # Edge types in test metapath, in path order: [edge 1, edge 2]
validation_perm_range = (21, 31)  # Half-open (start, stop) passed to range(): perms 21-30 for validation
model_type = "rf"  # 'rf' or 'poly' or 'ensemble'
chunk_size = 10000  # Process predictions in chunks
use_cache = True  # Use lookup table caching
save_checkpoints = True  # Save intermediate results (checkpoint of the extracted true null)

In [None]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.sparse as sp
from scipy.stats import pearsonr, ks_2samp
import joblib
import time
from collections import Counter
import warnings
import pickle
import gc
# Silence every warning category for the rest of the notebook.
warnings.simplefilter('ignore')

# All project locations are resolved relative to the working directory.
repo_dir = Path.cwd()
data_dir = repo_dir / 'data'
null_models_dir = repo_dir / 'results' / 'null_models'
results_dir = repo_dir / 'results' / 'compositional_null'
results_dir.mkdir(parents=True, exist_ok=True)

# Echo the resolved locations and the optimization parameters.
print("\n".join([
    f"Repository directory: {repo_dir}",
    f"Null models directory: {null_models_dir}",
    f"Results will be saved to: {results_dir}",
    "\nOptimization settings:",
    f"  Chunk size: {chunk_size:,}",
    f"  Use cache: {use_cache}",
    f"  Save checkpoints: {save_checkpoints}",
]))

# Plot style shared by all figures in this notebook.
sns.set_style('whitegrid')
plt.rcParams.update({'figure.dpi': 100})

## 1. Load Null Models with Error Handling

In [None]:
def load_null_model(edge_type, model_type='rf'):
    """
    Load the trained null model(s) for one edge type from ``null_models_dir``.

    Parameters
    ----------
    edge_type : str
        Edge-type abbreviation used in the model file names (e.g. ``'CbG'``).
    model_type : str
        ``'rf'`` loads ``<edge_type>_rf_null.pkl``; ``'poly'`` loads
        ``<edge_type>_poly_null.pkl`` plus ``<edge_type>_poly_features.pkl``;
        ``'ensemble'`` loads both. Any other value loads nothing.

    Returns
    -------
    dict
        Subset of the keys ``'rf'``, ``'poly'``, ``'poly_features'`` that were
        loaded. Missing files and load errors are reported with ``print`` and
        simply leave the corresponding keys out (the dict may be empty).
    """
    loaded = {}

    try:
        wants_rf = model_type in ('rf', 'ensemble')
        wants_poly = model_type in ('poly', 'ensemble')

        if wants_rf:
            rf_path = null_models_dir / f'{edge_type}_rf_null.pkl'
            if not rf_path.exists():
                print(f"    ⚠️ Warning: {rf_path.name} not found")
            else:
                print(f"    Loading {rf_path.name}...")
                loaded['rf'] = joblib.load(rf_path)

        if wants_poly:
            model_path = null_models_dir / f'{edge_type}_poly_null.pkl'
            features_path = null_models_dir / f'{edge_type}_poly_features.pkl'
            # The regression model is useless without its feature transformer,
            # so both files must be present.
            if not (model_path.exists() and features_path.exists()):
                print(f"    ⚠️ Warning: Polynomial model files not found")
            else:
                print(f"    Loading {model_path.name}...")
                loaded['poly'] = joblib.load(model_path)
                loaded['poly_features'] = joblib.load(features_path)
    except Exception as err:
        # Best effort: report and return whatever was loaded before the failure.
        print(f"    ❌ Error loading models for {edge_type}: {err}")

    return loaded

# Load the null model(s) for every edge type of the test metapath; stop the
# notebook with FileNotFoundError if any edge type has no model at all.
print(f"Loading null models for metapath {test_metapath}...")
null_models = {}

for edge_type in test_edge_types:
    print(f"\n  {edge_type}:")
    models = load_null_model(edge_type, model_type)
    if not models:
        print(f"    ❌ Failed to load models")
        raise FileNotFoundError(f"Required models for {edge_type} not found")
    null_models[edge_type] = models
    print(f"    ✅ Loaded: {list(models)}")

print(f"\n✓ Loaded models for {len(null_models)} edge types")

## 2. Get Degree Distributions (Optimized)

In [None]:
def get_degree_distribution(node_type, edge_types_in_metapath, perm_id=0):
    """
    Return the relative frequency of nonzero gene degrees across the given edge types.

    For each edge type the sparse adjacency matrix is read from
    ``data/permutations/<perm_id:03d>.hetmat/edges/<edge_type>.sparse.npz``,
    falling back to ``data/edges/<edge_type>.sparse.npz``. Only edge types whose
    abbreviation contains ``'G'`` contribute: if the abbreviation ends in ``'G'``
    the column sums are used (gene on the target side), otherwise the row sums.
    Nonzero degrees from all edge types are pooled into one distribution.

    ``node_type`` is not used by the implementation; the gene side is inferred
    from the edge-type abbreviation.

    If no degrees are found, a fixed placeholder distribution is used instead
    (a warning is printed).

    Returns
    -------
    dict
        Maps degree -> fraction of pooled (node, edge type) entries with that degree,
        in first-seen order.
    """
    collected = []
    perm_edges_dir = data_dir / 'permutations' / f'{perm_id:03d}.hetmat' / 'edges'

    for edge_type in edge_types_in_metapath:
        npz_name = f'{edge_type}.sparse.npz'
        path = perm_edges_dir / npz_name
        if not path.exists():
            # Try alternative path structure
            path = data_dir / 'edges' / npz_name

        if not path.exists():
            print(f"    ⚠️ Warning: {path} not found")
            continue

        print(f"    Loading {path.name}...")
        adjacency = sp.load_npz(str(path))

        if 'G' in edge_type:  # Gene-related edges
            # Gene as target -> sum down columns; gene as source -> sum across rows.
            axis = 0 if edge_type.endswith('G') else 1
            per_node = np.array(adjacency.sum(axis=axis)).ravel()
            collected += per_node[per_node > 0].tolist()

        # Release the matrix before loading the next one.
        del adjacency
        gc.collect()

    if not collected:
        print(f"    ⚠️ Warning: No degrees found, using default distribution")
        collected = [1] * 100 + [2] * 50 + [3] * 30 + [5] * 20 + [10] * 10

    tally = Counter(collected)
    n_entries = sum(tally.values())
    return {degree: hits / n_entries for degree, hits in tally.items()}

# Degree distribution of the intermediate (Gene) nodes, taken from permutation 0.
print("\nComputing degree distributions for intermediate nodes...")
gene_degree_freq = get_degree_distribution('Gene', test_edge_types, perm_id=0)

if not gene_degree_freq:
    print(f"  ⚠️ No degree distribution found")
else:
    # A non-empty dict always has at least one key, so range and samples are safe.
    print(f"  Gene degrees: {len(gene_degree_freq)} unique degrees")
    print(f"  Degree range: {min(gene_degree_freq)}-{max(gene_degree_freq)}")
    print(f"  Sample frequencies: {dict(list(gene_degree_freq.items())[:5])}")

## 3. Optimized Compositional Null Calculator

In [None]:
class OptimizedCompositionalCalculator:
    """
    Compositional null calculator for a two-edge metapath (source -> intermediate -> target).

    For a pair with source degree ``s`` and target degree ``t`` the null value is

        sum over intermediate degrees d of  freq(d) * p1(s, d) * p2(d, t)

    where ``freq`` is ``intermediate_degree_freq``, ``p1`` is the edge-1 null model
    evaluated on the feature pair ``(s, d)``, and ``p2`` is the edge-2 null model
    evaluated on ``(d, t)``. Predictions are clipped to [0, 1].

    Because ``p1`` depends only on ``s`` and ``p2`` only on ``t``, predictions are
    made once per distinct source degree and once per distinct target degree, in
    large batched ``predict`` calls. They are then combined for every pair instead
    of calling the models twice per pair.

    NOTE(review): the same intermediate degree ``d`` is fed to both edge models,
    although an intermediate node's degree in edge 1 can differ from its degree
    in edge 2 — confirm this approximation is intended.
    """

    def __init__(self, edge1_models, edge2_models, intermediate_degree_freq,
                 model_type='rf', use_cache=True):
        """
        Parameters
        ----------
        edge1_models, edge2_models : dict
            Model dicts with keys among ``'rf'``, ``'poly'``, ``'poly_features'``
            for the first and second edge of the metapath.
        intermediate_degree_freq : dict
            Maps intermediate-node degree -> relative frequency.
        model_type : {'rf', 'poly', 'ensemble'}
            Model(s) to use; ``'ensemble'`` averages the available rf and poly predictions.
        use_cache : bool
            Keep a ``(source_degree, target_degree) -> null value`` dict across calls.
        """
        self.edge1_models = edge1_models
        self.edge2_models = edge2_models
        self.intermediate_degree_freq = intermediate_degree_freq
        self.model_type = model_type
        self.use_cache = use_cache
        self.cache = {} if use_cache else None

        # Degrees and their frequencies as aligned arrays for vectorized sums.
        self.inter_degrees = np.array(list(intermediate_degree_freq.keys()))
        self.inter_freqs = np.array([intermediate_degree_freq[d] for d in self.inter_degrees])

    def _batch_predict(self, degrees_array, models, edge_num):
        """
        Predict edge probabilities for an (n, 2) array of degree-pair features.

        ``'rf'`` uses ``models['rf']``. ``'poly'`` uses ``models['poly']`` on
        ``models['poly_features'].transform(...)``. ``'ensemble'`` averages
        whichever of the two are present. The result is clipped to [0, 1].

        Raises
        ------
        ValueError
            If none of the models requested by ``model_type`` is in ``models``.
        """
        degrees_array = np.asarray(degrees_array)
        if len(degrees_array) == 0:
            return np.zeros(0)

        predictions = []
        if self.model_type in ('rf', 'ensemble') and 'rf' in models:
            predictions.append(
                np.asarray(models['rf'].predict(degrees_array), dtype=float).reshape(-1)
            )
        if self.model_type in ('poly', 'ensemble') and 'poly' in models:
            X_poly = models['poly_features'].transform(degrees_array)
            predictions.append(
                np.asarray(models['poly'].predict(X_poly), dtype=float).reshape(-1)
            )

        if not predictions:
            # Returning zeros here would silently turn every null value into 0.
            raise ValueError(
                f"No '{self.model_type}' null model available for edge {edge_num} "
                f"(available keys: {sorted(models)})"
            )
        return np.clip(np.mean(predictions, axis=0), 0, 1)

    def _degree_table(self, degrees, models, edge_num, degree_first, chunk_size):
        """
        Predict edge probabilities for every (degree, intermediate degree) combination.

        Parameters
        ----------
        degrees : numpy.ndarray
            1-D array of distinct source degrees (edge 1) or target degrees (edge 2).
        degree_first : bool
            True builds features ``(degree, inter_degree)`` (edge 1); False builds
            ``(inter_degree, degree)`` (edge 2).
        chunk_size : int
            Approximate maximum number of feature rows per ``predict`` call.

        Returns
        -------
        numpy.ndarray
            Shape ``(len(degrees), len(self.inter_degrees))``.
        """
        n_inter = len(self.inter_degrees)
        table = np.empty((len(degrees), n_inter))
        # Degrees per predict call, so that each call has about chunk_size rows.
        step = max(1, chunk_size // max(n_inter, 1))
        for start in range(0, len(degrees), step):
            block = degrees[start:start + step]
            repeated = np.repeat(block, n_inter)               # b0,b0,..,b1,b1,..
            tiled = np.tile(self.inter_degrees, len(block))    # i0,i1,..,i0,i1,..
            columns = [repeated, tiled] if degree_first else [tiled, repeated]
            preds = self._batch_predict(np.column_stack(columns), models, edge_num)
            table[start:start + len(block)] = preds.reshape(len(block), n_inter)
        return table

    def compute_metapath_null_vectorized(self, source_degrees, target_degrees,
                                        chunk_size=10000):
        """
        Compute compositional null values for aligned arrays of pair degrees.

        Parameters
        ----------
        source_degrees, target_degrees : array-like of equal length
            Degree of the source and target node of each pair.
        chunk_size : int, default 10000
            Bounds memory: about this many feature rows per ``predict`` call, and
            this many pairs per vectorized combination step. Each step allocates
            roughly ``chunk_size * len(inter_degrees)`` floats.

        Returns
        -------
        numpy.ndarray
            float64 null value per input pair, same length as the inputs.

        Raises
        ------
        ValueError
            If the input lengths differ or a requested model is missing.
        """
        source_degrees = np.asarray(source_degrees).reshape(-1)
        target_degrees = np.asarray(target_degrees).reshape(-1)
        if len(source_degrees) != len(target_degrees):
            raise ValueError(
                f"source_degrees ({len(source_degrees)}) and target_degrees "
                f"({len(target_degrees)}) must have the same length"
            )
        n_pairs = len(source_degrees)
        if n_pairs == 0:
            return np.zeros(0)
        chunk_size = max(1, int(chunk_size))

        # The null value depends only on the two degrees, so collapse the input to
        # distinct (source degree, target degree) pairs and compute each once.
        src_unique, src_inv = np.unique(source_degrees, return_inverse=True)
        tgt_unique, tgt_inv = np.unique(target_degrees, return_inverse=True)
        n_tgt = len(tgt_unique)
        pair_codes = src_inv.reshape(-1) * n_tgt + tgt_inv.reshape(-1)
        unique_codes, pair_inv = np.unique(pair_codes, return_inverse=True)
        pair_inv = pair_inv.reshape(-1)
        u_src = unique_codes // n_tgt  # index into src_unique per distinct pair
        u_tgt = unique_codes % n_tgt   # index into tgt_unique per distinct pair
        keys = list(zip(src_unique[u_src].tolist(), tgt_unique[u_tgt].tolist()))

        unique_vals = np.zeros(len(keys))
        if self.use_cache:
            todo = []
            for k, key in enumerate(keys):
                cached = self.cache.get(key)
                if cached is None:
                    todo.append(k)
                else:
                    unique_vals[k] = cached
            todo = np.asarray(todo, dtype=np.intp)
        else:
            todo = np.arange(len(keys))

        n_todo = len(todo)
        print(f"    {len(keys):,} unique degree pairs, {n_todo:,} to compute")

        # With no intermediate degrees the sum is empty, so values stay 0.
        if n_todo > 0 and len(self.inter_degrees) > 0:
            need_src, src_pos = np.unique(u_src[todo], return_inverse=True)
            need_tgt, tgt_pos = np.unique(u_tgt[todo], return_inverse=True)
            src_pos = src_pos.reshape(-1)
            tgt_pos = tgt_pos.reshape(-1)

            # p1_table[i, j] = p1(src_unique[need_src[i]], inter_degrees[j])
            p1_table = self._degree_table(src_unique[need_src], self.edge1_models, 1,
                                          True, chunk_size)
            # p2_table[i, j] = p2(inter_degrees[j], tgt_unique[need_tgt[i]])
            p2_table = self._degree_table(tgt_unique[need_tgt], self.edge2_models, 2,
                                          False, chunk_size)

            for start in range(0, n_todo, chunk_size):
                end = min(start + chunk_size, n_todo)
                p1 = p1_table[src_pos[start:end]]
                p2 = p2_table[tgt_pos[start:end]]
                # Compositional sum over intermediate degrees, one row per pair.
                unique_vals[todo[start:end]] = np.sum(p1 * p2 * self.inter_freqs, axis=1)

                if end % (chunk_size * 10) == 0 or end == n_todo:
                    print(f"    Processed {end:,} / {n_todo:,} unique degree pairs "
                          f"({end / n_todo * 100:.1f}%)")

        if self.use_cache:
            for k in todo.tolist():
                self.cache[keys[k]] = float(unique_vals[k])

        return unique_vals[pair_inv]

# Build the calculator: first edge type of the metapath supplies edge-1 models,
# second edge type supplies edge-2 models; genes are the intermediate nodes.
print("\nCreating optimized compositional calculator...")
calculator = OptimizedCompositionalCalculator(
    edge1_models=null_models[test_edge_types[0]],
    edge2_models=null_models[test_edge_types[1]],
    intermediate_degree_freq=gene_degree_freq,
    model_type=model_type,
    use_cache=use_cache,
)
print("✅ Calculator ready!")

## 4. Extract True Null with Checkpointing

In [None]:
def extract_metapath_frequencies(metapath_edges, perm_id):
    """
    Extract observed two-edge metapath counts from one permutation.

    Loads ``data/permutations/<perm_id:03d>.hetmat/edges/<edge>.sparse.npz`` for
    both edge types, multiplies the two matrices, and returns one row per
    (source, target) pair with a nonzero path count.

    If a permutation file is missing, the unpermuted ``data/edges`` file is used
    instead. A warning is printed in that case, because counts from the
    unpermuted graph are not a null sample.

    Parameters
    ----------
    metapath_edges : sequence of two str
        ``(edge1_type, edge2_type)``. The column nodes of edge 1 must be the row
        nodes of edge 2 so that ``matrix1 @ matrix2`` counts paths.
    perm_id : int
        Permutation number (zero-padded to three digits in the path).

    Returns
    -------
    pandas.DataFrame or None
        Columns ``source_idx``, ``target_idx``, ``source_degree``,
        ``target_degree``, ``metapath_count``, ``perm_id``, sorted by
        (source_idx, target_idx). Returns None if an edge file is missing or
        loading/processing fails.
    """
    edge1_type, edge2_type = metapath_edges
    perm_edges_dir = data_dir / 'permutations' / f'{perm_id:03d}.hetmat' / 'edges'

    edge_files = []
    for edge_type in (edge1_type, edge2_type):
        edge_file = perm_edges_dir / f'{edge_type}.sparse.npz'
        if not edge_file.exists():
            fallback = data_dir / 'edges' / f'{edge_type}.sparse.npz'
            if fallback.exists():
                print(f"    ⚠️ {edge_file.name} missing for perm {perm_id}; "
                      f"falling back to unpermuted {fallback}")
            edge_file = fallback
        edge_files.append(edge_file)
    edge1_file, edge2_file = edge_files

    if not edge1_file.exists() or not edge2_file.exists():
        print(f"    ⚠️ Edge files not found for perm {perm_id}")
        return None

    try:
        matrix1 = sp.load_npz(str(edge1_file))  # presumably source x intermediate (e.g. C x G)
        matrix2 = sp.load_npz(str(edge2_file))  # presumably intermediate x target (e.g. G x PW)

        # Path counts; canonicalize and drop explicit zeros to match nonzero().
        metapath_matrix = sp.coo_matrix(matrix1 @ matrix2)
        metapath_matrix.sum_duplicates()
        metapath_matrix.eliminate_zeros()

        source_degrees = np.asarray(matrix1.sum(axis=1)).ravel()
        target_degrees = np.asarray(matrix2.sum(axis=0)).ravel()

        # Row-major order, matching the order of the former nonzero() loop.
        order = np.lexsort((metapath_matrix.col, metapath_matrix.row))
        rows = metapath_matrix.row[order].astype(np.int64)
        cols = metapath_matrix.col[order].astype(np.int64)
        counts = metapath_matrix.data[order]

        result = pd.DataFrame({
            'source_idx': rows,
            'target_idx': cols,
            'source_degree': source_degrees[rows].astype(np.int64),
            'target_degree': target_degrees[cols].astype(np.int64),
            'metapath_count': counts.astype(np.int64),
            'perm_id': perm_id,
        })

        # Clean up
        del matrix1, matrix2, metapath_matrix
        gc.collect()

        return result

    except Exception as e:
        # Best effort across many permutations: report and skip this one.
        print(f"    ❌ Error processing perm {perm_id}: {e}")
        return None

# Checkpoint of the aggregated true null for this metapath.
checkpoint_file = results_dir / f'{test_metapath}_true_null_checkpoint.pkl'

if save_checkpoints and checkpoint_file.exists():
    # Recovery path: resume from a checkpoint written by an earlier run.
    # NOTE: the checkpoint does not record validation_perm_range or
    # test_edge_types; delete it after changing either. It is read with pickle,
    # so only load files this notebook wrote itself.
    print(f"Loading checkpoint from {checkpoint_file}...")
    with open(checkpoint_file, 'rb') as f:
        true_null_agg = pickle.load(f)
    print(f"  ✅ Loaded {len(true_null_agg)} metapath pairs from checkpoint")
else:
    # Extract true null from validation permutations
    print(f"\nExtracting true null from permutations {validation_perm_range[0]}-{validation_perm_range[1]-1}...")
    all_true_null = []
    n_perms_used = 0  # permutations whose edge files loaded, including ones with no paths

    for perm_id in range(validation_perm_range[0], validation_perm_range[1]):
        print(f"  Processing permutation {perm_id}...")
        df = extract_metapath_frequencies(test_edge_types, perm_id)
        if df is None:
            continue
        n_perms_used += 1
        if len(df) > 0:
            all_true_null.append(df)
            print(f"    Found {len(df):,} metapath pairs")

    if all_true_null:
        true_null_df = pd.concat(all_true_null, ignore_index=True)

        # Aggregate across permutations. A pair absent from a permutation had zero
        # paths there, so average over every usable permutation (sum / n_perms_used).
        # A per-group mean would ignore those zeros and inflate the counts.
        true_null_agg = true_null_df.groupby(
            ['source_idx', 'target_idx', 'source_degree', 'target_degree']
        ).agg({'metapath_count': 'sum'}).reset_index()
        true_null_agg['metapath_count'] = true_null_agg['metapath_count'] / n_perms_used

        # Normalize to probabilities
        total_paths = true_null_agg['metapath_count'].sum()
        if total_paths > 0:
            true_null_agg['true_null_prob'] = true_null_agg['metapath_count'] / total_paths
        else:
            true_null_agg['true_null_prob'] = 0

        # Save checkpoint
        if save_checkpoints:
            with open(checkpoint_file, 'wb') as f:
                pickle.dump(true_null_agg, f)
            print(f"  💾 Saved checkpoint to {checkpoint_file}")
    else:
        print("  ❌ No validation data found!")
        true_null_agg = pd.DataFrame()

if len(true_null_agg) > 0:
    print(f"\n📊 Extracted {len(true_null_agg)} metapath pairs")
    print(f"  Source degree range: {true_null_agg['source_degree'].min()}-{true_null_agg['source_degree'].max()}")
    print(f"  Target degree range: {true_null_agg['target_degree'].min()}-{true_null_agg['target_degree'].max()}")
    print(f"  Mean null probability: {true_null_agg['true_null_prob'].mean():.6e}")

## 5. Compute ML-Compositional Null (Optimized)

In [None]:
if len(true_null_agg) > 0:
    print(f"\nComputing ML-compositional null using {model_type} model...")
    print(f"  Processing {len(true_null_agg):,} pairs in chunks of {chunk_size:,}")

    start_time = time.time()

    # Compute null for each (source, target) pair
    ml_null_probs = calculator.compute_metapath_null_vectorized(
        true_null_agg['source_degree'].values,
        true_null_agg['target_degree'].values,
        chunk_size=chunk_size
    )

    elapsed_time = time.time() - start_time

    true_null_agg['ml_null_prob'] = ml_null_probs

    n_computed = len(ml_null_probs)
    # Guard against a zero elapsed time on very small inputs.
    pairs_per_second = n_computed / max(elapsed_time, 1e-9)

    print(f"\n✅ Computation complete!")
    print(f"  Computed {n_computed:,} null probabilities")
    print(f"  Computation time: {elapsed_time:.2f} seconds")
    print(f"  Speed: {pairs_per_second:.0f} pairs/second")
    print(f"  Mean ML-null probability: {ml_null_probs.mean():.6e}")

    if use_cache:
        n_cached = len(calculator.cache)
        print(f"  Cache size: {n_cached:,} unique degree pairs")
        # Each distinct (source_degree, target_degree) pair is computed at most
        # once; the remaining pairs reuse a cached value. The cache may also hold
        # entries from earlier runs of this cell, so this is a lower bound.
        reuse_rate = max(0.0, 1 - n_cached / n_computed) * 100
        print(f"  Cache reuse: {reuse_rate:.1f}% of pairs served from cache")
else:
    print("⚠️ No data to compute null probabilities")

## 6. Validate Against True Null

In [None]:
if len(true_null_agg) > 0 and 'ml_null_prob' in true_null_agg.columns:
    # true_null_prob sums to 1 over all pairs, but ml_null_prob is an unnormalized
    # compositional score. MAE/RMSE/R² on the raw columns would compare different
    # scales, so normalize the ML null over the same pairs first. Pearson r is
    # unaffected by this rescaling.
    ml_total = true_null_agg['ml_null_prob'].sum()
    true_null_agg['ml_null_prob_normalized'] = (
        true_null_agg['ml_null_prob'] / ml_total if ml_total > 0 else 0.0
    )

    # Remove any zero or invalid values
    valid_mask = (true_null_agg['true_null_prob'] > 0) & (true_null_agg['ml_null_prob'] > 0)
    valid_data = true_null_agg[valid_mask].copy()

    print(f"\n📊 Validation Results:")
    print(f"  Valid pairs: {len(valid_data)} / {len(true_null_agg)}")

    if len(valid_data) > 1:
        from sklearn.metrics import r2_score

        y_true = valid_data['true_null_prob']
        y_pred = valid_data['ml_null_prob_normalized']

        # Correlation
        corr, p_val = pearsonr(y_true, y_pred)
        print(f"  Pearson correlation: r = {corr:.4f} (p = {p_val:.2e})")

        # MAE and RMSE on the common (normalized) scale
        residuals = y_true - y_pred
        mae = residuals.abs().mean()
        rmse = np.sqrt((residuals ** 2).mean())
        print(f"  MAE (normalized): {mae:.6e}")
        print(f"  RMSE (normalized): {rmse:.6e}")

        # R²
        r2 = r2_score(y_true, y_pred)
        print(f"  R² (normalized): {r2:.4f}")

        # Success criteria
        # NOTE(review): the RMSE threshold looks chosen for a different scale;
        # normalized probabilities are tiny, so it is almost always met — confirm.
        print(f"\n✅ Success Criteria:")
        print(f"  Correlation > 0.75: {'✓' if corr > 0.75 else '✗'} ({corr:.4f})")
        print(f"  RMSE < 0.20: {'✓' if rmse < 0.20 else '✗'} ({rmse:.4e})")
        print(f"  R² > 0.50: {'✓' if r2 > 0.50 else '✗'} ({r2:.4f})")

        # Performance improvement (guard against a zero elapsed time)
        print(f"\n🚀 Performance Improvement:")
        pairs_per_second = len(ml_null_probs) / max(elapsed_time, 1e-9)
        speedup = pairs_per_second / 10  # Original was ~10 pairs/sec
        print(f"  Speedup: {speedup:.1f}x faster than original")
        print(f"  Time to process 1M pairs: {1_000_000 / pairs_per_second / 60:.1f} minutes")
else:
    print("⚠️ Insufficient data for validation")

## 7. Save Results

In [None]:
if len(true_null_agg) > 0 and 'ml_null_prob' in true_null_agg.columns:
    import json
    import math

    # Save validation data
    output_file = results_dir / f'{test_metapath}_null_validation_optimized.csv'
    true_null_agg.to_csv(output_file, index=False)
    print(f"\n💾 Saved validation data to {output_file}")

    def _finite_or_none(value):
        """Return ``value`` as a float, or None when it is missing or not finite.

        json.dump would otherwise emit bare NaN/Infinity, which is not valid JSON
        (e.g. pearsonr returns NaN for constant input).
        """
        if value is None:
            return None
        value = float(value)
        return value if math.isfinite(value) else None

    # Metrics exist only if the validation cell produced them in this session.
    summary = {
        'metapath': test_metapath,
        'edge_types': list(test_edge_types),
        'model_type': model_type,
        'validation_perm_range': list(validation_perm_range),
        'n_pairs': len(valid_data) if 'valid_data' in globals() else 0,
        'correlation': _finite_or_none(globals().get('corr')),
        'mae': _finite_or_none(globals().get('mae')),
        'rmse': _finite_or_none(globals().get('rmse')),
        'r2': _finite_or_none(globals().get('r2')),
        'computation_time_sec': float(elapsed_time),
        'pairs_per_second': (
            _finite_or_none(len(ml_null_probs) / elapsed_time) if elapsed_time > 0 else None
        ),
        'optimization_settings': {
            'chunk_size': chunk_size,
            'use_cache': use_cache,
            'cache_size': len(calculator.cache) if use_cache else 0
        }
    }

    summary_file = results_dir / f'{test_metapath}_summary_optimized.json'
    with open(summary_file, 'w') as f:
        json.dump(summary, f, indent=2, allow_nan=False)
    print(f"💾 Saved summary to {summary_file}")

# Closing banner: blank line, rule, title, rule.
print("\n".join([
    "\n" + "=" * 70,
    "OPTIMIZED COMPOSITIONAL NULL COMPLETE!",
    "=" * 70,
]))