# Claim 3: Temporal Dynamics Analysis with PertPy
## Testing: "Temporal dynamics reveal progressive mitochondrial dysfunction"

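This notebook tests the claim by analyzing how mitochondrial and proteostasis protein levels change along pseudotime: per-protein correlations with pseudotime, a tau-status × pseudotime interaction model, and a comparison of protein-set expression across early, middle, and late phases.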

In [None]:
import pertpy as pt
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from sklearn.linear_model import LinearRegression
import warnings
# Suppress every warning for the rest of the session so cell output stays
# readable. NOTE(review): this also hides numerical warnings (for example from
# all-NaN reductions or constant-input correlations) — confirm that is intended.
warnings.filterwarnings('ignore')

# Global plotting defaults applied to every figure created below.
# NOTE(review): the 'seaborn-v0_8-*' style names presumably require a recent
# matplotlib release — verify against the pinned version.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('Set2')

## 1. Load Data and Define Protein Sets

In [None]:
# Read the AnnData object written by the data-preparation step and report its
# dimensions.
adata = sc.read_h5ad('../01_data_preparation/prepared_for_pertpy.h5ad')
print(f"Data shape: {adata.shape}")

# Pseudotime is usable only when the column exists and holds at least one
# non-missing value; otherwise an error line is printed and execution continues.
_pseudotime_ok = 'pseudotime' in adata.obs and not adata.obs['pseudotime'].isna().all()
if _pseudotime_ok:
    _lowest = adata.obs['pseudotime'].min()
    _highest = adata.obs['pseudotime'].max()
    print(f"Pseudotime range: {_lowest:.3f} to {_highest:.3f}")
else:
    print("ERROR: Pseudotime not available for temporal analysis")

# Curated protein sets, keyed by set name, that the temporal analyses iterate
# over in this order.
temporal_proteins = dict(
    early_response=['HSP70', 'HSP90', 'HSPA1A', 'HSPA5', 'HSPB1'],
    mitochondrial_complex_I=['NDUFB8', 'NDUFA9', 'NDUFS3', 'NDUFV1'],
    mitochondrial_complex_V=['ATP5A1', 'ATP5B', 'ATP5C1', 'ATP5D'],
    autophagy_early=['BECN1', 'ATG5', 'ATG7', 'ATG12'],
    autophagy_late=['SQSTM1', 'NBR1', 'OPTN', 'TAX1BP1'],
    proteasome=['PSMA1', 'PSMB5', 'PSMC4', 'PSMD1'],
)

## 2. Temporal Correlation Analysis

In [None]:
# Correlate every curated protein with pseudotime (Spearman) and estimate a
# linear slope, then FDR-correct across all proteins that were tested.
# Produces `temporal_df` with one row per matched protein and the columns
# protein, set, correlation, p_value, slope, direction, p_adjusted, significant.
from statsmodels.stats.multitest import multipletests

temporal_results = []

protein_names = adata.var['protein_name'] if 'protein_name' in adata.var else adata.var.index

# Materialise the labels once; rebuilding the list and re-uppercasing every
# label for each symbol is wasted work.
name_list = list(protein_names)
upper_names = [str(name).upper() for name in name_list]

# A missing pseudotime column degrades to all-NaN so that no protein passes the
# sample-size check and `temporal_df` ends up empty (downstream cells test for
# that) instead of raising a KeyError here.
if 'pseudotime' in adata.obs:
    pseudotime = adata.obs['pseudotime'].astype(float).to_numpy()
else:
    pseudotime = np.full(adata.n_obs, np.nan)
pseudotime_missing = np.isnan(pseudotime)

for set_name, proteins in temporal_proteins.items():
    print(f"\nAnalyzing {set_name}...")

    for protein in proteins:
        # Prefer an exact (case-insensitive) label match. Pure substring
        # matching would let e.g. 'PSMD1' resolve to 'PSMD10'. The substring
        # fallback is kept for composite labels that merely contain the symbol.
        symbol = protein.upper()
        hits = [i for i, label in enumerate(upper_names) if label == symbol]
        if not hits:
            hits = [i for i, label in enumerate(upper_names) if symbol in label]
        if not hits:
            continue

        protein_idx = hits[0]

        # adata.X may be dense or scipy-sparse; np.isnan needs a dense 1-D
        # float vector either way.
        expr = adata.X[:, protein_idx]
        if hasattr(expr, 'toarray'):
            expr = expr.toarray()
        expr = np.asarray(expr, dtype=float).ravel()

        # Keep observations where both expression and pseudotime are present.
        valid_mask = ~(np.isnan(expr) | pseudotime_missing)

        # Require more than 10 paired observations before testing.
        if valid_mask.sum() <= 10:
            continue

        x_valid = pseudotime[valid_mask]
        y_valid = expr[valid_mask]

        # Rank correlation with pseudotime
        corr, pval = stats.spearmanr(x_valid, y_valid)

        # Ordinary least-squares slope of expression on pseudotime
        lr = LinearRegression()
        lr.fit(x_valid.reshape(-1, 1), y_valid)
        slope = lr.coef_[0]

        temporal_results.append({
            'protein': name_list[protein_idx],
            'set': set_name,
            'correlation': corr,
            'p_value': pval,
            'slope': slope,
            'direction': 'increasing' if slope > 0 else 'decreasing'
        })

temporal_df = pd.DataFrame(temporal_results)

# Apply FDR correction (Benjamini-Hochberg)
if len(temporal_df) > 0:
    # spearmanr yields NaN for constant input; such rows are left out of the
    # correction so they cannot distort the adjusted values of the other
    # proteins, and they stay non-significant (NaN < 0.05 is False).
    testable = temporal_df['p_value'].notna()
    temporal_df['p_adjusted'] = np.nan
    if testable.any():
        temporal_df.loc[testable, 'p_adjusted'] = multipletests(
            temporal_df.loc[testable, 'p_value'], method='fdr_bh'
        )[1]
    temporal_df['significant'] = temporal_df['p_adjusted'] < 0.05

    print("\nTemporal Analysis Summary:")
    print(temporal_df.groupby('set')[['correlation', 'significant']].agg({
        'correlation': 'mean',
        'significant': 'sum'
    }))

## 3. PyDESeq2 with Pseudotime as Continuous Covariate

In [None]:
# Fit a tau_status x pseudotime interaction model (PyDESeq2 through pertpy) on
# the proteins that were significantly associated with pseudotime above.
# `interaction_results` holds the contrast table on success and None otherwise.
interaction_results = None  # defined up front so later cells can always read it

if len(temporal_df) > 0 and temporal_df['significant'].sum() > 0:
    significant_temporal = temporal_df[temporal_df['significant']]['protein'].tolist()

    # Set membership keeps the scan over all proteins linear.
    significant_lookup = set(significant_temporal)
    protein_indices = [i for i, p in enumerate(protein_names) if p in significant_lookup]

    if len(protein_indices) > 0:
        adata_temporal = adata[:, protein_indices].copy()

        # Observations without pseudotime cannot contribute to a model that
        # uses pseudotime as a covariate, so they are dropped before fitting.
        has_pseudotime = adata_temporal.obs['pseudotime'].notna().to_numpy()
        if not has_pseudotime.all():
            print(f"Excluding {int((~has_pseudotime).sum())} observations without pseudotime")
            adata_temporal = adata_temporal[has_pseudotime].copy()

        print(f"Running PyDESeq2 on {len(protein_indices)} temporally dynamic proteins...")

        try:
            # Use counts if available, keeping the previous X as a 'log2' layer
            if 'counts' in adata_temporal.layers:
                adata_temporal.layers['log2'] = adata_temporal.X.copy()
                adata_temporal.X = adata_temporal.layers['counts'].copy()
            else:
                # NOTE(review): DESeq2-style models presumably expect raw
                # counts; confirm that fitting on X as-is is acceptable.
                print("WARNING: no 'counts' layer found; fitting PyDESeq2 on adata.X as-is")

            # Model with pseudotime interaction
            pds2 = pt.tl.PyDESeq2(
                adata=adata_temporal,
                design="~tau_status * pseudotime",
                refit_cooks=True
            )

            pds2.fit()

            # Test interaction effect
            interaction_results = pds2.test_contrasts(
                ["tau_status[T.positive]:pseudotime"]
            )

            print("\n✓ PyDESeq2 with interaction term completed")
            print("\nProteins with significant tau:pseudotime interaction:")

            # NOTE(review): the result column names may differ between the raw
            # PyDESeq2 output and the pertpy wrapper — accept either spelling
            # rather than failing on a KeyError; confirm against the installed
            # version.
            padj_col = next(
                (c for c in ('padj', 'adj_p_value') if c in interaction_results.columns),
                None,
            )
            if padj_col is None:
                sig_interaction = interaction_results.iloc[0:0]
                print(f"No adjusted p-value column found; columns: {list(interaction_results.columns)}")
            else:
                sig_interaction = interaction_results[interaction_results[padj_col] < 0.05]
                display_cols = [
                    c for c in ('protein', 'variable', 'log2FoldChange', 'log_fc', padj_col)
                    if c in sig_interaction.columns
                ]
                print(sig_interaction[display_cols].head())

        except Exception as e:
            # Best-effort step: report the failure and let the notebook go on.
            print(f"PyDESeq2 failed: {e}")
            interaction_results = None
else:
    print("No significant temporal proteins found for interaction analysis")

## 4. Visualization of Temporal Dynamics

In [None]:
# Draw one scatter panel per protein set showing the set's most significant
# protein against pseudotime, coloured by tau status, with a linear trend line.
# The figure is written to 'temporal_dynamics.png'.
if len(temporal_df) > 0:
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    # Values that do not change between panels are computed once.
    name_list = list(protein_names)
    pseudotime = adata.obs['pseudotime'].astype(float).to_numpy()
    colors = ['red' if x == 1 else 'blue' for x in adata.obs['tau_positive']]
    x_line = np.linspace(np.nanmin(pseudotime), np.nanmax(pseudotime), 100)

    panels_used = 0

    # Plot top proteins from each category
    for idx, (set_name, set_df) in enumerate(temporal_df.groupby('set')):
        if idx >= len(axes):
            break

        ax = axes[idx]
        panels_used = idx + 1

        if len(set_df) == 0:
            continue

        # Get most significant protein from this set
        top_protein = set_df.nsmallest(1, 'p_value').iloc[0]
        protein_idx = name_list.index(top_protein['protein'])

        # adata.X may be dense or scipy-sparse; work on a dense 1-D vector.
        expr = adata.X[:, protein_idx]
        if hasattr(expr, 'toarray'):
            expr = expr.toarray()
        expr = np.asarray(expr, dtype=float).ravel()

        # Plot with tau status coloring
        ax.scatter(pseudotime, expr, c=colors, alpha=0.5, s=20)

        # Add trend line (first-degree fit on the complete pairs only)
        valid_mask = ~(np.isnan(expr) | np.isnan(pseudotime))
        if valid_mask.sum() > 10:
            z = np.polyfit(pseudotime[valid_mask], expr[valid_mask], 1)
            p = np.poly1d(z)
            ax.plot(x_line, p(x_line), 'g--', alpha=0.8, linewidth=2)

        ax.set_title(f"{set_name}\n{top_protein['protein']} (r={top_protein['correlation']:.2f})")
        ax.set_xlabel('Pseudotime')
        ax.set_ylabel('Expression (log2)')

    # Hide panels that received no data so the figure has no empty axes.
    for unused_ax in axes[panels_used:]:
        unused_ax.set_visible(False)

    plt.suptitle('Temporal Dynamics of Key Protein Sets', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig('temporal_dynamics.png', dpi=300, bbox_inches='tight')
    plt.show()

## 5. Phase Detection Analysis

In [None]:
# Split observations into Early / Middle / Late phases by pseudotime tertile,
# average each curated protein set within every phase, and draw the result as
# a heatmap saved to 'phase_heatmap.png'. `phase_pivot` is only defined when at
# least one curated protein was found.
if 'pseudotime' in adata.obs and adata.obs['pseudotime'].notna().any():
    phase_order = ['Early', 'Middle', 'Late']

    # Define phases based on pseudotime tertiles (exact thirds)
    tertiles = adata.obs['pseudotime'].quantile([1 / 3, 2 / 3])

    adata.obs['phase'] = pd.cut(
        adata.obs['pseudotime'],
        bins=[-np.inf, tertiles.iloc[0], tertiles.iloc[1], np.inf],
        labels=phase_order
    )

    print("Phase distribution:")
    print(adata.obs['phase'].value_counts())

    # Labels are materialised once instead of once per protein.
    name_list = list(protein_names)
    upper_names = [str(name).upper() for name in name_list]

    # Analyze protein changes by phase
    phase_results = []

    for set_name, proteins in temporal_proteins.items():
        set_expression = []

        for protein in proteins:
            # Prefer an exact (case-insensitive) label match; fall back to a
            # substring match so composite labels are still found, without
            # letting e.g. 'PSMD1' resolve to 'PSMD10' when both exist.
            symbol = protein.upper()
            hits = [i for i, label in enumerate(upper_names) if label == symbol]
            if not hits:
                hits = [i for i, label in enumerate(upper_names) if symbol in label]
            if not hits:
                continue

            # adata.X may be dense or scipy-sparse; collect dense 1-D vectors.
            expr = adata.X[:, hits[0]]
            if hasattr(expr, 'toarray'):
                expr = expr.toarray()
            set_expression.append(np.asarray(expr, dtype=float).ravel())

        if set_expression:
            # Per-observation mean over the proteins of the set, ignoring NaN
            mean_expr = np.nanmean(np.vstack(set_expression), axis=0)

            # Compare phases
            for phase in phase_order:
                phase_mask = (adata.obs['phase'] == phase).to_numpy()
                phase_mean = np.nanmean(mean_expr[phase_mask])

                phase_results.append({
                    'protein_set': set_name,
                    'phase': phase,
                    'mean_expression': phase_mean
                })

    phase_df = pd.DataFrame(phase_results)

    if len(phase_df) > 0:
        # Pivot for heatmap. pivot() orders the string phase labels
        # alphabetically (Early, Late, Middle), so restore temporal order.
        phase_pivot = phase_df.pivot(
            index='protein_set', columns='phase', values='mean_expression'
        ).reindex(columns=phase_order)

        # Create heatmap
        plt.figure(figsize=(8, 6))
        sns.heatmap(phase_pivot, annot=True, fmt='.2f', cmap='RdBu_r', center=0)
        plt.title('Protein Set Expression by Disease Phase')
        plt.ylabel('Protein Set')
        plt.xlabel('Disease Phase')
        plt.tight_layout()
        plt.savefig('phase_heatmap.png', dpi=300, bbox_inches='tight')
        plt.show()
    else:
        print("No curated proteins were found in the dataset; skipping phase heatmap")

## 6. Evaluate Claim

In [None]:
# Turn the per-protein temporal statistics for the two mitochondrial sets into
# a verdict on the claim and print supporting phase-level evidence.
# Sets `verdict` and `explanation`; `n_decreasing` / `n_increasing` are only
# defined when mitochondrial proteins were analysed.
print("\n" + "="*60)
print("CLAIM EVALUATION")
print("="*60)
print("Claim: Temporal dynamics reveal progressive mitochondrial dysfunction")
print()

# Defined unconditionally: the phase-evidence loop at the bottom needs it even
# when the temporal analysis produced no rows.
mito_sets = ['mitochondrial_complex_I', 'mitochondrial_complex_V']

# Analyze mitochondrial proteins specifically
if len(temporal_df) > 0:
    mito_df = temporal_df[temporal_df['set'].isin(mito_sets)]

    if len(mito_df) > 0:
        # Calculate summary statistics
        is_decreasing = mito_df['direction'] == 'decreasing'
        is_significant = mito_df['significant'].astype(bool)

        n_decreasing = is_decreasing.sum()
        n_increasing = (mito_df['direction'] == 'increasing').sum()
        n_significant = is_significant.sum()
        n_sig_decreasing = (is_decreasing & is_significant).sum()
        n_sig_increasing = n_significant - n_sig_decreasing
        mean_corr = mito_df['correlation'].mean()

        print("Mitochondrial Protein Analysis:")
        print(f"  Total analyzed: {len(mito_df)}")
        print(f"  Decreasing over time: {n_decreasing}")
        print(f"  Increasing over time: {n_increasing}")
        print(f"  Significantly correlated: {n_significant}")
        print(f"  Significant and decreasing: {n_sig_decreasing}")
        print(f"  Mean correlation: {mean_corr:.3f}")
        print()

        # Determine verdict. Besides the original majority / 30 % thresholds,
        # the significant proteins themselves must be predominantly
        # decreasing; otherwise significant *increases* could "support" a
        # claim about progressive loss.
        if (
            n_decreasing > n_increasing
            and n_significant > len(mito_df) * 0.3
            and n_sig_decreasing > n_sig_increasing
        ):
            verdict = "SUPPORTED"
            explanation = f"{n_decreasing}/{len(mito_df)} mitochondrial proteins decrease over pseudotime"
        elif n_significant > 0:
            verdict = "PARTIALLY SUPPORTED"
            explanation = f"Some temporal changes detected ({n_significant}/{len(mito_df)} significant)"
        else:
            verdict = "REFUTED"
            explanation = "No significant temporal dynamics in mitochondrial proteins"
    else:
        verdict = "UNSURE"
        explanation = "Insufficient mitochondrial proteins found for analysis"
else:
    verdict = "UNSURE"
    explanation = "Temporal analysis could not be performed"

print(f"VERDICT: {verdict}")
print(f"Explanation: {explanation}")

# Additional evidence from phases (only when the phase cell produced a table)
if 'phase_pivot' in locals():
    print("\nPhase-based Evidence:")
    for set_name in mito_sets:
        if set_name in phase_pivot.index:
            early = phase_pivot.loc[set_name, 'Early']
            late = phase_pivot.loc[set_name, 'Late']
            # Divide by |early| so the sign of the change follows the
            # direction of late - early even when the early mean is negative.
            change = ((late - early) / abs(early) * 100) if early != 0 else 0
            print(f"  {set_name}: {change:.1f}% change from early to late phase")

## 7. Save Results

In [None]:
# Persist the per-protein temporal statistics and print a compact summary of
# the claim evaluation.
from pathlib import Path

# Save temporal analysis results
if len(temporal_df) > 0:
    report_path = Path('../05_statistical_reports/claim3_temporal_dynamics.csv')
    # Create the report directory on first use so to_csv cannot fail on a
    # missing folder.
    report_path.parent.mkdir(parents=True, exist_ok=True)
    temporal_df.to_csv(report_path, index=False)
    print("Temporal results saved")

# Create summary. Counts are cast to plain int so the dict holds only built-in
# types; the mitochondrial counts exist only if the evaluation cell computed
# them, hence the locals() checks.
summary = {
    'claim': 'Progressive mitochondrial dysfunction over time',
    'verdict': verdict,
    'proteins_analyzed': len(temporal_df),
    'significant_temporal': int(temporal_df['significant'].sum()) if len(temporal_df) > 0 else 0,
    'mito_decreasing': int(n_decreasing) if 'n_decreasing' in locals() else 0,
    'mito_increasing': int(n_increasing) if 'n_increasing' in locals() else 0
}

print("\nSummary:")
for key, value in summary.items():
    print(f"  {key}: {value}")