# Group 2 Claim 1: V-ATPase Subunit Analysis with pertpy
## Testing: "V-ATPase subunits show differential expression patterns"

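This notebook tests whether V-ATPase subunits are differentially expressed between tau-positive and tau-negative neurons. It uses pertpy's PyDESeq2 wrapper, with a per-protein t-test as a fallback.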

In [None]:
import pertpy as pt
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

# Global plotting defaults applied to every figure in this notebook
PLOT_STYLE = 'seaborn-v0_8-whitegrid'
PLOT_PALETTE = 'husl'
plt.style.use(PLOT_STYLE)
sns.set_palette(PLOT_PALETTE)

## 1. Load Data and Define V-ATPase Subunits

In [None]:
# Read the AnnData object written by the data-preparation step
input_path = '../01_data_preparation/prepared_for_pertpy.h5ad'
adata = sc.read_h5ad(input_path)
print(f"Data shape: {adata.shape}")

# Define V-ATPase subunits, grouped by sector (V0 and V1 domains).
# TCIRG1 is the official gene symbol of the a3 isoform of V0 subunit a
# (alias ATP6V0A3). Without it the a-subunit isoform list (a1, a2, a4) is
# incomplete, because no feature is named "ATP6V0A3".
# NOTE(review): ATP6AP1 and ATP6AP2 are usually described as accessory
# subunits; they are counted with the V0 domain here, as in the original plan.
vatpase_subunits = {
    'V0_domain': [
        'ATP6V0A1', 'ATP6V0A2', 'TCIRG1', 'ATP6V0A4',
        'ATP6V0B', 'ATP6V0C', 'ATP6V0D1', 'ATP6V0D2',
        'ATP6V0E1', 'ATP6V0E2', 'ATP6AP1', 'ATP6AP2'
    ],
    'V1_domain': [
        'ATP6V1A', 'ATP6V1B1', 'ATP6V1B2',
        'ATP6V1C1', 'ATP6V1C2', 'ATP6V1D',
        'ATP6V1E1', 'ATP6V1E2', 'ATP6V1F',
        'ATP6V1G1', 'ATP6V1G2', 'ATP6V1G3', 'ATP6V1H'
    ]
}

# Flatten to a single ordered list of subunit symbols (V0 first, then V1)
all_vatpase = [p for sublist in vatpase_subunits.values() for p in sublist]
print(f"Total V-ATPase subunits to analyze: {len(all_vatpase)}")

## 2. Find V-ATPase Proteins in Dataset

In [None]:
# Find V-ATPase proteins.
# Each subunit symbol is matched against the dataset's feature names, trying:
#   1. an exact name match;
#   2. a case-insensitive match against each member of a ';'-separated
#      protein-group name (e.g. 'ATP6V0A1;ATP6V0A2'). This also catches
#      symbols written in another case, such as 'Atp6v0a1';
#   3. the same member match with the 'ATP6' prefix removed (e.g. 'V1A').
# Plain substring matching is deliberately NOT used: the stripped names 'AP1'
# and 'AP2' are substrings of unrelated proteins such as 'MAP1B' and 'MAP2'.
protein_names = adata.var['protein_name'] if 'protein_name' in adata.var else adata.var.index
found_vatpase = []     # distinct dataset feature names that matched a subunit
missing_vatpase = []   # subunit symbols with no match in the dataset
vatpase_matches = {}   # subunit symbol -> matched dataset feature name

# Build the lookup tables once instead of rescanning every name per subunit.
exact_names = set()
member_index = {}
for feature_name in protein_names:
    if not isinstance(feature_name, str):
        continue  # e.g. NaN for a feature without an annotated name
    exact_names.add(feature_name)
    for member in feature_name.split(';'):
        # setdefault keeps the first feature that lists this member
        member_index.setdefault(member.strip().casefold(), feature_name)


def _lookup_vatpase(symbol):
    """Return the dataset feature name matching subunit *symbol*, or None."""
    if symbol in exact_names:
        return symbol
    for candidate in (symbol, symbol.replace('ATP6', '')):
        match = member_index.get(candidate.casefold())
        if match is not None:
            return match
    return None


for protein in all_vatpase:
    match = _lookup_vatpase(protein)
    if match is None:
        missing_vatpase.append(protein)
        continue
    vatpase_matches[protein] = match
    if match not in found_vatpase:  # two subunits may share one protein group
        found_vatpase.append(match)

# Domain counts come from the subunit table, not from 'V0'/'V1' substrings,
# so subunits such as ATP6AP1/ATP6AP2 are counted in their declared domain.
v0_symbols = set(vatpase_subunits['V0_domain'])
v1_symbols = set(vatpase_subunits['V1_domain'])
print(f"Found V-ATPase proteins: {len(vatpase_matches)}/{len(all_vatpase)}")
print(f"V0 domain proteins found: {sum(s in v0_symbols for s in vatpase_matches)}")
print(f"V1 domain proteins found: {sum(s in v1_symbols for s in vatpase_matches)}")

if missing_vatpase:
    print(f"\nMissing proteins ({len(missing_vatpase)}):")
    print(missing_vatpase[:5])  # show at most the first five

## 3. Subset Data and Run PyDESeq2

In [None]:
# Create the V-ATPase subset and test tau-positive vs tau-negative.
# Primary test: pertpy's PyDESeq2 wrapper (on raw counts when a 'counts' layer
# exists). Fallback: a per-protein two-sample t-test on the log-scale matrix,
# with Benjamini-Hochberg FDR. Either way, results_vatpase ends up with the
# columns 'protein', 'log2FoldChange', 'pvalue' and 'padj' used by later cells.
#
# NOTE(review): PyDESeq2 is fitted on the V-ATPase subset only, so any
# normalisation or dispersion trend it estimates is based on ~25 features.
# Fitting on the full dataset and then selecting the V-ATPase rows is
# probably more robust. Confirm before relying on the DESeq2 path.


def _standardize_de_results(raw, labels):
    """Return *raw* DE results using this notebook's column names.

    pertpy's DE methods report 'variable', 'log_fc', 'p_value' and
    'adj_p_value'. These are renamed to 'protein', 'log2FoldChange', 'pvalue'
    and 'padj'. A DESeq2-style table indexed by feature name (already using
    the DESeq2 column names) is also accepted. Feature ids are mapped to
    display names through *labels*; unknown ids are kept unchanged.
    """
    df = pd.DataFrame(raw).copy()
    if 'variable' not in df.columns and 'protein' not in df.columns:
        df = df.rename_axis('variable').reset_index()
    df = df.rename(columns={
        'variable': 'protein',
        'log_fc': 'log2FoldChange',
        'p_value': 'pvalue',
        'adj_p_value': 'padj',
    })
    df['protein'] = [labels.get(v, v) for v in df['protein']]
    return df.reset_index(drop=True)


def _fallback_ttest(adata_sub, expr, labels):
    """Per-feature two-sample t-test (tau 'positive' vs 'negative').

    *expr* must be the log-scale matrix of *adata_sub*. The mean difference
    reported as 'log2FoldChange' is only a log2 fold change on log data.
    Missing values (NaN) are ignored per feature. P-values that are still NaN
    are left out of the Benjamini-Hochberg correction and get a NaN 'padj'.
    """
    from statsmodels.stats.multitest import multipletests

    if hasattr(expr, 'toarray'):  # scipy sparse -> dense
        expr = expr.toarray()
    expr = np.asarray(expr, dtype=float)
    status = adata_sub.obs['tau_status']
    tau_pos = (status == 'positive').to_numpy()
    tau_neg = (status == 'negative').to_numpy()

    rows = []
    for j, var_name in enumerate(adata_sub.var_names):
        expr_pos = expr[tau_pos, j]
        expr_neg = expr[tau_neg, j]
        tstat, pval = stats.ttest_ind(expr_pos, expr_neg, nan_policy='omit')
        rows.append({
            'protein': labels.get(var_name, var_name),
            'log2FoldChange': np.nanmean(expr_pos) - np.nanmean(expr_neg),
            'pvalue': pval,
            'stat': tstat,
        })
    results = pd.DataFrame(rows, columns=['protein', 'log2FoldChange', 'pvalue', 'stat'])

    pvals = results['pvalue'].to_numpy(dtype=float)
    padj = np.full(pvals.shape, np.nan)
    tested = ~np.isnan(pvals)
    if tested.any():
        padj[tested] = multipletests(pvals[tested], method='fdr_bh')[1]
    results['padj'] = padj
    return results


if found_vatpase:
    found_set = set(found_vatpase)
    vatpase_indices = [i for i, p in enumerate(protein_names) if p in found_set]
    adata_vatpase = adata[:, vatpase_indices].copy()
    # Display name for each selected feature. var_names may be ids that differ
    # from the names used for matching (e.g. a 'protein_name' column).
    vatpase_labels = dict(zip(adata_vatpase.var_names,
                              np.asarray(protein_names)[vatpase_indices]))

    print(f"V-ATPase subset shape: {adata_vatpase.shape}")

    # Keep the log-scale matrix before X may be replaced by raw counts.
    # The fallback's mean-difference "log2FC" is only meaningful on log data.
    log2_expr = adata_vatpase.X.copy()
    if 'counts' in adata_vatpase.layers:
        adata_vatpase.layers['log2'] = log2_expr.copy()
        adata_vatpase.X = adata_vatpase.layers['counts'].copy()

    try:
        pds2 = pt.tl.PyDESeq2(
            adata=adata_vatpase,
            design="~tau_status",
            refit_cooks=True
        )
        pds2.fit()
        raw_results = pds2.test_contrasts(
            pds2.contrast(
                column="tau_status",
                baseline="negative",
                group_to_compare="positive"
            )
        )
    except Exception as e:  # deliberate catch-all: any failure -> fallback test
        print(f"PyDESeq2 failed: {e}")
        print("Using fallback analysis...")
        results_vatpase = _fallback_ttest(adata_vatpase, log2_expr, vatpase_labels)
    else:
        results_vatpase = _standardize_de_results(raw_results, vatpase_labels)
        print("✓ PyDESeq2 analysis completed")
else:
    print("No V-ATPase proteins found for analysis")
    results_vatpase = pd.DataFrame()

## 4. Analyze Results by Domain

In [None]:
if len(results_vatpase) > 0:
    # Assign each subunit the domain it has in the vatpase_subunits table.
    # This way, subunits whose names contain neither 'V0' nor 'V1' (e.g.
    # ATP6AP1/ATP6AP2) get their declared domain instead of 'Unknown'.
    # Names not listed verbatim (protein groups, other casing) fall back to a
    # case-insensitive 'V0'/'V1' substring test.
    sector_by_symbol = {
        symbol.casefold(): domain_key.split('_')[0]  # 'V0_domain' -> 'V0'
        for domain_key, symbols in vatpase_subunits.items()
        for symbol in symbols
    }

    def _annotate_domain(name):
        """Return 'V0', 'V1' or 'Unknown' for a result's protein name."""
        name = str(name)
        sector = sector_by_symbol.get(name.casefold())
        if sector is not None:
            return sector
        upper = name.upper()
        if 'V0' in upper:
            return 'V0'
        if 'V1' in upper:
            return 'V1'
        return 'Unknown'

    results_vatpase['domain'] = results_vatpase['protein'].map(_annotate_domain)

    # Summary statistics
    print("V-ATPase Differential Expression Summary:")
    print("="*50)

    # Overall statistics (NaN padj counts as not significant)
    n_sig = (results_vatpase['padj'] < 0.05).sum()
    n_total = len(results_vatpase)

    print(f"Total V-ATPase subunits analyzed: {n_total}")
    print(f"Significant (FDR < 0.05): {n_sig} ({n_sig/n_total*100:.1f}%)")

    # Domain-specific analysis
    print("\nBy Domain:")
    for domain in ['V0', 'V1']:
        domain_df = results_vatpase[results_vatpase['domain'] == domain]
        if len(domain_df) > 0:
            n_sig_domain = (domain_df['padj'] < 0.05).sum()
            mean_fc = domain_df['log2FoldChange'].mean()
            print(f"  {domain} domain: {n_sig_domain}/{len(domain_df)} significant, mean log2FC = {mean_fc:.3f}")

    # Top differentially expressed (smallest FDR first)
    print("\nTop 5 Differentially Expressed V-ATPase Subunits:")
    top_vatpase = results_vatpase.nsmallest(5, 'padj')
    for _, row in top_vatpase.iterrows():
        direction = "↑" if row['log2FoldChange'] > 0 else "↓"
        print(f"  {row['protein']} ({row['domain']}): {direction} log2FC={row['log2FoldChange']:.2f}, FDR={row['padj']:.3e}")

## 5. Create Volcano Plot

In [None]:
if len(results_vatpase) > 0:
    from matplotlib.patches import Patch

    fig, ax = plt.subplots(figsize=(10, 8))

    # -log10(raw p-value); the tiny offset avoids log10(0) when p == 0
    results_vatpase['neg_log10_pval'] = -np.log10(results_vatpase['pvalue'] + 1e-300)

    # Significance is judged on FDR ('padj'); NaN padj counts as not significant.
    is_sig = (results_vatpase['padj'] < 0.05).to_numpy()
    # Significant points whose domain is neither V0 nor V1 get their own
    # colour, so they are not drawn and labelled as "V1 domain".
    domain_colors = {'V0': 'red', 'V1': 'blue'}
    other_color = 'purple'
    colors = [
        domain_colors.get(domain, other_color) if sig else 'gray'
        for domain, sig in zip(results_vatpase['domain'], is_sig)
    ]

    ax.scatter(results_vatpase['log2FoldChange'],
               results_vatpase['neg_log10_pval'],
               c=colors, alpha=0.7, s=100)

    # Reference lines: raw p = 0.05 and |log2FC| = 0.5. Colouring uses FDR, so
    # points above the horizontal line can still be grey.
    ax.axhline(y=-np.log10(0.05), color='black', linestyle='--', alpha=0.3)
    ax.axvline(x=0.5, color='black', linestyle='--', alpha=0.3)
    ax.axvline(x=-0.5, color='black', linestyle='--', alpha=0.3)

    # Label FDR-significant proteins
    for _, row in results_vatpase[is_sig].iterrows():
        ax.annotate(row['protein'],
                    (row['log2FoldChange'], row['neg_log10_pval']),
                    fontsize=8, alpha=0.8)

    ax.set_xlabel('Log2 Fold Change (Tau+ vs Tau-)', fontsize=12)
    ax.set_ylabel('-Log10(p-value)', fontsize=12)
    ax.set_title('V-ATPase Subunits Differential Expression\n(Red=V0 domain, Blue=V1 domain)',
                 fontsize=14, fontweight='bold')

    legend_elements = [
        Patch(facecolor='red', alpha=0.7, label='V0 domain (significant)'),
        Patch(facecolor='blue', alpha=0.7, label='V1 domain (significant)'),
    ]
    if other_color in colors:
        legend_elements.append(
            Patch(facecolor=other_color, alpha=0.7, label='Other domain (significant)'))
    legend_elements.append(Patch(facecolor='gray', alpha=0.7, label='Not significant'))
    ax.legend(handles=legend_elements, loc='upper left')

    plt.tight_layout()
    plt.savefig('vatpase_volcano_plot.png', dpi=300, bbox_inches='tight')
    plt.show()

## 6. V-ATPase Expression Heatmap by Tau Status

In [None]:
if len(results_vatpase) > 0 and len(found_vatpase) > 3:
    from matplotlib.patches import Rectangle
    from scipy.stats import zscore

    # Group samples with the same 'tau_status' column used by the DE test and
    # show only the two compared groups.
    tau_status = adata.obs['tau_status']
    in_groups = tau_status.isin(['negative', 'positive']).to_numpy()
    is_positive = (tau_status == 'positive').to_numpy()[in_groups]

    # One column per distinct V-ATPase feature (first occurrence of each name)
    feature_names = list(dict.fromkeys(found_vatpase))
    name_array = np.asarray(protein_names)
    feature_idx = [int(np.flatnonzero(name_array == p)[0]) for p in feature_names]
    expr = adata.X[:, feature_idx]
    if hasattr(expr, 'toarray'):  # scipy sparse -> dense
        expr = expr.toarray()
    vatpase_expr = pd.DataFrame(np.asarray(expr, dtype=float)[in_groups],
                                columns=feature_names)

    # Per-protein z-scores. Missing values are ignored instead of turning the
    # whole column into NaN.
    vatpase_zscore = vatpase_expr.apply(zscore, nan_policy='omit')

    # Tau-negative samples first, then tau-positive (stable sort keeps order)
    sort_idx = np.argsort(is_positive, kind='stable')
    vatpase_zscore_sorted = vatpase_zscore.iloc[sort_idx]
    n_neg = int((~is_positive).sum())
    n_samples = len(sort_idx)
    n_rows = vatpase_zscore_sorted.shape[1]

    fig, ax = plt.subplots(figsize=(12, 8))
    sns.heatmap(vatpase_zscore_sorted.T, cmap='RdBu_r', center=0,
                cbar_kws={'label': 'Z-score'},
                yticklabels=True, xticklabels=False,
                vmin=-2, vmax=2, ax=ax)

    # Tau-status strip above the first subunit row. Samples are sorted, so the
    # strip is just two contiguous runs. The heatmap's y-axis is inverted (row
    # 0 at the top), so the top limit is extended to -1 to make room.
    # Otherwise the strip lies outside the axes and is clipped away.
    ax.add_patch(Rectangle((0, -1), n_neg, 0.8, color='blue', alpha=0.7))
    ax.add_patch(Rectangle((n_neg, -1), n_samples - n_neg, 0.8, color='red', alpha=0.7))
    ax.set_ylim(n_rows, -1)

    ax.set_xlabel('Samples (sorted by tau status)', fontsize=12)
    ax.set_ylabel('V-ATPase Subunits', fontsize=12)
    ax.set_title('V-ATPase Expression Heatmap\n(Blue=Tau-, Red=Tau+)', fontsize=14, fontweight='bold')

    plt.tight_layout()
    plt.savefig('vatpase_heatmap.png', dpi=300, bbox_inches='tight')
    plt.show()

## 7. Evaluate Claim

In [None]:
print("\n" + "="*60)
print("CLAIM EVALUATION")
print("="*60)
print("Claim: V-ATPase subunits show differential expression patterns")
print()

# Default so the "Key V-ATPase changes" report below also works when there
# are no results (n_sig would otherwise be undefined -> NameError).
n_sig = 0

if len(results_vatpase) > 0:
    # Calculate evaluation metrics (NaN padj counts as not significant)
    n_sig = (results_vatpase['padj'] < 0.05).sum()
    n_total = len(results_vatpase)
    percent_sig = n_sig / n_total * 100

    # Check for domain-specific patterns
    v0_df = results_vatpase[results_vatpase['domain'] == 'V0']
    v1_df = results_vatpase[results_vatpase['domain'] == 'V1']

    v0_sig = (v0_df['padj'] < 0.05).sum()
    v1_sig = (v1_df['padj'] < 0.05).sum()

    # "Differential patterns": significant changes in both directions
    sig_df = results_vatpase[results_vatpase['padj'] < 0.05]
    n_up = (sig_df['log2FoldChange'] > 0).sum()
    n_down = (sig_df['log2FoldChange'] < 0).sum()
    has_bidirectional = bool(n_up > 0 and n_down > 0)

    print("Analysis Results:")
    print(f"  V-ATPase subunits tested: {n_total}")
    print(f"  Significantly changed: {n_sig} ({percent_sig:.1f}%)")
    print(f"  V0 domain significant: {v0_sig}/{len(v0_df)}")
    print(f"  V1 domain significant: {v1_sig}/{len(v1_df)}")
    print(f"  Bidirectional changes: {has_bidirectional}")
    print()

    # Determine verdict from the number / fraction of significant subunits
    if n_sig >= 3 and percent_sig > 20:
        verdict = "SUPPORTED"
        explanation = f"Multiple V-ATPase subunits show differential expression ({n_sig}/{n_total})"
    elif n_sig >= 2:
        verdict = "PARTIALLY SUPPORTED"
        explanation = f"Some V-ATPase subunits differentially expressed ({n_sig}/{n_total})"
    elif n_sig == 1:
        verdict = "WEAKLY SUPPORTED"
        explanation = "Only one V-ATPase subunit significantly changed"
    else:
        verdict = "REFUTED"
        explanation = "No significant differential expression in V-ATPase subunits"
else:
    verdict = "UNSURE"
    explanation = "V-ATPase proteins not found in dataset"

print(f"VERDICT: {verdict}")
print(f"Explanation: {explanation}")

# Additional evidence: up to three most significant subunits that pass FDR
if n_sig > 0:
    print("\nKey V-ATPase changes:")
    for _, row in results_vatpase.nsmallest(3, 'padj').iterrows():
        if row['padj'] < 0.05:
            direction = "upregulated" if row['log2FoldChange'] > 0 else "downregulated"
            print(f"  {row['protein']}: {direction} (log2FC={row['log2FoldChange']:.2f}, FDR={row['padj']:.3e})")

## 8. Save Results

In [None]:
# Save results
if len(results_vatpase) > 0:
    results_vatpase.to_csv('../05_statistical_reports/group2_claim1_vatpase.csv', index=False)
    print("Results saved to: ../05_statistical_reports/group2_claim1_vatpase.csv")
    
    # Save summary
    summary = {
        'claim': 'V-ATPase differential expression',
        'verdict': verdict,
        'n_proteins': n_total,
        'n_significant': n_sig,
        'percent_significant': percent_sig,
        'v0_significant': v0_sig,
        'v1_significant': v1_sig,
        'bidirectional_changes': has_bidirectional
    }
    
    summary_df = pd.DataFrame([summary])
    summary_df.to_csv('../05_statistical_reports/group2_claim1_summary.csv', index=False)
    
    print("\nSummary saved")