In [4]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import re
from pathlib import Path
import numpy as np
import warnings
from scipy.stats import pearsonr, spearmanr
# Suppress warnings
warnings.filterwarnings('ignore')

# File paths
EXPR_PATH = "/nfs/nas12.ethz.ch/fs1201/green_groups_let_public/Euler/Vakil/Mouse_brain_Sept2024/gene_annotation2/CRS_Morning_Evening_TPM_rearr.csv"
CPG_OXID = "../data_normalized/cpg_Normalized.csv"
BIN_OXID = "../data_normalized/cleaned_Normalized_1000.csv"
LIST = [
    "../data_anova/cpg/cpg_overlap_genes_result.bed",
    "../data_anova/bin1000/overlap_genes_result.bed",
    "../data_anova/bin1000/overlap_promoters_result.bed"
]

# Create output directory for plots
output_dir = Path("../images/gene_correlation_plots")
output_dir.mkdir(exist_ok=True)

# Function to extract group from sample name
def extract_group_expr(sample):
    # Expected format: '20_CRS_evening'
    parts = sample.split('_')
    if len(parts) >= 3:
        return f"{parts[1]}_{parts[2]}"
    return "Unknown"

def extract_group_oxid(sample):
    # Expected format: 'Sample_14_CRS_evening_S14_'
    parts = sample.split('_')
    if len(parts) >= 5:
        return f"{parts[2]}_{parts[3]}"
    return "Unknown"

# Load data
def load_data():
    print("Loading expression data...")
    expr_df = pd.read_csv(EXPR_PATH)
    
    print("Loading CpG oxidation data...")
    cpg_df = pd.read_csv(CPG_OXID)
    
    print("Loading bin oxidation data...")
    bin_df = pd.read_csv(BIN_OXID)
    
    # Process expression data
    # Add group column based on Sample column
    expr_df['Group'] = expr_df['Sample'].apply(extract_group_expr)
    print (expr_df.head(5))
    
    # Process oxidation data
    # Add group column based on sample name
    cpg_df['Group'] = cpg_df['sample'].apply(extract_group_oxid)
    bin_df['Group'] = bin_df['Sample'].apply(extract_group_oxid)
    
    return expr_df, cpg_df, bin_df

# Load gene lists from BED files
def load_gene_lists():
    all_genes = []
    for list_file in LIST:
        try:
            print(f"Loading genes from {list_file}...")
            # Read BED file, assuming genes are in the 10th column (index 9)
            df = pd.read_csv(list_file, sep='\t', header=None)
            
            # Get genes from the 10th column
            genes = df.iloc[:, 9].tolist()
            ids = df.iloc[:, 3].tolist()
            
            # Keep track of which file each gene came from
            file_name = os.path.basename(list_file)
            gene_data = [(gene, id_val, file_name) for gene, id_val in zip(genes, ids)]
            all_genes.extend(gene_data)
      
            
            print(f"Found {len(genes)} unique genes in {file_name}")
        except Exception as e:
            print(f"Error loading {list_file}: {e}")
    
    return all_genes

# Function to calculate both correlations with p-values
def calculate_correlations(x, y):
    """Calculate both Pearson and Spearman correlations with p-values"""
    # Remove NaN values
    mask = ~(np.isnan(x) | np.isnan(y))
    x_clean = x[mask]
    y_clean = y[mask]
    
    if len(x_clean) < 3:  # Need at least 3 points for meaningful correlation
        return None, None, None, None
    
    # Calculate Pearson correlation
    pearson_r, pearson_p = pearsonr(x_clean, y_clean)
    
    # Calculate Spearman correlation
    spearman_r, spearman_p = spearmanr(x_clean, y_clean)
    
    return pearson_r, pearson_p, spearman_r, spearman_p

# Function to create scatter plot for a gene with both correlations
def plot_gene_correlation(gene, ID, source_file, expr_df, cpg_df, bin_df):
    print(f"Processing gene: {gene} with id: {ID} from {source_file}")
    
    # Determine which oxidation data to use based on source file
    if "cpg" in source_file:
        oxid_df = cpg_df
        oxid_col = "median_normalized_damage"
        oxid_df['bin_id'] = oxid_df['id'].astype(str) + oxid_df['strand']
        oxid_id_col = "bin_id"
        oxid_type = "CpG"
    else:  # bin1000
        oxid_df = bin_df
        oxid_col = "Median_Normalized_Damage"
        oxid_df['bin_id'] = oxid_df['Bin'].astype(str) + '_' + oxid_df['Strand'] + oxid_df['Chromosome'].astype(str)
        oxid_id_col = "bin_id"
        oxid_type = "Bin1000"
    
    
    gene_expr = expr_df[expr_df['Gene'] == gene]

    if gene_expr.empty:
        print(f"No expression data found for gene {gene}")
        return None
    
    # Find corresponding oxidation data entries
    gene_oxid = oxid_df[oxid_df[oxid_id_col] == ID]

    if gene_oxid.empty:
        print(f"No oxidation data found for gene {gene}")
        return None
    
    # Prepare data for plotting by matching samples
    plot_data = []
    
    for group in ['CRS_evening', 'CRS_morning', 'Ctrl_evening', 'Ctrl_morning']:
        # Get expression values for this group
        group_expr = gene_expr[gene_expr['Group'] == group]
        
        # Get oxidation values for this group
        group_oxid = gene_oxid[gene_oxid['Group'] == group]
        
        # Combine data from both sources
        for i in range(min(len(group_expr), len(group_oxid))):
            try:
                expr_value = group_expr.iloc[i]['Expression_level']
                oxid_value = group_oxid.iloc[i][oxid_col]
                
                plot_data.append({
                    'Expression': expr_value,
                    'Oxidation': oxid_value,
                    'Group': group
                })
            except IndexError:
                continue
    
    # Create DataFrame from collected data
    if not plot_data:
        print(f"No matching data points found for gene {gene}")
        return None
    
    plot_df = pd.DataFrame(plot_data)
    
    # Calculate both correlations
    expr_values = plot_df['Expression'].values
    oxid_values = plot_df['Oxidation'].values
    
    pearson_r, pearson_p, spearman_r, spearman_p = calculate_correlations(expr_values, oxid_values)
    
    if pearson_r is None:
        print(f"Insufficient data points for correlation analysis for gene {gene}")
        return None
    
    # Create plot
    plt.figure(figsize=(12, 8))
    
    # Create scatter plot with color-coded groups
    sns.scatterplot(
        data=plot_df, 
        x='Expression', 
        y='Oxidation', 
        hue='Group', 
        palette='tab10',
        s=100,
        alpha=0.7
    )
    
    # Add trend line for all data
    if len(plot_df) > 1:
        sns.regplot(
            data=plot_df, 
            x='Expression', 
            y='Oxidation', 
            scatter=False, 
            line_kws={"color": "red", "linestyle": "--", "alpha": 0.7}
        )
    
    # Create title with both correlations
    title = f'Gene: {gene} | Oxidation Type: {oxid_type}\n'
    title += f'Pearson r = {pearson_r:.3f} (p = {pearson_p:.3f}) | '
    title += f'Spearman ρ = {spearman_r:.3f} (p = {spearman_p:.3f})'
    
    plt.title(title, fontsize=12, pad=20)
    plt.xlabel('Gene Expression Level (TPM)', fontsize=12)
    plt.ylabel(f'Oxidation Level ({oxid_col})', fontsize=12)
    

    # Add legend
    plt.legend(title='Group', fontsize=10)
    
    # Add grid
    plt.grid(True, linestyle='--', alpha=0.3)
    
    # Adjust layout to prevent text cutoff
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.15)
    
    # Save plot
    output_path = output_dir / f"{gene}_{oxid_type}_correlations.png"
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"Plot saved to {output_path}")
    print(f"Correlations - Pearson: {pearson_r:.3f} (p={pearson_p:.3f}), Spearman: {spearman_r:.3f} (p={spearman_p:.3f})")
    
    return {
        'gene': gene,
        'oxid_type': oxid_type,
        'pearson_r': pearson_r,
        'pearson_p': pearson_p,
        'spearman_r': spearman_r,
        'spearman_p': spearman_p,
        'n_points': len(plot_df)
    }

# Function to run analysis on all genes and create summary
def run_correlation_analysis():
    """Run correlation analysis on all genes and create a summary table"""
    # Load data
    expr_df, cpg_df, bin_df = load_data()
    gene_list = load_gene_lists()
    
    # Store results
    results = []
    
    # Process each gene
    for gene, gene_id, source_file in gene_list:
        result = plot_gene_correlation(gene, gene_id, source_file, expr_df, cpg_df, bin_df)
        if result is not None:
            results.append(result)
    
    # Create summary DataFrame
    if results:
        summary_df = pd.DataFrame(results)
        
        # Save summary to CSV
        summary_path = output_dir / "correlation_summary.csv"
        summary_df.to_csv(summary_path, index=False)
        print(f"\nSummary saved to {summary_path}")
        
        # Print summary statistics
        print("\n=== CORRELATION ANALYSIS SUMMARY ===")
        print(f"Total genes analyzed: {len(summary_df)}")
        print(f"Mean Pearson correlation: {summary_df['pearson_r'].mean():.3f}")
        print(f"Mean Spearman correlation: {summary_df['spearman_r'].mean():.3f}")
        
        # Count significant correlations
        sig_pearson = (summary_df['pearson_p'] < 0.05).sum()
        sig_spearman = (summary_df['spearman_p'] < 0.05).sum()
        
        print(f"Significant Pearson correlations (p < 0.05): {sig_pearson}/{len(summary_df)} ({sig_pearson/len(summary_df)*100:.1f}%)")
        print(f"Significant Spearman correlations (p < 0.05): {sig_spearman}/{len(summary_df)} ({sig_spearman/len(summary_df)*100:.1f}%)")
        
        # Show cases where correlations differ substantially
        diff_threshold = 0.2
        large_diff = abs(summary_df['pearson_r'] - summary_df['spearman_r']) > diff_threshold
        if large_diff.any():
            print(f"\nGenes with substantial difference between Pearson and Spearman (|Δr| > {diff_threshold}):")
            diff_genes = summary_df[large_diff][['gene', 'pearson_r', 'spearman_r']]
            print(diff_genes.to_string(index=False))
    
    return summary_df if results else None


summary = run_correlation_analysis()

Loading expression data...
Loading CpG oxidation data...
Loading bin oxidation data...
   Unnamed: 0                Gene  Expression_level           Sample  \
0           0  ENSMUSG00000000001         15.613475  01_Ctrl_morning   
1           1  ENSMUSG00000000003          0.000000  01_Ctrl_morning   
2           2  ENSMUSG00000000028          0.775004  01_Ctrl_morning   
3           3  ENSMUSG00000000031          0.099366  01_Ctrl_morning   
4           4  ENSMUSG00000000037          0.403567  01_Ctrl_morning   

          Group  
0  Ctrl_morning  
1  Ctrl_morning  
2  Ctrl_morning  
3  Ctrl_morning  
4  Ctrl_morning  
Loading genes from ../data_anova/cpg/cpg_overlap_genes_result.bed...
Found 2 unique genes in cpg_overlap_genes_result.bed
Loading genes from ../data_anova/bin1000/overlap_genes_result.bed...
Found 11 unique genes in overlap_genes_result.bed
Loading genes from ../data_anova/bin1000/overlap_promoters_result.bed...
Found 1 unique genes in overlap_promoters_result.bed
Proce