# Source Detection in Gene Regulatory Networks - Progress Report

**Date:** July 9, 2025  
**Research Focus:** Comparing GAT and PDGrapher models for source detection in biological networks

## Executive Summary

This report presents the current state of our research on source detection in gene regulatory networks, comparing two main approaches:
1. **Graph Attention Networks (GAT)** - Custom implementation for source detection
2. **PDGrapher** - Specialized perturbation discovery framework

We analyze data creation processes, model architectures, training procedures, and evaluation results to provide insights for the next phase of research.

In [None]:
# Import required libraries
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import torch
from IPython.display import Image, display
import warnings
warnings.filterwarnings('ignore')

# Set up plotting style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12

# Define paths
BASE_PATH = Path("/sc/home/johanna.dahlkemper/source_detection_in_grns")
REPORTS_PATH = BASE_PATH / "reports"
FIGURES_PATH = BASE_PATH / "figures"
DATA_PATH = BASE_PATH / "data"

print("Setup complete!")
print(f"Base path: {BASE_PATH}")
print(f"Reports available: {list(REPORTS_PATH.iterdir())}")
print(f"Figures available: {len(list(FIGURES_PATH.glob('*.png')))} PNG files")

# 1. Data Creation

## Overview
This section examines the data creation processes for both GAT and PDGrapher models, including the underlying network structures and simulation parameters.

In [None]:
# Load evaluation results from JSON files
def latest_json(directory):
    """Return the most recently modified JSON file in `directory`, or None."""
    files = list(directory.glob("*.json"))
    # glob() yields files in arbitrary order, so pick the newest by modification time
    return max(files, key=lambda p: p.stat().st_mtime) if files else None

def load_results():
    results = {}
    
    # Load GAT results
    gat_results_path = REPORTS_PATH / "gat_debug"
    if gat_results_path.exists():
        latest = latest_json(gat_results_path)
        if latest:
            with open(latest, 'r') as f:  # Load most recent
                results['gat'] = json.load(f)
    
    # Load baseline results (check various directories)
    baseline_dirs = ['baseline_random', 'baseline_rumor']
    for baseline_type in baseline_dirs:
        baseline_path = REPORTS_PATH / baseline_type
        if baseline_path.exists():
            latest = latest_json(baseline_path)
            if latest:
                with open(latest, 'r') as f:
                    results[baseline_type.replace('baseline_', '')] = json.load(f)
    
    # Load PDGrapher results (first existing directory wins)
    pdgrapher_dirs = ['pdgrapher_test', 'pdgrapher_tp53_0708']
    for pdg_dir in pdgrapher_dirs:
        pdg_path = REPORTS_PATH / pdg_dir
        if pdg_path.exists():
            latest = latest_json(pdg_path)
            if latest:
                with open(latest, 'r') as f:
                    results['pdgrapher'] = json.load(f)
                break
    
    return results

results = load_results()
print("Loaded results for:", list(results.keys()))

# Display data statistics from GAT results
if 'gat' in results:
    gat_data = results['gat']
    print("\n=== GAT Data Statistics ===")
    if 'data stats' in gat_data:
        print(f"Graph stats: {gat_data['data stats']['graph stats']}")
        print(f"Infection stats: {gat_data['data stats']['infection stats']}")
    if 'parameters' in gat_data:
        dc_params = gat_data['parameters'].get('data_creation', {})
        print(f"Training size: {dc_params.get('training_size', 'N/A')}")
        print(f"Validation size: {dc_params.get('validation_size', 'N/A')}")
        print(f"Test size: {dc_params.get('test_size', 'N/A')}")
        print(f"Number of sources: {dc_params.get('n_sources', 'N/A')}")
    print(f"Network type: {gat_data.get('network', 'N/A')}")
else:
    print("No GAT results found")

In [None]:
# Display sample network visualizations
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
fig.suptitle('Sample Network States - TP53 Network', fontsize=16, fontweight='bold')

# Show different states of the network
states = ['initial', 'current', 'prediction']
sample_idx = 0  # Show first sample

for i, state in enumerate(states):
    # Plot actual state
    img_path = FIGURES_PATH / f"tp53_{state}_{sample_idx}.png"
    if img_path.exists():
        img = plt.imread(img_path)
        axes[0, i].imshow(img)
        axes[0, i].set_title(f'{state.capitalize()} State', fontweight='bold')
        axes[0, i].axis('off')
    else:
        axes[0, i].text(0.5, 0.5, f'No {state} image found', 
                       ha='center', va='center', transform=axes[0, i].transAxes)
        axes[0, i].set_title(f'{state.capitalize()} State', fontweight='bold')

# Show difference visualization
diff_img_path = FIGURES_PATH / f"tp53_diff_{sample_idx}.png"
if diff_img_path.exists():
    diff_img = plt.imread(diff_img_path)
    axes[1, 0].imshow(diff_img)
    axes[1, 0].axis('off')
else:
    axes[1, 0].text(0.5, 0.5, 'No diff image found',
                   ha='center', va='center', transform=axes[1, 0].transAxes)
axes[1, 0].set_title('Difference (Prediction - Current)', fontweight='bold')

# Hide unused subplots
axes[1, 1].axis('off')
axes[1, 2].axis('off')

plt.tight_layout()
plt.show()

print(f"Available network visualizations: {len(list(FIGURES_PATH.glob('tp53_*.png')))}")

## Current Data Usage

### GAT Model Data
- **Network:** TP53 pathway (biological network)
- **Simulation Type:** Diffusion process with source nodes
- **Dataset Sizes:** Training: 1000, Validation: 100, Test: 100
- **Node Features:** PLACEHOLDER (need to check feature dimensions)
- **Source Configuration:** Single source per simulation
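The exact diffusion model used by the GAT pipeline is not recorded here. As a minimal sketch, assuming a susceptible-infected (SI) style spread where every active node independently activates each downstream neighbour with probability `beta` per step, one single-source simulation could look like this (`simulate_si`, `beta` and `steps` are hypothetical names, not the pipeline's own):

```python
import numpy as np

def simulate_si(adj: np.ndarray, source: int, beta: float, steps: int, seed: int = 0) -> np.ndarray:
    """Hypothetical SI-style spread from a single source node.

    adj[i, j] == 1 means node i can pass the signal to node j.
    Returns a 0/1 vector marking the nodes reached after `steps` rounds.
    """
    rng = np.random.default_rng(seed)
    state = np.zeros(adj.shape[0], dtype=int)
    state[source] = 1
    for _ in range(steps):
        # Every active node tries each of its outgoing edges with probability beta
        attempts = (adj * state[:, None]) > 0
        success = attempts & (rng.random(adj.shape) < beta)
        newly_active = success.any(axis=0)
        # SI dynamics: once active, a node stays active
        state = np.maximum(state, newly_active.astype(int))
    return state

# Toy chain 0 -> 1 -> 2 -> 3 with certain transmission for two steps
adj = np.zeros((4, 4), dtype=int)
adj[0, 1] = adj[1, 2] = adj[2, 3] = 1
state = simulate_si(adj, source=0, beta=1.0, steps=2)
label = np.zeros(4, dtype=int)
label[0] = 1  # one-hot marker of the true source
```

The observed `state` would serve as node input and `label` as the per-node target, matching the single-source configuration above.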

### PDGrapher Data
**PLACEHOLDER** - Need to analyze PDGrapher data creation process:
- What biological networks are used?
- What simulation parameters?
- How is perturbation data generated?
- What are the dataset sizes?

# 2. Data Processing

## Model Input Formats
This section describes the data preprocessing and input formats for both models.

In [None]:
# Examine GAT data format
print("=== GAT Data Format ===")
gat_data_path = DATA_PATH / "GAT"
if gat_data_path.exists():
    gat_files = sorted(gat_data_path.glob("*.pt"))
    for file in gat_files[:3]:  # Show first 3 files (alphabetical order)
        try:
            data = torch.load(file, map_location='cpu', weights_only=False)
            print(f"\n{file.name}:")
            if isinstance(data, dict):
                for key, value in data.items():
                    if hasattr(value, 'shape'):
                        print(f"  {key}: {value.shape}")
                    else:
                        print(f"  {key}: {type(value)}")
            elif hasattr(data, 'shape'):
                print(f"  Shape: {data.shape}")
            else:
                print(f"  Type: {type(data)}")
        except Exception as e:
            print(f"  Error loading {file.name}: {e}")
else:
    print("GAT data directory not found")

print("\n=== PDGrapher Data Format ===")
pdg_data_path = DATA_PATH / "pdgrapher"
if pdg_data_path.exists():
    processed_path = pdg_data_path / "processed"
    if processed_path.exists():
        pdg_files = list(processed_path.glob("*.pt"))
        for file in pdg_files:
            try:
                data = torch.load(file, map_location='cpu', weights_only=False)
                print(f"\n{file.name}:")
                if hasattr(data, 'shape'):
                    print(f"  Shape: {data.shape}")
                elif isinstance(data, (list, tuple)):
                    print(f"  Length: {len(data)}")
                    if len(data) > 0 and hasattr(data[0], 'shape'):
                        print(f"  First element shape: {data[0].shape}")
                else:
                    print(f"  Type: {type(data)}")
            except Exception as e:
                print(f"  Error loading {file.name}: {e}")
    else:
        print("PDGrapher processed data not found")
else:
    print("PDGrapher data directory not found")

## Input Data Summary

### GAT Model Input
- **Graph Structure:** Edge indices representing network topology
- **Node Features:** PLACEHOLDER (need to check feature dimensions and content)
- **Target Labels:** Binary classification (source vs non-source nodes)
- **Batch Processing:** PLACEHOLDER (check batch size and structure)
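A minimal sketch of how one simulation could be packed into these inputs, assuming PyTorch Geometric's COO convention (`edge_index` of shape `[2, E]`) and a single node feature holding the observed infection state; `build_sample` is a hypothetical helper and the real feature set is still to be checked:

```python
import torch

def build_sample(edges, num_nodes, infected, source):
    """Hypothetical helper: turn one simulation into GAT-style tensors."""
    # COO edge list: row 0 holds edge origins, row 1 holds edge targets
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    # Assumed node feature: one column with the observed 0/1 infection state
    x = torch.zeros(num_nodes, 1)
    for v in infected:
        x[v, 0] = 1.0
    # Binary target: 1 for the true source node, 0 for every other node
    y = torch.zeros(num_nodes, dtype=torch.long)
    y[source] = 1
    return edge_index, x, y

edge_index, x, y = build_sample([(0, 1), (1, 2), (2, 3)], num_nodes=4, infected={0, 1, 2}, source=0)
```

If PyTorch Geometric is installed, the three tensors can be wrapped as `torch_geometric.data.Data(x=x, edge_index=edge_index, y=y)` and batched with its `DataLoader`.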

### PDGrapher Model Input
- **Forward Data:** PLACEHOLDER (perturbation response data)
- **Backward Data:** PLACEHOLDER (inverse perturbation data)
- **Edge Index:** Network topology representation
- **PLACEHOLDER:** Need to analyze the specific data format and preprocessing steps

# 3. Models

## Architecture Comparison
Detailed analysis of both model architectures, parameters, and training characteristics.

In [None]:
# Load and analyze model architectures
def count_parameters(model_path):
    """Count parameters stored in a saved PyTorch checkpoint.

    Tensors loaded from a state dict are detached (requires_grad=False), so every
    tensor entry is counted; this also includes buffers such as BatchNorm running stats.
    """
    try:
        model_data = torch.load(model_path, map_location='cpu', weights_only=False)
        if isinstance(model_data, torch.nn.Module):
            # Whole model saved with torch.save(model, ...)
            return sum(p.numel() for p in model_data.parameters())
        if isinstance(model_data, dict):
            # Try to find a nested state dict, else treat the dict itself as one
            state_dict = model_data.get('state_dict', model_data.get('model_state_dict', model_data))
        else:
            state_dict = model_data
        return sum(v.numel() for v in state_dict.values() if isinstance(v, torch.Tensor))
    except Exception as e:
        print(f"Error loading model {model_path}: {e}")
        return None

print("=== Model Analysis ===")

# Check GAT models
models_path = BASE_PATH / "models"
gat_models = list(models_path.glob("GAT*.pth"))
print(f"\nGAT Models found: {len(gat_models)}")
for model_path in gat_models:
    params = count_parameters(model_path)
    if params is not None:
        print(f"  {model_path.name}: {params:,} parameters")

# Check for latest model
latest_model = models_path / "latest.pth"
if latest_model.exists():
    params = count_parameters(latest_model)
    if params is not None:
        print(f"  latest.pth: {params:,} parameters")

# Extract training parameters from results
if 'gat' in results and 'parameters' in results['gat']:
    training_params = results['gat']['parameters'].get('training', {})
    print(f"\n=== GAT Training Configuration ===")
    for key, value in training_params.items():
        print(f"  {key}: {value}")

print(f"\n=== PDGrapher Model ===")
print("PLACEHOLDER - Need to analyze:")
print("  - Model architecture details")
print("  - Number of parameters")
print("  - Training configuration")
print("  - Based on which framework/paper")

## Model Comparison

| Aspect | GAT Model | PDGrapher |
|--------|-----------|-----------|
| **Architecture** | Graph Attention Network | PLACEHOLDER |
| **Parameters** | PLACEHOLDER (~X,XXX parameters) | PLACEHOLDER |
| **Based on** | Graph Attention Networks (Veličković et al., ICLR 2018) | PLACEHOLDER |
| **Training Time** | PLACEHOLDER | PLACEHOLDER |
| **Framework** | PyTorch + PyTorch Geometric | PLACEHOLDER |
| **Key Features** | Attention mechanism for node importance | Perturbation discovery framework |
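For reference, the attention mechanism of a standard GAT layer as published by Veličković et al. is given below; whether our implementation uses exactly this form (for example the LeakyReLU slope or head concatenation) still needs to be checked. With input features $\mathbf{h}_i$, shared weight matrix $\mathbf{W}$, attention vector $\mathbf{a}$ and neighbourhood $\mathcal{N}(i)$ (including $i$ itself):

$$
e_{ij} = \mathrm{LeakyReLU}\left(\mathbf{a}^\top \left[\mathbf{W}\mathbf{h}_i \,\|\, \mathbf{W}\mathbf{h}_j\right]\right),
\qquad
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}
$$

$$
\mathbf{h}_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}\, \mathbf{W}\mathbf{h}_j\right),
\qquad
\mathbf{h}_i'^{\,(K\ \text{heads})} = \Big\Vert_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{k}\, \mathbf{W}^{k}\mathbf{h}_j\right)
$$

The coefficients $\alpha_{ij}$ are what the table above calls node importance: how strongly node $i$ weights each neighbour $j$ when updating its representation.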

### GAT Model Details
- **Input:** Graph structure + node features
- **Output:** Binary classification per node (source probability)
- **Architecture:** PLACEHOLDER (need to check layer details)
- **Attention Heads:** PLACEHOLDER
- **Hidden Dimensions:** PLACEHOLDER
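Until the layer details are filled in, here is a self-contained sketch of a two-layer, single-head GAT that outputs a per-node source probability. It uses a dense adjacency matrix for readability; the layer count, `hidden_dim=16`, single head and class names are illustrative assumptions, not the trained model's configuration (which most likely relies on PyTorch Geometric's `GATConv`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DenseGATLayer(nn.Module):
    """Single-head GAT layer over a dense adjacency matrix (illustrative sketch)."""
    def __init__(self, in_dim, out_dim, negative_slope=0.2):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        # The attention vector a is split into the halves acting on W h_i and W h_j
        self.a_dst = nn.Parameter(torch.randn(out_dim) * 0.1)
        self.a_src = nn.Parameter(torch.randn(out_dim) * 0.1)
        self.negative_slope = negative_slope

    def forward(self, x, adj):
        # adj[i, j] > 0 means node i attends to (aggregates from) node j
        h = self.W(x)                                                     # [N, out_dim]
        e = (h @ self.a_dst).unsqueeze(1) + (h @ self.a_src).unsqueeze(0) # e[i, j], [N, N]
        e = F.leaky_relu(e, self.negative_slope)
        # Restrict attention to neighbours plus self-loops so no row is empty
        mask = (adj + torch.eye(adj.shape[0], device=adj.device)) > 0
        e = e.masked_fill(~mask, float('-inf'))
        alpha = torch.softmax(e, dim=1)                                   # rows sum to 1
        return alpha @ h

class SourceGAT(nn.Module):
    """Hypothetical two-layer GAT producing one source probability per node."""
    def __init__(self, in_dim, hidden_dim=16):
        super().__init__()
        self.gat1 = DenseGATLayer(in_dim, hidden_dim)
        self.gat2 = DenseGATLayer(hidden_dim, 1)

    def forward(self, x, adj):
        h = F.elu(self.gat1(x, adj))
        return torch.sigmoid(self.gat2(h, adj)).squeeze(-1)              # [N]
```

Training such a model would pair the per-node probabilities with the binary source labels, for example via `F.binary_cross_entropy`.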

### PDGrapher Model Details
**PLACEHOLDER** - Need detailed analysis:
- Architecture components
- Input/output specifications  
- Training methodology
- Theoretical foundation

# 4. Evaluation

## Performance Comparison
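Evaluation of the GAT model, the baselines and (once available) PDGrapher on the source detection task, using node-level classification metrics (AUC-ROC, F1, precision, recall) and source-ranking metrics (average rank of the true source, source in top-k).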
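The ranking metrics reported below ("avg rank of source", "source in top 3/5") can be computed from per-node scores roughly as follows. This is a sketch: the helper names are hypothetical, ties are broken pessimistically (nodes with an equal score count as ranked ahead), and top-k is reported as a fraction of samples; the existing evaluation code may differ on both points.

```python
import numpy as np

def source_rank(scores, source):
    """1-based rank of the true source when nodes are sorted by descending score.

    Pessimistic tie handling: every node scoring >= the source counts as ahead of it.
    """
    scores = np.asarray(scores, dtype=float)
    return int((scores >= scores[source]).sum())

def rank_metrics(score_list, sources, ks=(3, 5)):
    """Average source rank and top-k hit rates over a list of simulations."""
    ranks = np.array([source_rank(s, src) for s, src in zip(score_list, sources)])
    out = {"avg rank of source": float(ranks.mean())}
    for k in ks:
        out[f"source in top {k}"] = float((ranks <= k).mean())
    return out

# Two toy simulations over three nodes: source ranked 1st in the first, 3rd in the second
metrics = rank_metrics([[0.1, 0.9, 0.5], [0.1, 0.9, 0.5]], sources=[1, 0], ks=(1, 3))
```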

In [None]:
# Extract and compare performance metrics
def extract_metrics(results_dict):
    """Extract key metrics from results dictionary"""
    metrics_data = []
    
    for model_name, result in results_dict.items():
        if 'metrics' in result:
            metrics = result['metrics']
            row = {
                'Model': model_name.upper(),
                'AUC-ROC': metrics.get('node_auc_roc', metrics.get('auc_roc', None)),
                'F1 Score': metrics.get('node_f1', metrics.get('f1 score', None)),
                'Precision': metrics.get('node_precision', metrics.get('precision', None)),
                'Recall': metrics.get('node_recall', metrics.get('recall', None)),
                'True Positive Rate': metrics.get('true positive rate', None),
                'False Positive Rate': metrics.get('false positive rate', None),
                'Avg Rank of Source': metrics.get('avg rank of source', None)
            }
            metrics_data.append(row)
    
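    df = pd.DataFrame(metrics_data)
    # Missing metrics arrive as None; coerce to floats (NaN) so rounding, plotting and sorting work
    if not df.empty:
        metric_cols = [c for c in df.columns if c != 'Model']
        df[metric_cols] = df[metric_cols].apply(pd.to_numeric, errors='coerce')
    return df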

# Create performance comparison
performance_df = extract_metrics(results)
print("=== Performance Summary ===")
print(performance_df.round(3))

# Create visualization
if not performance_df.empty:
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Model Performance Comparison', fontsize=16, fontweight='bold')
    
    # AUC-ROC comparison
    if 'AUC-ROC' in performance_df.columns and performance_df['AUC-ROC'].notna().any():
        performance_df.plot(x='Model', y='AUC-ROC', kind='bar', ax=axes[0,0], color='skyblue')
        axes[0,0].set_title('AUC-ROC Score')
        axes[0,0].set_ylabel('Score')
        axes[0,0].tick_params(axis='x', rotation=45)
    
    # F1 Score comparison
    if 'F1 Score' in performance_df.columns and performance_df['F1 Score'].notna().any():
        performance_df.plot(x='Model', y='F1 Score', kind='bar', ax=axes[0,1], color='lightgreen')
        axes[0,1].set_title('F1 Score')
        axes[0,1].set_ylabel('Score')
        axes[0,1].tick_params(axis='x', rotation=45)
    
    # Precision vs Recall (only models that report both metrics)
    pr_df = performance_df.dropna(subset=['Precision', 'Recall'])
    if not pr_df.empty:
        axes[1,0].scatter(pr_df['Recall'], pr_df['Precision'], 
                         s=100, alpha=0.7, c=range(len(pr_df)), cmap='viridis')
        for _, row in pr_df.iterrows():
            axes[1,0].annotate(row['Model'], (row['Recall'], row['Precision']),
                              xytext=(5, 5), textcoords='offset points')
        axes[1,0].set_xlabel('Recall')
        axes[1,0].set_ylabel('Precision')
        axes[1,0].set_title('Precision vs Recall')
        axes[1,0].grid(True, alpha=0.3)
    
    # True/False Positive Rate (only models that report both rates)
    roc_df = performance_df.dropna(subset=['True Positive Rate', 'False Positive Rate'])
    if not roc_df.empty:
        axes[1,1].scatter(roc_df['False Positive Rate'], roc_df['True Positive Rate'], 
                         s=100, alpha=0.7, c=range(len(roc_df)), cmap='viridis')
        for _, row in roc_df.iterrows():
            axes[1,1].annotate(row['Model'], (row['False Positive Rate'], row['True Positive Rate']),
                              xytext=(5, 5), textcoords='offset points')
        axes[1,1].set_xlabel('False Positive Rate')
        axes[1,1].set_ylabel('True Positive Rate')
        axes[1,1].set_title('ROC Space')
        axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
else:
    print("No performance data available for plotting")

In [None]:
# Detailed evaluation analysis
print("=== Detailed Performance Analysis ===")

# GAT Performance
if 'gat' in results:
    gat_metrics = results['gat'].get('metrics', {})
    print(f"\n🔬 GAT Performance:")
    print(f"  • AUC-ROC: {gat_metrics.get('node_auc_roc', 'N/A')}")
    print(f"  • F1 Score: {gat_metrics.get('node_f1', 'N/A')}")
    print(f"  • Average Rank of Source: {gat_metrics.get('avg rank of source', 'N/A')}")
    print(f"  • Source in Top 3: {gat_metrics.get('source in top 3', 'N/A')}")
    print(f"  • Source in Top 5: {gat_metrics.get('source in top 5', 'N/A')}")

# Baseline Performance
baseline_models = ['random', 'rumor']
for baseline in baseline_models:
    if baseline in results:
        baseline_metrics = results[baseline].get('metrics', {})
        print(f"\n📊 {baseline.upper()} Baseline:")
        print(f"  • AUC-ROC: {baseline_metrics.get('node_auc_roc', baseline_metrics.get('auc_roc', 'N/A'))}")
        print(f"  • F1 Score: {baseline_metrics.get('node_f1', baseline_metrics.get('f1 score', 'N/A'))}")
        if 'avg rank of source' in baseline_metrics:
            print(f"  • Average Rank of Source: {baseline_metrics['avg rank of source']}")

# PDGrapher Performance
print(f"\n🧬 PDGrapher Performance:")
if 'pdgrapher' in results:
    print("  PLACEHOLDER - Add PDGrapher metrics when available")
else:
    print("  PLACEHOLDER - No PDGrapher results found")

# Create summary table
print(f"\n=== Model Ranking ===")
if not performance_df.empty and 'AUC-ROC' in performance_df.columns:
    # Skip models without an AUC-ROC value (None cannot be formatted with :.3f)
    ranking = performance_df.dropna(subset=['AUC-ROC']).sort_values('AUC-ROC', ascending=False)
    for i, (_, row) in enumerate(ranking.iterrows(), 1):
        print(f"  {i}. {row['Model']}: AUC-ROC = {row['AUC-ROC']:.3f}")
else:
    print("  Insufficient data for ranking")

# 5. Conclusions and Next Steps

## Key Findings

### Data
- Successfully implemented data creation for both GAT and PDGrapher pipelines
- Fixed critical node mapping bugs that were causing inconsistent results
- Working with TP53 pathway as primary biological network

### Models
- GAT model shows **PLACEHOLDER** performance characteristics
- PDGrapher integration completed but needs **PLACEHOLDER** analysis
- Both models successfully train on SLURM cluster infrastructure

### Performance
- **GAT:** AUC-ROC = 1.0 on a tiny network, meaning every source node received a higher score than every non-source node; whether this holds on the full TP53 network is still open
- **Baselines:** **PLACEHOLDER** - need to complete baseline evaluations
- **PDGrapher:** **PLACEHOLDER** - evaluation in progress
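AUC-ROC measures ranking, not thresholded classification: it equals the probability that a randomly chosen source outscores a randomly chosen non-source. A perfect AUC can therefore coexist with a poor F1 if the decision threshold is badly placed. The toy numbers below are illustrative only, not taken from our runs, and the helpers are hypothetical:

```python
import numpy as np

def pairwise_auc(scores, labels):
    """AUC-ROC as P(score of a random source > score of a random non-source), ties count 0.5."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=bool)
    pos, neg = scores[labels], scores[~labels]
    diff = pos[:, None] - neg[None, :]
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())

def f1_at_threshold(scores, labels, threshold=0.5):
    """F1 score after calling every node with score >= threshold a source."""
    pred = np.asarray(scores, dtype=float) >= threshold
    labels = np.asarray(labels, dtype=bool)
    tp = int((pred & labels).sum())
    fp = int((pred & ~labels).sum())
    fn = int((~pred & labels).sum())
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0

# Source scored highest but below 0.5: perfect ranking, no positive prediction at 0.5
scores, labels = [0.4, 0.1, 0.2, 0.3], [1, 0, 0, 0]
auc = pairwise_auc(scores, labels)
f1 = f1_at_threshold(scores, labels, threshold=0.5)
```

Reporting AUC together with the rank-based metrics and F1 avoids reading an AUC of 1.0 as perfect classification.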

## Technical Achievements
✅ Fixed data consistency bugs in both pipelines  
✅ Implemented SLURM job orchestration with CPU/GPU optimization  
✅ Created robust evaluation framework with JSON result reporting  
✅ Established visualization pipeline for network states  
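The loader at the top of this notebook reads the keys `network`, `parameters`, `metrics` and `data stats` and picks the newest file per directory. A minimal sketch of a matching writer, assuming that layout (`save_results` is hypothetical, not the framework's actual function), where a sortable timestamp in the file name also keeps "most recent" well defined when files are sorted by name:

```python
import json
from datetime import datetime
from pathlib import Path

def save_results(out_dir, network, parameters, metrics, data_stats=None):
    """Hypothetical writer producing JSON files in the layout load_results() expects."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    payload = {"network": network, "parameters": parameters, "metrics": metrics}
    if data_stats is not None:
        payload["data stats"] = data_stats
    # Sortable timestamp (down to microseconds) so name order matches creation order
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    path = out_dir / f"results_{stamp}.json"
    with open(path, "w") as f:
        json.dump(payload, f, indent=2)
    return path
```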

## Next Steps

### Immediate (Next 2 weeks)
1. Complete PDGrapher evaluation and compare with GAT
2. Implement and evaluate rumor centrality baseline
3. Run large-scale experiments on full TP53 network
4. Analyze model parameter sensitivity
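For the rumor centrality baseline, the estimator of Shah and Zaman scores each node $v$ of an infected tree with $N$ nodes as $R(v) = N! / \prod_u T_u^v$, where $T_u^v$ is the size of the subtree rooted at $u$ when the tree is rooted at $v$; the node with the highest score is the source estimate. A sketch in log space, assuming the infected subgraph is treated as an undirected tree (for general graphs a BFS tree of the infected subgraph is the usual heuristic); this is not necessarily what the existing `baseline_rumor` run does:

```python
import math
from collections import defaultdict

def rumor_centrality(edges):
    """Log rumor centrality of every node in an undirected tree given as an edge list."""
    adj = defaultdict(list)
    for u, w in edges:
        adj[u].append(w)
        adj[w].append(u)
    nodes = list(adj)
    log_n_fact = math.lgamma(len(nodes) + 1)  # log(N!)
    scores = {}
    for root in nodes:
        # Iterative DFS from the candidate root: parent map plus discovery order
        parent = {root: None}
        order = [root]
        stack = [root]
        while stack:
            u = stack.pop()
            for w in adj[u]:
                if w not in parent:
                    parent[w] = u
                    order.append(w)
                    stack.append(w)
        # Children are discovered after parents, so reverse order accumulates subtree sizes
        size = {u: 1 for u in order}
        for u in reversed(order[1:]):
            size[parent[u]] += size[u]
        scores[root] = log_n_fact - sum(math.log(s) for s in size.values())
    return scores

# Path 0 - 1 - 2: the middle node is the most likely source
scores = rumor_centrality([(0, 1), (1, 2)])
estimate = max(scores, key=scores.get)
```

This version is O(N^2) because it re-roots the tree for every candidate; the original paper also describes a linear-time message-passing variant.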

### Medium-term (Next month)
1. Extend to additional biological networks
2. Investigate multi-source detection scenarios
3. Optimize training procedures and hyperparameters
4. Prepare results for publication

### Technical TODOs
- [ ] Complete PDGrapher architecture analysis
- [ ] Implement additional baseline methods
- [ ] Scale up dataset sizes
- [ ] Perform statistical significance testing
- [ ] Create comprehensive benchmark suite
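For the significance-testing item, one option when two methods are evaluated on the same test simulations is a paired sign-flip permutation test on per-simulation values (for example source ranks or top-5 hit indicators of GAT versus a baseline). A minimal sketch, assuming such paired per-sample values are available; the function name and defaults are hypothetical:

```python
import numpy as np

def paired_permutation_test(a, b, n_permutations=10_000, seed=0):
    """Two-sided paired sign-flip test on the mean difference of per-sample values a and b."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    diff = a - b
    observed = abs(diff.mean())
    rng = np.random.default_rng(seed)
    # Under H0 each paired difference is equally likely to have either sign
    signs = rng.choice([-1.0, 1.0], size=(n_permutations, diff.size))
    permuted = np.abs((signs * diff).mean(axis=1))
    # +1 correction keeps the p-value strictly positive
    return float((np.sum(permuted >= observed) + 1) / (n_permutations + 1))
```

Pairing matters here: both methods see the same simulations, so comparing them per sample removes the variance that comes from some simulations simply being harder than others.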