# In [None]:
# Z-score transformation for calcium imaging data
# Uses robust median/MAD normalization with proper scaling
# Updated: 2025-11-17

import pandas as pd
import numpy as np
import plotly.graph_objects as go
import os

# =============================================================================
# Configuration
# =============================================================================

# CSV file read by load_and_clean_data(): its first column is used as the
# time axis and every column whose name starts with 'cell-' as one trace.
INPUT_FILE = "path/to/your/data.csv"
# Directory that receives both CSV files written by main().
OUTPUT_DIR = "path/to/your/output_directory"
# File name (inside OUTPUT_DIR) of the Z-scored traces; the per-cell
# normalization statistics go to "normalization_stats.csv" next to it.
OUTPUT_FILENAME = "your-file-name.csv"

# When True, main() shows one figure per cell while processing.
PLOT_CELLS = True  # Set False to skip plotting
# Passed to plot_zscore() as marker_time for every cell.
SPECIFIC_TIME_MARKER = 60  # Vertical line in plots (seconds)

# =============================================================================
# Functions
# =============================================================================

def load_and_clean_data(filepath):
    """Load a CSV of fluorescence traces and return cleaned data.

    The first column of the file is taken as the time axis and every column
    whose name starts with ``'cell-'`` is treated as one cell's trace.
    Non-numeric entries and +/-inf are converted to NaN.  Rows in which
    every cell is NaN are dropped (together with their time stamps) *before*
    the remaining gaps are forward- and then backward-filled, so such rows
    are removed instead of being silently filled in from their neighbours.
    A column that contains no valid value at all stays NaN.

    Args:
        filepath: Path of the CSV file to read.

    Returns:
        Tuple ``(time, F_clean, cell_columns)`` where ``time`` is the first
        column as a Series, ``F_clean`` is a DataFrame holding the cleaned
        cell columns, and ``cell_columns`` is the list of their names.
        ``time`` and ``F_clean`` share a fresh 0..n-1 index.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
        ValueError: If the file is empty, is not parseable as CSV, or has
            no column starting with ``'cell-'``.
    """
    try:
        data = pd.read_csv(filepath)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"File {filepath} not found.") from err
    except pd.errors.EmptyDataError as err:
        raise ValueError(f"File {filepath} is empty.") from err
    except pd.errors.ParserError as err:
        raise ValueError(f"File {filepath} is not valid CSV.") from err

    print(f"Loaded: {data.shape[0]} rows × {data.shape[1]} columns")

    cell_columns = [col for col in data.columns if col.startswith('cell-')]
    print(f"Found {len(cell_columns)} cells")

    if not cell_columns:
        raise ValueError("No columns starting with 'cell-' found.")

    time = data.iloc[:, 0].copy()

    # Coerce anything unparseable to NaN and treat +/-inf as missing too.
    F_clean = data[cell_columns].apply(pd.to_numeric, errors='coerce')
    F_clean = F_clean.replace([np.inf, -np.inf], np.nan)

    # Drop all-NaN rows.  This has to happen before filling: once the gaps
    # are filled, a blank row is indistinguishable from real data.
    all_nan = F_clean.isna().all(axis=1)
    if all_nan.any():
        print(f"Dropping {all_nan.sum()} all-NaN rows")
        F_clean = F_clean.loc[~all_nan]
        time = time.loc[F_clean.index]

    # Fill the remaining isolated gaps from the previous sample, then from
    # the next one for gaps at the very start of a trace.
    F_clean = F_clean.ffill().bfill()

    F_clean = F_clean.reset_index(drop=True)
    time = time.reset_index(drop=True)

    return time, F_clean, cell_columns


def robust_zscore(F_cell, eps=1e-9):
    """Z-score one trace using its median and scaled MAD.

    Computes ``z = (x - median) / (1.4826 * MAD)``, where MAD is the median
    absolute deviation from the median.  The factor 1.4826 rescales the MAD
    so that it estimates the standard deviation for normally distributed
    data.  Non-finite samples (NaN, +/-inf) are treated as missing: they
    are ignored when estimating median and MAD and come out as NaN in ``z``.

    The input is never modified; the computation works on a private copy.

    Args:
        F_cell: 1-D array-like of samples (list, ndarray, pandas Series).
        eps: Lower bound for the scaled MAD.  It is used instead of the
            scaled MAD when that value is below ``eps`` or not finite, so
            a constant trace maps to zeros rather than dividing by zero.

    Returns:
        Tuple ``(z, stats)`` where ``z`` is a float ndarray with the same
        length as the input and ``stats`` is a dict with the keys
        ``'median'``, ``'mad'`` and ``'scaled_mad'`` (the divisor actually
        used).
    """
    # np.array (unlike np.asarray) always copies.  Writing NaN into a view
    # would alter the caller's data, or raise on a read-only buffer.
    arr = np.array(F_cell, dtype=float)
    finite = np.isfinite(arr)
    arr[~finite] = np.nan

    if not finite.any():
        # Empty or entirely missing trace: nothing to estimate.  Return the
        # all-NaN result directly instead of letting np.nanmedian emit a
        # RuntimeWarning for the same outcome.
        nan = float('nan')
        return arr, {'median': nan, 'mad': nan, 'scaled_mad': eps}

    median_val = np.nanmedian(arr)
    mad_val = np.nanmedian(np.abs(arr - median_val))
    scaled_mad = 1.4826 * mad_val

    if not np.isfinite(scaled_mad) or scaled_mad < eps:
        scaled_mad = eps

    z = (arr - median_val) / scaled_mad

    stats = {
        'median': median_val,
        'mad': mad_val,
        'scaled_mad': scaled_mad
    }

    return z, stats


def plot_zscore(time, z, cell_name, marker_time=None):
    """Show an interactive plot of one Z-scored trace.

    Draws the trace as a line, a dashed horizontal line at 0, dotted
    horizontal lines at Z = 2 and Z = 3, and optionally a dotted vertical
    marker line.  Nothing is drawn when ``z`` has no finite value.

    Args:
        time: X values, one per sample of ``z``.
        z: Z-scored samples of a single cell.
        cell_name: Label used for the trace and in the figure title.
        marker_time: X position of the vertical marker line, or ``None``
            for no marker.  A marker at 0 is drawn like any other value.

    Returns:
        None.  The figure is displayed via ``fig.show()``.
    """
    if not np.any(np.isfinite(z)):
        print(f"Skipping plot for {cell_name}: all NaN")
        return

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=time, y=z,
        mode='lines',
        name=cell_name,
        line=dict(color='royalblue', width=1)
    ))

    # Reference lines
    fig.add_hline(y=0, line_dash="dash", line_color="gray")
    for thresh in [2, 3]:
        fig.add_hline(y=thresh, line_dash="dot", line_color="orange", opacity=0.5)

    # Compare against None explicitly: a plain truthiness test would also
    # suppress a legitimate marker at time 0.
    if marker_time is not None:
        fig.add_vline(x=marker_time, line_dash="dot", line_color="red")

    fig.update_layout(
        title=f"Z-score: {cell_name}",
        xaxis_title='Time (s)',
        yaxis_title='Z-score',
        template='plotly_white'
    )
    fig.show()


def validate_zscore(Z_df):
    """Print and return sanity-check statistics for robust Z-scored traces.

    Reports whether every cell is centred on a median of 0, the pooled
    mean / median / standard deviation over all finite values, the standard
    deviation of the non-positive half (``Z <= 0``) and the skewness of the
    pooled values.

    Cells without any valid sample are left out of the centering check, so
    a single all-NaN column does not turn the check into a failure.  When
    there is nothing to evaluate (no columns, no rows, or no finite value)
    the affected statistics are NaN instead of raising.  Skewness is NaN
    when the pooled standard deviation is 0.

    Args:
        Z_df: DataFrame with one column of Z-scores per cell.

    Returns:
        Dict with the keys ``'global_mean'``, ``'global_median'``,
        ``'global_std'``, ``'mean_of_cell_medians'``, ``'baseline_std'``
        and ``'skewness'``.
    """
    nan = float('nan')

    values = Z_df.to_numpy(dtype=float)
    all_z = values.ravel()
    all_z = all_z[np.isfinite(all_z)]
    has_data = all_z.size > 0

    global_mean = np.mean(all_z) if has_data else nan
    global_median = np.median(all_z) if has_data else nan
    global_std = np.std(all_z) if has_data else nan

    # Per-cell medians (should all be ~0); NaN samples are ignored and
    # cells consisting only of NaN contribute nothing.
    cell_medians = []
    for column in values.T:
        present = column[~np.isnan(column)]
        if present.size:
            cell_medians.append(np.median(present))

    if cell_medians:
        mean_of_medians = np.mean(cell_medians)
        max_median_deviation = np.max(np.abs(cell_medians))
    else:
        mean_of_medians = nan
        max_median_deviation = nan

    # Baseline statistics (points at or below the median, i.e., Z <= 0)
    baseline_z = all_z[all_z <= 0]
    baseline_std = np.std(baseline_z) if baseline_z.size > 0 else nan

    # Skewness is undefined for zero spread; report NaN rather than 0/0.
    if has_data and global_std > 0:
        skewness = np.mean((all_z - global_mean)**3) / (global_std**3)
    else:
        skewness = nan

    print("\n" + "="*70)
    print("Z-SCORE VALIDATION (Robust median/MAD method)")
    print("="*70)

    print("\n[CENTERING CHECK] — Per-cell medians should be 0:")
    print(f"  Mean of cell medians:    {mean_of_medians:>10.6f}  (should be ~0)")
    print(f"  Max |median| any cell:   {max_median_deviation:>10.6f}  (should be ~0)")

    # A NaN deviation (nothing to check) compares False and reports FAIL.
    median_ok = max_median_deviation < 1e-6
    print(f"  Status: {'✓ PASS' if median_ok else '✗ FAIL'}")

    print("\n[GLOBAL STATISTICS] — For skewed Ca²⁺ data:")
    print(f"  Global median: {global_median:>8.3f}  (should be ~0) {'✓' if abs(global_median) < 0.1 else ''}")
    print(f"  Global mean:   {global_mean:>8.3f}  (expected > 0 due to transients)")
    print(f"  Global std:    {global_std:>8.3f}  (expected > 1 due to transients)")

    print("\n[BASELINE NOISE CHECK]:")
    print(f"  Std of Z <= 0: {baseline_std:>8.3f}  (should be ~0.6-0.8 for half-distribution)")

    print(f"\n[SKEWNESS]: {skewness:>8.3f}  (positive = right-skewed, expected for Ca²⁺)")

    print("\n" + "-"*70)
    print("INTERPRETATION:")
    print("  • Median = 0 confirms correct centering")
    print("  • Mean > 0 is NORMAL for Ca²⁺ (transients pull mean up)")
    print("  • Std > 1 is NORMAL for Ca²⁺ (transients inflate variance)")
    print("  • Transients appear as Z >> 0 (e.g., Z > 2-3)")
    print("="*70)

    return {
        'global_mean': global_mean,
        'global_median': global_median,
        'global_std': global_std,
        'mean_of_cell_medians': mean_of_medians,
        'baseline_std': baseline_std,
        'skewness': skewness
    }


# =============================================================================
# Main
# =============================================================================

def main():
    """Run the full pipeline: load, Z-score every cell, validate, save.

    Reads ``INPUT_FILE``, applies ``robust_zscore`` to each cell column
    (optionally plotting each one when ``PLOT_CELLS`` is true), prints the
    validation report and writes two CSV files into ``OUTPUT_DIR``: the
    Z-scored traces (``OUTPUT_FILENAME``, with a leading ``time`` column)
    and the per-cell normalization statistics
    (``normalization_stats.csv``).  ``OUTPUT_DIR`` is created if missing.

    Returns:
        Tuple ``(df_out, validation)``: the DataFrame that was written to
        ``OUTPUT_FILENAME`` and the dict returned by ``validate_zscore``.
    """
    print("="*70)
    print("ROBUST Z-SCORE TRANSFORMATION FOR Ca²⁺ IMAGING DATA")
    print("="*70)

    # Load data
    time, F_clean, cell_columns = load_and_clean_data(INPUT_FILE)
    print(f"\nProcessing {len(cell_columns)} cells, {len(time)} time points")

    # Make sure the destination exists before doing any heavy work, so a
    # missing directory does not discard a finished run at save time.
    if OUTPUT_DIR:
        os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Z-score each cell
    Z_dict = {}
    stats_list = []

    for i, cell in enumerate(cell_columns):
        z, stats = robust_zscore(F_clean[cell])
        Z_dict[cell] = z
        stats['cell'] = cell
        stats_list.append(stats)

        if PLOT_CELLS:
            plot_zscore(time, z, cell, SPECIFIC_TIME_MARKER)

        # Progress report every 50 cells
        if (i + 1) % 50 == 0:
            print(f"  Processed {i + 1}/{len(cell_columns)} cells")

    # Create output DataFrame
    Z_df = pd.DataFrame(Z_dict)

    # Validate
    validation = validate_zscore(Z_df)

    # Save Z-scored data
    df_out = Z_df.copy()
    df_out.insert(0, 'time', time)

    output_path = os.path.join(OUTPUT_DIR, OUTPUT_FILENAME)
    df_out.to_csv(output_path, index=False)
    print(f"\nZ-scored data saved to: {output_path}")

    # Save normalization stats
    stats_df = pd.DataFrame(stats_list)
    stats_path = os.path.join(OUTPUT_DIR, "normalization_stats.csv")
    stats_df.to_csv(stats_path, index=False)
    print(f"Normalization stats saved to: {stats_path}")

    return df_out, validation


# Run the pipeline only when this file is executed as a script.  The results
# are bound to module-level names: the written DataFrame and the validation
# dict returned by main().
if __name__ == "__main__":
    df_zscore, validation = main()