# Time-Resolved Neural Selectivity Analysis
Quantifying % of neurons selective for different aspects as function of time

## 1: ENVIRONMENT SETUP AND IMPORTS


In [None]:
# Core data manipulation and analysis
import glob
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.io import loadmat
import warnings
warnings.filterwarnings('ignore')

# Statistical analysis
import statsmodels.formula.api as smf
import statsmodels.api as sm
from statsmodels.stats.multitest import multipletests

# Progress tracking
from tqdm import tqdm

# Set plotting style
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12

print("✓ All packages imported successfully")
print("✓ Environment setup complete")

## 2: LOAD DATA


### Loading MATLAB Data Files

Loading of MATLAB data files containing neural recordings and trial information.

In [None]:
# Find all WMB_P*_v7.mat files in the current directory
mat_files = glob.glob('./WMB_P*_v7.mat')
print(f"Found {len(mat_files)} .mat files: {[os.path.basename(f) for f in mat_files]}")

# Initialize lists to store data from each file
cell_mats = []
total_mats = []

# Load each file and append its data
for mat_file in mat_files:
    print(f"\nLoading {mat_file}...")
    mat_data = loadmat(mat_file)
    cell_mats.append(mat_data['cellStatsAll'])
    total_mats.append(mat_data['totStats'])

# Print shapes of loaded data for debugging
print("\nShapes of loaded data:")
for i, (cell, total) in enumerate(zip(cell_mats, total_mats)):
    print(f"File {i}: cell_mat shape: {cell.shape}, total_mat shape: {total.shape}")

In [None]:
# Combine the data
# For cell_mat, we need to handle the different dimensions
# First, convert each cell_mat to a list of records
all_cell_records = []
for cell_mat in cell_mats:
    # Convert to list of records
    cell_list = cell_mat[0]  # now shape is (n,)
    records = []
    for cell in cell_list:
        record = {key: cell[key] for key in cell.dtype.names}
        records.append(record)
    all_cell_records.extend(records)

# Convert combined records to DataFrame
df = pd.DataFrame(all_cell_records)

# For total_mat, we can concatenate directly since they have the same structure
total_mat = np.concatenate(total_mats, axis=0)

print(f"\nCombined data shape - total_mat: {total_mat.shape}")
print(f"\nCombined data shape - df: {df.shape}")

### Translates numeric area codes

In [4]:
from collections import Counter

def count_area_codes(area_column):
    
    mapping = {
        1: 'RH', 2: 'LH', 3: 'RA', 4: 'LA', 5: 'RAC', 6: 'LAC',
        7: 'RSMA', 8: 'LSMA', 9: 'RPT', 10: 'LPT', 11: 'ROFC', 12: 'LOFC',
        50: 'RFFA', 51: 'REC', 52: 'RCM', 53: 'LCM', 54: 'RPUL', 55: 'LPUL',
        56: 'N/A', 57: 'RPRV', 58: 'LPRV'
    }
    
    labels = []
    for code in area_column:
        label = mapping.get(code, 'Unknown')
        labels.append(label)
    
    return dict(Counter(labels))

In [None]:
# neuron number in each area
area_codes = total_mat[:, 3]

counts = count_area_codes(area_codes)
print("Area counts (no prefix):")
for area, count in counts.items():
    print(f"{area}: {count}")

### Data Preprocessing & Filtering

First, we'll format the cell data and collapse brain areas into broader regions:

In [6]:
collapsed_area_map = {
    1: 'H', 2: 'H',
    3: 'A', 4: 'A',
    5: 'AC', 6: 'AC',
    7: 'SMA', 8: 'SMA',
    9: 'PT', 10: 'PT',
    11: 'OFC', 12: 'OFC',
    50: 'FFA', 51: 'EC',
    52: 'CM', 53: 'CM',
    54: 'PUL', 55: 'PUL',
    56: 'N/A', 57: 'PRV', 58: 'PRV'
}
# Convert brain area codes in the DataFrame
df['brainAreaOfCell'] = df['brainAreaOfCell'].apply(
    lambda x: collapsed_area_map.get(int(x[0, 0]), 'Unknown') if isinstance(x, np.ndarray) else collapsed_area_map.get(x, 'Unknown')
)

### Quality Control & Unit Selection

Filtering criteria for reliable unit selection:
- Minimum firing rate threshold
- Signal quality metrics

In [7]:
# Filter out units with low firing rate
fr = df['timestamps'].apply(lambda x: len(x) / (x[-1] - x[0]) * 1e6)
df_sample_new = df[fr > 0.1].reset_index(drop=True)

# unit id
df_sample_new = df_sample_new.reset_index(drop=True)
df_sample_new["unit_id"] = df_sample_new.index

In [8]:
# Filter out units with low firing rate
fr = df['timestamps'].apply(lambda x: len(x) / (x[-1] - x[0]) * 1e6)
df_sample_new = df[fr > 0.1].reset_index(drop=True)

# unit id
df_sample_new = df_sample_new.reset_index(drop=True)
df_sample_new["unit_id"] = df_sample_new.index

###  Neural Activity & Trial Processing

#### Trial Data Extraction

Extraction and alignment of trial-specific information with neural data:

In [9]:
def extract_trial_info(trials_struct, unit_id):
    # Build a DataFrame from the trials structure.
    df_trial = pd.DataFrame({field: trials_struct[field].squeeze() 
                             for field in trials_struct.dtype.names})
    # Add the unit_id so that you can later separate trials by unit/session.
    df_trial["unit_id"] = unit_id
    df_trial["trial_nr"] = df_trial["trial"].apply(lambda x: np.squeeze(x).item() if isinstance(x, (list, np.ndarray)) else x) - 1 # Adjust for 0-indexing
    return df_trial

In [10]:
trial_info_list = []
for idx, row in df_sample_new.iterrows():
    # Use the unit identifier from this row
    unit_id = row["unit_id"]  
    # Extract the trial DataFrame, including the unit identifier.
    trial_info_list.append(extract_trial_info(row["Trials"], unit_id, ))

# Concatenate the list of trial info DataFrames into one.
trial_info = pd.concat(trial_info_list, ignore_index=True)

#### Data Integration & Column Selection

Merge neural firing rate data with trial information and select relevant columns for analysis. This step combines the computed firing rates with behavioral trial data to enable selectivity analysis.


In [None]:
def extract_event_timestamps(df_sample_new, start_idx_col='idxEnc1', end_idx_col='idxDel1', is_window=False, window_size=0.5):
    """
    Extract event timestamps for computing firing rates across task epochs.
    
    Parameters:
    -----------
    df_sample_new : DataFrame
        Neural data with event indices and timestamps
    start_idx_col : str
        Column name containing start event indices
    end_idx_col : str  
        Column name containing end event indices
    is_window : bool
        If True, extract fixed window around start event; if False, extract epoch between start and end
    window_size : float
        Window size in seconds (only used if is_window=True)
        
    Returns:
    --------
    epoch_ts : dict
        Dictionary mapping unit_id to array of [start_time, end_time] pairs for each trial
    """
    epoch_ts = {}
    
    for i, row in df_sample_new.iterrows():
        unit_id = row['unit_id']
        events = row['events'].squeeze()  # Event timestamps array [trial_idx, event_type, timestamp]
        
        if is_window:
            # Extract fixed window around specific event (e.g., ±0.5s around response)
            idxs = row[start_idx_col].squeeze() - 1  # Convert to 0-based indexing
            extracted = events[idxs]
            center_times = extracted[:, 0]  # Get timestamp column
            window_start = center_times - window_size * 1e6  # Convert to microseconds
            window_end = center_times + window_size * 1e6
            combined = np.column_stack((window_start, window_end))
        else:
            # Extract epoch between two task events (e.g., encoding to delay)
            idxs_start = row[start_idx_col].squeeze() - 1  # Start event indices (0-based)
            idxs_end = row[end_idx_col].squeeze() - 1      # End event indices (0-based)
            
            # Handle mismatched trial counts between start and end events
            min_length = min(len(idxs_start), len(idxs_end))
            idxs_start = idxs_start[:min_length]
            idxs_end = idxs_end[:min_length]
            
            # Extract timestamps for start and end events
            extracted_start = events[idxs_start]  # [n_trials, 3] 
            extracted_end = events[idxs_end]      # [n_trials, 3]

            # Combine start and end timestamps for each trial
            combined = np.column_stack((extracted_start[:, 0], extracted_end[:, 0]))
        
        epoch_ts[unit_id] = combined
        
    return epoch_ts


def compute_firing_rates(df_sample_new, start_idx_col='idxEnc1', end_idx_col='idxDel1', 
                         fr_prefix='fr', is_window=False, window_size=0.5):
    """
    Compute firing rates for a specific task epoch.
    
    Parameters:
    -----------
    df_sample_new : DataFrame
        Neural data containing spike timestamps and event indices
    start_idx_col : str
        Column with start event indices
    end_idx_col : str
        Column with end event indices  
    fr_prefix : str
        Prefix for the firing rate column name
    is_window : bool
        Whether to use fixed window (True) or epoch between events (False)
    window_size : float
        Window size in seconds (only for is_window=True)
        
    Returns:
    --------
    df_sample_new : DataFrame
        Input dataframe with added firing rate columns
    """
    # Extract epoch timestamps for all units
    epoch_ts = extract_event_timestamps(df_sample_new, start_idx_col, end_idx_col, is_window, window_size)
    
    # Compute baseline firing rate (1 second before first stimulus) - only on first call
    if 'fr_baseline' not in df_sample_new.columns:
        print("Computing baseline firing rates...")
        
        # Get first stimulus onset timestamps
        enc_ts = extract_event_timestamps(df_sample_new, 'idxEnc1', 'idxEnc1')
        
        # Create baseline windows: 1 second before stimulus onset
        baseline_ts = {}
        for unit_id, timestamps in enc_ts.items():
            baseline_start = timestamps[:, 0] - 1e6  # 1 second before (microseconds)
            baseline_end = timestamps[:, 0]          # End at stimulus onset
            baseline_ts[unit_id] = np.column_stack((baseline_start, baseline_end))
            
        # Compute baseline firing rates for each unit and trial
        df_sample_new['fr_baseline'] = df_sample_new.apply(
            lambda row: [
                # Count spikes in baseline window and convert to Hz
                np.sum((np.ravel(row["timestamps"]) >= baseline_on) & 
                       (np.ravel(row["timestamps"]) < baseline_off)) / ((baseline_off - baseline_on) / 1e6)
                for baseline_on, baseline_off in baseline_ts[row["unit_id"]]
            ],
            axis=1
        )
    
    # Compute firing rates for the specified epoch
    epoch_col = f"{fr_prefix}_epoch"
    print(f"Computing firing rates for {epoch_col}...")
    
    df_sample_new[epoch_col] = df_sample_new.apply(
        lambda row: [
            # Count spikes in epoch window and convert to Hz
            np.sum((np.ravel(row["timestamps"]) >= epoch_on) & 
                   (np.ravel(row["timestamps"]) < epoch_off)) / ((epoch_off - epoch_on) / 1e6)
            for epoch_on, epoch_off in epoch_ts[row["unit_id"]]
        ],
        axis=1
    )
    
    # Add trial numbers if not already present
    if "trial_nr" not in df_sample_new.columns:
        df_sample_new["trial_nr"] = df_sample_new[epoch_col].apply(lambda x: np.arange(len(x)))
    
    return df_sample_new


print("Computing firing rates for all task epochs...")

# Compute firing rates for each task epoch
# Each call adds a new firing rate column to the dataframe

# 1. First stimulus encoding period (from stimulus onset to delay start)
df_sample_new = compute_firing_rates(df_sample_new, 'idxEnc1', 'idxDel1', fr_prefix='fr')

# 2. First delay period (between first stimulus and second stimulus)  
df_sample_new = compute_firing_rates(df_sample_new, 'idxDel1', 'idxEnc2', fr_prefix='fr_del1')

# 3. Second stimulus encoding period (from second stimulus onset to second delay)
df_sample_new = compute_firing_rates(df_sample_new, 'idxEnc2', 'idxDel2', fr_prefix='fr_enc2')

# 4. Second delay period (between second stimulus and probe)
df_sample_new = compute_firing_rates(df_sample_new, 'idxDel2', 'idxProbeOn', fr_prefix='fr_del2')

# 5. Response period (±0.5 second window around button press)
df_sample_new = compute_firing_rates(df_sample_new, 'idxResp', None, fr_prefix='fr_resp', 
                                   is_window=True, window_size=0.5)

print("Expanding dataframe from unit-based to trial-based format...")

# Transform from unit-level to trial-level format
# Each unit currently has lists of firing rates (one per trial)
# After exploding, each row will represent one unit-trial combination
columns_to_explode = ['fr_baseline', 'fr_epoch', 'fr_del1_epoch', 'fr_enc2_epoch', 
                      'fr_del2_epoch', 'fr_resp_epoch', "trial_nr"]
df_sample_new = df_sample_new.explode(columns_to_explode)

print(f"Final neural data shape: {df_sample_new.shape}")


In [None]:
print("Merging neural data with trial information...")

# Reset indices to ensure proper alignment
df_sample_new = df_sample_new.reset_index(drop=True)
trial_info = trial_info.reset_index(drop=True)

# Merge neural firing rates with behavioral trial data
# This combines unit-trial firing rates with stimulus information for each trial
data = pd.merge(
    df_sample_new,
    trial_info,
    on=["unit_id", "trial_nr"],  # Join on unit and trial identifiers
    how="left",                  # Keep all neural data, add matching trial info
).infer_objects()

print(f"Combined data shape: {data.shape}")

# Select columns needed for selectivity analysis
# Include: neural identifiers, firing rates, stimulus properties, task timing
cols_to_keep = [
    # Neural data identifiers and firing rates
    "unit_id", "timestamps", "brainAreaOfCell", 
    "fr_epoch", "fr_baseline", "fr_del1_epoch", "fr_enc2_epoch", "fr_del2_epoch", "fr_resp_epoch", 
    "trial_nr",
    
    # Stimulus properties for selectivity analysis
    "first_cat", "second_cat", "first_num", "second_num",
    "first_pic", "second_pic", "probe_cat", "probe_pic",
    
    # Task variables
    "probe_validity", "probe_num", "correct_answer",
    "rt", "acc", "key", "cat_comparison", 
    
    # Event timing and metadata
    "events", "nTrials", "Trials", 
    "idxEnc1", "idxEnc2", "idxDel1", "idxDel2", "idxProbeOn", "idxResp", 
    "nrProcessed", "periods_Enc1", "periods_Enc2", "periods_Del1", "periods_Del2", 
    "periods_Probe", "periods_Resp", "prestimEnc", "prestimMaint", "prestimProbe",
    "prestimButtonPress", "poststimEnc", "poststimMaint", "poststimProbe", 
    "poststimButtonPress", "sessionIdx", "channel", "cellNr", "sessionID", "origClusterID"
]

# Create filtered dataset with only necessary columns
data_filtered = data[cols_to_keep].copy()

print("Converting stimulus variables to simple string format for statistical analysis...")

# Convert stimulus variables to simple string format
# MATLAB arrays need to be converted to hashable Python types for grouping operations
data_filtered["first_cat_simple"] = data_filtered["first_cat"].apply(
    lambda x: str(np.squeeze(x)) if isinstance(x, (list, np.ndarray)) else str(x)
)
data_filtered["second_cat_simple"] = data_filtered["second_cat"].apply(
    lambda x: str(np.squeeze(x)) if isinstance(x, (list, np.ndarray)) else str(x)
)
data_filtered["first_num_simple"] = data_filtered["first_num"].apply(
    lambda x: str(np.squeeze(x)) if isinstance(x, (list, np.ndarray)) else str(x)
)
data_filtered["second_num_simple"] = data_filtered["second_num"].apply(
    lambda x: str(np.squeeze(x)) if isinstance(x, (list, np.ndarray)) else str(x)
)


In [None]:
# Verify data structure
print(f"Data shape: {data_filtered.shape if 'data_filtered' in locals() else 'data_filtered not found'}")


## 3: DEFINE TIME ANALYSIS PARAMETERS

Configuration parameters for time-resolved analysis

In [None]:
# Check the columns in full_results to see what's available
print("Columns in full_results:")
print(full_results.columns.tolist())
print("\nFirst few rows:")
print(full_results.head())


In [None]:
class TimeAnalysisConfig:
    
    
    # Time window parameters (all in milliseconds)
    WINDOW_WIDTH = 500      # Width of each analysis window
    STEP_SIZE = 100         # Step between windows
    
    # Analysis period relative to stimulus onset
    TIME_START = -500      # Pre-stimulus baseline
    TIME_END = 6000        # Post-stimulus through response
    
    # Statistical parameters
    P_THRESHOLD = 0.05     # Significance threshold for selectivity
    MIN_TRIALS = 4        # Minimum trials per condition per neuron
    MIN_FIRING_RATE = 0.5  # Minimum average firing rate (Hz)
    
    # Event timings (in ms, relative to first stimulus)
    EVENTS = {
        'first_stimulus': 0,
        'delay1': 1000,
        'second_stimulus': 2000, 
        'delay2': 3000,
        'probe': 5500,
        'response': 5500
    }

config = TimeAnalysisConfig()

# Generate time windows
time_windows = []
window_centers = []

for t_start in range(config.TIME_START, config.TIME_END - config.WINDOW_WIDTH, config.STEP_SIZE):
    t_end = t_start + config.WINDOW_WIDTH
    time_windows.append((t_start, t_end))
    window_centers.append(t_start + config.WINDOW_WIDTH // 2)

print(f"✓ Created {len(time_windows)} time windows")
print(f"✓ Window width: {config.WINDOW_WIDTH}ms, Step: {config.STEP_SIZE}ms")
print(f"✓ Analysis period: {config.TIME_START}ms to {config.TIME_END}ms")
print(f"✓ First few window centers: {window_centers[:5]}ms")

## 4: EXTRACT TRIAL-ALIGNED SPIKE TIMES

In [None]:
def extract_trial_aligned_spikes(data_filtered, unit_id, alignment_event='first_stimulus'):
    
    
    # Get data for this unit
    unit_data = data_filtered[data_filtered['unit_id'] == unit_id].copy()
    
    if len(unit_data) == 0:
        return {}
    
    # Get spike timestamps (convert to seconds)
    spike_times = np.asarray(unit_data["timestamps"].iloc[0]).flatten() / 1e6
    
    # Get trial events and alignment indices
    events = unit_data['events'].iloc[0].squeeze()
    
    # Choose alignment based on event type
    if alignment_event == 'first_stimulus':
        align_indices = unit_data['idxEnc1'].iloc[0].squeeze() - 1
    elif alignment_event == 'second_stimulus':
        align_indices = unit_data['idxEnc2'].iloc[0].squeeze() - 1
    elif alignment_event == 'probe':
        align_indices = unit_data['idxProbeOn'].iloc[0].squeeze() - 1
    else:
        align_indices = unit_data['idxEnc1'].iloc[0].squeeze() - 1  # Default to first stimulus
    
    trial_data = {}
    
    # Extract spikes for each trial
    for trial_idx, event_idx in enumerate(align_indices):
        try:
            # Get alignment time for this trial
            align_time = events[event_idx, 0] / 1e6  # Convert to seconds
            
            # Extract spikes in analysis window around alignment
            window_start = align_time + config.TIME_START / 1000  # Convert ms to s
            window_end = align_time + config.TIME_END / 1000
            
            trial_spikes = spike_times[(spike_times >= window_start) & (spike_times <= window_end)]
            trial_spikes_aligned = (trial_spikes - align_time) * 1000  # Convert to ms relative to event
            
            # Get trial conditions
            trial_data[trial_idx] = {
                'spikes': trial_spikes_aligned,
                'first_cat': unit_data.iloc[trial_idx]['first_cat_simple'],
                'second_cat': unit_data.iloc[trial_idx]['second_cat_simple'],
                'first_num': unit_data.iloc[trial_idx]['first_num_simple'],
                'second_num': unit_data.iloc[trial_idx]['second_num_simple'],
                'probe_validity': unit_data.iloc[trial_idx].get('probe_validity', 'unknown')
            }
            
        except (IndexError, KeyError):
            # Skip trials with missing data
            continue
            
    return trial_data

# Test the function
print("✓ Trial alignment function defined")

## 5: COMPUTE SLIDING WINDOW FIRING RATES

In [None]:
def compute_sliding_window_firing_rates(trial_aligned_spikes, time_windows):
    
    
    results = []
    
    for trial_idx, trial_data in trial_aligned_spikes.items():
        spikes = trial_data['spikes']
        
        for window_idx, (t_start, t_end) in enumerate(time_windows):
            # Count spikes in this window
            spike_count = np.sum((spikes >= t_start) & (spikes < t_end))
            
            # Convert to firing rate (Hz)
            window_duration = (t_end - t_start) / 1000  # Convert ms to seconds
            firing_rate = spike_count / window_duration
            
            # Store results
            results.append({
                'trial_idx': trial_idx,
                'window_idx': window_idx,
                'window_start': t_start,
                'window_end': t_end,
                'window_center': (t_start + t_end) / 2,
                'firing_rate': firing_rate,
                'spike_count': spike_count,
                'first_cat': trial_data['first_cat'],
                'second_cat': trial_data['second_cat'],
                'first_num': trial_data['first_num'],
                'second_num': trial_data['second_num'],
                'probe_validity': trial_data['probe_validity']
            })
    
    return pd.DataFrame(results)

# Test the function
print("✓ Sliding window firing rate function defined")

## 6: TEST SELECTIVITY AT SINGLE TIME POINT

In [35]:
def test_selectivity_at_timepoint_corrected(firing_rate_data, window_center):
    """
    Test neural selectivity with proper statistical framework
    """
    results = []
    
    # Group by unit
    for unit_id, unit_data in firing_rate_data.groupby('unit_id'):
        
        # Check minimum requirements
        if len(unit_data) < config.MIN_TRIALS:
            continue
            
        if unit_data['firing_rate'].mean() < config.MIN_FIRING_RATE:
            continue
            
        # Skip if no variance
        if unit_data['firing_rate'].std() == 0:
            continue
        
        try:
            # Test category selectivity for first stimulus
            categories = unit_data['first_cat'].unique()
            if len(categories) > 1:
                cat1_pval = stats.f_oneway(*[unit_data[unit_data['first_cat'] == cat]['firing_rate'] 
                                           for cat in categories]).pvalue
            else:
                cat1_pval = 1.0
            
            # Test category selectivity for second stimulus  
            categories = unit_data['second_cat'].unique()
            if len(categories) > 1:
                cat2_pval = stats.f_oneway(*[unit_data[unit_data['second_cat'] == cat]['firing_rate'] 
                                           for cat in categories]).pvalue
            else:
                cat2_pval = 1.0
            
            # Test numerosity selectivity for first stimulus
            numbers = unit_data['first_num'].unique()
            if len(numbers) > 1:
                num1_pval = stats.f_oneway(*[unit_data[unit_data['first_num'] == num]['firing_rate'] 
                                           for num in numbers]).pvalue
            else:
                num1_pval = 1.0
                
            # Test numerosity selectivity for second stimulus
            numbers = unit_data['second_num'].unique()
            if len(numbers) > 1:
                num2_pval = stats.f_oneway(*[unit_data[unit_data['second_num'] == num]['firing_rate'] 
                                           for num in numbers]).pvalue
            else:
                num2_pval = 1.0
            
            # Store results with RAW p-values (correction applied later)
            results.append({
                'unit_id': unit_id,
                'window_center': window_center,
                'n_trials': len(unit_data),
                'mean_firing_rate': unit_data['firing_rate'].mean(),
                
                # Raw p-values (no correction yet)
                'cat1_pval': cat1_pval,
                'cat2_pval': cat2_pval, 
                'num1_pval': num1_pval,
                'num2_pval': num2_pval,
                
                # Uncorrected significance (for comparison)
                'cat1_selective_uncorrected': cat1_pval < config.P_THRESHOLD,
                'cat2_selective_uncorrected': cat2_pval < config.P_THRESHOLD,
                'num1_selective_uncorrected': num1_pval < config.P_THRESHOLD,
                'num2_selective_uncorrected': num2_pval < config.P_THRESHOLD
            })
            
        except Exception as e:
            # Handle any statistical errors
            continue
    
    return pd.DataFrame(results)


def apply_multiple_comparison_corrections(full_results):
    """
    Apply different multiple comparison correction methods
    """
    from statsmodels.stats.multitest import multipletests
    
    print("Applying multiple comparison corrections...")
    
    corrected_results = []
    
    for unit_id, unit_data in tqdm(full_results.groupby('unit_id'), desc="Correcting p-values"):
        unit_data = unit_data.copy().sort_values('window_center')
        
        # METHOD 1: FDR Correction (across time for each condition)
        for condition in ['cat1', 'cat2', 'num1', 'num2']:
            pval_col = f'{condition}_pval'
            pvals = unit_data[pval_col].values
            
            # Apply FDR correction
            _, fdr_corrected, _, _ = multipletests(pvals, method='fdr_bh')
            unit_data[f'{condition}_pval_fdr'] = fdr_corrected
            unit_data[f'{condition}_selective_fdr'] = fdr_corrected < 0.05
        
        # METHOD 2: Bonferroni Correction (very conservative)
        bonferroni_threshold = 0.05 / (len(unit_data) * 4)  # 4 conditions tested
        for condition in ['cat1', 'cat2', 'num1', 'num2']:
            pval_col = f'{condition}_pval'
            unit_data[f'{condition}_selective_bonferroni'] = unit_data[pval_col] < bonferroni_threshold
        
        # METHOD 3: Cluster-based correction (consecutive significant windows)
        min_consecutive = 3  # Require 3 consecutive significant windows
        
        for condition in ['cat1', 'cat2', 'num1', 'num2']:
            uncorrected_col = f'{condition}_selective_uncorrected'
            sig_windows = unit_data[uncorrected_col].values
            
            # Find consecutive runs
            cluster_corrected = np.zeros_like(sig_windows, dtype=bool)
            current_run = []
            
            for i, is_sig in enumerate(sig_windows):
                if is_sig:
                    current_run.append(i)
                else:
                    if len(current_run) >= min_consecutive:
                        cluster_corrected[current_run] = True
                    current_run = []
            
            # Check final run
            if len(current_run) >= min_consecutive:
                cluster_corrected[current_run] = True
            
            unit_data[f'{condition}_selective_cluster'] = cluster_corrected
        
        corrected_results.append(unit_data)
    
    return pd.concat(corrected_results, ignore_index=True)


def calculate_corrected_population_timecourse(full_results_corrected):
    """
    Calculate population percentages using different correction methods
    """
    
    # Group by time and calculate percentages for each correction method
    methods = ['uncorrected', 'fdr', 'bonferroni', 'cluster']
    conditions = ['cat1', 'cat2', 'num1', 'num2']
    
    population_data = []
    
    for time_point, time_data in full_results_corrected.groupby('window_center'):
        
        row = {'time_ms': time_point, 'n_units': len(time_data)}
        
        for method in methods:
            for condition in conditions:
                if method == 'uncorrected':
                    col_name = f'{condition}_selective_uncorrected'
                else:
                    col_name = f'{condition}_selective_{method}'
                
                if col_name in time_data.columns:
                    percentage = time_data[col_name].mean() * 100
                    row[f'pct_{condition}_selective_{method}'] = percentage
                else:
                    row[f'pct_{condition}_selective_{method}'] = 0
        
        population_data.append(row)
    
    return pd.DataFrame(population_data)


# UPDATED MAIN ANALYSIS FUNCTION
def analyze_time_resolved_selectivity_with_corrections(data_filtered, alignment_event='first_stimulus'):
    """
    Main analysis with proper multiple comparison corrections
    """
    print(f"Starting time-resolved selectivity analysis with corrections...")
    
    # Get all unique units
    all_units = data_filtered['unit_id'].unique()
    all_selectivity_results = []
    
    # Process each time window (same as before)
    for window_idx, window_center in enumerate(tqdm(window_centers, desc="Processing time windows")):
        
        timepoint_data = []
        
        for unit_id in all_units:
            trial_spikes = extract_trial_aligned_spikes(data_filtered, unit_id, alignment_event)
            
            if len(trial_spikes) == 0:
                continue
                
            unit_rates = compute_sliding_window_firing_rates(trial_spikes, [time_windows[window_idx]])
            
            if len(unit_rates) > 0:
                unit_rates['unit_id'] = unit_id
                timepoint_data.append(unit_rates)
        
        if len(timepoint_data) == 0:
            continue
            
        timepoint_df = pd.concat(timepoint_data, ignore_index=True)
        
        # Use corrected version of selectivity testing
        selectivity_results = test_selectivity_at_timepoint_corrected(timepoint_df, window_center)
        all_selectivity_results.append(selectivity_results)
    
    # Combine results across time
    if len(all_selectivity_results) > 0:
        full_results = pd.concat(all_selectivity_results, ignore_index=True)
        
        # Apply multiple comparison corrections
        full_results_corrected = apply_multiple_comparison_corrections(full_results)
        
        # Calculate population timecourses for each correction method
        population_timecourse = calculate_corrected_population_timecourse(full_results_corrected)
        
        print(f"✓ Analysis complete with corrections!")
        print(f"✓ {len(population_timecourse)} timepoints analyzed")
        print(f"✓ Applied uncorrected, FDR, Bonferroni, and cluster-based corrections")
        
        return population_timecourse, full_results_corrected
    
    else:
        print("✗ No results generated")
        return None, None

## 7: MAIN ANALYSIS PIPELINE

In [36]:
population_timecourse, full_results = analyze_time_resolved_selectivity_with_corrections(data_filtered)

print(population_timecourse.columns)


Starting time-resolved selectivity analysis with corrections...


Processing time windows:  15%|█▌        | 9/60 [04:06<23:45, 27.96s/it]

## 8: PLOT

In [None]:
# ===================================================================
# TIME-RESOLVED NEURAL SELECTIVITY VISUALIZATION FUNCTIONS
# ===================================================================

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from matplotlib.patches import Rectangle
import matplotlib.patches as mpatches

# Set up plotting parameters
plt.rcParams['figure.figsize'] = (15, 10)
plt.rcParams['font.size'] = 12
plt.rcParams['axes.linewidth'] = 1.5
plt.rcParams['xtick.major.size'] = 6
plt.rcParams['ytick.major.size'] = 6

def plot_main_selectivity_timecourse(population_timecourse, save_path="selectivity_timecourse.png"):
    """
    Main figure: % of neurons selective for different aspects over time
    """
    
    print("Creating main selectivity timecourse plot...")
    
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 12), sharex=True)
    
    # Colors for different selectivity types
    colors = {
        'cat1': '#E74C3C',  # Red
        'cat2': '#3498DB',  # Blue  
        'num1': '#E67E22',  # Orange
        'num2': '#9B59B6'   # Purple
    }
    
    # Plot 1: Category Selectivity
    ax1.plot(population_timecourse['time_ms'], population_timecourse['pct_cat1_selective'], 
             color=colors['cat1'], linewidth=3, label='First Category', alpha=0.9)
    ax1.plot(population_timecourse['time_ms'], population_timecourse['pct_cat2_selective'], 
             color=colors['cat2'], linewidth=3, label='Second Category', alpha=0.9)
    
    ax1.fill_between(population_timecourse['time_ms'], population_timecourse['pct_cat1_selective'], 
                     alpha=0.3, color=colors['cat1'])
    ax1.fill_between(population_timecourse['time_ms'], population_timecourse['pct_cat2_selective'], 
                     alpha=0.3, color=colors['cat2'])
    
    ax1.set_ylabel('% Neurons Selective\nfor Category', fontsize=14, fontweight='bold')
    ax1.set_title('Time Course of Neural Selectivity During Working Memory Task', 
                  fontsize=16, fontweight='bold', pad=20)
    ax1.legend(loc='upper right', fontsize=12)
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim(0, max(population_timecourse[['pct_cat1_selective', 'pct_cat2_selective']].max()) * 1.1)
    
    # Plot 2: Numerosity Selectivity  
    ax2.plot(population_timecourse['time_ms'], population_timecourse['pct_num1_selective'], 
             color=colors['num1'], linewidth=3, label='First Numerosity', alpha=0.9)
    ax2.plot(population_timecourse['time_ms'], population_timecourse['pct_num2_selective'], 
             color=colors['num2'], linewidth=3, label='Second Numerosity', alpha=0.9)
    
    ax2.fill_between(population_timecourse['time_ms'], population_timecourse['pct_num1_selective'], 
                     alpha=0.3, color=colors['num1'])
    ax2.fill_between(population_timecourse['time_ms'], population_timecourse['pct_num2_selective'], 
                     alpha=0.3, color=colors['num2'])
    
    ax2.set_ylabel('% Neurons Selective\nfor Numerosity', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Time from First Stimulus Onset (ms)', fontsize=14, fontweight='bold')
    ax2.legend(loc='upper right', fontsize=12)
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim(0, max(population_timecourse[['pct_num1_selective', 'pct_num2_selective']].max()) * 1.1)
    
    # Add task event markers
    events = {
        'First Stimulus': 0,
        'First Delay': 1000, 
        'Second Stimulus': 2000,
        'Second Delay': 3000,
        'Probe': 5500
    }
    
    for ax in [ax1, ax2]:
        for event_name, event_time in events.items():
            ax.axvline(event_time, color='black', linestyle='--', alpha=0.7, linewidth=2)
            ax.text(event_time + 50, ax.get_ylim()[1] * 0.9, event_name, 
                   rotation=90, fontsize=10, alpha=0.8, fontweight='bold')
    
    # Add epoch shading
    epoch_colors = {
        'Encoding 1': (0, 1000, '#FFE5E5'),
        'Delay 1': (1000, 2000, '#E5F2FF'), 
        'Encoding 2': (2000, 3000, '#FFE5E5'),
        'Delay 2': (3000, 5500, '#E5F2FF'),
        'Response': (5500, 6000, '#F0F0F0')
    }
    
    for ax in [ax1, ax2]:
        for epoch_name, (start, end, color) in epoch_colors.items():
            ax.axvspan(start, end, alpha=0.2, color=color, zorder=0)
    
    # Add sample size info
    avg_units = population_timecourse['n_units'].mean()
    fig.text(0.02, 0.02, f'Average units per timepoint: {avg_units:.0f}', 
             fontsize=10, style='italic')
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"✓ Saved main timecourse plot: {save_path}")
    plt.show()
    
    return fig


def plot_selectivity_heatmap(population_timecourse, save_path="selectivity_heatmap.png"):
    """
    Heatmap showing all selectivity types over time
    """
    
    print("Creating selectivity heatmap...")
    
    # Prepare data for heatmap
    heatmap_data = population_timecourse.set_index('time_ms')[
        ['pct_cat1_selective', 'pct_cat2_selective', 'pct_num1_selective', 'pct_num2_selective']
    ].T
    
    # Rename for better labels
    heatmap_data.index = ['First Category', 'Second Category', 'First Numerosity', 'Second Numerosity']
    
    fig, ax = plt.subplots(figsize=(16, 6))
    
    # Create heatmap
    im = ax.imshow(heatmap_data.values, aspect='auto', cmap='viridis', interpolation='gaussian')
    
    # Customize axes
    ax.set_xticks(np.arange(0, len(heatmap_data.columns), 5))
    ax.set_xticklabels(heatmap_data.columns[::5])
    ax.set_yticks(range(len(heatmap_data.index)))
    ax.set_yticklabels(heatmap_data.index, fontsize=12)
    
    ax.set_xlabel('Time from First Stimulus Onset (ms)', fontsize=14, fontweight='bold')
    ax.set_title('Neural Selectivity Heatmap: % of Neurons Over Time', 
                 fontsize=16, fontweight='bold', pad=20)
    
    # Add colorbar
    cbar = plt.colorbar(im, ax=ax, shrink=0.8)
    cbar.set_label('% Neurons Selective', fontsize=12, fontweight='bold')
    
    # Add task event lines
    events = [0, 1000, 2000, 3000, 5500]  # Convert to indices
    time_points = heatmap_data.columns.values
    
    for event_time in events:
        # Find closest time index
        event_idx = np.argmin(np.abs(time_points - event_time))
        ax.axvline(event_idx, color='white', linestyle='--', alpha=0.8, linewidth=2)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"✓ Saved selectivity heatmap: {save_path}")
    plt.show()
    
    return fig


def plot_brain_region_comparison(full_results, save_path="brain_region_comparison.png"):
    """
    Compare selectivity across brain regions
    """
    
    print("Creating brain region comparison...")
    
    # Calculate average selectivity by brain region across time
    region_data = []
    
    for region in full_results['brainAreaOfCell'].unique():
        region_subset = full_results[full_results['brainAreaOfCell'] == region]
        
        region_summary = region_subset.groupby('window_center').agg({
            'cat1_selective': 'mean',
            'cat2_selective': 'mean', 
            'num1_selective': 'mean',
            'num2_selective': 'mean',
            'unit_id': 'count'
        }).reset_index()
        
        region_summary['region'] = region
        region_data.append(region_summary)
    
    region_df = pd.concat(region_data, ignore_index=True)
    
    # Create subplot for each brain region
    regions = sorted(full_results['brainAreaOfCell'].unique())
    n_regions = len(regions)
    
    fig, axes = plt.subplots(2, 3, figsize=(20, 12), sharex=True, sharey=True)
    axes = axes.flatten()
    
    colors = ['#E74C3C', '#3498DB', '#E67E22', '#9B59B6']
    labels = ['First Category', 'Second Category', 'First Numerosity', 'Second Numerosity']
    
    for i, region in enumerate(regions):
        if i >= len(axes):
            break
            
        ax = axes[i]
        region_data = region_df[region_df['region'] == region]
        
        if len(region_data) == 0:
            continue
            
        # Plot each selectivity type
        for j, (col, color, label) in enumerate(zip(
            ['cat1_selective', 'cat2_selective', 'num1_selective', 'num2_selective'],
            colors, labels)):
            
            ax.plot(region_data['window_center'], region_data[col] * 100, 
                   color=color, linewidth=2.5, label=label, alpha=0.8)
        
        ax.set_title(f'{region} (n={region_data["unit_id"].iloc[0]:.0f})', 
                    fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0, 60)
        
        # Add task events
        events = [0, 1000, 2000, 3000, 5500]
        for event_time in events:
            ax.axvline(event_time, color='gray', linestyle='--', alpha=0.5)
        
        if i == 0:  # Add legend to first subplot
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    
    # Remove empty subplots
    for i in range(n_regions, len(axes)):
        fig.delaxes(axes[i])
    
    fig.text(0.5, 0.02, 'Time from First Stimulus Onset (ms)', 
             ha='center', fontsize=14, fontweight='bold')
    fig.text(0.02, 0.5, '% Neurons Selective', va='center', rotation=90, 
             fontsize=14, fontweight='bold')
    fig.suptitle('Neural Selectivity by Brain Region', fontsize=16, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"✓ Saved brain region comparison: {save_path}")
    plt.show()
    
    return fig


def plot_peak_selectivity_analysis(population_timecourse, save_path="peak_selectivity_analysis.png"):
    """
    Analyze timing of peak selectivity for each condition
    """
    
    print("Creating peak selectivity analysis...")
    
    # Find peak times and values
    conditions = ['pct_cat1_selective', 'pct_cat2_selective', 'pct_num1_selective', 'pct_num2_selective']
    condition_names = ['First Category', 'Second Category', 'First Numerosity', 'Second Numerosity']
    colors = ['#E74C3C', '#3498DB', '#E67E22', '#9B59B6']
    
    peak_data = []
    for condition, name in zip(conditions, condition_names):
        peak_idx = population_timecourse[condition].argmax()
        peak_time = population_timecourse.iloc[peak_idx]['time_ms']
        peak_value = population_timecourse.iloc[peak_idx][condition]
        
        peak_data.append({
            'condition': name,
            'peak_time': peak_time,
            'peak_value': peak_value
        })
    
    peak_df = pd.DataFrame(peak_data)
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # Plot 1: Peak times
    bars1 = ax1.bar(peak_df['condition'], peak_df['peak_time'], color=colors, alpha=0.7, edgecolor='black')
    ax1.set_ylabel('Peak Time (ms)', fontsize=12, fontweight='bold')
    ax1.set_title('Timing of Peak Selectivity', fontsize=14, fontweight='bold')
    ax1.tick_params(axis='x', rotation=45)
    
    # Add value labels on bars
    for bar, time in zip(bars1, peak_df['peak_time']):
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height + 50,
                f'{time:.0f}ms', ha='center', va='bottom', fontweight='bold')
    
    # Add task event reference lines
    events = {'First Stim': 0, 'Second Stim': 2000, 'Probe': 5500}
    for event, time in events.items():
        ax1.axhline(time, color='gray', linestyle='--', alpha=0.5)
        ax1.text(0.1, time + 100, event, fontsize=10, alpha=0.7)
    
    # Plot 2: Peak values
    bars2 = ax2.bar(peak_df['condition'], peak_df['peak_value'], color=colors, alpha=0.7, edgecolor='black')
    ax2.set_ylabel('Peak Selectivity (%)', fontsize=12, fontweight='bold')
    ax2.set_title('Maximum Selectivity Reached', fontsize=14, fontweight='bold')
    ax2.tick_params(axis='x', rotation=45)
    
    # Add value labels on bars
    for bar, value in zip(bars2, peak_df['peak_value']):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height + 0.5,
                f'{value:.1f}%', ha='center', va='bottom', fontweight='bold')
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"✓ Saved peak selectivity analysis: {save_path}")
    plt.show()
    
    return fig, peak_df


def plot_selectivity_onset_analysis(population_timecourse, threshold=10, save_path="selectivity_onset_analysis.png"):
    """
    Analyze when selectivity first emerges (onset analysis)
    """
    
    print("Creating selectivity onset analysis...")
    
    conditions = ['pct_cat1_selective', 'pct_cat2_selective', 'pct_num1_selective', 'pct_num2_selective']
    condition_names = ['First Category', 'Second Category', 'First Numerosity', 'Second Numerosity']
    colors = ['#E74C3C', '#3498DB', '#E67E22', '#9B59B6']
    
    onset_data = []
    
    fig, ax = plt.subplots(figsize=(14, 8))
    
    for condition, name, color in zip(conditions, condition_names, colors):
        # Find first time point where selectivity exceeds threshold
        above_threshold = population_timecourse[population_timecourse[condition] > threshold]
        
        if len(above_threshold) > 0:
            onset_time = above_threshold.iloc[0]['time_ms']
            onset_data.append({'condition': name, 'onset_time': onset_time})
            
            # Plot full timecourse
            ax.plot(population_timecourse['time_ms'], population_timecourse[condition], 
                   color=color, linewidth=3, label=name, alpha=0.8)
            
            # Mark onset
            ax.scatter([onset_time], [threshold], color=color, s=100, zorder=5, 
                      marker='o', edgecolor='black', linewidth=2)
            ax.text(onset_time + 100, threshold + 1, f'{onset_time:.0f}ms', 
                   color=color, fontweight='bold', fontsize=10)
    
    # Add threshold line
    ax.axhline(threshold, color='black', linestyle=':', linewidth=2, alpha=0.7, 
              label=f'Threshold ({threshold}%)')
    
    # Add task events
    events = [0, 1000, 2000, 3000, 5500]
    for event_time in events:
        ax.axvline(event_time, color='gray', linestyle='--', alpha=0.3)
    
    ax.set_xlabel('Time from First Stimulus Onset (ms)', fontsize=12, fontweight='bold')
    ax.set_ylabel('% Neurons Selective', fontsize=12, fontweight='bold')
    ax.set_title(f'Selectivity Onset Analysis (Threshold: {threshold}%)', 
                fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"✓ Saved selectivity onset analysis: {save_path}")
    plt.show()
    
    onset_df = pd.DataFrame(onset_data) if onset_data else pd.DataFrame()
    return fig, onset_df


def create_comprehensive_report(population_timecourse, full_results, save_prefix="selectivity_analysis"):
    """
    Create all visualizations and save them
    """
    
    print("Creating comprehensive visualization report...")
    print("="*60)
    
    # Create main plots
    fig1 = plot_main_selectivity_timecourse(population_timecourse, f"{save_prefix}_main_timecourse.png")
    
    fig2 = plot_selectivity_heatmap(population_timecourse, f"{save_prefix}_heatmap.png")
    
    # Skip brain region comparison if brainAreaOfCell column is missing
    fig3 = None
    if 'brainAreaOfCell' in full_results.columns:
        fig3 = plot_brain_region_comparison(full_results, f"{save_prefix}_brain_regions.png")
    else:
        print("⚠️ Skipping brain region comparison - 'brainAreaOfCell' column not found")
    
    fig4, peak_df = plot_peak_selectivity_analysis(population_timecourse, f"{save_prefix}_peak_analysis.png")
    
    fig5, onset_df = plot_selectivity_onset_analysis(population_timecourse, f"{save_prefix}_onset_analysis.png", threshold=10)
    
    # Print summary statistics
    print("\n" + "="*60)
    print("SUMMARY STATISTICS")
    print("="*60)
    
    print("\nPeak Selectivity Times:")
    print(peak_df.to_string(index=False))
    
    if len(onset_df) > 0:
        print("\nSelectivity Onset Times (>10% threshold):")
        print(onset_df.to_string(index=False))
    
    print(f"\nAverage number of units analyzed: {population_timecourse['n_units'].mean():.1f}")
    print(f"Time range analyzed: {population_timecourse['time_ms'].min():.0f} to {population_timecourse['time_ms'].max():.0f} ms")
    
    # Only print brain regions if the column exists
    if 'brainAreaOfCell' in full_results.columns:
        print(f"Number of brain regions: {full_results['brainAreaOfCell'].nunique()}")
    else:
        print("Brain region information not available")
    
    overall_max = population_timecourse[['pct_cat1_selective', 'pct_cat2_selective', 
                                       'pct_num1_selective', 'pct_num2_selective']].max().max()
    print(f"Maximum selectivity reached: {overall_max:.1f}%")
    
    print("\n✓ All visualizations complete!")
    
    return {
        'main_timecourse': fig1,
        'heatmap': fig2, 
        'brain_regions': fig3,
        'peak_analysis': fig4,
        'onset_analysis': fig5,
        'peak_data': peak_df,
        'onset_data': onset_df
    }


# ===================================================================
# QUICK START: Run this to create all plots
# ===================================================================

# Uncomment to run all visualizations:
# results = create_comprehensive_report(population_timecourse, full_results)

print("✓ All visualization functions loaded!")
print("\nTo create all plots, run:")
print("results = create_comprehensive_report(population_timecourse, full_results)")

In [None]:
# add brain area back to full_results

# Check what columns exist in your current data
print("📊 CURRENT DATA STATUS:")
try:
    print(f"✅ data_filtered shape: {data_filtered.shape}")
    print(f"✅ data_filtered columns: {list(data_filtered.columns)}")
    print(f"❌ Contains brainAreaOfCell: {'brainAreaOfCell' in data_filtered.columns}")
except NameError:
    print("❌ data_filtered not found")

# Check if original processed data still has brain area info
print("\n📊 CHECKING ORIGINAL DATA:")
try:
    print(f"✅ df_sample_new shape: {df_sample_new.shape}")
    print(f"✅ Contains brainAreaOfCell: {'brainAreaOfCell' in df_sample_new.columns}")
    if 'brainAreaOfCell' in df_sample_new.columns:
        print(f"✅ Brain areas in df_sample_new: {df_sample_new['brainAreaOfCell'].unique()}")
        print("🎯 FOUND IT! Brain area info is in df_sample_new")
except NameError:
    print("❌ df_sample_new not found")

# Check if we can trace back to the original df
try:
    print(f"\n✅ df shape: {df.shape}")
    print(f"✅ Contains brainAreaOfCell: {'brainAreaOfCell' in df.columns}")
    if 'brainAreaOfCell' in df.columns:
        print(f"✅ Brain areas in df: {df['brainAreaOfCell'].unique()}")
except NameError:
    print("❌ df not found")


# ===================================================================
# SOLUTION 1: RECREATE BRAIN AREA MAPPING FROM AVAILABLE DATA
# ===================================================================

def create_brain_area_mapping_from_source():
    """
    Recreate the unit_id to brain area mapping from the available source data
    """
    print("\n🛠️  CREATING BRAIN AREA MAPPING...")
    
    # Try to find brain area info in available DataFrames
    brain_area_source = None
    
    # Option 1: Use df_sample_new if available
    try:
        if 'brainAreaOfCell' in df_sample_new.columns:
            brain_area_source = df_sample_new[['unit_id', 'brainAreaOfCell']].drop_duplicates()
            print("✅ Using brain area info from df_sample_new")
    except NameError:
        pass
    
    # Option 2: Use df if df_sample_new not available
    if brain_area_source is None:
        try:
            if 'brainAreaOfCell' in df.columns:
                # Need to add unit_id to df first
                df_with_unit_id = df.copy()
                df_with_unit_id['unit_id'] = df_with_unit_id.index
                brain_area_source = df_with_unit_id[['unit_id', 'brainAreaOfCell']].drop_duplicates()
                print("✅ Using brain area info from df")
        except NameError:
            pass
    
    # Option 3: Recreate from original MATLAB data
    if brain_area_source is None:
        print("🔧 Recreating from MATLAB data...")
        try:
            # Use the collapsed area mapping we defined earlier
            collapsed_area_map = {
                1: 'H', 2: 'H', 3: 'A', 4: 'A', 5: 'AC', 6: 'AC',
                7: 'SMA', 8: 'SMA', 9: 'PT', 10: 'PT', 11: 'OFC', 12: 'OFC',
                50: 'FFA', 51: 'EC', 52: 'CM', 53: 'CM', 54: 'PUL', 55: 'PUL',
                56: 'N/A', 57: 'PRV', 58: 'PRV'
            }
            
            # If we have the original cell data, recreate the mapping
            if 'df' in globals() and len(df) > 0:
                # Create unit_id mapping
                unit_brain_map = []
                for idx, row in df.iterrows():
                    brain_code = row['brainAreaOfCell']
                    if isinstance(brain_code, np.ndarray):
                        brain_code = int(brain_code[0, 0])
                    brain_area = collapsed_area_map.get(brain_code, 'Unknown')
                    unit_brain_map.append({'unit_id': idx, 'brainAreaOfCell': brain_area})
                
                brain_area_source = pd.DataFrame(unit_brain_map)
                print("✅ Recreated brain area mapping from original data")
        except Exception as e:
            print(f"❌ Failed to recreate from MATLAB data: {e}")
    
    return brain_area_source


def fix_data_filtered_brain_area():
    """
    Add brain area information back to data_filtered
    """
    global data_filtered  # We'll modify the global variable
    
    print("\n🔧 FIXING data_filtered...")
    
    # Get brain area mapping
    brain_mapping = create_brain_area_mapping_from_source()
    
    if brain_mapping is None:
        print("❌ Could not find brain area information anywhere!")
        return None
    
    print(f"✅ Found brain area mapping for {len(brain_mapping)} units")
    print(f"✅ Brain areas: {brain_mapping['brainAreaOfCell'].unique()}")
    
    # Add brain area to data_filtered
    data_filtered_fixed = data_filtered.merge(
        brain_mapping, 
        on='unit_id', 
        how='left'
    )
    
    # Check success
    missing_areas = data_filtered_fixed['brainAreaOfCell'].isna().sum()
    if missing_areas > 0:
        print(f"⚠️  Warning: {missing_areas} rows missing brain area info")
    else:
        print("✅ Successfully added brain area info to all rows!")
    
    print(f"✅ Updated data_filtered shape: {data_filtered_fixed.shape}")
    print(f"✅ Brain areas in fixed data: {data_filtered_fixed['brainAreaOfCell'].unique()}")
    
    return data_filtered_fixed


# ===================================================================
# SOLUTION 2: QUICK MANUAL RECONSTRUCTION
# ===================================================================

def manual_brain_area_reconstruction():
    """
    Manually reconstruct brain area mapping if automated methods fail
    """
    print("\n🛠️  MANUAL RECONSTRUCTION:")
    print("If automated methods fail, we can manually recreate the mapping...")
    
    collapsed_area_map = {
        1: 'H', 2: 'H',           # Hippocampus
        3: 'A', 4: 'A',           # Amygdala  
        5: 'AC', 6: 'AC',         # Anterior Cingulate
        7: 'SMA', 8: 'SMA',       # Supplementary Motor Area
        9: 'PT', 10: 'PT',        # Parahippocampal cortex
        11: 'OFC', 12: 'OFC',     # Orbitofrontal Cortex
        50: 'FFA', 51: 'EC',      # Other areas
        52: 'CM', 53: 'CM',
        54: 'PUL', 55: 'PUL',
        56: 'N/A', 57: 'PRV', 58: 'PRV'
    }
    
    print("Brain area mapping codes:")
    for code, area in collapsed_area_map.items():
        print(f"  {code}: {area}")
    
    return collapsed_area_map


# ===================================================================
# RUN THE FIX
# ===================================================================

print("\n" + "="*60)
print("🚀 RUNNING THE FIX...")
print("="*60)

# Step 1: Diagnose what data we have
brain_mapping = create_brain_area_mapping_from_source()

if brain_mapping is not None:
    # Step 2: Fix data_filtered
    data_filtered_fixed = fix_data_filtered_brain_area()
    
    if data_filtered_fixed is not None:
        # Update the global variable
        data_filtered = data_filtered_fixed
        print("\n✅ SUCCESS! data_filtered now has brain area information")
        
        # Step 3: Now fix the results
        print("\n🔧 Now fixing full_results...")
        try:
            # Create the unit-brain mapping
            unit_brain_mapping = data_filtered[['unit_id', 'brainAreaOfCell']].drop_duplicates()
            
            # Merge into full_results
            full_results_fixed = full_results.merge(
                unit_brain_mapping, 
                on='unit_id', 
                how='left'
            )
            
            # Update global variable
            full_results = full_results_fixed
            
            print("✅ SUCCESS! full_results now has brain area information")
            print(f"✅ Brain areas in results: {full_results['brainAreaOfCell'].unique()}")
            
            # Final verification
            print(f"\n📊 FINAL STATUS:")
            print(f"✅ data_filtered shape: {data_filtered.shape}")
            print(f"✅ data_filtered has brainAreaOfCell: {'brainAreaOfCell' in data_filtered.columns}")
            print(f"✅ full_results shape: {full_results.shape}")
            print(f"✅ full_results has brainAreaOfCell: {'brainAreaOfCell' in full_results.columns}")
            
            # Count units per area
            if 'brainAreaOfCell' in full_results.columns:
                area_counts = full_results.groupby(['unit_id', 'brainAreaOfCell']).size().reset_index()[['unit_id', 'brainAreaOfCell']].groupby('brainAreaOfCell').size()
                print(f"\n📊 Units per brain area:")
                print(area_counts.to_string())
            
        except Exception as e:
            print(f"❌ Error fixing full_results: {e}")
    
else:
    print("❌ Could not create brain area mapping. Need to trace back further...")
    manual_brain_area_reconstruction()

print("\n" + "="*60)
print("🎯 NEXT STEPS:")
print("1. Verify: print(data_filtered['brainAreaOfCell'].unique())")
print("2. Verify: print(full_results['brainAreaOfCell'].unique())")  
print("3. Run visualizations: create_comprehensive_report(population_timecourse, full_results)")
print("="*60)

In [None]:
# Run all visualizations at once
results = create_comprehensive_report(population_timecourse, full_results)