# Brain Wide Map: Statistical Analysis Example

This notebook demonstrates advanced statistical analysis of Brain Wide Map data.

In [None]:
import sys
sys.path.append('..')

from brainwidemap import DataLoader, Explorer, Statistics, Visualizer
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

## 1. Load Data

In [None]:
# Initialize the data-access, statistics and plotting helpers that the later
# cells of this notebook reuse.
loader = DataLoader(mode='auto')
explorer = Explorer(loader)
stats = Statistics()
viz = Visualizer()

# Pick the first session returned for the n_trials_min=MIN_TRIALS filter.
MIN_TRIALS = 400
sessions = explorer.list_sessions(n_trials_min=MIN_TRIALS)
if len(sessions) == 0:
    # Fail with an actionable message instead of the opaque IndexError that
    # `.iloc[0]` would raise on an empty result.
    raise RuntimeError(
        f"No sessions matched n_trials_min={MIN_TRIALS}; "
        "lower MIN_TRIALS or check the data source."
    )
eid = sessions.iloc[0]['eid']

# Load spike and trial data for the chosen session
spikes, clusters = loader.load_spike_data(eid)
trials = loader.load_trials(eid)

print(f"Session: {eid}")
print(f"Units: {len(clusters)}")
print(f"Trials: {len(trials)}")

## 2. Trial-Aligned Analysis (PSTH)

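Compute peri-stimulus time histograms (PSTHs) aligned to stimulus onset (`stimOn_times`) for the first four units, using a window from 0.5 s before to 1.0 s after the stimulus and 20 ms bins.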

In [None]:
# PSTHs for the first four units, aligned to stimulus onset.
if 'stimOn_times' not in trials.columns:
    print("No stimulus onset times available")
else:
    # Trials without a recorded stimulus onset (NaN) cannot be aligned.
    event_times = trials['stimOn_times'].values
    event_times = event_times[~np.isnan(event_times)]

    # 2x2 grid, flattened so panels can be paired with units one-to-one.
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    axes = axes.flatten()

    for ax, cluster_id in zip(axes, clusters.index[:4]):
        time_bins, psth = stats.compute_psth(
            spikes['times'],
            spikes['clusters'],
            cluster_id=cluster_id,
            event_times=event_times,
            window=(-0.5, 1.0),
            bin_size=0.02,
        )

        ax.plot(time_bins, psth)
        # Dashed red line marks the alignment event (t = 0).
        ax.axvline(x=0, color='r', linestyle='--', alpha=0.5)
        ax.set(
            xlabel='Time from stimulus (s)',
            ylabel='Firing rate (Hz)',
            title=f'Unit {cluster_id}',
        )

    plt.tight_layout()
    plt.show()

## 3. Population Analysis by Brain Region

In [None]:
# Population statistics per brain region for the loaded session.
regions = loader.get_brain_regions(eid)

# Only the first MAX_REGIONS entries returned by the loader are analysed
# (no ranking is applied here).
MAX_REGIONS = 10

if len(regions) > 0:
    region_stats_list = []

    for region in regions[:MAX_REGIONS]:
        try:
            pop_stats = stats.compute_population_statistics(
                spikes, clusters, brain_region=region
            )
            pop_stats['region'] = region
        except Exception as exc:
            # Best effort: keep going, but report what was skipped and why
            # instead of hiding the failure. `Exception` (not a bare except)
            # lets KeyboardInterrupt still stop the cell.
            print(f"Skipping region {region!r}: {type(exc).__name__}: {exc}")
            continue
        region_stats_list.append(pop_stats)

    if region_stats_list:
        # One row per successfully analysed region
        region_stats_df = pd.DataFrame(region_stats_list)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

        # (axis, column, colour, y-label, title) for each panel
        panels = [
            (ax1, 'mean_firing_rate', 'steelblue',
             'Mean Firing Rate (Hz)', 'Average Firing Rate by Region'),
            (ax2, 'n_units', 'coral',
             'Number of Units', 'Unit Count by Region'),
        ]
        for ax, column, color, ylabel, title in panels:
            region_stats_df.plot.bar(
                x='region', y=column,
                ax=ax, legend=False, color=color
            )
            ax.set_xlabel('Brain Region')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.tick_params(axis='x', rotation=45)

        plt.tight_layout()
        plt.show()

        print("\nRegion Statistics:")
        print(region_stats_df[['region', 'n_units', 'mean_firing_rate', 'mean_cv']].to_string(index=False))
    else:
        # Every region failed; plotting an empty frame would raise a KeyError.
        print("Population statistics could not be computed for any region")
else:
    print("No brain region information available")

## 4. Trial-by-Trial Variability Analysis

In [None]:
# Analyze trial-by-trial variability (Fano factor) between stimulus onset and
# feedback for the first MAX_FANO_UNITS units.
MAX_FANO_UNITS = 50

# Defined before the column check so that `trial_windows` always exists, even
# when the timing columns are missing.
trial_windows = []

if 'stimOn_times' in trials.columns and 'feedback_times' in trials.columns:
    # One (start, end) window per trial; trials with a NaN at either end are
    # dropped.
    trial_windows = [
        (s, e)
        for s, e in zip(trials['stimOn_times'].values,
                        trials['feedback_times'].values)
        if not (np.isnan(s) or np.isnan(e))
    ]

if len(trial_windows) == 0:
    print("No valid stimulus-to-feedback trial windows available")
else:
    # Window durations do not depend on the unit, so compute them once.
    window_durations = np.array([e - s for s, e in trial_windows])

    fano_factors = []
    mean_rates = []

    for cluster_id in clusters.index[:MAX_FANO_UNITS]:
        trial_rates = stats.compute_trial_firing_rates(
            spikes['times'],
            spikes['clusters'],
            cluster_id=cluster_id,
            trial_windows=trial_windows
        )

        # The Fano factor is computed from counts, so convert rates back to
        # per-trial spike counts (rate * window duration).
        spike_counts = trial_rates * window_durations
        fano = stats.compute_fano_factor(spike_counts)

        # Units for which the helper returns NaN are left out of the summary.
        if not np.isnan(fano):
            fano_factors.append(fano)
            mean_rates.append(trial_rates.mean())

    if fano_factors:
        # Plot Fano factor vs mean rate
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.scatter(mean_rates, fano_factors, alpha=0.6)
        ax.set_xlabel('Mean Firing Rate (Hz)')
        ax.set_ylabel('Fano Factor')
        ax.set_title('Trial-to-Trial Variability vs Firing Rate')
        # Reference line at Fano factor = 1
        ax.axhline(y=1, color='r', linestyle='--', alpha=0.5, label='Poisson')
        ax.legend()
        plt.tight_layout()
        plt.show()

        print(f"\nMean Fano Factor: {np.mean(fano_factors):.3f}")
        print(f"Median Fano Factor: {np.median(fano_factors):.3f}")
    else:
        # Avoid an empty plot and NaN summaries (with NumPy warnings).
        print("No unit produced a defined Fano factor")

## 5. Correlation Analysis

In [None]:
# Pairwise correlation between units, computed from binned firing rates.
BIN_SIZE_S = 0.1      # 100 ms bins
MAX_CORR_UNITS = 20   # use at most this many units

bin_size = BIN_SIZE_S
t_start = spikes['times'].min()
t_end = spikes['times'].max()
# Bin edges spanning the recording; named `bin_edges` so it does not reuse the
# `time_bins` name that the PSTH cell uses for its own output.
bin_edges = np.arange(t_start, t_end, bin_size)

# Use subset of units
n_units = min(MAX_CORR_UNITS, len(clusters))
selected_units = clusters.index[:n_units]

if n_units < 2 or len(bin_edges) < 2:
    # A correlation matrix needs at least two units and one complete bin.
    print("Need at least two units and one full time bin to compute correlations")
else:
    # Firing rate matrix: one row per unit, one column per time bin
    fr_matrix = np.zeros((n_units, len(bin_edges) - 1))

    for i, cluster_id in enumerate(selected_units):
        cluster_spikes = spikes['times'][spikes['clusters'] == cluster_id]
        counts, _ = np.histogram(cluster_spikes, bins=bin_edges)
        fr_matrix[i, :] = counts / bin_size  # counts -> Hz

    # Compute correlation matrix
    corr_matrix = stats.compute_correlation_matrix(fr_matrix)

    # Label by cluster ID (matching the PSTH panel titles) rather than by
    # position in the subset.
    fig = viz.plot_correlation_matrix(
        corr_matrix,
        labels=[f'U{cluster_id}' for cluster_id in selected_units]
    )
    plt.show()

    # Summary statistics over off-diagonal entries. Entries that are NaN
    # (possible e.g. for a unit whose binned rate never varies — depends on
    # the helper) are excluded so they do not turn the summary into NaN.
    off_diag = corr_matrix[~np.eye(n_units, dtype=bool)]
    finite_corr = off_diag[~np.isnan(off_diag)]

    if finite_corr.size > 0:
        print(f"\nMean pairwise correlation: {finite_corr.mean():.3f}")
        print(f"Std pairwise correlation: {finite_corr.std():.3f}")
        n_undefined = off_diag.size - finite_corr.size
        if n_undefined > 0:
            print(f"({n_undefined} undefined off-diagonal entries excluded)")
    else:
        print("\nAll pairwise correlations are undefined (NaN)")

## 6. Behavioral Correlation Analysis

In [None]:
# Compare one unit's firing rate on correct vs. incorrect trials.
#
# The trial windows and outcome labels are built here from the same
# valid-trial mask, so rate i always belongs to outcome i. (Reusing a
# NaN-filtered window list together with the first N rows of `feedbackType`
# would misalign the two as soon as any trial had been dropped.)
required_cols = ['feedbackType', 'stimOn_times', 'feedback_times']

if not all(col in trials.columns for col in required_cols):
    print("Feedback type and/or trial timing not available")
elif len(clusters) == 0:
    print("No units available")
else:
    # Keep only trials with both a stimulus onset and a feedback time.
    valid_trials = trials[['stimOn_times', 'feedback_times']].notna().all(axis=1).values
    outcome_windows = list(zip(
        trials['stimOn_times'].values[valid_trials],
        trials['feedback_times'].values[valid_trials]
    ))
    outcomes = trials['feedbackType'].values[valid_trials]

    if len(outcome_windows) == 0:
        print("No valid stimulus-to-feedback trial windows available")
    else:
        # Select a unit
        cluster_id = clusters.index[0]

        # One firing rate per window, in window order
        trial_rates = stats.compute_trial_firing_rates(
            spikes['times'],
            spikes['clusters'],
            cluster_id=cluster_id,
            trial_windows=outcome_windows
        )

        # Split by correct/incorrect: feedbackType == 1 is treated as correct,
        # every other value as incorrect.
        correct_mask = outcomes == 1
        correct_rates = trial_rates[correct_mask]
        incorrect_rates = trial_rates[~correct_mask]

        # Statistical test needs at least one trial in each group
        if len(correct_rates) > 0 and len(incorrect_rates) > 0:
            t_stat, p_val = stats.perform_ttest(correct_rates, incorrect_rates)

            # Plot
            fig, ax = plt.subplots(figsize=(8, 6))

            positions = [1, 2]
            data = [correct_rates, incorrect_rates]
            labels = ['Correct', 'Incorrect']

            ax.boxplot(data, positions=positions)
            # Tick labels are set on the axis rather than via boxplot's
            # `labels=` keyword, which newer Matplotlib releases deprecate.
            ax.set_xticks(positions)
            ax.set_xticklabels(labels)
            ax.set_ylabel('Firing Rate (Hz)')
            ax.set_title(f'Unit {cluster_id}: Activity by Trial Outcome\n'
                         f't={t_stat:.3f}, p={p_val:.4f}')

            plt.tight_layout()
            plt.show()

            print(f"\nCorrect trials: mean={correct_rates.mean():.2f}, std={correct_rates.std():.2f}")
            print(f"Incorrect trials: mean={incorrect_rates.mean():.2f}, std={incorrect_rates.std():.2f}")
            print(f"t-test: t={t_stat:.3f}, p={p_val:.4f}")
        else:
            print("Need both correct and incorrect trials to compare outcomes")

## Summary

This notebook demonstrated:
- PSTH analysis aligned to task events
- Population statistics by brain region
- Trial-to-trial variability (Fano factor)
- Neural correlation analysis
- Behavioral correlation of neural activity

These analyses provide insights into neural coding and decision-making in the Brain Wide Map dataset.