In [None]:
"""
Speed-Calcium Signal Analysis Script

This script analyzes the relationship between mean speed and calcium signals from a CSV file.
It performs the following operations:
1. Loads and validates input data
2. Calculates Spearman correlations with bootstrap confidence intervals
3. Applies FDR correction for multiple comparisons
4. Categorizes effect sizes
5. Generates visualizations (scatter plots and dual-axis charts)
6. Saves statistical results to a new CSV file

Key improvements:
- Proper FDR correction across all tests
- Efficient bootstrap implementation
- Comprehensive error handling
- Type hints for better code clarity
- Detailed annotations and documentation
- Optimized visualization functions
"""

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
from scipy.stats import spearmanr
from statsmodels.stats.multitest import multipletests
from sklearn.utils import resample
import os
from typing import List, Dict, Tuple, Union

def load_and_validate_data(file_path: str) -> pd.DataFrame:
    """
    Read the input CSV and confirm it has the columns the analysis relies on.

    Args:
        file_path: Location of the CSV file on disk

    Returns:
        The DataFrame exactly as read by ``pd.read_csv``

    Raises:
        FileNotFoundError: If nothing exists at ``file_path``
        ValueError: If the 'mean_speed' column is absent, or if no column
            named like 'cell-XXX_mean' is present
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    frame = pd.read_csv(file_path)
    column_names = frame.columns

    if 'mean_speed' not in column_names:
        raise ValueError("CSV file must contain 'mean_speed' column.")

    # At least one per-cell calcium column ("cell-..._mean") is required.
    matching_cells = [
        name for name in column_names
        if name.startswith("cell-") and name.endswith("_mean")
    ]
    if len(matching_cells) == 0:
        raise ValueError("No cell columns found (expected format: 'cell-XXX_mean')")

    return frame

def bootstrap_ci(
    data_x: pd.Series,
    data_y: pd.Series,
    num_bootstrap: int = 1000,
    confidence: float = 0.95,
    random_state: Union[int, np.random.Generator, None] = None,
) -> np.ndarray:
    """
    Calculate a percentile-bootstrap confidence interval for Spearman correlation.

    The (x, y) pairs are resampled with replacement ``num_bootstrap`` times and
    Spearman's rho is recomputed on each resample. A resample whose correlation
    is undefined (for example, one in which a variable is constant, for which
    ``spearmanr`` returns NaN) is discarded. It no longer turns the whole
    interval into NaN.

    Args:
        data_x: First variable (mean speed)
        data_y: Second variable (calcium signal), positionally paired with data_x
        num_bootstrap: Number of bootstrap resamples
        confidence: Two-sided confidence level in (0, 1); 0.95 gives the
            2.5th / 97.5th percentiles
        random_state: Seed or ``np.random.Generator`` for reproducible
            intervals; ``None`` uses fresh entropy on every call

    Returns:
        Array ``[lower, upper]``. Both entries are NaN when no resample yields a
        defined correlation (including inputs with fewer than two points).

    Raises:
        ValueError: If ``confidence`` is not strictly between 0 and 1
    """
    if not 0.0 < confidence < 1.0:
        raise ValueError(f"confidence must be in (0, 1), got {confidence}")

    # Work on plain arrays so indexing is positional regardless of the Series index.
    x = np.asarray(data_x)
    y = np.asarray(data_y)
    n = len(x)
    undefined_ci = np.array([np.nan, np.nan])
    if n < 2 or num_bootstrap < 1:
        return undefined_ci

    rng = np.random.default_rng(random_state)
    boot_corrs = np.empty(num_bootstrap)
    for b in range(num_bootstrap):
        # Resample whole (x, y) pairs so the pairing is preserved.
        indices = rng.integers(0, n, size=n)
        corr, _ = spearmanr(x[indices], y[indices])
        boot_corrs[b] = corr

    # NaN correlations come from degenerate resamples; dropping them keeps a
    # single constant resample from making np.percentile return NaN.
    valid = boot_corrs[np.isfinite(boot_corrs)]
    if valid.size == 0:
        return undefined_ci

    tail = 50.0 * (1.0 - confidence)
    return np.percentile(valid, [tail, 100.0 - tail])

def calculate_statistics(speed: pd.Series, signal: pd.Series) -> Tuple[float, float, float, np.ndarray]:
    """
    Compute Spearman correlation statistics with bootstrap confidence interval.

    Time bins where either ``speed`` or ``signal`` is missing are dropped
    pairwise before any statistic is computed. ``spearmanr`` propagates NaN
    by default, so without this a single empty bin would make every result
    for the cell NaN.

    Args:
        speed: Mean speed values
        signal: Calcium signal values, positionally paired with ``speed``

    Returns:
        Tuple containing:
        - Spearman correlation coefficient
        - p-value
        - Effect size (squared Spearman rho)
        - Bootstrap confidence interval as ``[lower, upper]``
        All four are NaN when fewer than three complete pairs remain, because
        no meaningful correlation test is possible.
    """
    # Positional mask of bins where both values are present.
    complete = speed.notna().to_numpy() & signal.notna().to_numpy()
    speed = speed.iloc[complete]
    signal = signal.iloc[complete]

    if len(speed) < 3:
        return np.nan, np.nan, np.nan, np.array([np.nan, np.nan])

    spearman_corr, spearman_p = spearmanr(speed, signal)
    effect_size = spearman_corr ** 2
    spearman_ci = bootstrap_ci(speed, signal)

    return spearman_corr, spearman_p, effect_size, spearman_ci

def categorize_effect_size(effect_size: float) -> str:
    """
    Map an r-squared effect size to a qualitative label.

    The thresholds apply to r-squared and correspond to |r| = 0.5, 0.3 and 0.2:

    - ``>= 0.25``: "Strong"
    - ``>= 0.09``: "Moderate"
    - ``>= 0.04``: "Weak"
    - otherwise: "Very Weak"

    The Strong and Moderate cut-offs match Cohen's large and medium correlation
    benchmarks (r = 0.5 and 0.3). The Weak cut-off (r = 0.2) is stricter than
    Cohen's small benchmark (r = 0.1).

    Args:
        effect_size: R-squared value (e.g. squared Spearman rho); may be NaN

    Returns:
        "Strong", "Moderate", "Weak" or "Very Weak". Returns "Undefined" when
        the effect size is NaN, i.e. the correlation could not be computed.
        Previously, NaN failed every comparison and was mislabelled "Very Weak".
    """
    if np.isnan(effect_size):
        return "Undefined"
    if effect_size >= 0.25:
        return "Strong"
    elif effect_size >= 0.09:
        return "Moderate"
    elif effect_size >= 0.04:
        return "Weak"
    else:
        return "Very Weak"

def create_scatter_plot(df: pd.DataFrame, cell: str, stats: Dict[str, Union[float, str, bool]]) -> go.Figure:
    """
    Create a speed-vs-signal scatter plot annotated with the cell's statistics.

    Args:
        df: DataFrame containing 'mean_speed' and the ``cell`` column
        cell: Cell column name
        stats: Per-cell result dictionary. Reads the keys 'Spearman_r',
            'Adjusted_P_value', 'Effect_Size', 'Effect_Size_Category',
            'CI_Lower', 'CI_Upper' and 'Speed_Modulated'.

    Returns:
        Plotly Figure object
    """
    fig = px.scatter(df, x='mean_speed', y=cell, title=f'Speed vs {cell}')
    fig.update_layout(
        xaxis_title='Mean Speed',
        yaxis_title=cell,
        # Extra top margin so the four-line annotation fits between the
        # title and the plot area instead of overlapping or being clipped.
        margin=dict(t=180),
    )

    # Plotly text uses HTML-like markup, so line breaks must be "<br>".
    # "\n" is not rendered as a break and collapsed everything onto one line.
    # The p-value shown is the FDR-adjusted one; ".3g" keeps very small
    # values readable instead of printing 0.000000.
    annotation_lines = [
        f"Spearman r: {stats['Spearman_r']:.2f}, "
        f"FDR-adjusted p-value: {stats['Adjusted_P_value']:.3g}",
        f"Effect Size: {stats['Effect_Size']:.2f} ({stats['Effect_Size_Category']})",
        f"95% CI: [{stats['CI_Lower']:.2f}, {stats['CI_Upper']:.2f}]",
        f"Speed Modulated: {'Yes' if stats['Speed_Modulated'] else 'No'}",
    ]

    fig.add_annotation(
        text="<br>".join(annotation_lines),
        xref="paper", yref="paper",
        # Bottom edge of the annotation sits just above the plot area.
        x=0.5, y=1.02,
        xanchor="center", yanchor="bottom",
        showarrow=False,
        align="center",
        bgcolor="rgba(255,255,255,0.8)"
    )

    return fig

def create_dual_axis_plot(df: pd.DataFrame, cell: str) -> go.Figure:
    """
    Create a dual-axis bar chart with a linear trend line for the cell signal.

    The cell signal is drawn as bars on the left y-axis and mean speed as more
    transparent bars on a right y-axis that overlays it. Both use the DataFrame
    index as x.

    Args:
        df: DataFrame containing 'mean_speed' and the ``cell`` column; its
            index is used as the x (time-bin) coordinate and must be numeric
            for the trend fit
        cell: Cell column name

    Returns:
        Plotly Figure object. The trend trace is omitted when fewer than two
        finite (index, signal) points are available to fit.
    """
    fig = go.Figure()

    # Cell signal bars (left axis)
    fig.add_trace(go.Bar(
        x=df.index,
        y=df[cell],
        name=cell,
        yaxis="y1",
        opacity=0.7
    ))

    # Mean speed bars (right axis), more transparent so the signal stays visible
    fig.add_trace(go.Bar(
        x=df.index,
        y=df['mean_speed'],
        name="Mean Speed",
        yaxis="y2",
        opacity=0.3
    ))

    # Least-squares straight-line trend of the signal against the index.
    # np.polyfit cannot handle NaN (it raises), so only finite points are
    # fitted; the line is then evaluated at every index position.
    x_vals = np.asarray(df.index, dtype=float)
    y_vals = df[cell].to_numpy(dtype=float, na_value=np.nan)
    finite = np.isfinite(x_vals) & np.isfinite(y_vals)
    if np.count_nonzero(finite) >= 2:
        trend = np.poly1d(np.polyfit(x_vals[finite], y_vals[finite], 1))
        fig.add_trace(go.Scatter(
            x=df.index,
            y=trend(x_vals),
            mode='lines',
            name=f"Trend {cell}",
            yaxis="y1",
            line=dict(color='red', width=2)
        ))

    fig.update_layout(
        title=f"Time Series for {cell}",
        xaxis_title="Time Bins",
        yaxis=dict(
            title=cell,
            side="left",
            showgrid=False
        ),
        yaxis2=dict(
            title="Mean Speed",
            overlaying="y",
            side="right",
            showgrid=False
        ),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
    )

    return fig

def save_results(results: List[Dict[str, Union[float, str, bool]]], output_path: str) -> None:
    """
    Write the per-cell statistics table to disk and report where it went.

    Each dictionary becomes one row; its keys become the CSV columns. The
    DataFrame index is not written.

    Args:
        results: One statistics dictionary per analyzed cell
        output_path: Destination path of the CSV file
    """
    pd.DataFrame(results).to_csv(output_path, index=False)
    print(f"Results saved to: {output_path}")

def _stats_output_path(file_path: str) -> str:
    """Derive the results-CSV path: the input path minus its extension, plus a fixed suffix."""
    # splitext strips only the final extension, whatever its case, so the
    # output can never equal the input. The former str.replace(".csv", ...)
    # left e.g. "data.CSV" unchanged (overwriting the input on save) and also
    # rewrote ".csv" occurring in directory names.
    stem, _ = os.path.splitext(file_path)
    return stem + "_speed_modulation_stats.csv"


def _fdr_adjust(p_values: List[float]) -> np.ndarray:
    """
    Apply Benjamini-Hochberg adjustment to the finite p-values; NaN entries stay NaN.

    multipletests does not skip NaN: a single NaN propagates through the BH
    running minimum and makes every adjusted p-value NaN. Undefined tests
    (e.g. a cell with a constant signal) are therefore left out of the family
    being corrected.
    """
    p_arr = np.asarray(p_values, dtype=float)
    adjusted = np.full(p_arr.shape, np.nan)
    finite = np.isfinite(p_arr)
    if finite.any():
        adjusted[finite] = multipletests(p_arr[finite], method='fdr_bh')[1]
    return adjusted


def analyze_speed_calcium_relationship(file_path: str, significance_threshold: float = 0.05) -> None:
    """
    Main analysis function for speed-calcium relationships.

    Steps:
    1. For every 'cell-*_mean' column, compute the Spearman correlation with
       'mean_speed' and a bootstrap CI.
    2. Apply Benjamini-Hochberg FDR correction across all cells.
    3. Add an effect-size label and a speed-modulation flag.
    4. Show a scatter plot and a dual-axis plot per cell.
    5. Write the results table to '<input path without extension>_speed_modulation_stats.csv'.

    Args:
        file_path: Path to the input CSV file
        significance_threshold: A cell is flagged 'Speed_Modulated' when its
            FDR-adjusted p-value is strictly below this value
    """
    df = load_and_validate_data(file_path)

    cell_columns = [col for col in df.columns if col.startswith("cell-") and col.endswith("_mean")]
    print(f"Found {len(cell_columns)} cell columns to analyze")

    # Per-cell statistics; raw p-values are corrected below once all are known.
    results = []
    for cell in cell_columns:
        spearman_corr, spearman_p, effect_size, spearman_ci = calculate_statistics(
            df['mean_speed'], df[cell]
        )
        results.append({
            'Cell': cell,
            'Spearman_r': spearman_corr,
            'P_value': spearman_p,
            'Effect_Size': effect_size,
            'CI_Lower': spearman_ci[0],
            'CI_Upper': spearman_ci[1]
        })

    # FDR correction needs every p-value at once.
    adjusted_p_values = _fdr_adjust([result['P_value'] for result in results])
    for result, adjusted_p in zip(results, adjusted_p_values):
        result['Adjusted_P_value'] = adjusted_p
        result['Effect_Size_Category'] = categorize_effect_size(result['Effect_Size'])
        # NaN < threshold is False, so undefined tests are never flagged.
        result['Speed_Modulated'] = adjusted_p < significance_threshold

    output_csv_path = _stats_output_path(file_path)

    total = len(results)
    for i, result in enumerate(results, start=1):
        cell = result['Cell']
        create_scatter_plot(df, cell, result).show()
        create_dual_axis_plot(df, cell).show()
        print(f"Processed {i}/{total}: {cell}")

    save_results(results, output_csv_path)

# Script entry point: runs the full analysis on the CSV named below.
if __name__ == "__main__":
    # Edit this placeholder to point at the binned speed/calcium CSV.
    file_path = "path_to_your_file.csv"  # Change to your actual file path
    analyze_speed_calcium_relationship(file_path=file_path)

Found 21 cell columns to analyze


Processed 1/21: cell-00_mean


Processed 2/21: cell-01_mean


Processed 3/21: cell-02_mean


Processed 4/21: cell-03_mean


Processed 5/21: cell-04_mean


Processed 6/21: cell-09_mean


Processed 7/21: cell-10_mean


Processed 8/21: cell-11_mean


Processed 9/21: cell-12_mean


Processed 10/21: cell-13_mean


Processed 11/21: cell-14_mean


Processed 12/21: cell-15_mean


Processed 13/21: cell-16_mean


Processed 14/21: cell-17_mean


Processed 15/21: cell-19_mean


Processed 16/21: cell-20_mean


Processed 17/21: cell-21_mean


Processed 18/21: cell-22_mean


Processed 19/21: cell-24_mean


Processed 20/21: cell-25_mean


Processed 21/21: cell-26_mean
Results saved to: C:/Users/Labo/Downloads/Risna Speed Analysis/130-trial1-rectangular-arena-binned_calcium_speed_analysis_speed_modulation_stats.csv
