# Stage 4: Final Correlation Analysis

This notebook combines morphogen-regulon networks and calculates final correlations for publication.

- **Input**: Morphogen-regulon networks from Stage 3
- **Output**: Final correlation matrices, TF-target-morphogen relationships
- **Method**: Correlation analysis (network importance used as a proxy for correlation strength), summary statistics, visualization

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import os
import warnings
# Silence every warning for the rest of the session so library chatter does
# not clutter the rendered notebook output.
warnings.simplefilter('ignore')

# Plot styling: matplotlib's default style sheet plus seaborn's "husl" palette.
plt.style.use('default')
sns.set_palette(palette="husl")

# Stage banner
for banner_line in (
    "📈 Stage 4: Final Correlation Analysis",
    "Combining morphogen-regulon networks for publication",
):
    print(banner_line)

## 4.1 Load Network Data

Load the morphogen-regulon networks from Stage 3.

In [None]:
# Read the combined (all cell lines) morphogen-regulon network written by Stage 3.
# `networks` is set to None when the file is absent so later cells can skip.
network_file = "../03_morphogen_networks/networks/morphogen_regulon_networks_combined.csv"

if not os.path.exists(network_file):
    print(f"❌ Network file not found: {network_file}")
    print("Please run Stage 3 first")
    networks = None
else:
    networks = pd.read_csv(network_file)
    print(f"Loaded combined networks: {len(networks)} interactions")
    print(f"Columns: {list(networks.columns)}")
    print("\nSample data:")
    print(networks.head())

In [None]:
# Read each cell line's own network (when its file exists) so per-line results
# can be compared with the combined network. Missing files are reported and skipped.
individual_networks = {}

for cell_line in ('H1', 'WTC', 'H9', 'WIBJ2'):
    network_file = f"../03_morphogen_networks/networks/morphogen_regulon_network_{cell_line}.csv"

    if not os.path.exists(network_file):
        print(f"❌ {cell_line} network not found")
        continue

    df = pd.read_csv(network_file)
    individual_networks[cell_line] = df
    print(f"Loaded {cell_line}: {len(df)} interactions")

print(f"\nLoaded networks for {len(individual_networks)} cell lines")

## 4.2 Calculate Correlations

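Derive morphogen-regulon correlation scores from the network importance values. Importance is used as a proxy for correlation strength, and interactions below the minimum importance threshold (0.1 by default) are dropped.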

In [None]:
def calculate_correlations(network_df, min_importance=0.1):
    """
    Build a morphogen-regulon "correlation" table from network importance scores.

    Network importance is used directly as a proxy for correlation strength;
    no correlation coefficient is computed here.

    Parameters
    ----------
    network_df : pandas.DataFrame
        Network edge list. Must contain the columns ``TF`` (reported as the
        morphogen), ``target`` (reported as the regulon) and ``importance``.
        An optional ``cell_line`` column is carried through unchanged.
    min_importance : float, default 0.1
        Edges whose ``importance`` is below this value are dropped.

    Returns
    -------
    pandas.DataFrame
        Columns ``morphogen``, ``regulon``, ``correlation`` and ``cell_line``
        (``'unknown'`` when the input has no ``cell_line`` column), in input
        order with a fresh 0..n-1 index. The four columns are present even
        when no edge passes the threshold, so callers can always select
        them by name.

    Raises
    ------
    KeyError
        If a required column is missing from ``network_df``.
    """
    required = ['TF', 'target', 'importance']
    missing = [col for col in required if col not in network_df.columns]
    if missing:
        raise KeyError(f"network_df is missing required column(s): {missing}")

    # Boolean mask of edges that pass the importance threshold
    keep = network_df['importance'] >= min_importance

    # Column selection + rename keeps the schema intact for an empty result,
    # unlike building the frame from an (empty) list of row dicts.
    result = (
        network_df.loc[keep, required]
        .rename(columns={
            'TF': 'morphogen',
            'target': 'regulon',
            'importance': 'correlation',  # Using importance as correlation
        })
        .reset_index(drop=True)
    )

    if 'cell_line' in network_df.columns:
        # to_numpy() avoids index alignment against the freshly reset index
        result['cell_line'] = network_df.loc[keep, 'cell_line'].to_numpy()
    else:
        result['cell_line'] = 'unknown'

    return result

# Calculate correlations for combined data.
# `correlations_combined` is None when the combined network was not loaded,
# which lets the later cells skip their work.
if networks is not None:
    print("Calculating correlations from combined networks...")
    correlations_combined = calculate_correlations(networks)
    print(f"Significant correlations: {len(correlations_combined)}")

    # Show top correlations. Guarded the same way as the per-cell-line cell:
    # an empty result has nothing to rank and may not be safe to rank by column.
    if len(correlations_combined) > 0:
        top_corr = correlations_combined.nlargest(10, 'correlation')
        print("\nTop 10 morphogen-regulon correlations:")
        for _, row in top_corr.iterrows():
            print(f"  {row['morphogen']} -> {row['regulon']}: r={row['correlation']:.3f} ({row['cell_line']})")
else:
    correlations_combined = None

In [None]:
# Per-cell-line correlation tables, keyed by cell line name
correlations_individual = {}

for cell_line, network_df in individual_networks.items():
    print(f"\nCalculating correlations for {cell_line}...")

    corr_df = calculate_correlations(network_df)
    # Stamp every row with the cell line this network came from
    corr_df['cell_line'] = cell_line
    correlations_individual[cell_line] = corr_df

    print(f"  Significant correlations: {len(corr_df)}")

    # Nothing to rank when no interaction passed the threshold
    if len(corr_df) == 0:
        continue

    # Show the five strongest interactions
    top_corr = corr_df.nlargest(5, 'correlation')
    print("  Top correlations:")
    for morphogen, regulon, score in zip(
        top_corr['morphogen'], top_corr['regulon'], top_corr['correlation']
    ):
        print(f"    {morphogen} -> {regulon}: r={score:.3f}")

## 4.3 Create Correlation Matrix

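Pivot the correlation table into a morphogen × regulon matrix for visualization and analysis. Pairs with no retained interaction are filled with 0; when a pair appears in several rows (e.g. once per cell line), the values are averaged.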

In [None]:
def create_correlation_matrix(correlations_df, aggfunc='mean'):
    """
    Pivot a long-format correlation table into a morphogen x regulon matrix.

    Parameters
    ----------
    correlations_df : pandas.DataFrame
        Long-format table with ``morphogen``, ``regulon`` and ``correlation``
        columns.
    aggfunc : str or callable, default 'mean'
        How to combine several rows that share the same (morphogen, regulon)
        pair. The default is the same as ``pivot_table``'s own default, so
        existing callers get unchanged results.

    Returns
    -------
    pandas.DataFrame
        Rows are morphogens, columns are regulons. Pairs absent from the
        input are filled with 0. An empty input yields an empty (0 x 0)
        matrix instead of an error.
    """
    # An empty table has nothing to pivot; return a well-formed empty matrix
    # with the same axis names the pivot would produce.
    if len(correlations_df) == 0:
        return pd.DataFrame(
            index=pd.Index([], name='morphogen'),
            columns=pd.Index([], name='regulon'),
            dtype=float,
        )

    return correlations_df.pivot_table(
        index='morphogen',
        columns='regulon',
        values='correlation',
        aggfunc=aggfunc,
        fill_value=0,
    )

# Create correlation matrices.
# `corr_matrix` is None when there is no correlation table or the table has no
# rows; the plotting and saving cells check for None before using it.
if correlations_combined is not None and len(correlations_combined) > 0:
    print("Creating correlation matrix...")
    corr_matrix = create_correlation_matrix(correlations_combined)
    n_morphogens, n_regulons = corr_matrix.shape
    print(f"Matrix shape: {corr_matrix.shape}")
    print(f"Morphogens: {n_morphogens}")
    print(f"Regulons: {n_regulons}")

    # Show matrix info: share of cells holding a non-zero score
    non_zero = int((corr_matrix != 0).sum().sum())
    total = n_morphogens * n_regulons
    # Guard the percentage so a zero-size matrix does not divide by zero
    pct_non_zero = 100 * non_zero / total if total else 0.0
    print(f"Non-zero correlations: {non_zero}/{total} ({pct_non_zero:.1f}%)")
else:
    corr_matrix = None

## 4.4 Visualizations

Create publication-quality visualizations of the results.

In [None]:
# Create output directory for plots
os.makedirs("plots", exist_ok=True)

# Plot 1: Heatmap of correlation matrix
# A matrix with no rows or columns cannot be drawn, so it is skipped in the
# same way as a missing (None) matrix.
if corr_matrix is not None and not corr_matrix.empty:
    fig, ax = plt.subplots(figsize=(12, 8))

    # Only show top correlations for clarity: the 20 morphogens and the
    # 20 regulons with the highest single score.
    top_n = 20
    top_morphogens = corr_matrix.max(axis=1).nlargest(top_n).index
    top_regulons = corr_matrix.max(axis=0).nlargest(top_n).index

    subset_matrix = corr_matrix.loc[top_morphogens, top_regulons]

    sns.heatmap(subset_matrix, cmap='RdYlBu_r', center=0,
                cbar_kws={'label': 'Correlation'},
                xticklabels=True, yticklabels=True, ax=ax)

    ax.set_title('Top Morphogen-Regulon Correlations')
    ax.set_xlabel('Regulons')
    ax.set_ylabel('Morphogens')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)
    fig.tight_layout()
    fig.savefig('plots/correlation_heatmap.png', dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated runs do not accumulate open figures
    plt.close(fig)

    print("✅ Saved correlation heatmap")

In [None]:
# Plot 2: Distribution of correlations by cell line
if correlations_combined is not None:
    fig, ax = plt.subplots(figsize=(10, 6))

    # One box per cell line, showing the spread of its correlation scores
    sns.boxplot(data=correlations_combined, x='cell_line', y='correlation', ax=ax)
    ax.set_title('Distribution of Morphogen-Regulon Correlations by Cell Line')
    ax.set_xlabel('Cell Line')
    ax.set_ylabel('Correlation Strength')
    plt.xticks(rotation=45)

    fig.tight_layout()
    fig.savefig('plots/correlation_distribution.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("✅ Saved correlation distribution plot")

In [None]:
# Plot 3: Network graph of top interactions
# Skipped when there is no correlation table or it has no rows.
if correlations_combined is not None and len(correlations_combined) > 0:
    # Get top interactions
    top_interactions = correlations_combined.nlargest(50, 'correlation')

    fig, ax = plt.subplots(figsize=(12, 10))

    # Create a simple two-column layout: morphogens on the left (x=0),
    # regulons on the right (x=1), each spread evenly over y in [0, 1].
    unique_morphogens = top_interactions['morphogen'].unique()
    unique_regulons = top_interactions['regulon'].unique()

    morph_y = np.linspace(0, 1, len(unique_morphogens))
    reg_y = np.linspace(0, 1, len(unique_regulons))

    # Name -> y-position lookups, replacing a list search per edge
    morph_pos = dict(zip(unique_morphogens, morph_y))
    reg_pos = dict(zip(unique_regulons, reg_y))

    # Plot nodes.
    # NOTE: nodes come from the top 50 interactions while only the top 20
    # edges are drawn below, so some nodes can appear without an edge.
    for morph, y in morph_pos.items():
        ax.scatter(0, y, s=100, c='red', alpha=0.7)
        ax.text(-0.05, y, morph, ha='right', va='center', fontsize=8)

    for reg, y in reg_pos.items():
        ax.scatter(1, y, s=100, c='blue', alpha=0.7)
        ax.text(1.05, y, reg, ha='left', va='center', fontsize=8)

    # Plot edges.
    # matplotlib rejects alpha values outside [0, 1], and the score is a network
    # importance that is not guaranteed to stay in that range. Scores are
    # divided by the largest score only when it exceeds 1 (values already in
    # [0, 1] are drawn unchanged), then clipped to [0, 1].
    scale = max(1.0, float(top_interactions['correlation'].max()))
    for _, row in top_interactions.head(20).iterrows():  # Top 20 for clarity
        strength = min(max(row['correlation'] / scale, 0.0), 1.0)
        ax.plot([0, 1], [morph_pos[row['morphogen']], reg_pos[row['regulon']]],
                'k-', alpha=strength, linewidth=2 * strength)

    ax.set_xlim(-0.3, 1.3)
    ax.set_ylim(-0.1, 1.1)
    ax.set_title('Top Morphogen-Regulon Network Interactions')
    ax.set_xlabel('Morphogens → Regulons')
    ax.axis('off')
    fig.tight_layout()
    fig.savefig('plots/network_graph.png', dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated runs do not accumulate open figures
    plt.close(fig)

    print("✅ Saved network graph")

## 4.5 Save Final Results

Save all results for publication and further analysis.

In [None]:
# Save correlation results.
# Every file written by this cell is recorded so the closing summary lists
# only outputs that actually exist, not a fixed list.
written_files = []

if correlations_combined is not None:
    # Save combined correlations
    correlations_combined.to_csv('final_correlations_combined.csv', index=False)
    written_files.append('final_correlations_combined.csv')
    print("✅ Saved combined correlations")

    # Save correlation matrix
    if corr_matrix is not None:
        corr_matrix.to_csv('correlation_matrix.csv')
        written_files.append('correlation_matrix.csv')
        print("✅ Saved correlation matrix")

# Save individual cell line correlations and collect one summary row per line
summary_columns = [
    'cell_line', 'total_interactions', 'unique_morphogens', 'unique_regulons',
    'mean_correlation', 'max_correlation', 'min_correlation',
]
summary_stats = []

for cell_line, corr_df in correlations_individual.items():
    out_name = f'final_correlations_{cell_line}.csv'
    corr_df.to_csv(out_name, index=False)
    written_files.append(out_name)
    print(f"✅ Saved correlations for {cell_line}")

    # A table with no rows may lack the score columns entirely, so its
    # statistics are filled in explicitly instead of being read from columns.
    has_rows = len(corr_df) > 0
    summary_stats.append({
        'cell_line': cell_line,
        'total_interactions': len(corr_df),
        'unique_morphogens': corr_df['morphogen'].nunique() if has_rows else 0,
        'unique_regulons': corr_df['regulon'].nunique() if has_rows else 0,
        'mean_correlation': corr_df['correlation'].mean() if has_rows else np.nan,
        'max_correlation': corr_df['correlation'].max() if has_rows else np.nan,
        'min_correlation': corr_df['correlation'].min() if has_rows else np.nan,
    })

# Explicit columns keep the CSV header present even with zero cell lines
summary_df = pd.DataFrame(summary_stats, columns=summary_columns)
summary_df.to_csv('summary_statistics.csv', index=False)
written_files.append('summary_statistics.csv')
print("✅ Saved summary statistics")

# Display summary
print("\n📊 Final Summary:")
print(summary_df.to_string(index=False))

print("\n🎉 Stage 4 Complete!")
print("\n📁 Output files:")
for path in written_files:
    print(f"  - {path}")

# Plots are written by earlier cells; list only those present on disk
for plot_path in (
    'plots/correlation_heatmap.png',
    'plots/correlation_distribution.png',
    'plots/network_graph.png',
):
    if os.path.exists(plot_path):
        print(f"  - {plot_path}")

# Only announce completion when both combined and per-line results were saved
if correlations_combined is not None and correlations_individual:
    print("\n🚀 Pipeline complete! Ready for publication.")
else:
    print("\n⚠️ Some Stage 3 inputs were missing; outputs are incomplete.")