# Metapath Compositionality Analysis

## Research Question

**Are metapath edge probabilities compositional (independent) or conditional (dependent)?**

### Hypothesis

For a metapath like **CbG → GpPW** (Compound binds Gene, Gene participates in Pathway):

- **Compositional (H0)**: P(CbG ∩ GpPW) = P(CbG) × P(GpPW)
  - Edge probabilities are independent
  - Simple multiplication of marginal probabilities
  - Analytical prior should work well

- **Conditional (H1)**: P(CbG ∩ GpPW) ≠ P(CbG) × P(GpPW)
  - Edge probabilities are conditionally dependent
  - P(GpPW | CbG) ≠ P(GpPW)
  - Need learned/empirical priors
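
The two hypotheses are linked by the chain rule, which is why H1 can be restated as a conditional probability. A short derivation, treating CbG and GpPW as events that share an intermediate gene and assuming P(CbG) > 0:

$$
P(\mathrm{CbG} \cap \mathrm{GpPW}) = P(\mathrm{CbG}) \, P(\mathrm{GpPW} \mid \mathrm{CbG})
$$

Comparing this with the H0 factorization shows that H0 holds exactly when P(GpPW | CbG) = P(GpPW); any departure from that equality is H1.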

### Measurement

We use **Pointwise Mutual Information (PMI)** to quantify dependency:

```
PMI(edge1, edge2) = log₂(P(edge1, edge2) / (P(edge1) × P(edge2)))
```

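- **PMI = 0**: Independent (compositional); the joint probability equals the product of the marginals
- **PMI > 0**: Positive association (conditional); the edges co-occur more often than independence predicts
- **PMI < 0**: Negative association (anti-correlation); the edges co-occur less often than independence predicts, which is also a dependency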
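
As a quick sanity check on the three cases, the formula can be evaluated on a few hand-picked probabilities. This is a minimal sketch: the numbers are illustrative and not taken from Hetionet, and the `pmi` helper is a hypothetical name, not a function used elsewhere in this notebook.

```python
import math


def pmi(p_joint: float, p_a: float, p_b: float) -> float:
    """Pointwise mutual information in bits; all inputs must be positive."""
    if min(p_joint, p_a, p_b) <= 0:
        raise ValueError("PMI is undefined for zero or negative probabilities")
    return math.log2(p_joint / (p_a * p_b))


# Illustrative marginals only (not Hetionet values).
p_a, p_b = 0.25, 0.5

independent = pmi(p_a * p_b, p_a, p_b)  # joint equals the product of marginals
positive = pmi(0.25, p_a, p_b)          # joint is twice the product
negative = pmi(0.0625, p_a, p_b)        # joint is half the product

print(independent, positive, negative)
```

Because the logarithm is base 2, each unit of PMI corresponds to a factor of two between the joint probability and the independence prediction.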

### Test Metapath

**CbGpPW**: Compound → binds Gene → participates in Pathway

We compare:
1. **Observed**: Actual metapath frequencies in Hetionet
2. **Compositional**: Analytical prior (independent edge probabilities)
3. **Learned** (optional): Learned formula or empirical priors from N permutations
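
The observed quantity in item 1 is derived from path counts, which for a two-edge metapath are the product of the two adjacency matrices. The sketch below illustrates this on made-up data, assuming binary adjacency matrices; it uses dense NumPy arrays for brevity, whereas the notebook works with SciPy sparse matrices, and the names `cbg` and `gppw` are illustrative.

```python
import numpy as np

# Toy binary adjacency matrices (made-up data, not Hetionet).
# Rows: 2 compounds, columns: 3 genes.
cbg = np.array([[1, 1, 0],
                [0, 0, 1]])
# Rows: 3 genes, columns: 2 pathways.
gppw = np.array([[1, 0],
                 [1, 1],
                 [0, 0]])

# Entry (i, j) counts the genes that link compound i to pathway j.
path_counts = cbg @ gppw

# Divide each row by the compound's gene degree, mirroring the
# observed-frequency normalization used later in this notebook.
# (Every toy compound has at least one gene, so no division by zero.)
compound_degree = cbg.sum(axis=1)
observed = path_counts / compound_degree[:, None]

print(path_counts)
print(observed)
```

In this toy case the second compound binds only a gene that participates in no pathway, so its row of path counts is all zeros.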

In [None]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.sparse as sp
from scipy.stats import pearsonr, spearmanr
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# Setup paths
repo_dir = Path.cwd().parent
src_dir = repo_dir / 'src'
data_dir = repo_dir / 'data'
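results_dir = repo_dir / 'results'
results_dir.mkdir(parents=True, exist_ok=True)  # ensure the output directory exists before plots/CSVs are saved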

sys.path.append(str(src_dir))

print(f"Repository: {repo_dir}")

# Set plot style
sns.set_style('whitegrid')
plt.rcParams['figure.dpi'] = 100

## Configuration

In [None]:
# Test metapath: CbGpPW
metapath = ['CbG', 'GpPW']
metapath_name = 'CbGpPW'

# Source type: Compound
# Intermediate type: Gene
# Target type: Pathway

print(f"Testing metapath: {metapath_name}")
print(f"  Edge 1: {metapath[0]} (Compound → Gene)")
print(f"  Edge 2: {metapath[1]} (Gene → Pathway)")

# Prior options (can be swapped)
PRIOR_TYPE = 'analytical'  # Options: 'analytical', 'learned', 'empirical'

print(f"\nUsing prior: {PRIOR_TYPE}")

## Helper Functions

In [None]:
def load_edge_matrix(edge_type: str, perm_id: int = 0) -> sp.csr_matrix:
    """Load edge matrix for given edge type and permutation."""
    edge_file = data_dir / 'permutations' / f'{perm_id:03d}.hetmat' / 'edges' / f'{edge_type}.sparse.npz'
    return sp.load_npz(edge_file)

def filter_zero_degrees(matrix: sp.csr_matrix):
    """Remove nodes with zero degree."""
    # Get degrees
    source_degrees = np.array(matrix.sum(axis=1)).flatten()
    target_degrees = np.array(matrix.sum(axis=0)).flatten()
    
    # Find non-zero indices
    source_nonzero = np.where(source_degrees > 0)[0]
    target_nonzero = np.where(target_degrees > 0)[0]
    
    # Filter matrix
    filtered = matrix[source_nonzero, :][:, target_nonzero]
    
    return filtered, source_nonzero, target_nonzero

def analytical_prior(u: float, v: float, m: float) -> float:
    """Current analytical formula for edge probability."""
    uv = u * v
    denominator = np.sqrt(uv**2 + (m - u - v + 1)**2)
    return uv / denominator if denominator > 0 else 0.0

def compute_edge_priors_sparse(edge_matrix: sp.csr_matrix, prior_type: str = 'analytical'):
    """
    Compute prior probabilities ONLY for existing edges (sparse).
    
    Parameters
    ----------
    edge_matrix : sparse matrix
        Edge matrix (sources × targets)
    prior_type : str
        'analytical', 'learned', or 'empirical'
    
    Returns
    -------
    priors : dict
        {(source_idx, target_idx): probability}
    """
    n_sources, n_targets = edge_matrix.shape
    m = edge_matrix.nnz
    
    # Get degrees
    source_degrees = np.array(edge_matrix.sum(axis=1)).flatten()
    target_degrees = np.array(edge_matrix.sum(axis=0)).flatten()
    
    priors = {}
    
    if prior_type == 'analytical':
        # Analytical prior - only for existing edges
        rows, cols = edge_matrix.nonzero()
        for i, j in zip(rows, cols):
            u, v = source_degrees[i], target_degrees[j]
            if u > 0 and v > 0:
                priors[(i, j)] = analytical_prior(u, v, m)
    
    elif prior_type == 'learned':
        # TODO: Load learned formula and compute priors
        # For now, fallback to analytical
        print("WARNING: Learned prior not yet implemented, using analytical")
        return compute_edge_priors_sparse(edge_matrix, 'analytical')
    
    elif prior_type == 'empirical':
        # Empirical prior from permutations
        # TODO: Load from empirical frequency files
        print("WARNING: Empirical prior not yet implemented, using analytical")
        return compute_edge_priors_sparse(edge_matrix, 'analytical')
    
    return priors

## Load Hetionet Data (Observed)

In [None]:
# Load Hetionet (permutation 000)
print("Loading Hetionet edge matrices...\n")

edge1_matrix = load_edge_matrix(metapath[0], perm_id=0)
edge2_matrix = load_edge_matrix(metapath[1], perm_id=0)

# Filter zero degrees
edge1_filtered, edge1_source_map, edge1_target_map = filter_zero_degrees(edge1_matrix)
edge2_filtered, edge2_source_map, edge2_target_map = filter_zero_degrees(edge2_matrix)

print(f"{metapath[0]} (Compound → Gene):")
print(f"  Original shape: {edge1_matrix.shape}")
print(f"  Filtered shape: {edge1_filtered.shape}")
print(f"  Edges: {edge1_filtered.nnz}")
print(f"  Density: {edge1_filtered.nnz / (edge1_filtered.shape[0] * edge1_filtered.shape[1]):.6f}")

print(f"\n{metapath[1]} (Gene → Pathway):")
print(f"  Original shape: {edge2_matrix.shape}")
print(f"  Filtered shape: {edge2_filtered.shape}")
print(f"  Edges: {edge2_filtered.nnz}")
print(f"  Density: {edge2_filtered.nnz / (edge2_filtered.shape[0] * edge2_filtered.shape[1]):.6f}")

# Align gene dimensions - use ORIGINAL unfiltered matrices
# Both matrices should have the same gene dimension in the original data
print(f"\nAligning to common gene space (using original dimensions)...")
print(f"  CbG genes (columns): {edge1_matrix.shape[1]}")
print(f"  GpPW genes (rows): {edge2_matrix.shape[0]}")

# Use original matrices for metapath computation
assert edge1_matrix.shape[1] == edge2_matrix.shape[0], "Gene dimension mismatch in original matrices!"

# Filter only source (Compound) and target (Pathway) nodes with zero degree
# Keep all genes to maintain alignment
compound_degrees = np.array(edge1_matrix.sum(axis=1)).flatten()
pathway_degrees = np.array(edge2_matrix.sum(axis=0)).flatten()

compound_nonzero = np.where(compound_degrees > 0)[0]
pathway_nonzero = np.where(pathway_degrees > 0)[0]

# Filter only compounds and pathways, keep all genes
edge1_aligned = edge1_matrix[compound_nonzero, :]
edge2_aligned = edge2_matrix[:, pathway_nonzero]

n_compounds = edge1_aligned.shape[0]
n_genes = edge1_aligned.shape[1]
n_pathways = edge2_aligned.shape[1]

print(f"\nAligned matrices:")
print(f"  CbG: {edge1_aligned.shape} (Compounds × Genes)")
print(f"  GpPW: {edge2_aligned.shape} (Genes × Pathways)")
print(f"\nNode counts:")
print(f"  Compounds: {n_compounds}")
print(f"  Genes: {n_genes}")
print(f"  Pathways: {n_pathways}")

## Compute Metapath Frequencies (Observed)

In [None]:
# Compute metapath matrix: Compound → Gene → Pathway
print("Computing metapath matrix...")

# Metapath matrix = edge1 × edge2
metapath_matrix = edge1_aligned @ edge2_aligned

print(f"\nMetapath matrix {metapath_name}:")
print(f"  Shape: {metapath_matrix.shape} (Compounds × Pathways)")
print(f"  Metapaths: {metapath_matrix.nnz}")
print(f"  Max count: {metapath_matrix.max()}")
print(f"  Mean count (nonzero): {metapath_matrix.data.mean():.2f}")

# Compute observed frequencies
# For each (compound, pathway) pair, divide the path count by the compound's gene degree
print("\nComputing observed metapath frequencies...")

observed_freq = {}

for i, j in zip(*metapath_matrix.nonzero()):
    # Get compound i's gene neighbors
    compound_genes = edge1_aligned.getrow(i).nonzero()[1]
    
    # Get pathway j's gene neighbors  
    pathway_genes = edge2_aligned.getcol(j).nonzero()[0]
    
    # Count shared genes (actual metapaths)
    shared_genes = set(compound_genes) & set(pathway_genes)
    n_paths = len(shared_genes)
    
    # Denominator = number of genes bound by compound i (its CbG degree),
    # whether or not those genes participate in any pathway
    n_possible = len(compound_genes)
    
    if n_possible > 0:
        observed_freq[(i, j)] = n_paths / n_possible

print(f"Observed frequencies computed for {len(observed_freq)} (compound, pathway) pairs")

## Compute Compositional Probabilities (Prior)

In [None]:
print(f"Computing compositional probabilities using {PRIOR_TYPE} prior...\n")

# Compute priors for each edge type (sparse - only existing edges)
edge1_priors = compute_edge_priors_sparse(edge1_aligned, PRIOR_TYPE)
edge2_priors = compute_edge_priors_sparse(edge2_aligned, PRIOR_TYPE)

print(f"Computed {len(edge1_priors)} priors for {metapath[0]}")
print(f"Computed {len(edge2_priors)} priors for {metapath[1]}")

# Compute compositional metapath probabilities
# P(compound → pathway) = Σ_gene P(compound → gene) × P(gene → pathway)
print("\nComputing compositional metapath probabilities...")

compositional_prob = {}

# Precompute each pathway's gene set once instead of slicing inside the nested loops
edge2_csc = edge2_aligned.tocsc()
pathway_gene_sets = [set(edge2_csc.getcol(j).nonzero()[0]) for j in range(n_pathways)]

for i in range(n_compounds):
    # Get genes connected to this compound
    compound_genes = set(edge1_aligned.getrow(i).nonzero()[1])
    
    for j in range(n_pathways):
        # Sum over all shared intermediate genes
        total_prob = 0.0
        
        for gene in compound_genes & pathway_gene_sets[j]:
            # Compositional assumption: independent
            p_edge1 = edge1_priors.get((i, gene), 0.0)
            p_edge2 = edge2_priors.get((gene, j), 0.0)
            total_prob += p_edge1 * p_edge2
        
        if total_prob > 0:
            compositional_prob[(i, j)] = total_prob

print(f"Compositional probabilities computed for {len(compositional_prob)} (compound, pathway) pairs")

## Compute Pointwise Mutual Information (PMI)

In [None]:
print("Computing Pointwise Mutual Information (PMI)...\n")

# Align observed and compositional probabilities
common_pairs = set(observed_freq.keys()) & set(compositional_prob.keys())

print(f"Common (compound, pathway) pairs: {len(common_pairs)}")

pmi_values = {}
observed_vals = []
compositional_vals = []

for pair in common_pairs:
    p_observed = observed_freq[pair]
    p_compositional = compositional_prob[pair]
    
    # PMI = log₂(P(observed) / P(compositional))
    if p_observed > 0 and p_compositional > 0:
        pmi = np.log2(p_observed / p_compositional)
        pmi_values[pair] = pmi
        observed_vals.append(p_observed)
        compositional_vals.append(p_compositional)

pmi_array = np.array(list(pmi_values.values()))

print(f"\nPMI Statistics:")
print(f"  Mean PMI: {pmi_array.mean():.4f}")
print(f"  Median PMI: {np.median(pmi_array):.4f}")
print(f"  Std PMI: {pmi_array.std():.4f}")
print(f"  Min PMI: {pmi_array.min():.4f}")
print(f"  Max PMI: {pmi_array.max():.4f}")

# Interpretation
print(f"\nInterpretation:")
if pmi_array.mean() > 0.5:
    print(f"  ✓ STRONG CONDITIONAL DEPENDENCY (mean PMI = {pmi_array.mean():.4f} > 0.5)")
    print(f"    → Metapath probabilities are NOT compositional")
    print(f"    → P(GpPW | CbG) ≠ P(GpPW)")
    print(f"    → Need learned/empirical priors")
elif pmi_array.mean() >= 0.1:  # inclusive, so a mean of exactly 0.1 is not reported as negative
    print(f"  → MODERATE CONDITIONAL DEPENDENCY (mean PMI = {pmi_array.mean():.4f})")
    print(f"    → Some conditional structure present")
elif abs(pmi_array.mean()) < 0.1:
    print(f"  ✓ APPROXIMATELY INDEPENDENT (mean PMI ≈ {pmi_array.mean():.4f} ≈ 0)")
    print(f"    → Metapath probabilities are compositional")
    print(f"    → P(GpPW | CbG) ≈ P(GpPW)")
    print(f"    → Analytical prior works well")
else:
    print(f"  → NEGATIVE ASSOCIATION (mean PMI = {pmi_array.mean():.4f} < 0)")
    print(f"    → Anti-correlation between edge types")

## Correlation Analysis

In [None]:
# Compute correlations
observed_vals = np.array(observed_vals)
compositional_vals = np.array(compositional_vals)

pearson_r, pearson_p = pearsonr(observed_vals, compositional_vals)
spearman_r, spearman_p = spearmanr(observed_vals, compositional_vals)

print("\n" + "="*80)
print("CORRELATION BETWEEN OBSERVED AND COMPOSITIONAL PROBABILITIES")
print("="*80)
print(f"\nPearson correlation: r = {pearson_r:.4f} (p = {pearson_p:.2e})")
print(f"Spearman correlation: ρ = {spearman_r:.4f} (p = {spearman_p:.2e})")

# Compute MAE and RMSE
mae = np.mean(np.abs(observed_vals - compositional_vals))
rmse = np.sqrt(np.mean((observed_vals - compositional_vals)**2))

print(f"\nMean Absolute Error: {mae:.6f}")
print(f"Root Mean Squared Error: {rmse:.6f}")

# Interpretation
print(f"\nInterpretation:")
if pearson_r > 0.9:
    print(f"  ✓ EXCELLENT FIT (r = {pearson_r:.4f} > 0.9)")
    print(f"    → Compositional model works very well")
    print(f"    → Analytical prior is appropriate")
elif pearson_r > 0.7:
    print(f"  → GOOD FIT (r = {pearson_r:.4f} > 0.7)")
    print(f"    → Compositional model captures most variance")
    print(f"    → But some conditional structure exists")
elif pearson_r > 0.5:
    print(f"  → MODERATE FIT (r = {pearson_r:.4f})")
    print(f"    → Significant conditional dependencies")
    print(f"    → Consider learned/empirical priors")
else:
    print(f"  ✗ POOR FIT (r = {pearson_r:.4f} < 0.5)")
    print(f"    → Strong conditional dependencies")
    print(f"    → Compositional assumption fails")
    print(f"    → Must use learned/empirical priors")

## Visualizations

In [None]:
# Create comprehensive visualization
fig, axes = plt.subplots(2, 2, figsize=(16, 14))

# Plot 1: Observed vs Compositional scatter
ax = axes[0, 0]
ax.scatter(compositional_vals, observed_vals, alpha=0.5, s=20, edgecolors='black', linewidths=0.5)
ax.plot([0, max(compositional_vals)], [0, max(compositional_vals)], 'r--', linewidth=2, label='Perfect composition')
ax.set_xlabel(f'Compositional Probability ({PRIOR_TYPE} prior)', fontsize=12)
ax.set_ylabel('Observed Frequency (Hetionet)', fontsize=12)
ax.set_title(f'{metapath_name}: Observed vs Compositional\nr = {pearson_r:.4f}', fontsize=14, fontweight='bold')
ax.legend()
ax.grid(alpha=0.3)

# Plot 2: PMI distribution
ax = axes[0, 1]
ax.hist(pmi_array, bins=50, edgecolor='black', alpha=0.7, color='steelblue')
ax.axvline(0, color='red', linestyle='--', linewidth=2, label='PMI = 0 (independent)')
ax.axvline(pmi_array.mean(), color='orange', linestyle='-', linewidth=2, label=f'Mean PMI = {pmi_array.mean():.2f}')
ax.set_xlabel('Pointwise Mutual Information (PMI)', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.set_title(f'{metapath_name}: PMI Distribution', fontsize=14, fontweight='bold')
ax.legend()
ax.grid(alpha=0.3)

# Plot 3: Residuals
ax = axes[1, 0]
residuals = observed_vals - compositional_vals
ax.scatter(compositional_vals, residuals, alpha=0.5, s=20, edgecolors='black', linewidths=0.5)
ax.axhline(0, color='red', linestyle='--', linewidth=2)
ax.set_xlabel(f'Compositional Probability ({PRIOR_TYPE} prior)', fontsize=12)
ax.set_ylabel('Residual (Observed - Compositional)', fontsize=12)
ax.set_title(f'{metapath_name}: Residuals\nMAE = {mae:.6f}, RMSE = {rmse:.6f}', fontsize=14, fontweight='bold')
ax.grid(alpha=0.3)

# Plot 4: PMI vs Compositional probability
ax = axes[1, 1]
pmi_list = pmi_array  # built in the same order as observed_vals and compositional_vals
ax.scatter(compositional_vals, pmi_list, alpha=0.5, s=20, edgecolors='black', linewidths=0.5, c=observed_vals, cmap='viridis')
ax.axhline(0, color='red', linestyle='--', linewidth=2, label='PMI = 0')
ax.set_xlabel(f'Compositional Probability ({PRIOR_TYPE} prior)', fontsize=12)
ax.set_ylabel('PMI', fontsize=12)
ax.set_title(f'{metapath_name}: PMI vs Compositional Probability', fontsize=14, fontweight='bold')
cbar = plt.colorbar(ax.collections[0], ax=ax)
cbar.set_label('Observed Frequency', fontsize=10)
ax.legend()
ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(results_dir / f'metapath_{metapath_name}_compositionality_analysis.png', dpi=300, bbox_inches='tight')
plt.show()
print(f"\nSaved plot to: {results_dir / f'metapath_{metapath_name}_compositionality_analysis.png'}")

## Save Results

In [None]:
# Create results DataFrame
results_data = []

for pair in common_pairs:
    i, j = pair
    results_data.append({
        'compound_idx': i,
        'pathway_idx': j,
        'observed_freq': observed_freq[pair],
        'compositional_prob': compositional_prob[pair],
        'pmi': pmi_values.get(pair, np.nan),
        'residual': observed_freq[pair] - compositional_prob[pair]
    })

results_df = pd.DataFrame(results_data)

# Save to CSV
output_file = results_dir / f'metapath_{metapath_name}_compositionality_results.csv'
results_df.to_csv(output_file, index=False)
print(f"\nResults saved to: {output_file}")

# Save summary statistics
summary = {
    'metapath': metapath_name,
    'prior_type': PRIOR_TYPE,
    'n_pairs': len(common_pairs),
    'pearson_r': pearson_r,
    'pearson_p': pearson_p,
    'spearman_r': spearman_r,
    'spearman_p': spearman_p,
    'mae': mae,
    'rmse': rmse,
    'mean_pmi': pmi_array.mean(),
    'median_pmi': np.median(pmi_array),
    'std_pmi': pmi_array.std(),
    'min_pmi': pmi_array.min(),
    'max_pmi': pmi_array.max()
}

summary_df = pd.DataFrame([summary])
summary_file = results_dir / f'metapath_{metapath_name}_compositionality_summary.csv'
summary_df.to_csv(summary_file, index=False)
print(f"Summary saved to: {summary_file}")

## Conclusion and Recommendations

In [None]:
print("\n" + "="*80)
print("CONCLUSIONS")
print("="*80)

print(f"\nMetapath: {metapath_name}")
print(f"Prior type: {PRIOR_TYPE}")

print(f"\n1. COMPOSITIONALITY TEST:")
if abs(pmi_array.mean()) < 0.1:
    print(f"   ✓ COMPOSITIONAL (mean PMI ≈ 0)")
    print(f"   → Edge probabilities are approximately independent")
else:
    print(f"   ✗ NON-COMPOSITIONAL (mean PMI = {pmi_array.mean():.4f})")
    print(f"   → Edge probabilities exhibit conditional dependencies")

print(f"\n2. PRIOR PERFORMANCE:")
if pearson_r > 0.9:
    print(f"   ✓ EXCELLENT (r = {pearson_r:.4f})")
    print(f"   → {PRIOR_TYPE.capitalize()} prior works very well")
elif pearson_r > 0.7:
    print(f"   ✓ GOOD (r = {pearson_r:.4f})")
    print(f"   → {PRIOR_TYPE.capitalize()} prior captures most structure")
elif pearson_r > 0.5:
    print(f"   → MODERATE (r = {pearson_r:.4f})")
    print(f"   → {PRIOR_TYPE.capitalize()} prior has limitations")
else:
    print(f"   ✗ POOR (r = {pearson_r:.4f})")
    print(f"   → {PRIOR_TYPE.capitalize()} prior fails to capture structure")

print(f"\n3. RECOMMENDATIONS:")
if pearson_r > 0.9 and abs(pmi_array.mean()) < 0.1:
    print(f"   → Continue using {PRIOR_TYPE} prior")
    print(f"   → Compositional assumption is valid")
    print(f"   → No need for learned/empirical priors")
elif pearson_r > 0.7:
    print(f"   → {PRIOR_TYPE.capitalize()} prior is reasonable baseline")
    print(f"   → Consider testing learned/empirical priors for improvement")
    print(f"   → Some conditional structure could be captured")
else:
    print(f"   → Switch to learned/empirical priors")
    print(f"   → Strong conditional dependencies detected")
    print(f"   → Compositional assumption violated")

print(f"\n{'='*80}")
print("ANALYSIS COMPLETE")
print(f"{'='*80}")

## How to Swap Priors

To test with a different prior, change the `PRIOR_TYPE` variable in the Configuration cell to one of:

```python
# Set exactly one of these; the others are shown commented out
PRIOR_TYPE = 'analytical'    # Current analytical formula
# PRIOR_TYPE = 'learned'     # Learned formula from notebook 8
# PRIOR_TYPE = 'empirical'   # Empirical frequencies from N permutations
```

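The `compute_edge_priors_sparse()` function currently falls back to the analytical prior (with a warning) for `'learned'` and `'empirical'`. It can be extended to support: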
- **Learned**: Load trained LearnedAnalyticalFormula and use its predictions
- **Empirical**: Load frequency files from `results/empirical_edge_frequencies/`
- **ML Model**: Load trained RF/LogReg model and use its predictions

This makes it easy to compare different prior sources!
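
One possible way to structure that extension is a small registry that maps each prior name to a function of the source degree `u`, target degree `v`, and edge count `m`. The sketch below is an assumption about how this could be organized, not the notebook's current implementation: `PRIOR_REGISTRY` and `get_prior` are hypothetical names, only the analytical formula is registered, and unlike the existing fallback behavior it raises an error for unregistered priors.

```python
import math
from typing import Callable, Dict


def analytical_prior(u: float, v: float, m: float) -> float:
    """Same formula as the notebook's analytical_prior helper."""
    uv = u * v
    denominator = math.sqrt(uv**2 + (m - u - v + 1)**2)
    return uv / denominator if denominator > 0 else 0.0


# Hypothetical registry: each entry maps (u, v, m) to an edge probability.
PriorFn = Callable[[float, float, float], float]
PRIOR_REGISTRY: Dict[str, PriorFn] = {"analytical": analytical_prior}


def get_prior(prior_type: str) -> PriorFn:
    """Look up a prior by name, failing loudly if it is not registered."""
    try:
        return PRIOR_REGISTRY[prior_type]
    except KeyError:
        raise NotImplementedError(
            f"No prior registered for {prior_type!r}"
        ) from None


# Example: source degree 3, target degree 1, 7 edges in total.
p = get_prior("analytical")(3, 1, 7)
print(p)
```

Raising instead of silently falling back is a design choice for this sketch: it avoids results being labeled `learned` or `empirical` when they were actually computed with the analytical formula.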