# Fast Compositional Null Calculator

This notebook builds and validates a fast compositional null generator using ML model predictions.

## Purpose

Create a fast method to compute metapath null probabilities:
- **Use ML predictions** instead of analytical formula
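- **Compositional calculation**: $P_{\text{null}}(s \to t) = \sum_{d} f(d)\, P(s \to d)\, P(d \to t)$ — ML edge probabilities multiplied along the path and summed over intermediate-node degrees $d$, weighted by their frequency $f(d)$ (longer metapaths chain more factors)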
- **Validate** against true null from permutations
- **Benchmark** speed and accuracy

## Key Insight

**Current problem**: The analytical degree-aware null correlates poorly with the permutation null (r = 0.065).  
**Solution**: Replace the analytical edge probabilities with ML model predictions (r = 0.83) while keeping the compositional structure.

## Workflow

1. Load null models from notebook 13
2. Create prediction lookup tables
3. Implement compositional calculator
4. Validate on test metapath (CbGpPW)
5. Error analysis by degree
6. Performance benchmarks

In [1]:
# Papermill parameters: defaults below may be overridden at execution time.
# Metapath under validation.
test_metapath = "CbGpPW"
# Its metaedges, listed in path order.
test_edge_types = ["CbG", "GpPW"]
# Half-open range of permutation IDs -> permutations 21 through 30.
validation_perm_range = (21, 31)
# Null-model flavour: one of 'rf', 'poly', 'ensemble'.
model_type = "rf"

In [2]:
import re
import sys
import time
import warnings
from collections import Counter
from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.sparse as sp
import seaborn as sns
from scipy.stats import pearsonr, ks_2samp
from sklearn.metrics import r2_score

warnings.filterwarnings('ignore')

# Directory layout, resolved relative to the current working directory.
repo_dir = Path.cwd()
data_dir, null_models_dir, results_dir = (
    repo_dir / 'data',
    repo_dir / 'results' / 'null_models',
    repo_dir / 'results' / 'compositional_null',
)
results_dir.mkdir(exist_ok=True, parents=True)

print(f"Repository directory: {repo_dir}")
print(f"Null models directory: {null_models_dir}")
print(f"Results will be saved to: {results_dir}")

# Plot appearance for every figure in the notebook.
sns.set_style('whitegrid')
plt.rcParams['figure.dpi'] = 100

Repository directory: /projects/lgillenwater@xsede.org/repositories/Context-Aware-Path-Probability
Null models directory: /projects/lgillenwater@xsede.org/repositories/Context-Aware-Path-Probability/results/null_models
Results will be saved to: /projects/lgillenwater@xsede.org/repositories/Context-Aware-Path-Probability/results/compositional_null


## 1. Load Null Models

In [3]:
def load_null_model(edge_type, model_type='rf', models_dir=None):
    """
    Load trained null model(s) for an edge type from disk.
    
    Parameters:
    -----------
    edge_type : str
        Edge type (e.g., 'CbG')
    model_type : str
        'rf' for Random Forest, 'poly' for Polynomial LogReg, 'ensemble' for both
    models_dir : path-like, optional
        Directory holding the ``{edge_type}_*_null.pkl`` files. Defaults to the
        notebook-level ``null_models_dir``.
        
    Returns:
    --------
    dict with any of the keys 'rf', 'poly', 'poly_features'. Files that do not
    exist are skipped, so the dict may be partial or empty; callers must check.
    
    Raises:
    -------
    ValueError
        If ``model_type`` is not one of 'rf', 'poly', 'ensemble'. Previously an
        unknown type silently returned ``{}``, making every prediction 0.0.
    """
    valid_model_types = ('rf', 'poly', 'ensemble')
    if model_type not in valid_model_types:
        raise ValueError(f"model_type must be one of {valid_model_types}, got {model_type!r}")
    
    models_dir = Path(null_models_dir if models_dir is None else models_dir)
    models = {}
    
    # NOTE: joblib.load unpickles arbitrary objects; only load trusted model files.
    if model_type in ('rf', 'ensemble'):
        rf_file = models_dir / f'{edge_type}_rf_null.pkl'
        if rf_file.exists():
            models['rf'] = joblib.load(rf_file)
    
    if model_type in ('poly', 'ensemble'):
        poly_file = models_dir / f'{edge_type}_poly_null.pkl'
        poly_features_file = models_dir / f'{edge_type}_poly_features.pkl'
        # The polynomial model is unusable without its feature transformer, so
        # load both or neither.
        if poly_file.exists() and poly_features_file.exists():
            models['poly'] = joblib.load(poly_file)
            models['poly_features'] = joblib.load(poly_features_file)
    
    return models

# Load the null models for every metaedge of the test metapath. Fail fast when a
# model required by `model_type` is missing: otherwise predict_edge_probability
# silently returns 0.0 and every downstream metric is computed on garbage.
required_model_keys = {
    'rf': {'rf'},
    'poly': {'poly', 'poly_features'},
    'ensemble': {'rf', 'poly', 'poly_features'},
}.get(model_type)
if required_model_keys is None:
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'rf', 'poly' or 'ensemble'")

print(f"Loading null models for metapath {test_metapath}...")
null_models = {}

for edge_type in test_edge_types:
    models = load_null_model(edge_type, model_type)
    null_models[edge_type] = models
    print(f"  {edge_type}: {list(models.keys())}")
    missing_keys = required_model_keys - set(models)
    if missing_keys:
        raise FileNotFoundError(
            f"Missing {sorted(missing_keys)} for edge type {edge_type!r} in {null_models_dir}"
        )

print(f"\n✓ Loaded models for {len(null_models)} edge types")

Loading null models for metapath CbGpPW...


  CbG: ['rf']
  GpPW: ['rf']

✓ Loaded models for 2 edge types


## 2. Get Degree Distributions

In [4]:
def get_degree_distribution(node_type, edge_types_in_metapath, perm_id=0, permutations_dir=None):
    """
    Get the degree distribution of ``node_type`` nodes over the given metaedges.
    
    For every metaedge that involves the node type, each node's degree is taken
    on the node's side of that metaedge: column sums when it is the target,
    row sums when it is the source. Nodes with degree 0 are dropped. Degrees
    from all matching metaedges are pooled into one distribution.
    
    NOTE(review): because of pooling, each node contributes one entry per
    metaedge, and the compositional calculator then uses the same degree for
    both hops. A joint (edge1-degree, edge2-degree) distribution per node would
    match the compositional formula more closely. TODO: confirm intended design.
    
    Parameters:
    -----------
    node_type : str
        Hetionet node type name (e.g., 'Gene', 'Pathway') or its metaedge
        abbreviation (e.g., 'G', 'PW')
    edge_types_in_metapath : list of str
        Metaedge abbreviations such as 'CbG' or 'GpPW'
    perm_id : int
        Permutation ID (0 for Hetionet)
    permutations_dir : path-like, optional
        Directory containing ``{perm_id:03d}.hetmat`` folders. Defaults to
        ``data_dir / 'permutations'``.
        
    Returns:
    --------
    dict: {degree: frequency}, frequencies summing to 1
    
    Raises:
    -------
    ValueError
        If a metaedge abbreviation cannot be parsed, or no positive degrees are
        found for ``node_type`` (previously an empty dict was returned and
        later crashed in ``min()``).
    FileNotFoundError
        If a needed edge matrix is missing (previously skipped silently,
        which biased the distribution).
    """
    node_abbreviations = {
        'Anatomy': 'A', 'Biological Process': 'BP', 'Cellular Component': 'CC',
        'Compound': 'C', 'Disease': 'D', 'Gene': 'G', 'Molecular Function': 'MF',
        'Pathway': 'PW', 'Pharmacologic Class': 'PC', 'Side Effect': 'SE',
        'Symptom': 'S',
    }
    node_abbrev = node_abbreviations.get(node_type, node_type)
    
    if permutations_dir is None:
        permutations_dir = data_dir / 'permutations'
    edges_dir = Path(permutations_dir) / f'{perm_id:03d}.hetmat' / 'edges'
    
    degrees = []
    
    for edge_type in edge_types_in_metapath:
        # Metaedge abbreviations are <SOURCE><kind><TARGET>, e.g. CbG -> (C, b, G),
        # GpPW -> (G, p, PW); the kind may carry a direction marker ('Gr>G').
        match = re.fullmatch(r'([A-Z]+)([a-z]+[<>]?)([A-Z]+)', edge_type)
        if match is None:
            raise ValueError(f"Cannot parse metaedge abbreviation {edge_type!r}")
        source_abbrev, _, target_abbrev = match.groups()
        
        if node_abbrev == target_abbrev:
            axis = 0  # column sums: degree of each target node
        elif node_abbrev == source_abbrev:
            axis = 1  # row sums: degree of each source node
        else:
            continue  # this metaedge does not touch the node type
        
        edge_file = edges_dir / f'{edge_type}.sparse.npz'
        if not edge_file.exists():
            raise FileNotFoundError(f"Edge matrix not found: {edge_file}")
        
        matrix = sp.load_npz(str(edge_file))
        node_degrees = np.asarray(matrix.sum(axis=axis)).ravel()
        degrees.extend(node_degrees[node_degrees > 0].tolist())
    
    if not degrees:
        raise ValueError(
            f"No positive degrees found for node type {node_type!r} in {edge_types_in_metapath}"
        )
    
    # Count degree frequencies
    degree_counts = Counter(degrees)
    total = len(degrees)
    return {deg: count / total for deg, count in degree_counts.items()}

# Gene is the intermediate node of CbGpPW; measure its degrees on the
# unpermuted network (perm 0 = Hetionet).
print("Computing degree distributions for intermediate nodes...")
gene_degree_freq = get_degree_distribution('Gene', test_edge_types, perm_id=0)

print(f"  Gene degrees: {len(gene_degree_freq)} unique degrees")
print(f"  Degree range: {min(gene_degree_freq)}-{max(gene_degree_freq)}")
print(f"  Sample frequencies: {dict(zip(list(gene_degree_freq)[:5], list(gene_degree_freq.values())[:5]))}")

Computing degree distributions for intermediate nodes...
  Gene degrees: 132 unique degrees
  Degree range: 1-516
  Sample frequencies: {1: 0.17547806524184478, 6: 0.06092988376452944, 5: 0.07845894263217097, 3: 0.08361454818147732, 12: 0.017622797150356206}


## 3. Fast Compositional Null Calculator

In [5]:
def _model_scores(model, X):
    """Return per-row P(edge) scores from a fitted model for feature matrix ``X``."""
    # Classifiers (e.g. a logistic regression) return hard 0/1 labels from
    # ``predict``; use the probability of the last (positive, for 0/1 labels)
    # class instead. Regressors have no predict_proba and use ``predict``.
    if hasattr(model, 'predict_proba'):
        return np.asarray(model.predict_proba(X), dtype=float)[:, -1]
    return np.asarray(model.predict(X), dtype=float)


def _predict_edge_probabilities(source_degs, target_degs, models, model_type):
    """Vectorized core: clipped P(edge) for aligned 1-D arrays of degrees."""
    X = np.column_stack([source_degs, target_degs]).astype(float)
    if model_type == 'rf' and 'rf' in models:
        pred = _model_scores(models['rf'], X)
    elif model_type == 'poly' and 'poly' in models:
        pred = _model_scores(models['poly'], models['poly_features'].transform(X))
    elif model_type == 'ensemble' and 'rf' in models and 'poly' in models:
        pred_rf = _model_scores(models['rf'], X)
        pred_poly = _model_scores(models['poly'], models['poly_features'].transform(X))
        pred = 0.5 * pred_rf + 0.5 * pred_poly
    else:
        # Requested model not loaded: fall back to 0.0 (original behaviour).
        pred = np.zeros(len(X))
    return np.clip(pred, 0, 1)


def predict_edge_probability(source_deg, target_deg, models, model_type='rf'):
    """
    Predict edge probability using a null model.
    
    Accepts scalars (returns a scalar) or broadcastable arrays of degrees
    (returns an array of the broadcast shape), so many pairs are scored with a
    single model call. Predictions are clipped to [0, 1]; if the requested
    model is not present in ``models`` the prediction is 0.0.
    """
    src, tgt = np.broadcast_arrays(np.asarray(source_deg, dtype=float),
                                   np.asarray(target_deg, dtype=float))
    if src.size == 0:
        return np.zeros(src.shape)
    probs = _predict_edge_probabilities(src.ravel(), tgt.ravel(), models, model_type)
    probs = probs.reshape(src.shape)
    return probs[()] if probs.ndim == 0 else probs


def compute_metapath_null_2edge(source_degrees, target_degrees, 
                                edge1_models, edge2_models, 
                                intermediate_degree_freq,
                                model_type='rf'):
    """
    Compute null probability for 2-edge metapath (e.g., C-G-P).
    
    P_null(C→P) = Σ_g P(C→g) × P(g→P) × freq(g)
    
    Implemented with lookup tables: edge probabilities are predicted once for
    every (unique source degree, intermediate degree) and (intermediate degree,
    unique target degree) pair, combined into a (source × target) table by a
    matrix product, and then indexed per pair. This gives the same values as
    looping over pairs, but with two batched model calls instead of
    2 × n_pairs × n_degrees single-row calls.
    
    NOTE(review): the same intermediate degree is used for both hops; see
    get_degree_distribution for the pooled-degree caveat.
    
    Parameters:
    -----------
    source_degrees : array
        Source node degrees (e.g., compound degrees)
    target_degrees : array
        Target node degrees (e.g., pathway degrees); same length as source_degrees
    edge1_models : dict
        Null models for first edge type
    edge2_models : dict
        Null models for second edge type
    intermediate_degree_freq : dict
        {degree: frequency} for intermediate nodes
    model_type : str
        'rf', 'poly', or 'ensemble'
        
    Returns:
    --------
    array of null probabilities, one per (source, target) pair
    
    Raises:
    -------
    ValueError
        If source_degrees and target_degrees differ in length (previously
        zip() silently truncated to the shorter input).
    """
    source_degrees = np.asarray(source_degrees).ravel()
    target_degrees = np.asarray(target_degrees).ravel()
    if source_degrees.shape != target_degrees.shape:
        raise ValueError(
            f"source_degrees ({source_degrees.size}) and target_degrees "
            f"({target_degrees.size}) must have the same length"
        )
    
    n_pairs = source_degrees.size
    if n_pairs == 0 or not intermediate_degree_freq:
        return np.zeros(n_pairs)
    
    n_inter = len(intermediate_degree_freq)
    inter_degs = np.fromiter(intermediate_degree_freq.keys(), dtype=float, count=n_inter)
    weights = np.fromiter(intermediate_degree_freq.values(), dtype=float, count=n_inter)
    
    # Lookup tables over unique degrees only.
    uniq_src, src_pos = np.unique(source_degrees, return_inverse=True)
    uniq_tgt, tgt_pos = np.unique(target_degrees, return_inverse=True)
    
    # P(source → intermediate): shape (n_unique_src, n_inter)
    p1 = predict_edge_probability(uniq_src[:, None], inter_degs[None, :], edge1_models, model_type)
    # P(intermediate → target): shape (n_inter, n_unique_tgt)
    p2 = predict_edge_probability(inter_degs[:, None], uniq_tgt[None, :], edge2_models, model_type)
    
    # Σ_k p1[s, k] · freq[k] · p2[k, t] for every (s, t) at once.
    null_table = (p1 * weights) @ p2
    
    return null_table[src_pos.ravel(), tgt_pos.ravel()]

print("Compositional null calculator ready!")

Compositional null calculator ready!


## 4. Extract True Null from Validation Permutations

In [6]:
def extract_metapath_frequencies(metapath_edges, perm_id, permutations_dir=None):
    """
    Extract observed metapath path counts from a permutation.
    
    For 2-edge metapath C-G-P:
    - Load C→G and G→P edge matrices
    - Compute metapath matrix: C-G-P = (C→G) @ (G→P)
    - Return one row per (source, target) pair with a non-zero path count
    
    Parameters:
    -----------
    metapath_edges : sequence of two str
        Metaedge abbreviations in path order, e.g. ['CbG', 'GpPW']
    perm_id : int
        Permutation ID
    permutations_dir : path-like, optional
        Directory containing ``{perm_id:03d}.hetmat`` folders. Defaults to
        ``data_dir / 'permutations'``.
    
    Returns:
    --------
    DataFrame with columns source_idx, target_idx, source_degree,
    target_degree, metapath_count, perm_id (columns are present even when
    there are no paths), or None if either edge matrix is missing.
    """
    edge1_type, edge2_type = metapath_edges
    
    if permutations_dir is None:
        permutations_dir = data_dir / 'permutations'
    edges_dir = Path(permutations_dir) / f'{perm_id:03d}.hetmat' / 'edges'
    edge1_file = edges_dir / f'{edge1_type}.sparse.npz'
    edge2_file = edges_dir / f'{edge2_type}.sparse.npz'
    
    if not edge1_file.exists() or not edge2_file.exists():
        return None
    
    # Cast to integers: if the stored matrices are boolean, a boolean product
    # would report "at least one path" (True) instead of the number of paths.
    matrix1 = sp.load_npz(str(edge1_file)).astype(np.int64)  # C × G
    matrix2 = sp.load_npz(str(edge2_file)).astype(np.int64)  # G × P
    
    # Metapath matrix (C × P): entry (i, j) counts paths from source i to target j.
    metapath_matrix = sp.csr_matrix(matrix1 @ matrix2)
    metapath_matrix.eliminate_zeros()  # keep only pairs with at least one path
    metapath_coo = metapath_matrix.tocoo()
    
    # Degrees of the endpoint nodes on their own metaedge
    source_degrees = np.asarray(matrix1.sum(axis=1)).ravel()  # Compound degrees
    target_degrees = np.asarray(matrix2.sum(axis=0)).ravel()  # Pathway degrees
    
    rows, cols = metapath_coo.row, metapath_coo.col
    # Vectorized construction instead of per-element sparse indexing in Python.
    return pd.DataFrame({
        'source_idx': rows,
        'target_idx': cols,
        'source_degree': source_degrees[rows].astype(int),
        'target_degree': target_degrees[cols].astype(int),
        'metapath_count': metapath_coo.data.astype(int),
        'perm_id': perm_id,
    })

# Extract true null from validation permutations (half-open range).
first_perm, stop_perm = validation_perm_range
print(f"Extracting true null from permutations {first_perm}-{stop_perm-1}...")
all_true_null = []

for perm_id in range(first_perm, stop_perm):
    df = extract_metapath_frequencies(test_edge_types, perm_id)
    if df is not None:
        all_true_null.append(df)

n_perms_loaded = len(all_true_null)
if n_perms_loaded == 0:
    raise FileNotFoundError(
        f"No permutations {first_perm}-{stop_perm-1} with edge matrices {test_edge_types} "
        f"found under {data_dir / 'permutations'}"
    )
print(f"  Loaded {n_perms_loaded} permutations")

true_null_df = pd.concat(all_true_null, ignore_index=True)

# Aggregate across permutations. A pair absent from a permutation had zero paths
# there, so the per-permutation mean is the summed count divided by the number
# of permutations, not the mean over only those permutations where the pair
# appeared. Degrees are part of the key; this assumes degree-preserving
# permutations (otherwise a pair would split across several rows).
true_null_agg = (
    true_null_df
    .groupby(['source_idx', 'target_idx', 'source_degree', 'target_degree'], as_index=False)
    ['metapath_count']
    .sum()
)
true_null_agg['metapath_count'] = true_null_agg['metapath_count'] / n_perms_loaded

# Normalize to probabilities: each pair's share of all observed paths.
total_paths = true_null_agg['metapath_count'].sum()
true_null_agg['true_null_prob'] = true_null_agg['metapath_count'] / total_paths

print(f"  Extracted {len(true_null_agg)} metapath pairs")
print(f"  Source degree range: {true_null_agg['source_degree'].min()}-{true_null_agg['source_degree'].max()}")
print(f"  Target degree range: {true_null_agg['target_degree'].min()}-{true_null_agg['target_degree'].max()}")
print(f"  Mean null probability: {true_null_agg['true_null_prob'].mean():.6e}")

Extracting true null from permutations 21-30...


  Extracted 684757 metapath pairs
  Source degree range: 1-132
  Target degree range: 2-1956
  Mean null probability: 1.460372e-06


## 5. Compute ML-Compositional Null

In [None]:
print(f"Computing ML-compositional null using {model_type} model...")
# perf_counter is the monotonic, high-resolution clock intended for benchmarking.
start_time = time.perf_counter()

# Compute null for each (source, target) pair
ml_null_probs = compute_metapath_null_2edge(
    true_null_agg['source_degree'].values,
    true_null_agg['target_degree'].values,
    null_models[test_edge_types[0]],
    null_models[test_edge_types[1]],
    gene_degree_freq,
    model_type=model_type
)

elapsed_time = time.perf_counter() - start_time

# The raw compositional values are frequency-weighted averages over intermediate
# degrees, while `true_null_prob` is each pair's share of all observed paths.
# Normalize the ML values the same way over the same pairs so MAE/RMSE/R²/KS and
# the y = x plot compare like with like (Pearson r is unaffected).
ml_null_total = ml_null_probs.sum()
true_null_agg['ml_null_raw'] = ml_null_probs
true_null_agg['ml_null_prob'] = (
    ml_null_probs / ml_null_total if ml_null_total > 0 else np.zeros_like(ml_null_probs)
)

print(f"  ✓ Computed {len(ml_null_probs)} null probabilities")
print(f"  Computation time: {elapsed_time:.2f} seconds")
print(f"  Speed: {len(ml_null_probs)/max(elapsed_time, 1e-12):.0f} pairs/second")
print(f"  Mean raw ML-null value: {ml_null_probs.mean():.6e}")
print(f"  Mean normalized ML-null probability: {true_null_agg['ml_null_prob'].mean():.6e}")

## 6. Validate Against True Null

In [None]:
# Keep only pairs where both nulls are strictly positive (the log-scale plots
# below require it).
valid_mask = (true_null_agg['true_null_prob'] > 0) & (true_null_agg['ml_null_prob'] > 0)
valid_data = true_null_agg[valid_mask].copy()

print("\nValidation Results:")
print(f"  Valid pairs: {len(valid_data)} / {len(true_null_agg)}")

# Fail with an actionable message instead of an opaque pearsonr/r2_score error
# (e.g. when missing models made every ML prediction 0.0).
if len(valid_data) < 2:
    raise ValueError(
        "Fewer than two pairs have positive true and ML null probabilities; "
        "check that the null models loaded and produce non-zero predictions."
    )

true_probs = valid_data['true_null_prob']
ml_probs = valid_data['ml_null_prob']
residuals = true_probs - ml_probs

# Correlation
corr, p_val = pearsonr(true_probs, ml_probs)
print(f"  Pearson correlation: r = {corr:.4f} (p = {p_val:.2e})")

# MAE and RMSE
mae = residuals.abs().mean()
rmse = np.sqrt((residuals ** 2).mean())
print(f"  MAE: {mae:.6e}")
print(f"  RMSE: {rmse:.6e}")

# R² (r2_score is imported in the top import cell)
r2 = r2_score(true_probs, ml_probs)
print(f"  R²: {r2:.4f}")

# KS test
ks_stat, ks_p = ks_2samp(true_probs, ml_probs)
print(f"  KS test: D = {ks_stat:.4f} (p = {ks_p:.2e})")

# Success criteria
# NOTE(review): per-pair probabilities are ~1e-6, so the absolute RMSE < 0.20
# threshold is met trivially; a relative error criterion would be more informative.
print("\nSuccess Criteria:")
print(f"  Correlation > 0.75: {'✓' if corr > 0.75 else '✗'} ({corr:.4f})")
print(f"  RMSE < 0.20: {'✓' if rmse < 0.20 else '✗'} ({rmse:.4f})")
print(f"  R² > 0.50: {'✓' if r2 > 0.50 else '✗'} ({r2:.4f})")

## 7. Error Analysis by Degree

In [None]:
def analyze_errors_by_degree(predictions, actuals, source_degrees, target_degrees, min_samples=10):
    """
    Analyze prediction errors stratified by source- and target-degree bins.
    
    Bins are left-closed ([0, 1), [1, 2), [2, 5), ...) so each label matches
    the degrees it contains: '0', '1', '2-4', '5-9', ..., '500+'. With the
    previous right-closed bins, degree 1 was labelled '0', degree 500 was
    labelled '100-499', and degree 0 was dropped.
    
    Parameters:
    -----------
    predictions, actuals : array-like of float
        Predicted and true values, aligned with the degree arrays
    source_degrees, target_degrees : array-like of int
        Degrees used for stratification
    min_samples : int
        Bins with fewer samples are omitted (default 10)
    
    Returns:
    --------
    DataFrame with columns stratification, bin, n_samples, mean_prediction,
    mean_actual, mae, rmse, correlation (columns present even when empty)
    """
    degree_bins = [0, 1, 2, 5, 10, 20, 50, 100, 500, np.inf]
    degree_labels = ['0', '1', '2-4', '5-9', '10-19', '20-49', '50-99', '100-499', '500+']
    columns = ['stratification', 'bin', 'n_samples', 'mean_prediction',
               'mean_actual', 'mae', 'rmse', 'correlation']
    
    predictions = np.asarray(predictions, dtype=float)
    actuals = np.asarray(actuals, dtype=float)
    
    results = []
    for stratification, degrees in (('source_degree', source_degrees),
                                    ('target_degree', target_degrees)):
        binned = pd.cut(np.asarray(degrees), bins=degree_bins, labels=degree_labels, right=False)
        for bin_label in degree_labels:
            mask = np.asarray(binned == bin_label)
            n_samples = int(mask.sum())
            if n_samples < min_samples:
                continue
            pred_bin = predictions[mask]
            actual_bin = actuals[mask]
            errors = pred_bin - actual_bin
            results.append({
                'stratification': stratification,
                'bin': bin_label,
                'n_samples': n_samples,
                'mean_prediction': float(pred_bin.mean()),
                'mean_actual': float(actual_bin.mean()),
                'mae': float(np.abs(errors).mean()),
                'rmse': float(np.sqrt((errors ** 2).mean())),
                # corrcoef needs at least two points
                'correlation': float(np.corrcoef(pred_bin, actual_bin)[0, 1]) if n_samples > 1 else np.nan,
            })
    
    return pd.DataFrame(results, columns=columns)

error_analysis_df = analyze_errors_by_degree(
    predictions=valid_data['ml_null_prob'].values,
    actuals=valid_data['true_null_prob'].values,
    source_degrees=valid_data['source_degree'].values,
    target_degrees=valid_data['target_degree'].values,
)

# Split the per-bin table by stratification axis, then print each section.
source_errors = error_analysis_df.query("stratification == 'source_degree'")
target_errors = error_analysis_df.query("stratification == 'target_degree'")

print("\nError Analysis by Degree:")
for heading, stratum in (("\nSource Degree:", source_errors), ("\nTarget Degree:", target_errors)):
    print(heading)
    print(stratum[['bin', 'n_samples', 'correlation', 'mae', 'rmse']].to_string(index=False))

## 8. Visualizations

In [None]:
def _plot_mae_by_bin(ax, stratum_df, axis_label, **bar_kwargs):
    """Draw per-bin MAE bars for one stratification ('Source' or 'Target')."""
    positions = range(len(stratum_df))
    ax.bar(positions, stratum_df['mae'], alpha=0.7, **bar_kwargs)
    ax.set_xticks(positions)
    ax.set_xticklabels(stratum_df['bin'], rotation=45, ha='right')
    ax.set_xlabel(f'{axis_label} Degree Bin', fontsize=12)
    ax.set_ylabel('MAE', fontsize=12)
    ax.set_title(f'Error by {axis_label} Degree', fontsize=14, fontweight='bold')
    ax.grid(axis='y', alpha=0.3)


fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Panel 1: true null vs ML null on log-log axes, with the y = x reference line.
ax = axes[0, 0]
true_lo, true_hi = valid_data['true_null_prob'].min(), valid_data['true_null_prob'].max()
ax.scatter(valid_data['true_null_prob'], valid_data['ml_null_prob'], alpha=0.5, s=20)
ax.plot([true_lo, true_hi], [true_lo, true_hi], 'r--', alpha=0.8, label='Perfect agreement')
ax.set_xlabel('True Null Probability', fontsize=12)
ax.set_ylabel('ML-Compositional Null Probability', fontsize=12)
ax.set_title(f'ML-Null vs True Null (r={corr:.3f})', fontsize=14, fontweight='bold')
ax.set_xscale('log')
ax.set_yscale('log')
ax.legend()
ax.grid(True, alpha=0.3)

# Panels 2 and 3: MAE per degree bin for each stratification axis.
source_err = error_analysis_df[error_analysis_df['stratification'] == 'source_degree']
_plot_mae_by_bin(axes[0, 1], source_err, 'Source')
target_err = error_analysis_df[error_analysis_df['stratification'] == 'target_degree']
_plot_mae_by_bin(axes[1, 0], target_err, 'Target', color='orange')

# Panel 4: overlaid log10 densities of the two nulls.
ax = axes[1, 1]
for column, legend_label in (('true_null_prob', 'True Null'), ('ml_null_prob', 'ML-Compositional Null')):
    ax.hist(np.log10(valid_data[column] + 1e-10), bins=50, alpha=0.6,
            label=legend_label, density=True, edgecolor='black')
ax.set_xlabel('log₁₀(Probability)', fontsize=12)
ax.set_ylabel('Density', fontsize=12)
ax.set_title('Probability Distributions', fontsize=14, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

figure_path = results_dir / f'{test_metapath}_validation.png'
plt.tight_layout()
plt.savefig(figure_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Saved validation plots to {figure_path}")

## 9. Save Results

In [None]:
# Save validation data
valid_data.to_csv(results_dir / f'{test_metapath}_null_validation.csv', index=False)
print(f"Saved validation data to {results_dir / f'{test_metapath}_null_validation.csv'}")

# Save error analysis
error_analysis_df.to_csv(results_dir / f'{test_metapath}_error_analysis.csv', index=False)
print(f"Saved error analysis to {results_dir / f'{test_metapath}_error_analysis.csv'}")

# Save summary statistics
summary = {
    'metapath': test_metapath,
    'edge_types': test_edge_types,
    'model_type': model_type,
    'n_pairs': len(valid_data),
    'correlation': corr,
    'p_value': p_val,
    'mae': mae,
    'rmse': rmse,
    'r2': r2,
    'ks_stat': ks_stat,
    'ks_p': ks_p,
    'computation_time_sec': elapsed_time,
    'pairs_per_second': len(ml_null_probs)/elapsed_time
}

summary_df = pd.DataFrame([summary])
summary_df.to_csv(results_dir / f'{test_metapath}_summary.csv', index=False)
print(f"Saved summary to {results_dir / f'{test_metapath}_summary.csv'}")

## 10. Summary

In [None]:
# Correlation of the current analytical degree-aware null with the permutation
# null (quoted in the notebook introduction); used as the comparison baseline.
BASELINE_DEGREE_AWARE_CORR = 0.065
divider = "=" * 70

print("\n" + divider)
print("FAST COMPOSITIONAL NULL VALIDATION SUMMARY")
print(divider)

print(f"\nMetapath: {test_metapath}")
print(f"Edge types: {' → '.join(test_edge_types)}")
print(f"Model type: {model_type}")

print("\nPerformance:")
print(f"  Correlation with true null: r = {corr:.4f} {'✓' if corr > 0.75 else '✗'}")
print(f"  MAE: {mae:.6e}")
# NOTE(review): absolute RMSE threshold is met trivially for probabilities ~1e-6.
print(f"  RMSE: {rmse:.6e} {'✓' if rmse < 0.20 else '✗'}")
print(f"  R²: {r2:.4f} {'✓' if r2 > 0.50 else '✗'}")

print("\nSpeed:")
print(f"  Computation time: {elapsed_time:.2f} seconds")
print(f"  Pairs computed: {len(ml_null_probs):,}")
print(f"  Speed: {len(ml_null_probs)/elapsed_time:.0f} pairs/second")

print("\nComparison to Current Method:")
print(f"  Current degree-aware correlation: {BASELINE_DEGREE_AWARE_CORR}")
print(f"  ML-compositional correlation: {corr:.4f}")
print(f"  Improvement: {(corr - BASELINE_DEGREE_AWARE_CORR):.4f} "
      f"({(corr / BASELINE_DEGREE_AWARE_CORR - 1) * 100:.1f}% better)")

print("\nOutput Files:")
# Sorted so the listing is stable across filesystems and runs.
for output_file in sorted(results_dir.glob(f'{test_metapath}_*')):
    print(f"  - {output_file.name}")

print("\n" + divider)
print("VALIDATION COMPLETE!")
print(divider)