# Edge Prediction Using Ensemble Models

This notebook applies the aggregated ensemble models to make predictions on permuted networks and provides comprehensive reporting of predictions both overall and by edge type.

## Overview

- **Input**: Aggregated ensemble models and permuted network data
- **Process**: Generate predictions for all edge types across all permutations
- **Output**: Comprehensive prediction reports with statistical analysis

## Methodology

1. **Model Loading**: Load the pre-trained ensemble models for each edge type and model type
2. **Data Preparation**: Process permuted networks to extract features
3. **Prediction Generation**: Apply ensemble models to generate edge probabilities
4. **Analysis & Reporting**: Create detailed reports by edge type and overall statistics
5. **Visualization**: Generate comparative plots and heatmaps of predictions

In [None]:
# Import required libraries
import warnings
import pathlib
import sys
import json
import pickle
import glob
from collections import defaultdict
import re
from itertools import combinations

import numpy as np
import pandas as pd
import scipy.sparse
import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Set up paths
# The repository root is taken to be the parent of the current working
# directory, i.e. the notebook is expected to run from a direct subfolder.
repo_dir = pathlib.Path().cwd().parent
src_dir = repo_dir / "src"
# Put src/ at the front of the import path before the project imports below.
sys.path.insert(0, str(src_dir))

# Import custom modules
# NOTE(review): presumably these modules live under src/ — confirm. They must
# stay below the sys.path.insert call above.
from models import EdgePredictionNN
from data_processing import load_permutation_data, prepare_edge_prediction_data

# NOTE(review): this silences every warning category for the rest of the
# session, not just known-noisy ones.
warnings.filterwarnings('ignore')
print("Libraries imported successfully")
print(f"Repository directory: {repo_dir}")

## Parameters

Configure the prediction process:

In [None]:
# Input and output locations, all resolved relative to the repository root
aggregated_models_dir = repo_dir.joinpath("aggregated_models")
permutations_dir = repo_dir.joinpath("data", "permutations")
output_dir = repo_dir.joinpath("prediction_results")

# Restrict the run to specific edge types, e.g. ["AeG", "CbG", "DaG"];
# None means every edge type that has a discovered model is used
target_edge_types = None

# Restrict the run to specific permutations, e.g. ["000.hetmat", "001.hetmat"];
# None means every discovered permutation is used
target_permutations = None

# Probabilities strictly above this value count as a predicted edge
prediction_threshold = 0.5

# Make sure the results folder exists before anything is written to it
output_dir.mkdir(exist_ok=True)

print(f"Aggregated models directory: {aggregated_models_dir}")
print(f"Permutations directory: {permutations_dir}")
print(f"Output directory: {output_dir}")
print(f"Prediction threshold: {prediction_threshold}")

## Discover Available Models and Data

Find all available ensemble models and permuted network data:

In [None]:
def discover_ensemble_models(models_dir):
    """
    Discover all available ensemble models.

    Scans ``models_dir`` for files named
    ``ensemble_{model_type}_{edge_type}.pkl``. The edge type is taken to be
    the text after the LAST underscore, so model types that themselves contain
    underscores (e.g. ``random_forest``) are kept intact. This assumes edge
    type abbreviations contain no underscore, which holds for every key in
    this notebook's edge type mapping.

    Args:
        models_dir: ``pathlib.Path`` of the directory holding the pickled
            ensembles.

    Returns:
        dict: ``{edge_type: {model_type: path}}``. Empty when the directory
        does not exist (a warning is printed) or holds no matching files.
    """
    ensemble_files = {}

    if not models_dir.exists():
        print(f"Warning: Models directory {models_dir} does not exist")
        return ensemble_files

    prefix = "ensemble_"

    # Sorted so that discovery order does not depend on the filesystem
    for model_file in sorted(models_dir.glob("ensemble_*.pkl")):
        remainder = model_file.stem[len(prefix):]
        model_type, separator, edge_type = remainder.rpartition('_')

        # Skip names that do not carry both a model type and an edge type
        if not (separator and model_type and edge_type):
            continue

        ensemble_files.setdefault(edge_type, {})[model_type] = model_file

    return ensemble_files

def discover_permutations(permutations_dir):
    """
    List the permutation folders found in ``permutations_dir``.

    A permutation is any sub-directory whose name ends in ``.hetmat``.
    Returns the folder names in sorted order, or an empty list (after
    printing a warning) when ``permutations_dir`` does not exist.
    """
    if permutations_dir.exists():
        names = []
        for entry in permutations_dir.iterdir():
            if entry.name.endswith('.hetmat') and entry.is_dir():
                names.append(entry.name)
        names.sort()
        return names

    print(f"Warning: Permutations directory {permutations_dir} does not exist")
    return []

# Scan the configured directories for ensemble models and permutations
available_models = discover_ensemble_models(aggregated_models_dir)
available_permutations = discover_permutations(permutations_dir)

print("Available ensemble models:")
for edge_type, models in available_models.items():
    print(f"  {edge_type}: {list(models.keys())}")

print(f"\nAvailable permutations: {len(available_permutations)}")
# Only preview the first five names; summarise the remainder in one line
for perm in available_permutations[:5]:
    print(f"  {perm}")
if not len(available_permutations) <= 5:
    print(f"  ... and {len(available_permutations) - 5} more")

## Load Ensemble Models

Load the ensemble models for prediction:

In [None]:
def load_ensemble_models(available_models, target_edge_types=None):
    """
    Unpickle the ensemble model files for the requested edge types.

    ``available_models`` maps edge type -> {model type -> file path}. When
    ``target_edge_types`` is empty or None, every edge type in
    ``available_models`` is loaded. Requested edge types without any model
    files are reported and skipped. A file that fails to load is reported and
    skipped, so an edge type may end up with an empty model dict.

    Returns a dict shaped like ``available_models`` with the file paths
    replaced by the unpickled objects.
    """
    if target_edge_types:
        wanted = target_edge_types
    else:
        wanted = list(available_models.keys())

    loaded_models = {}

    for edge_type in wanted:
        if edge_type in available_models:
            per_edge = loaded_models[edge_type] = {}

            for model_type, model_file in available_models[edge_type].items():
                try:
                    # NOTE: unpickling runs arbitrary code from the file, so
                    # only point this at model files you trust.
                    with open(model_file, 'rb') as handle:
                        unpickled = pickle.load(handle)
                except Exception as e:
                    print(f"Error loading {model_type} for {edge_type}: {e}")
                else:
                    per_edge[model_type] = unpickled
                    print(f"Loaded {model_type} ensemble for {edge_type}")
        else:
            print(f"Warning: No models found for edge type {edge_type}")

    return loaded_models

# Load the ensembles for the configured edge types (all discovered when None)
ensemble_models = load_ensemble_models(
    available_models,
    target_edge_types=target_edge_types,
)

print(f"\nLoaded ensemble models for {len(ensemble_models)} edge types")

## Define Edge Type Mappings

Map edge types to their corresponding node types:

In [None]:
# Edge type abbreviation -> (source node type, target node type).
# Rows are grouped by source node type; insertion order is preserved.
edge_type_mappings = {
    abbreviation: (source_type, target_type)
    for abbreviation, source_type, target_type in (
        # Anatomy edges
        ("AdG", "Anatomy", "Gene"),
        ("AeG", "Anatomy", "Gene"),
        ("AuG", "Anatomy", "Gene"),
        # Compound edges
        ("CbG", "Compound", "Gene"),
        ("CcSE", "Compound", "Side Effect"),
        ("CdG", "Compound", "Gene"),
        ("CpD", "Compound", "Disease"),
        ("CrC", "Compound", "Compound"),
        ("CtD", "Compound", "Disease"),
        ("CuG", "Compound", "Gene"),
        # Disease edges
        ("DaG", "Disease", "Gene"),
        ("DdG", "Disease", "Gene"),
        ("DlA", "Disease", "Anatomy"),
        ("DpS", "Disease", "Symptom"),
        ("DrD", "Disease", "Disease"),
        ("DuG", "Disease", "Gene"),
        # Gene edges
        ("GcG", "Gene", "Gene"),
        ("GiG", "Gene", "Gene"),
        ("GpBP", "Gene", "Biological Process"),
        ("GpCC", "Gene", "Cellular Component"),
        ("GpMF", "Gene", "Molecular Function"),
        ("GpPW", "Gene", "Pathway"),
        ("Gr>G", "Gene", "Gene"),
        # Pharmacologic Class edges
        ("PCiC", "Pharmacologic Class", "Compound"),
    )
}

print(f"Defined mappings for {len(edge_type_mappings)} edge types")

## Generate Predictions

Apply ensemble models to generate predictions on permuted networks:

In [None]:
def generate_predictions_for_permutation(permutation_name, ensemble_models, edge_type_mappings):
    """
    Generate predictions for a single permutation across all edge types.

    NOTE: besides its arguments, this function reads the module-level
    ``permutations_dir`` (where permutations are loaded from) and
    ``prediction_threshold`` (cut-off for the binary predictions).

    Args:
        permutation_name: Name of the permutation to load.
        ensemble_models: ``{edge_type: {model_type: model}}``; each model must
            expose ``predict_proba(features)`` returning a 2-D array whose
            column 1 is used as the positive-class probability.
        edge_type_mappings: ``{edge_type: (source_type, target_type)}``.

    Returns:
        dict: ``{edge_type: {'features', 'true_labels', 'predictions',
        'metadata'}}``. Edge types with no mapping, no loadable data, or a
        processing error are reported and left out. A model whose prediction
        fails is reported and left out of that edge type's ``'predictions'``.
        Counts and metrics are stored as built-in ``int``/``float`` so the
        result can be written with ``json`` later.
    """
    results = {}

    for edge_type, models in ensemble_models.items():
        if edge_type not in edge_type_mappings:
            print(f"Warning: No mapping found for edge type {edge_type}")
            continue

        source_type, target_type = edge_type_mappings[edge_type]

        # Best-effort: one failing edge type must not abort the permutation
        try:
            # Load permutation data
            perm_data = load_permutation_data(
                permutations_dir,
                permutation_name,
                edge_type=edge_type,
                source_node_type=source_type,
                target_node_type=target_type
            )

            if not perm_data:
                print(f"Warning: Could not load data for {permutation_name}, {edge_type}")
                continue

            # Prepare features
            features, labels = prepare_edge_prediction_data(perm_data, sample_negative_ratio=1.0)

            # labels.sum() is typically a NumPy scalar, which json cannot
            # encode; convert once here
            num_positive = int(labels.sum())

            edge_results = {
                'features': features,
                'true_labels': labels,
                'predictions': {},
                'metadata': {
                    'num_samples': len(features),
                    'num_positive': num_positive,
                    'num_negative': len(labels) - num_positive,
                    'source_type': source_type,
                    'target_type': target_type
                }
            }

            for model_type, model in models.items():
                # Best-effort per model as well
                try:
                    positive_proba = model.predict_proba(features)[:, 1]
                    pred_binary = (positive_proba > prediction_threshold).astype(int)

                    edge_results['predictions'][model_type] = {
                        'probabilities': positive_proba,
                        'binary_predictions': pred_binary,
                        'auc': float(roc_auc_score(labels, positive_proba)),
                        'average_precision': float(average_precision_score(labels, positive_proba)),
                        'accuracy': float((pred_binary == labels).mean())
                    }

                except Exception as e:
                    print(f"Error predicting with {model_type} for {edge_type}: {e}")

            results[edge_type] = edge_results

        except Exception as e:
            print(f"Error processing {edge_type} for {permutation_name}: {e}")

    return results

print("Prediction function defined")

In [None]:
# Run every selected permutation through the ensembles and collect results
all_predictions = {}

# An explicit selection wins; otherwise fall back to everything discovered
if target_permutations:
    permutations_to_process = target_permutations
else:
    permutations_to_process = available_permutations

print(f"Generating predictions for {len(permutations_to_process)} permutations...")

for i, permutation in enumerate(tqdm(permutations_to_process, desc="Processing permutations")):
    print(f"\nProcessing permutation {i+1}/{len(permutations_to_process)}: {permutation}")

    perm_results = generate_predictions_for_permutation(
        permutation, ensemble_models, edge_type_mappings
    )
    all_predictions[permutation] = perm_results

    # Nothing to summarise when no edge type produced results
    if not perm_results:
        continue

    print(f"  Generated predictions for {len(perm_results)} edge types")
    for edge_type, results in perm_results.items():
        models_count = len(results['predictions'])
        print(f"    {edge_type}: {models_count} models, {results['metadata']['num_samples']} samples")

print(f"\nCompleted predictions for {len(all_predictions)} permutations")

## Prediction Analysis and Reporting

Analyze predictions and create comprehensive reports:

In [None]:
def analyze_predictions(all_predictions):
    """
    Analyze predictions across all permutations and edge types.

    Args:
        all_predictions: ``{permutation: {edge_type: {'predictions':
            {model_type: {'auc', 'average_precision', 'accuracy', ...}}}}}``.

    Returns:
        dict with four keys:

        - ``'overall_stats'``: ``'total_predictions'`` plus mean/std/min/max/
          median for ``'auc'``, ``'average_precision'`` and ``'accuracy'``.
        - ``'by_edge_type'``, ``'by_model_type'``, ``'by_permutation'``: for
          each group, the same statistics plus ``'count'`` under the keys
          ``'auc'``, ``'ap'`` and ``'accuracy'``.

        All statistics are built-in floats. With no predictions at all, the
        overall statistics are NaN and the grouped dicts are empty (instead of
        raising on an empty reduction).
    """
    metric_names = ('auc', 'ap', 'accuracy')

    def calc_stats(values):
        """Summary statistics of a list of scores; NaN-filled when empty."""
        if not values:
            nan = float('nan')
            return {'mean': nan, 'std': nan, 'min': nan, 'max': nan,
                    'median': nan, 'count': 0}

        scores = np.asarray(values, dtype=float)
        return {
            'mean': float(scores.mean()),
            'std': float(scores.std()),
            'min': float(scores.min()),
            'max': float(scores.max()),
            'median': float(np.median(scores)),
            'count': len(values)
        }

    def overall_stats(values):
        """Like calc_stats, minus 'count' (reported once as total_predictions)."""
        stats = calc_stats(values)
        del stats['count']
        return stats

    def summarise(grouped):
        """Turn {group: {metric: [scores]}} into {group: {metric: stats}}."""
        return {
            group: {metric: calc_stats(metrics[metric]) for metric in metric_names}
            for group, metrics in grouped.items()
        }

    overall = defaultdict(list)
    edge_type_stats = defaultdict(lambda: defaultdict(list))
    model_type_stats = defaultdict(lambda: defaultdict(list))
    permutation_stats = defaultdict(lambda: defaultdict(list))

    # Walk every (permutation, edge type, model type) result once and file its
    # scores under each grouping
    for permutation, perm_results in all_predictions.items():
        for edge_type, edge_results in perm_results.items():
            for model_type, pred_results in edge_results['predictions'].items():
                scores = {
                    'auc': pred_results['auc'],
                    'ap': pred_results['average_precision'],
                    'accuracy': pred_results['accuracy'],
                }
                for metric, score in scores.items():
                    overall[metric].append(score)
                    edge_type_stats[edge_type][metric].append(score)
                    model_type_stats[model_type][metric].append(score)
                    permutation_stats[permutation][metric].append(score)

    return {
        'overall_stats': {
            'total_predictions': len(overall['auc']),
            'auc': overall_stats(overall['auc']),
            # The overall section spells this key out in full, unlike the
            # grouped sections, which use 'ap'
            'average_precision': overall_stats(overall['ap']),
            'accuracy': overall_stats(overall['accuracy'])
        },
        'by_edge_type': summarise(edge_type_stats),
        'by_model_type': summarise(model_type_stats),
        'by_permutation': summarise(permutation_stats)
    }

# Aggregate the per-model metrics collected above
prediction_analysis = analyze_predictions(all_predictions)

overall = prediction_analysis['overall_stats']
auc_overall = overall['auc']
ap_overall = overall['average_precision']

print("Prediction analysis completed")
print(f"Total predictions analyzed: {overall['total_predictions']}")
print(f"Overall mean AUC: {auc_overall['mean']:.4f} ± {auc_overall['std']:.4f}")
print(f"Overall mean AP: {ap_overall['mean']:.4f} ± {ap_overall['std']:.4f}")

## Performance Summary Tables

Create detailed performance tables:

In [None]:
# Create summary tables
def _build_summary_table(grouped_stats, key_column, format_key=None):
    """Build one formatted table from {group: stats}, best mean AUC first."""
    columns = [key_column, 'Mean AUC', 'Std AUC', 'Mean AP', 'Std AP',
               'Mean Accuracy', 'Count']
    sort_column = '_mean_auc_numeric'

    rows = []
    for key, stats in grouped_stats.items():
        rows.append({
            key_column: format_key(key) if format_key else key,
            'Mean AUC': f"{stats['auc']['mean']:.4f}",
            'Std AUC': f"{stats['auc']['std']:.4f}",
            'Mean AP': f"{stats['ap']['mean']:.4f}",
            'Std AP': f"{stats['ap']['std']:.4f}",
            'Mean Accuracy': f"{stats['accuracy']['mean']:.4f}",
            'Count': stats['auc']['count'],
            # Unformatted value, used only for ordering
            sort_column: stats['auc']['mean'],
        })

    # Explicit columns keep an empty table well-formed (and sortable)
    table = pd.DataFrame(rows, columns=columns + [sort_column])

    # Sort on the number, not the formatted text: as text, 'nan' would rank
    # above every score. sort_values places NaN last by default.
    table = table.sort_values(sort_column, ascending=False, kind='mergesort')
    return table.drop(columns=sort_column)


def create_summary_tables(analysis):
    """
    Create summary tables for different groupings.

    Args:
        analysis: dict with ``'by_edge_type'``, ``'by_model_type'`` and
            ``'by_permutation'`` entries, each mapping a group name to
            ``{'auc': ..., 'ap': ..., 'accuracy': ...}`` statistics dicts
            (``'mean'``, ``'std'``, and ``'count'`` under ``'auc'``).

    Returns:
        dict of ``pandas.DataFrame`` keyed ``'by_edge_type'``,
        ``'by_model_type'`` and ``'by_permutation'``. Metric columns hold
        strings formatted to four decimals; rows are ordered by mean AUC,
        highest first. Model type names are title-cased with underscores
        replaced by spaces. The permutation table keeps only the top 10 rows.
        An empty grouping yields an empty table with the usual columns.
    """
    tables = {}

    # Performance by edge type
    tables['by_edge_type'] = _build_summary_table(analysis['by_edge_type'], 'Edge Type')

    # Performance by model type
    tables['by_model_type'] = _build_summary_table(
        analysis['by_model_type'],
        'Model Type',
        format_key=lambda model_type: model_type.replace('_', ' ').title(),
    )

    # Performance by permutation (top 10)
    perm_df = _build_summary_table(analysis['by_permutation'], 'Permutation')
    tables['by_permutation'] = perm_df.head(10)

    return tables

summary_tables = create_summary_tables(prediction_analysis)

# Print each table under its heading, without the DataFrame index
for heading, table_key in (
    ("Performance by Edge Type:", 'by_edge_type'),
    ("\n\nPerformance by Model Type:", 'by_model_type'),
    ("\n\nTop 10 Permutations by AUC:", 'by_permutation'),
):
    print(heading)
    print(summary_tables[table_key].to_string(index=False))

## Visualization

Create comprehensive visualizations of prediction performance:

In [None]:
# Create comprehensive visualizations
def create_prediction_visualizations(analysis, summary_tables, output_dir, predictions=None):
    """
    Create various visualizations of prediction performance.

    Two figures are shown and saved as PNG files in ``output_dir``:

    - ``prediction_performance_summary.png``: a 2x2 grid with mean AUC and
      mean AP per edge type, mean AUC per model type, and histograms of the
      per-edge-type mean AUC and mean AP.
    - ``performance_heatmap.png``: mean AUC for each edge type / model type
      pair, averaged over permutations (NaN where a pair has no result).
      Skipped when there are no edge types or no model types.

    Args:
        analysis: Output of ``analyze_predictions``; only the keys of
            ``'by_edge_type'`` and ``'by_model_type'`` are used here.
        summary_tables: Output of ``create_summary_tables``.
        output_dir: ``pathlib.Path`` of the directory the PNGs are written to.
        predictions: Raw per-permutation results used for the heatmap. When
            omitted, the module-level ``all_predictions`` is used, which is
            what this function always read before the parameter existed.
    """
    if predictions is None:
        # Backward-compatible fallback to the notebook-level variable
        predictions = all_predictions

    def as_floats(table, column):
        """Summary tables store metrics as formatted strings; parse them back."""
        return [float(x) for x in table[column]]

    def mean_auc(edge_type, model_type):
        """Mean AUC of one edge/model pair across permutations, NaN if absent."""
        scores = [
            perm_results[edge_type]['predictions'][model_type]['auc']
            for perm_results in predictions.values()
            if edge_type in perm_results
            and model_type in perm_results[edge_type]['predictions']
        ]
        return np.mean(scores) if scores else np.nan

    # Set style
    plt.style.use('default')
    sns.set_palette("husl")

    edge_table = summary_tables['by_edge_type']
    model_table = summary_tables['by_model_type']

    # Figure 1: Performance by Edge Type
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # AUC by edge type
    edge_types = edge_table['Edge Type']
    aucs = as_floats(edge_table, 'Mean AUC')
    auc_stds = as_floats(edge_table, 'Std AUC')

    ax1 = axes[0, 0]
    bars1 = ax1.bar(edge_types, aucs, yerr=auc_stds, capsize=5, alpha=0.7)
    ax1.set_title('AUC Performance by Edge Type')
    ax1.set_ylabel('AUC Score')
    ax1.tick_params(axis='x', rotation=45)
    ax1.grid(True, alpha=0.3)

    # Add value labels just above each bar
    for bar, auc in zip(bars1, aucs):
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                f'{auc:.3f}', ha='center', va='bottom', fontsize=8)

    # AP by edge type
    aps = as_floats(edge_table, 'Mean AP')
    ap_stds = as_floats(edge_table, 'Std AP')

    ax2 = axes[0, 1]
    ax2.bar(edge_types, aps, yerr=ap_stds, capsize=5, alpha=0.7, color='orange')
    ax2.set_title('Average Precision by Edge Type')
    ax2.set_ylabel('Average Precision')
    ax2.tick_params(axis='x', rotation=45)
    ax2.grid(True, alpha=0.3)

    # Model type comparison
    model_types = model_table['Model Type']
    model_aucs = as_floats(model_table, 'Mean AUC')
    model_auc_stds = as_floats(model_table, 'Std AUC')

    ax3 = axes[1, 0]
    ax3.bar(model_types, model_aucs, yerr=model_auc_stds, capsize=5, alpha=0.7, color='green')
    ax3.set_title('AUC Performance by Model Type')
    ax3.set_ylabel('AUC Score')
    ax3.tick_params(axis='x', rotation=45)
    ax3.grid(True, alpha=0.3)

    # Overall distribution of the per-edge-type means
    ax4 = axes[1, 1]
    ax4.hist(aucs, bins=10, alpha=0.7, label='AUC', color='skyblue')
    ax4.hist(aps, bins=10, alpha=0.7, label='AP', color='lightcoral')
    ax4.set_title('Distribution of Performance Metrics')
    ax4.set_xlabel('Score')
    ax4.set_ylabel('Frequency')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_dir / 'prediction_performance_summary.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Figure 2: Heatmap of performance by edge type and model type
    edge_keys = list(analysis['by_edge_type'].keys())
    model_keys = list(analysis['by_model_type'].keys())

    # A heatmap needs at least one row and one column
    if edge_keys and model_keys:
        heatmap_data = [
            [mean_auc(edge_type, model_type) for model_type in model_keys]
            for edge_type in edge_keys
        ]

        plt.figure(figsize=(12, 8))
        heatmap_df = pd.DataFrame(
            heatmap_data,
            index=edge_keys,
            columns=[mt.replace('_', ' ').title() for mt in model_keys]
        )

        sns.heatmap(heatmap_df, annot=True, fmt='.3f', cmap='viridis',
                   cbar_kws={'label': 'AUC Score'})
        plt.title('AUC Performance Heatmap: Edge Type vs Model Type')
        plt.ylabel('Edge Type')
        plt.xlabel('Model Type')
        plt.tight_layout()
        plt.savefig(output_dir / 'performance_heatmap.png', dpi=300, bbox_inches='tight')
        plt.show()

# Create visualizations
# Shows both figures inline and writes them as PNG files into output_dir.
create_prediction_visualizations(prediction_analysis, summary_tables, output_dir)
print(f"Visualizations saved to {output_dir}")

## Save Results

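Save per-model performance summaries (metrics and sample counts, not the raw probabilities), the aggregate analysis, the summary tables, and a final report: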

In [None]:
# Save all results
def _json_default(value):
    """Convert NumPy scalars and arrays, which json cannot encode, to built-ins."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def _write_json(path, payload):
    """Write payload to path as indented JSON, tolerating NumPy values."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, indent=2, default=_json_default)


def _top_entry(table, column):
    """Return column's value in the first row of table, or None if it is empty."""
    if len(table) == 0:
        return None
    return table.iloc[0][column]


def save_prediction_results(all_predictions, prediction_analysis, summary_tables, output_dir, threshold=None):
    """
    Save all prediction results and analysis to files.

    Files written to ``output_dir``:

    - ``predictions_summary.json``: per permutation and edge type, the
      metadata plus each model's AUC, average precision, accuracy and number
      of predictions. Raw probabilities and features are NOT saved.
    - ``prediction_analysis.json``: ``prediction_analysis`` as given.
    - ``{table_name}_summary.csv``: one CSV per entry of ``summary_tables``.
    - ``final_report.json``: the report dict that is also returned.

    Args:
        all_predictions: ``{permutation: {edge_type: edge_results}}`` as
            produced by the prediction step.
        prediction_analysis: Output of ``analyze_predictions``.
        summary_tables: Output of ``create_summary_tables``; the first row of
            each table is reported as the best performer.
        output_dir: ``pathlib.Path`` of an existing directory to write into.
        threshold: Prediction threshold to record in the report. When
            omitted, the module-level ``prediction_threshold`` is used.

    Returns:
        dict: The final report (experiment summary, overall performance and
        best performers; a best performer is ``None`` if its table is empty).
    """
    if threshold is None:
        # Backward-compatible fallback to the notebook-level setting
        threshold = prediction_threshold

    # Keep only metrics and metadata; the probability arrays are too large
    predictions_summary = {
        permutation: {
            edge_type: {
                'metadata': edge_results['metadata'],
                'model_performance': {
                    model_type: {
                        'auc': pred_results['auc'],
                        'average_precision': pred_results['average_precision'],
                        'accuracy': pred_results['accuracy'],
                        'num_predictions': len(pred_results['probabilities'])
                    }
                    for model_type, pred_results in edge_results['predictions'].items()
                }
            }
            for edge_type, edge_results in perm_results.items()
        }
        for permutation, perm_results in all_predictions.items()
    }

    # Save predictions summary
    _write_json(output_dir / 'predictions_summary.json', predictions_summary)

    # Save analysis
    _write_json(output_dir / 'prediction_analysis.json', prediction_analysis)

    # Save summary tables
    for table_name, table_df in summary_tables.items():
        table_df.to_csv(output_dir / f'{table_name}_summary.csv', index=False)

    # Create final report
    report = {
        'experiment_summary': {
            'total_permutations': len(all_predictions),
            'total_edge_types': len(prediction_analysis['by_edge_type']),
            'total_model_types': len(prediction_analysis['by_model_type']),
            'total_predictions': prediction_analysis['overall_stats']['total_predictions'],
            'prediction_threshold': threshold
        },
        'overall_performance': prediction_analysis['overall_stats'],
        # The summary tables are already ordered best-first
        'best_performing': {
            'edge_type': _top_entry(summary_tables['by_edge_type'], 'Edge Type'),
            'model_type': _top_entry(summary_tables['by_model_type'], 'Model Type'),
            'permutation': _top_entry(summary_tables['by_permutation'], 'Permutation')
        }
    }

    _write_json(output_dir / 'final_report.json', report)

    print(f"All results saved to {output_dir}")
    print("Files created:")
    for file_path in sorted(output_dir.glob('*')):
        print(f"  {file_path.name}")

    return report

# Write every artefact to disk and keep the returned report for the recap
final_report = save_prediction_results(all_predictions, prediction_analysis, summary_tables, output_dir)

banner = "="*60
experiment = final_report['experiment_summary']
performance = final_report['overall_performance']
best = final_report['best_performing']

print("\n" + banner)
print("PREDICTION EXPERIMENT COMPLETED SUCCESSFULLY")
print(banner)
print(f"Total predictions: {experiment['total_predictions']}")
print(f"Mean AUC: {performance['auc']['mean']:.4f}")
print(f"Mean AP: {performance['average_precision']['mean']:.4f}")
print(f"Best edge type: {best['edge_type']}")
print(f"Best model type: {best['model_type']}")
print(banner)