# Neural Basic Selectivity Analysis

Analysis of neuronal selectivity patterns during working memory binding tasks.

## Analysis Overview

1. **Data Loading & Initial Processing** - MATLAB data loading, brain area mapping, and unit quality control
2. **Neural Activity & Trial Processing** - Trial extraction, firing rate computation across task epochs, and data integration
3. **Feature Selectivity Analysis** - Statistical analysis identifying neurons selective for stimulus features and temporal position
4. **Results & Visualization** - Selectivity summaries, population distributions, and single-unit response visualization

---

## 1. Data Loading & Initial Processing

### Environment Setup

In [None]:
import sys
print(sys.executable)

In [5]:
# pip install numpy pandas matplotlib seaborn scipy statsmodels tqdm mat73
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.io import loadmat
from scipy.stats import ttest_1samp
import glob
import os
# import mat73

### Loading MATLAB Data Files

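Load every `WMB_P*_v7.mat` file in the working directory. Each file provides `cellStatsAll` (one record per recorded unit, including spike timestamps, event indices and trial information) and `totStats` (a numeric per-unit matrix whose column 3 holds the brain-area code).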
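MATLAB has no true 1-D arrays, so `scipy.io.loadmat` returns every variable with at least two dimensions. That is why the combining step below indexes `cell_mat[0]` to drop a leading singleton axis. The sketch below demonstrates this on an in-memory file; `squeeze_me=True` is an existing `loadmat` option that removes singleton axes, but switching to it would change the nesting the rest of this notebook expects, so treat it as an alternative to evaluate rather than a drop-in change.

```python
import io

import numpy as np
from scipy.io import loadmat, savemat

# Write a small .mat file to an in-memory buffer
buf = io.BytesIO()
savemat(buf, {"x": np.arange(3)})  # 1-D arrays are stored as 1x3 row vectors by default

buf.seek(0)
raw = loadmat(buf)["x"]  # default loading keeps MATLAB's 2-D shape
buf.seek(0)
squeezed = loadmat(buf, squeeze_me=True)["x"]  # singleton axes removed

print(raw.shape)       # (1, 3)
print(squeezed.shape)  # (3,)
```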

In [None]:
# Find all WMB_P*_v7.mat files in the current directory (sorted so the file order is reproducible)
mat_files = sorted(glob.glob('./WMB_P*_v7.mat'))
print(f"Found {len(mat_files)} .mat files: {[os.path.basename(f) for f in mat_files]}")

# Initialize lists to store data from each file
cell_mats = []
total_mats = []

# Load each file and append its data
for mat_file in mat_files:
    print(f"\nLoading {mat_file}...")
    mat_data = loadmat(mat_file)
    cell_mats.append(mat_data['cellStatsAll'])
    total_mats.append(mat_data['totStats'])

# Print shapes of loaded data for debugging
print("\nShapes of loaded data:")
for i, (cell, total) in enumerate(zip(cell_mats, total_mats)):
    print(f"File {i}: cell_mat shape: {cell.shape}, total_mat shape: {total.shape}")

In [None]:
# Combine the data
# For cell_mat, we need to handle the different dimensions
# First, convert each cell_mat to a list of records
all_cell_records = []
for cell_mat in cell_mats:
    # Convert to list of records
    cell_list = cell_mat[0]  # now shape is (n,)
    records = []
    for cell in cell_list:
        record = {key: cell[key] for key in cell.dtype.names}
        records.append(record)
    all_cell_records.extend(records)

# Convert combined records to DataFrame
df = pd.DataFrame(all_cell_records)

# For total_mat, we can concatenate directly since they have the same structure
total_mat = np.concatenate(total_mats, axis=0)

print(f"\nCombined data shape - total_mat: {total_mat.shape}")
print(f"\nCombined data shape - df: {df.shape}")

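### Translating Numeric Area Codes

Map each numeric brain-area code to a hemisphere-specific label (e.g., `1` → `RH`, `2` → `LH`) and count the units recorded in each area.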

In [9]:
from collections import Counter

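def count_area_codes(area_column):
    """Count units per hemisphere-specific area label; codes missing from the mapping count as 'Unknown'."""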
    mapping = {
        1: 'RH', 2: 'LH', 3: 'RA', 4: 'LA', 5: 'RAC', 6: 'LAC',
        7: 'RSMA', 8: 'LSMA', 9: 'RPT', 10: 'LPT', 11: 'ROFC', 12: 'LOFC',
        50: 'RFFA', 51: 'REC', 52: 'RCM', 53: 'LCM', 54: 'RPUL', 55: 'LPUL',
        56: 'N/A', 57: 'RPRV', 58: 'LPRV'
    }
    
    labels = []
    for code in area_column:
        label = mapping.get(code, 'Unknown')
        labels.append(label)
    
    return dict(Counter(labels))

In [None]:
# Number of units per brain area (column 3 of totStats holds the area code)
area_codes = total_mat[:, 3]

counts = count_area_codes(area_codes)
print("Unit counts per area (hemisphere-specific labels):")
for area, count in counts.items():
    print(f"{area}: {count}")

### Data Preprocessing & Filtering

Next, convert each unit's `brainAreaOfCell` code into a bilateral region label by merging left- and right-hemisphere codes (e.g., `1` (`RH`) and `2` (`LH`) both become `H`):
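Every collapsed label in the mapping below is the hemisphere-specific label from `count_area_codes` with its leading `R` or `L` removed. As a consistency check, the sketch below derives the collapsed mapping from the hemisphere mapping; the helper `strip_hemisphere` is hypothetical and `N/A` is treated as a special case.

```python
# Hemisphere-specific labels, copied from count_area_codes
hemisphere_map = {
    1: 'RH', 2: 'LH', 3: 'RA', 4: 'LA', 5: 'RAC', 6: 'LAC',
    7: 'RSMA', 8: 'LSMA', 9: 'RPT', 10: 'LPT', 11: 'ROFC', 12: 'LOFC',
    50: 'RFFA', 51: 'REC', 52: 'RCM', 53: 'LCM', 54: 'RPUL', 55: 'LPUL',
    56: 'N/A', 57: 'RPRV', 58: 'LPRV'
}

def strip_hemisphere(label):
    """Drop a leading 'R'/'L' hemisphere marker; 'N/A' is returned unchanged."""
    if label != 'N/A' and label[0] in 'RL':
        return label[1:]
    return label

derived_map = {code: strip_hemisphere(label) for code, label in hemisphere_map.items()}
print(derived_map[1], derived_map[8], derived_map[56])  # H SMA N/A
```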

In [11]:
collapsed_area_map = {
    1: 'H', 2: 'H',
    3: 'A', 4: 'A',
    5: 'AC', 6: 'AC',
    7: 'SMA', 8: 'SMA',
    9: 'PT', 10: 'PT',
    11: 'OFC', 12: 'OFC',
    50: 'FFA', 51: 'EC',
    52: 'CM', 53: 'CM',
    54: 'PUL', 55: 'PUL',
    56: 'N/A', 57: 'PRV', 58: 'PRV'
}
# Convert brain area codes in the DataFrame
df['brainAreaOfCell'] = df['brainAreaOfCell'].apply(
    lambda x: collapsed_area_map.get(int(x[0, 0]), 'Unknown') if isinstance(x, np.ndarray) else collapsed_area_map.get(x, 'Unknown')
)

### Quality Control & Unit Selection

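Filtering criteria for reliable unit selection:
- Mean firing rate above 0.1 Hz, computed as the spike count divided by the time between the first and last spike (timestamps are in microseconds)
- Signal quality metrics (not applied in the code below; only the firing-rate threshold is used)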
In [12]:
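# Filter out units with a low mean firing rate (< 0.1 Hz)
def mean_firing_rate(timestamps):
    ts = np.ravel(timestamps)  # MATLAB arrays load as 2-D; flatten before indexing
    if ts.size < 2 or ts[-1] == ts[0]:
        return 0.0
    return ts.size / ((ts[-1] - ts[0]) / 1e6)  # timestamps are in microseconds

fr = df['timestamps'].apply(mean_firing_rate)
df_sample_new = df[fr > 0.1].reset_index(drop=True)

# Unit id = row position after filtering
df_sample_new["unit_id"] = df_sample_new.index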

## 2. Neural Activity & Trial Processing

### Trial Data Extraction

Extraction and alignment of trial-specific information with neural data:

In [13]:
def extract_trial_info(trials_struct, unit_id):
    # Build a DataFrame from the trials structure.
    df_trial = pd.DataFrame({field: trials_struct[field].squeeze() 
                             for field in trials_struct.dtype.names})
    # Add the unit_id so that you can later separate trials by unit/session.
    df_trial["unit_id"] = unit_id
    df_trial["trial_nr"] = df_trial["trial"].apply(lambda x: np.squeeze(x).item() if isinstance(x, (list, np.ndarray)) else x) - 1 # Adjust for 0-indexing
    return df_trial

In [14]:
trial_info_list = []
for _, row in df_sample_new.iterrows():
    # Extract this unit's trial DataFrame, tagged with its unit identifier
    trial_info_list.append(extract_trial_info(row["Trials"], row["unit_id"]))

# Concatenate the list of trial info DataFrames into one.
trial_info = pd.concat(trial_info_list, ignore_index=True)

### Firing Rate Computation Functions

The following functions extract event timestamps and compute firing rates for different task epochs. This is the core neural data processing step that converts spike timestamps into firing rates aligned to task events.


### Event Timestamp Extraction & Firing Rate Computation

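Firing rates are computed for each task epoch as the spike count in the window divided by the window duration (Hz):
- Baseline period: 1 s before first stimulus onset (`idxEnc1`)
- First stimulus encoding: `idxEnc1` to `idxDel1`
- First delay period: `idxDel1` to `idxEnc2`
- Second stimulus encoding: `idxEnc2` to `idxDel2`
- Second delay period: `idxDel2` to `idxProbeOn`
- Response period: ±0.5 s around the button press (`idxResp`)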
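Each per-trial rate is the number of spikes in a half-open window `[start, end)` divided by the window duration. The functions below rescan the full spike train once per trial with boolean masks; the sketch here (the helper `count_spikes_in_windows` is hypothetical) produces the same counts with binary search via `np.searchsorted`, which may be faster for long recordings. It assumes spike and window times in microseconds, as elsewhere in this notebook.

```python
import numpy as np

def count_spikes_in_windows(spike_times_us, windows_us):
    """Firing rate (Hz) in each half-open [start, end) window.

    spike_times_us: spike timestamps in microseconds.
    windows_us: array of shape (n_trials, 2) holding [start, end] per trial, in microseconds.
    """
    ts = np.sort(np.ravel(spike_times_us))
    windows = np.asarray(windows_us, dtype=float)
    # Position of the first spike >= start and the first spike >= end;
    # their difference is the number of spikes in [start, end)
    n_start = np.searchsorted(ts, windows[:, 0], side="left")
    n_end = np.searchsorted(ts, windows[:, 1], side="left")
    durations_s = (windows[:, 1] - windows[:, 0]) / 1e6
    return (n_end - n_start) / durations_s

spikes = np.array([100_000, 200_000, 600_000, 1_500_000])
windows = np.array([[0, 500_000], [500_000, 1_500_000]])
print(count_spikes_in_windows(spikes, windows))  # [4. 1.]
```

The spike at exactly 1.5 s is excluded from the second window, matching the `>= on` / `< off` comparison used in `compute_firing_rates`.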


### Data Integration & Column Selection

Merge neural firing rate data with trial information and select relevant columns for analysis. This step combines the computed firing rates with behavioral trial data to enable selectivity analysis.
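The integration step turns one row per unit (with one list entry per trial) into one row per unit-trial pair, then joins trial information on `unit_id` and `trial_nr`. The toy sketch below shows the same `explode` and left `merge` pattern on made-up values (`cat_a`/`cat_b` are placeholder labels); exploding several columns at once requires pandas 1.3 or later and equal list lengths within each row. The explicit `astype` and `validate="one_to_one"` are additions for this sketch, not part of the original code.

```python
import numpy as np
import pandas as pd

# Toy unit-level table: one row per unit, one list entry per trial
units = pd.DataFrame({
    "unit_id": [0, 1],
    "fr_epoch": [[5.0, 7.0], [2.0, 3.0]],
    "fr_baseline": [[4.0, 4.5], [1.0, 1.5]],
    "trial_nr": [np.arange(2), np.arange(2)],
})

# Explode the per-trial lists together; exploded columns come back as object dtype
trials_level = units.explode(["fr_epoch", "fr_baseline", "trial_nr"]).reset_index(drop=True)
trials_level = trials_level.astype({"fr_epoch": float, "fr_baseline": float, "trial_nr": int})

# Toy trial table with one stimulus label per unit and trial
trial_info = pd.DataFrame({
    "unit_id": [0, 0, 1, 1],
    "trial_nr": [0, 1, 0, 1],
    "first_cat": ["cat_a", "cat_b", "cat_a", "cat_b"],
})

# validate= raises if a (unit_id, trial_nr) key is duplicated on either side
merged = trials_level.merge(trial_info, on=["unit_id", "trial_nr"], how="left",
                            validate="one_to_one")
print(merged)
```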


In [None]:
def extract_event_timestamps(df_sample_new, start_idx_col='idxEnc1', end_idx_col='idxDel1', is_window=False, window_size=0.5):
    """
    Extract event timestamps for computing firing rates across task epochs.
    
    Parameters:
    -----------
    df_sample_new : DataFrame
        Neural data with event indices and timestamps
    start_idx_col : str
        Column name containing start event indices
    end_idx_col : str  
        Column name containing end event indices
    is_window : bool
        If True, extract fixed window around start event; if False, extract epoch between start and end
    window_size : float
        Window size in seconds (only used if is_window=True)
        
    Returns:
    --------
    epoch_ts : dict
        Dictionary mapping unit_id to array of [start_time, end_time] pairs for each trial
    """
    epoch_ts = {}
    
    for i, row in df_sample_new.iterrows():
        unit_id = row['unit_id']
        events = row['events'].squeeze()  # Event matrix, one row per event; column 0 holds the timestamp (microseconds)
        
        if is_window:
            # Extract fixed window around specific event (e.g., ±0.5s around response)
            idxs = np.atleast_1d(row[start_idx_col].squeeze()).astype(int) - 1  # 0-based integer indices; atleast_1d handles single-trial units
            extracted = events[idxs]
            center_times = extracted[:, 0]  # Get timestamp column
            window_start = center_times - window_size * 1e6  # Convert to microseconds
            window_end = center_times + window_size * 1e6
            combined = np.column_stack((window_start, window_end))
        else:
            # Extract epoch between two task events (e.g., encoding to delay)
            idxs_start = np.atleast_1d(row[start_idx_col].squeeze()).astype(int) - 1  # Start event indices (0-based)
            idxs_end = np.atleast_1d(row[end_idx_col].squeeze()).astype(int) - 1      # End event indices (0-based)
            
            # Handle mismatched trial counts between start and end events
            min_length = min(len(idxs_start), len(idxs_end))
            idxs_start = idxs_start[:min_length]
            idxs_end = idxs_end[:min_length]
            
            # Extract timestamps for start and end events
            extracted_start = events[idxs_start]  # [n_trials, 3] 
            extracted_end = events[idxs_end]      # [n_trials, 3]

            # Combine start and end timestamps for each trial
            combined = np.column_stack((extracted_start[:, 0], extracted_end[:, 0]))
        
        epoch_ts[unit_id] = combined
        
    return epoch_ts


def compute_firing_rates(df_sample_new, start_idx_col='idxEnc1', end_idx_col='idxDel1', 
                         fr_prefix='fr', is_window=False, window_size=0.5):
    """
    Compute firing rates for a specific task epoch.
    
    Parameters:
    -----------
    df_sample_new : DataFrame
        Neural data containing spike timestamps and event indices
    start_idx_col : str
        Column with start event indices
    end_idx_col : str
        Column with end event indices  
    fr_prefix : str
        Prefix for the firing rate column name
    is_window : bool
        Whether to use fixed window (True) or epoch between events (False)
    window_size : float
        Window size in seconds (only for is_window=True)
        
    Returns:
    --------
    df_sample_new : DataFrame
        Input dataframe with added firing rate columns
    """
    # Extract epoch timestamps for all units
    epoch_ts = extract_event_timestamps(df_sample_new, start_idx_col, end_idx_col, is_window, window_size)
    
    # Compute baseline firing rate (1 second before first stimulus) - only on first call
    if 'fr_baseline' not in df_sample_new.columns:
        print("Computing baseline firing rates...")
        
        # Get first stimulus onset timestamps
        enc_ts = extract_event_timestamps(df_sample_new, 'idxEnc1', 'idxEnc1')
        
        # Create baseline windows: 1 second before stimulus onset
        baseline_ts = {}
        for unit_id, timestamps in enc_ts.items():
            baseline_start = timestamps[:, 0] - 1e6  # 1 second before (microseconds)
            baseline_end = timestamps[:, 0]          # End at stimulus onset
            baseline_ts[unit_id] = np.column_stack((baseline_start, baseline_end))
            
        # Compute baseline firing rates for each unit and trial
        df_sample_new['fr_baseline'] = df_sample_new.apply(
            lambda row: [
                # Count spikes in baseline window and convert to Hz
                np.sum((np.ravel(row["timestamps"]) >= baseline_on) & 
                       (np.ravel(row["timestamps"]) < baseline_off)) / ((baseline_off - baseline_on) / 1e6)
                for baseline_on, baseline_off in baseline_ts[row["unit_id"]]
            ],
            axis=1
        )
    
    # Compute firing rates for the specified epoch
    epoch_col = f"{fr_prefix}_epoch"
    print(f"Computing firing rates for {epoch_col}...")
    
    df_sample_new[epoch_col] = df_sample_new.apply(
        lambda row: [
            # Count spikes in epoch window and convert to Hz
            np.sum((np.ravel(row["timestamps"]) >= epoch_on) & 
                   (np.ravel(row["timestamps"]) < epoch_off)) / ((epoch_off - epoch_on) / 1e6)
            for epoch_on, epoch_off in epoch_ts[row["unit_id"]]
        ],
        axis=1
    )
    
    # Add trial numbers if not already present
    if "trial_nr" not in df_sample_new.columns:
        df_sample_new["trial_nr"] = df_sample_new[epoch_col].apply(lambda x: np.arange(len(x)))
    
    return df_sample_new


print("Computing firing rates for all task epochs...")

# Compute firing rates for each task epoch
# Each call adds a new firing rate column to the dataframe

# 1. First stimulus encoding period (from stimulus onset to delay start)
df_sample_new = compute_firing_rates(df_sample_new, 'idxEnc1', 'idxDel1', fr_prefix='fr')

# 2. First delay period (between first stimulus and second stimulus)  
df_sample_new = compute_firing_rates(df_sample_new, 'idxDel1', 'idxEnc2', fr_prefix='fr_del1')

# 3. Second stimulus encoding period (from second stimulus onset to second delay)
df_sample_new = compute_firing_rates(df_sample_new, 'idxEnc2', 'idxDel2', fr_prefix='fr_enc2')

# 4. Second delay period (between second stimulus and probe)
df_sample_new = compute_firing_rates(df_sample_new, 'idxDel2', 'idxProbeOn', fr_prefix='fr_del2')

# 5. Response period (±0.5 second window around button press)
df_sample_new = compute_firing_rates(df_sample_new, 'idxResp', None, fr_prefix='fr_resp', 
                                   is_window=True, window_size=0.5)

print("Expanding dataframe from unit-based to trial-based format...")

# Transform from unit-level to trial-level format.
# Each unit currently holds lists of firing rates (one entry per trial); after exploding,
# each row represents one unit-trial combination. Exploding several columns at once
# requires every list in a row to have the same length.
columns_to_explode = ['fr_baseline', 'fr_epoch', 'fr_del1_epoch', 'fr_enc2_epoch', 
                      'fr_del2_epoch', 'fr_resp_epoch', "trial_nr"]
df_sample_new = df_sample_new.explode(columns_to_explode)

print(f"Final neural data shape: {df_sample_new.shape}")


In [None]:
print("Merging neural data with trial information...")

# Reset indices to ensure proper alignment
df_sample_new = df_sample_new.reset_index(drop=True)
trial_info = trial_info.reset_index(drop=True)

# Merge neural firing rates with behavioral trial data
# This combines unit-trial firing rates with stimulus information for each trial
data = pd.merge(
    df_sample_new,
    trial_info,
    on=["unit_id", "trial_nr"],  # Join on unit and trial identifiers
    how="left",                  # Keep all neural data, add matching trial info
).infer_objects()

print(f"Combined data shape: {data.shape}")

# Select columns needed for selectivity analysis
# Include: neural identifiers, firing rates, stimulus properties, task timing
cols_to_keep = [
    # Neural data identifiers and firing rates
    "unit_id", "timestamps", "brainAreaOfCell", 
    "fr_epoch", "fr_baseline", "fr_del1_epoch", "fr_enc2_epoch", "fr_del2_epoch", "fr_resp_epoch", 
    "trial_nr",
    
    # Stimulus properties for selectivity analysis
    "first_cat", "second_cat", "first_num", "second_num",
    "first_pic", "second_pic", "probe_cat", "probe_pic",
    
    # Task variables
    "probe_validity", "probe_num", "correct_answer",
    "rt", "acc", "key", "cat_comparison", 
    
    # Event timing and metadata
    "events", "nTrials", "Trials", 
    "idxEnc1", "idxEnc2", "idxDel1", "idxDel2", "idxProbeOn", "idxResp", 
    "nrProcessed", "periods_Enc1", "periods_Enc2", "periods_Del1", "periods_Del2", 
    "periods_Probe", "periods_Resp", "prestimEnc", "prestimMaint", "prestimProbe",
    "prestimButtonPress", "poststimEnc", "poststimMaint", "poststimProbe", 
    "poststimButtonPress", "sessionIdx", "channel", "cellNr", "sessionID", "origClusterID"
]

# Create filtered dataset with only necessary columns
data_filtered = data[cols_to_keep].copy()

print("Converting stimulus variables to simple string format for statistical analysis...")

# Convert stimulus variables to plain strings.
# MATLAB values load as nested numpy arrays, which are unhashable and cannot be used for grouping.
def to_simple_str(x):
    return str(np.squeeze(x)) if isinstance(x, (list, np.ndarray)) else str(x)

for col in ["first_cat", "second_cat", "first_num", "second_num"]:
    data_filtered[f"{col}_simple"] = data_filtered[col].apply(to_simple_str)

print(f"Final processed dataset shape: {data_filtered.shape}")
print(f"Stimulus categories: {data_filtered['first_cat_simple'].unique()}")
print(f"Stimulus numbers: {data_filtered['first_num_simple'].unique()}")


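### Core Selectivity Analysis Functions

The following functions identify selective neurons. For each unit, an ordinary least squares (OLS) model of the epoch firing rate is fit with categorical predictors, and a Type II ANOVA tests whether firing rates differ significantly across stimulus categories, numerosities, or temporal positions. A unit is labeled selective for a factor when p < 0.05; no correction for multiple comparisons is applied.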


## 3. Feature Selectivity Analysis

### Statistical Analysis of Neuronal Selectivity

Identification of neurons selective for task-relevant features:
1. **Stimulus Categories** - Visual category preferences during first and second stimulus presentation
2. **Numerosities** - Number-selective responses during encoding periods  
3. **Temporal Position** - Differential responses based on stimulus presentation order (first vs second)

Selectivity is tested per unit with ANOVA across stimulus conditions; the resulting counts are then summarized by brain region.


In [None]:
import statsmodels.formula.api as smf
import statsmodels.api as sm
import numpy as np
import pandas as pd
from tqdm import tqdm

ALPHA = 0.05  # uncorrected significance threshold applied to every test

def identify_selective_neurons(data_filtered, alpha=ALPHA):
    """Fit per-unit OLS models and run Type II ANOVAs for category, number and position."""
    selectivity_stats = []

    for unit_id, unit_df in tqdm(data_filtered.groupby("unit_id")):
        unit_df = unit_df.copy()

        # Firing rates during the first (Enc1) and second (Enc2) stimulus epochs
        unit_df["fr_first"] = unit_df["fr_epoch"].astype(float)
        unit_df["fr_second"] = unit_df["fr_enc2_epoch"].astype(float)

        # Skip units without variance (a NaN std, e.g. from a single trial, is skipped too)
        if not (unit_df["fr_first"].std() > 0 and unit_df["fr_second"].std() > 0):
            continue

        # First stimulus: firing rate ~ category + numerosity (OLS, Type II ANOVA)
        model_first = smf.ols(
            "fr_first ~ C(first_cat_simple) + C(first_num_simple)", data=unit_df
        ).fit()
        anova_first = sm.stats.anova_lm(model_first, typ=2)

        # Second stimulus: same model on the Enc2 epoch
        model_second = smf.ols(
            "fr_second ~ C(second_cat_simple) + C(second_num_simple)", data=unit_df
        ).fit()
        anova_second = sm.stats.anova_lm(model_second, typ=2)

        # Position: pool both encoding epochs and add presentation order as a factor
        position_df = pd.DataFrame({
            'fr': np.concatenate([unit_df['fr_first'], unit_df['fr_second']]),
            'category': np.concatenate([unit_df['first_cat_simple'], unit_df['second_cat_simple']]),
            'number': np.concatenate([unit_df['first_num_simple'], unit_df['second_num_simple']]),
            'position': ['first'] * len(unit_df) + ['second'] * len(unit_df)
        })
        model_position = smf.ols(
            "fr ~ C(category) + C(number) + C(position)", data=position_df
        ).fit()
        anova_position = sm.stats.anova_lm(model_position, typ=2)

        pvalues = {
            'first_cat': anova_first.loc['C(first_cat_simple)', 'PR(>F)'],
            'first_num': anova_first.loc['C(first_num_simple)', 'PR(>F)'],
            'second_cat': anova_second.loc['C(second_cat_simple)', 'PR(>F)'],
            'second_num': anova_second.loc['C(second_num_simple)', 'PR(>F)'],
            'position': anova_position.loc['C(position)', 'PR(>F)'],
        }

        # Produces the same columns as before: <name>_pvalue and is_<name>_selective
        stats = {'unit_id': unit_id, 'area': unit_df['brainAreaOfCell'].iloc[0]}
        for name, p in pvalues.items():
            stats[f'{name}_pvalue'] = p
            stats[f'is_{name}_selective'] = p < alpha
        stats['r2_first'] = model_first.rsquared
        stats['r2_second'] = model_second.rsquared
        stats['r2_position'] = model_position.rsquared
        stats['is_any_selective'] = any(p < alpha for p in pvalues.values())
        selectivity_stats.append(stats)

    return pd.DataFrame(selectivity_stats)


### Running the Analysis

Execute the selectivity analysis on all units and display summary statistics.
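Because each unit is tested at p < 0.05 without correction, roughly 5% of units per test are expected to be flagged even with no true selectivity, and more for `is_any_selective`, which combines five tests. One common check, not part of the original analysis, is a one-sided binomial test of an individual flag's count against the per-test rate alpha. The helper `selective_fraction_vs_chance` is hypothetical and the example counts are made up; it is not appropriate for `is_any_selective`, whose chance rate is higher than alpha.

```python
from scipy.stats import binomtest

def selective_fraction_vs_chance(n_selective, n_total, alpha=0.05):
    """One-sided binomial test: is the selective count above the false-positive rate alpha?"""
    if n_total <= 0:
        raise ValueError("n_total must be positive")
    result = binomtest(n_selective, n_total, p=alpha, alternative="greater")
    return n_selective / n_total, result.pvalue

frac, p = selective_fraction_vs_chance(30, 200)  # hypothetical counts
print(f"{frac:.1%} selective, binomial p = {p:.2g}")
```

With real results, the inputs would be a `count` from `selectivity_summary` and `selectivity_summary['total_units']`.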


In [None]:
def summarize_selectivity(selectivity_results):
    total_units = len(selectivity_results)

    summary = {
        "total_units": total_units,
        "first_category_selective": {
            "count": int(selectivity_results["is_first_cat_selective"].sum()),
            "percentage": round(100 * selectivity_results["is_first_cat_selective"].mean(), 1)
        },
        "first_number_selective": {
            "count": int(selectivity_results["is_first_num_selective"].sum()),
            "percentage": round(100 * selectivity_results["is_first_num_selective"].mean(), 1)
        },
        "second_category_selective": {
            "count": int(selectivity_results["is_second_cat_selective"].sum()),
            "percentage": round(100 * selectivity_results["is_second_cat_selective"].mean(), 1)
        },
        "second_number_selective": {
            "count": int(selectivity_results["is_second_num_selective"].sum()),
            "percentage": round(100 * selectivity_results["is_second_num_selective"].mean(), 1)
        },
        "position_selective": {
            "count": int(selectivity_results["is_position_selective"].sum()),
            "percentage": round(100 * selectivity_results["is_position_selective"].mean(), 1)
        },
        "any_selective": {
            "count": int(selectivity_results["is_any_selective"].sum()),
            "percentage": round(100 * selectivity_results["is_any_selective"].mean(), 1)
        }
    }

    area_summary = selectivity_results.groupby("area").agg({
        "is_first_cat_selective": "sum",
        "is_first_num_selective": "sum",
        "is_second_cat_selective": "sum",
        "is_second_num_selective": "sum",
        "is_position_selective": "sum",
        "is_any_selective": "sum",
        "unit_id": "count"
    }).rename(columns={"unit_id": "total_units"})

    # Calculate percentage columns explicitly
    for col in ["is_first_cat_selective", "is_first_num_selective",
                "is_second_cat_selective", "is_second_num_selective",
                "is_position_selective", "is_any_selective"]:
        area_summary[f"{col}_pct"] = round(100 * area_summary[col] / area_summary["total_units"], 1)

    summary["by_area"] = area_summary.reset_index().to_dict(orient='records')

    return summary


In [None]:
print("Starting selectivity analysis for all units...")

# Run the main selectivity analysis
# This function analyzes each unit individually using ANOVA to test for selectivity
selectivity_results = identify_selective_neurons(data_filtered)

print(f"Analysis complete! Analyzed {len(selectivity_results)} units.")

# Generate summary statistics across all units and brain regions
selectivity_summary = summarize_selectivity(selectivity_results)

# Display overall population results
print("\n" + "="*50)
print("POPULATION SELECTIVITY RESULTS")
print("="*50)

print(f"Total units analyzed: {selectivity_summary['total_units']}")
print(f"First category selective: {selectivity_summary['first_category_selective']['count']} ({selectivity_summary['first_category_selective']['percentage']}%)")
print(f"First number selective: {selectivity_summary['first_number_selective']['count']} ({selectivity_summary['first_number_selective']['percentage']}%)")
print(f"Second category selective: {selectivity_summary['second_category_selective']['count']} ({selectivity_summary['second_category_selective']['percentage']}%)")
print(f"Second number selective: {selectivity_summary['second_number_selective']['count']} ({selectivity_summary['second_number_selective']['percentage']}%)")
print(f"Position selective: {selectivity_summary['position_selective']['count']} ({selectivity_summary['position_selective']['percentage']}%)")
print(f"Any selective: {selectivity_summary['any_selective']['count']} ({selectivity_summary['any_selective']['percentage']}%)")

# Display regional breakdown
print(f"\nRegional breakdown available for {len(selectivity_summary['by_area'])} brain areas.")


### Population Distribution Plots

Generate summary plots showing the distribution of different selectivity types across the population and brain regions.


## 4. Results & Visualization

### Population Selectivity Summary

Overview of selectivity patterns across the neural population and brain regions:

In [15]:
print("Generating population selectivity distribution plots...")

# Plot 1: Overall distribution of selectivity types across the population
selectivity_types = [
    'First Category', 
    'First Number', 
    'Second Category', 
    'Second Number', 
    'Position', 
    'Any'
]

# Extract counts for each selectivity type
selectivity_counts = [
    selectivity_summary['first_category_selective']['count'],
    selectivity_summary['first_number_selective']['count'],
    selectivity_summary['second_category_selective']['count'],
    selectivity_summary['second_number_selective']['count'],
    selectivity_summary['position_selective']['count'],
    selectivity_summary['any_selective']['count']
]

# Create bar plot showing absolute numbers of selective neurons
plt.figure(figsize=(10, 6))
plt.bar(selectivity_types, selectivity_counts, color='skyblue')
plt.title("Distribution of Feature Selectivity Across Population")
plt.ylabel("Number of Neurons")
plt.xlabel("Selectivity Type")
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig("feature_selectivity_overall.png", dpi=300)
plt.show()

print("Creating regional selectivity comparison plot...")

# Plot 2: Regional distribution of selective neurons
area_df = pd.DataFrame(selectivity_summary['by_area'])

# Sort brain areas by percentage of selective neurons (descending)
area_df_sorted = area_df.sort_values('is_any_selective_pct', ascending=False)

# Create bar plot comparing selectivity percentages across brain regions
plt.figure(figsize=(12, 8))
plt.bar(area_df_sorted['area'], area_df_sorted['is_any_selective_pct'], color='teal')
plt.title("Percentage of Selective Neurons by Brain Region")
plt.ylabel("Selective Neurons (%)")
plt.xlabel("Brain Region")

# Add horizontal line showing overall population average
plt.axhline(
    y=selectivity_summary['any_selective']['percentage'], 
    color='red', 
    linestyle='--', 
    label=f"Overall Average ({selectivity_summary['any_selective']['percentage']}%)"
)
plt.legend()
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig("selective_neurons_by_region.png", dpi=300)
plt.show()

print("Population plots saved as:")
print("  - feature_selectivity_overall.png")
print("  - selective_neurons_by_region.png")

### Regional Selectivity Analysis

Detailed breakdown of selectivity patterns by brain region:
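The per-area counts are already stored in `selectivity_summary['by_area']`. As a compact alternative, the sketch below (the helper `regional_breakdown` is hypothetical) builds a per-area table of percentages directly from the per-unit `selectivity_results` DataFrame, using every `is_*_selective` column it finds. The toy input is made up.

```python
import pandas as pd

def regional_breakdown(selectivity_results):
    """Percentage of units per area flagged by each is_*_selective column, sorted by is_any_selective."""
    flag_cols = [c for c in selectivity_results.columns
                 if c.startswith("is_") and c.endswith("_selective")]
    grouped = selectivity_results.groupby("area")
    table = (grouped[flag_cols].mean() * 100).round(1)  # mean of booleans = fraction selective
    table.insert(0, "total_units", grouped.size())
    return table.sort_values("is_any_selective", ascending=False)

toy = pd.DataFrame({
    "unit_id": range(5),
    "area": ["H", "H", "A", "A", "A"],
    "is_first_cat_selective": [True, False, False, False, True],
    "is_any_selective": [True, True, False, False, True],
})
print(regional_breakdown(toy))
```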


### Single-Unit Response Visualization

Spike raster plots and PSTHs for neurons exhibiting significant selectivity.

**Task Event Markers:**
- `pic1` (1): First picture presentation
- `delay1` (2): First delay period  
- `pic2` (3): Second picture presentation
- `delay2` (4): Second delay period
- `probeOnset` (5): Probe stimulus onset
- `response` (6): Response period
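A peri-stimulus time histogram (PSTH) averages spike counts in fixed bins aligned to a task event and converts them to a rate. The sketch below (the helper `psth` and its default window and bin width are hypothetical) assumes spike and event times in microseconds, as in the `timestamps` and `events` columns; event times would come from column 0 of the `events` matrix at the indices stored in, e.g., `idxEnc1`, as in `extract_event_timestamps`.

```python
import numpy as np

def psth(spike_times_us, event_times_us, window_s=(-1.0, 2.0), bin_s=0.1):
    """Mean firing rate (Hz) in bins aligned to each event.

    Spike and event times are in microseconds; window and bin width are in seconds.
    Returns bin centers (s, relative to the event) and the rate per bin.
    """
    ts = np.ravel(spike_times_us) / 1e6
    events = np.ravel(event_times_us) / 1e6
    if events.size == 0:
        raise ValueError("need at least one event")
    n_bins = int(round((window_s[1] - window_s[0]) / bin_s))
    edges = np.linspace(window_s[0], window_s[1], n_bins + 1)
    counts = np.zeros(n_bins)
    for t0 in events:
        counts += np.histogram(ts - t0, bins=edges)[0]  # spikes relative to this event
    rates = counts / (events.size * bin_s)
    centers = (edges[:-1] + edges[1:]) / 2
    return centers, rates

spikes = np.array([10_050_000, 20_050_000, 20_150_000])
events = np.array([10_000_000, 20_000_000])
centers, rates = psth(spikes, events, window_s=(0.0, 0.2), bin_s=0.1)
print(centers, rates)
```

For a raster, the per-trial relative times `ts - t0` that fall inside the window can be passed as a list of arrays to `plt.eventplot`.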


In [None]:
# Display basic information about the processed dataset
print(f"Dataset shape: {data_filtered.shape}")
print(f"Available columns: {len(data_filtered.columns)}")
print("Key variables for analysis:")
print(f"  - Brain areas: {sorted(data_filtered['brainAreaOfCell'].unique())}")
print(f"  - Stimulus categories: {sorted(data_filtered['first_cat_simple'].unique())}")
print(f"  - Stimulus numbers: {sorted(data_filtered['first_num_simple'].unique())}")
data_filtered.columns


In [15]:
def visualize_selective_responses(data_filtered, unit_id):
    """
    Create visualizations showing the selectivity pattern of a neuron.
    
    Parameters
    ----------
    data_filtered : pandas.DataFrame
        DataFrame containing neural data
    unit_id : int or str
        ID of the unit to visualize
        
    Returns
    -------
    matplotlib.figure.Figure
        Figure containing the visualization
    """
    unit_data = data_filtered[data_filtered["unit_id"] == unit_id].copy()
    unit_data["fr_normalized"] = unit_data["fr_epoch"] - unit_data["fr_baseline"]
    
    # Create figure with subplots for different selectivity types
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    
    # First stimulus features
    # 1. First Category selectivity
    cat_responses = unit_data.groupby("first_cat_simple")["fr_normalized"].mean().reset_index()
    cat_sem = unit_data.groupby("first_cat_simple")["fr_normalized"].sem().reset_index()
    
    sns.barplot(x="first_cat_simple", y="fr_normalized", data=cat_responses, ax=axes[0,0])
    axes[0,0].errorbar(
        x=range(len(cat_responses)), 
        y=cat_responses["fr_normalized"],
        yerr=cat_sem["fr_normalized"],
        fmt='none', color='black', capsize=5
    )
    axes[0,0].set_title(f"First Stimulus Category Selectivity")
    axes[0,0].set_xlabel("Category")
    axes[0,0].set_ylabel("Normalized Firing Rate (Hz)")
    
    # 2. First Number selectivity
    num_responses = unit_data.groupby("first_num_simple")["fr_normalized"].mean().reset_index()
    num_sem = unit_data.groupby("first_num_simple")["fr_normalized"].sem().reset_index()
    
    sns.barplot(x="first_num_simple", y="fr_normalized", data=num_responses, ax=axes[0,1])
    axes[0,1].errorbar(
        x=range(len(num_responses)), 
        y=num_responses["fr_normalized"],
        yerr=num_sem["fr_normalized"],
        fmt='none', color='black', capsize=5
    )
    axes[0,1].set_title(f"First Stimulus Number Selectivity")
    axes[0,1].set_xlabel("Number")
    axes[0,1].set_ylabel("Normalized Firing Rate (Hz)")
    
    # Second stimulus features
    # 3. Second Category selectivity
    cat_responses = unit_data.groupby("second_cat_simple")["fr_normalized"].mean().reset_index()
    cat_sem = unit_data.groupby("second_cat_simple")["fr_normalized"].sem().reset_index()
    
    sns.barplot(x="second_cat_simple", y="fr_normalized", data=cat_responses, ax=axes[1,0])
    axes[1,0].errorbar(
        x=range(len(cat_responses)), 
        y=cat_responses["fr_normalized"],
        yerr=cat_sem["fr_normalized"],
        fmt='none', color='black', capsize=5
    )
    axes[1,0].set_title(f"Second Stimulus Category Selectivity")
    axes[1,0].set_xlabel("Category")
    axes[1,0].set_ylabel("Normalized Firing Rate (Hz)")
    
    # 4. Second Number selectivity
    num_responses = unit_data.groupby("second_num_simple")["fr_normalized"].mean().reset_index()
    num_sem = unit_data.groupby("second_num_simple")["fr_normalized"].sem().reset_index()
    
    sns.barplot(x="second_num_simple", y="fr_normalized", data=num_responses, ax=axes[1,1])
    axes[1,1].errorbar(
        x=range(len(num_responses)), 
        y=num_responses["fr_normalized"],
        yerr=num_sem["fr_normalized"],
        fmt='none', color='black', capsize=5
    )
    axes[1,1].set_title(f"Second Stimulus Number Selectivity")
    axes[1,1].set_xlabel("Number")
    axes[1,1].set_ylabel("Normalized Firing Rate (Hz)")
    
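    # 5. Position selectivity (first vs second)
    # Baseline-subtracted rates for each presentation: fr_epoch is the first-stimulus
    # epoch and fr_enc2_epoch the second-stimulus epoch (as in the selectivity models)
    first_norm = unit_data["fr_epoch"] - unit_data["fr_baseline"]
    second_norm = unit_data["fr_enc2_epoch"] - unit_data["fr_baseline"]
    position_df = pd.DataFrame({
        'Position': ['First', 'Second'],
        'FR': [first_norm.mean(), second_norm.mean()],
        'SEM': [first_norm.sem(), second_norm.sem()]
    })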
    
    sns.barplot(x="Position", y="FR", data=position_df, ax=axes[0,2])
    axes[0,2].errorbar(
        x=range(len(position_df)), 
        y=position_df["FR"],
        yerr=position_df["SEM"],
        fmt='none', color='black', capsize=5
    )
    axes[0,2].set_title(f"Position Selectivity (First vs Second)")
    axes[0,2].set_xlabel("Stimulus Position")
    axes[0,2].set_ylabel("Normalized Firing Rate (Hz)")
    
    # Add unit information
    axes[1,2].axis('off')
    axes[1,2].text(0.5, 0.5, 
                  f"Unit ID: {unit_id}\nBrain Area: {unit_data['brainAreaOfCell'].iloc[0]}", 
                  horizontalalignment='center',
                  verticalalignment='center',
                  fontsize=12)
    
    plt.suptitle(f"Selectivity Profile for Unit {unit_id}", fontsize=16)
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    
    return fig

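def summarize_selectivity(selectivity_results):
    """
    Summarize selectivity counts and percentages, overall and per brain area.

    Parameters
    ----------
    selectivity_results : pandas.DataFrame
        Output of identify_selective_neurons (one row per unit).

    Returns
    -------
    dict
        Overall count and percentage for each selectivity type, plus a "by_area"
        list of per-area records with counts and *_pct columns.
    """
    total_units = len(selectivity_results)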

    summary = {
        "total_units": total_units,
        "first_category_selective": {
            "count": int(selectivity_results["is_first_cat_selective"].sum()),
            "percentage": round(100 * selectivity_results["is_first_cat_selective"].mean(), 1)
        },
        "first_number_selective": {
            "count": int(selectivity_results["is_first_num_selective"].sum()),
            "percentage": round(100 * selectivity_results["is_first_num_selective"].mean(), 1)
        },
        "second_category_selective": {
            "count": int(selectivity_results["is_second_cat_selective"].sum()),
            "percentage": round(100 * selectivity_results["is_second_cat_selective"].mean(), 1)
        },
        "second_number_selective": {
            "count": int(selectivity_results["is_second_num_selective"].sum()),
            "percentage": round(100 * selectivity_results["is_second_num_selective"].mean(), 1)
        },
        "position_selective": {
            "count": int(selectivity_results["is_position_selective"].sum()),
            "percentage": round(100 * selectivity_results["is_position_selective"].mean(), 1)
        },
        "any_selective": {
            "count": int(selectivity_results["is_any_selective"].sum()),
            "percentage": round(100 * selectivity_results["is_any_selective"].mean(), 1)
        }
    }

    area_summary = selectivity_results.groupby("area").agg({
        "is_first_cat_selective": "sum",
        "is_first_num_selective": "sum",
        "is_second_cat_selective": "sum",
        "is_second_num_selective": "sum",
        "is_position_selective": "sum",
        "is_any_selective": "sum",
        "unit_id": "count"
    }).rename(columns={"unit_id": "total_units"})

    # Calculate percentage columns explicitly
    for col in ["is_first_cat_selective", "is_first_num_selective",
                "is_second_cat_selective", "is_second_num_selective",
                "is_position_selective", "is_any_selective"]:
        area_summary[f"{col}_pct"] = round(100 * area_summary[col] / area_summary["total_units"], 1)

    summary["by_area"] = area_summary.reset_index().to_dict(orient='records')

    return summary
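Because the mean of a boolean column is the fraction of `True` values, the per-area percentages computed in `summarize_selectivity` can also be obtained with a single `groupby`. The sketch below uses a small made-up table with the same column names as `selectivity_results`; the area labels `"A"` and `"B"` and all values are placeholders, not results.

```python
import pandas as pd

# Made-up table with the same column layout as selectivity_results
toy = pd.DataFrame({
    "unit_id": [1, 2, 3, 4, 5],
    "area": ["A", "A", "B", "B", "B"],
    "is_any_selective": [True, False, True, True, False],
    "is_position_selective": [False, False, True, False, False],
})

flag_cols = ["is_any_selective", "is_position_selective"]

# Mean of a boolean column = fraction of True values; scale to percent
pct_by_area = (toy.groupby("area")[flag_cols].mean() * 100).round(1)
n_by_area = toy.groupby("area")["unit_id"].count().rename("total_units")
area_table = pct_by_area.join(n_by_area)
print(area_table)
```

Applied to the real data, `flag_cols` would list all six `is_*_selective` columns.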




In [None]:
# Run the selectivity analysis
selectivity_results = identify_selective_neurons(data_filtered)

# Summarize results
selectivity_summary = summarize_selectivity(selectivity_results)

# Print overall results
print("\nOVERALL RESULTS:")
print(f"Total units analyzed: {selectivity_summary['total_units']}")
print(f"First category selective: {selectivity_summary['first_category_selective']['count']} ({selectivity_summary['first_category_selective']['percentage']}%)")
print(f"First number selective: {selectivity_summary['first_number_selective']['count']} ({selectivity_summary['first_number_selective']['percentage']}%)")
print(f"Second category selective: {selectivity_summary['second_category_selective']['count']} ({selectivity_summary['second_category_selective']['percentage']}%)")
print(f"Second number selective: {selectivity_summary['second_number_selective']['count']} ({selectivity_summary['second_number_selective']['percentage']}%)")
print(f"Position selective: {selectivity_summary['position_selective']['count']} ({selectivity_summary['position_selective']['percentage']}%)")
print(f"Any selective: {selectivity_summary['any_selective']['count']} ({selectivity_summary['any_selective']['percentage']}%)")


### Selectivity Distribution Plots

Plot the number of neurons selective for each feature across all units, then the percentage of neurons selective for any feature in each brain region, with the overall average shown as a red dashed line.


In [None]:

# Overall distribution of selectivity types
selectivity_types = [
    'First Category', 
    'First Number', 
    'Second Category', 
    'Second Number', 
    'Position', 
    'Any'
]

selectivity_counts = [
    selectivity_summary['first_category_selective']['count'],
    selectivity_summary['first_number_selective']['count'],
    selectivity_summary['second_category_selective']['count'],
    selectivity_summary['second_number_selective']['count'],
    selectivity_summary['position_selective']['count'],
    selectivity_summary['any_selective']['count']
]

plt.figure(figsize=(10, 6))
plt.bar(selectivity_types, selectivity_counts, color='skyblue')
plt.title("Distribution of Feature Selectivity")
plt.ylabel("Number of Neurons")
plt.xlabel("Selectivity Type")
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig("feature_selectivity_overall.png", dpi=300)
plt.show()

# Regional distribution of selective neurons
area_df = pd.DataFrame(selectivity_summary['by_area'])

# Sort by percent selective
area_df_sorted = area_df.sort_values('is_any_selective_pct', ascending=False)

plt.figure(figsize=(12, 8))
plt.bar(area_df_sorted['area'], area_df_sorted['is_any_selective_pct'], color='teal')
plt.title("Percentage of Selective Neurons by Brain Region")
plt.ylabel("Selective Neurons (%)")
plt.xlabel("Brain Region")
plt.axhline(
    y=selectivity_summary['any_selective']['percentage'], 
    color='red', 
    linestyle='--', 
    label='Overall Average'
)
plt.legend()
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig("selective_neurons_by_region.png", dpi=300)
plt.show()
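Regional percentages from areas with few units are noisy, and the bar chart above shows no uncertainty. One option is a Wilson score interval for each area's proportion. The sketch below is a NumPy implementation of the standard formula (z = 1.96 for roughly 95% coverage); `wilson_interval` is a hypothetical helper name and the example counts are made up.

```python
import numpy as np

def wilson_interval(k, n, z=1.96):
    """Wilson score interval for a binomial proportion k/n; returns (lower, upper)."""
    k = np.asarray(k, dtype=float)
    n = np.asarray(n, dtype=float)
    p = k / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half = z * np.sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
    return center - half, center + half

# Made-up example: 5 of 20 and 12 of 40 units selective
lower, upper = wilson_interval([5, 12], [20, 40])
print(np.round(100 * lower, 1), np.round(100 * upper, 1))
```

For example, `wilson_interval(area_df_sorted['is_any_selective'], area_df_sorted['total_units'])` returns bounds that could be drawn as asymmetric error bars on the regional bar chart.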


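### Regional Analysis

Run the selectivity analysis separately for each brain region to look for regional differences: the first cell defines the per-area function, and the second prints selectivity counts and plots a summary for each region.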


In [18]:
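import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from tqdm import tqdm


def identify_selective_neurons_by_area(data_filtered):
    """
    Run the per-unit selectivity ANOVAs separately within each brain area.

    Returns a dict mapping area name -> DataFrame of per-unit statistics
    (same columns as identify_selective_neurons), and the list of areas.
    """
    area_results = {}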

    for area, area_df in data_filtered.groupby("brainAreaOfCell"):
        selectivity_stats = []

        for unit_id, unit_df in tqdm(area_df.groupby("unit_id"), desc=f"Analyzing {area}"):
            unit_df = unit_df.copy()

            # Analyze first and second stimulus separately
            unit_df["fr_first"] = unit_df["fr_epoch"]
            unit_df["fr_second"] = unit_df["fr_enc2_epoch"]

            # Skip if no variance
            if unit_df["fr_first"].std() == 0 or unit_df["fr_second"].std() == 0:
                continue

            # First stimulus model
            model_first = smf.ols(
                "fr_first ~ C(first_cat_simple) + C(first_num_simple)", 
                data=unit_df
            ).fit()
            anova_first = sm.stats.anova_lm(model_first, typ=2)

            # Second stimulus model
            model_second = smf.ols(
                "fr_second ~ C(second_cat_simple) + C(second_num_simple)", 
                data=unit_df
            ).fit()
            anova_second = sm.stats.anova_lm(model_second, typ=2)

            # Position analysis
            position_df = pd.DataFrame({
                'fr': np.concatenate([unit_df['fr_first'], unit_df['fr_second']]),
                'category': np.concatenate([unit_df['first_cat_simple'], unit_df['second_cat_simple']]),
                'number': np.concatenate([unit_df['first_num_simple'], unit_df['second_num_simple']]),
                'position': ['first'] * len(unit_df) + ['second'] * len(unit_df)
            })

            model_position = smf.ols(
                "fr ~ C(category) + C(number) + C(position)", 
                data=position_df
            ).fit()
            anova_position = sm.stats.anova_lm(model_position, typ=2)

            stats = {
                'unit_id': unit_id,
                'area': area,
                'first_cat_pvalue': anova_first.loc['C(first_cat_simple)', 'PR(>F)'],
                'first_num_pvalue': anova_first.loc['C(first_num_simple)', 'PR(>F)'],
                'second_cat_pvalue': anova_second.loc['C(second_cat_simple)', 'PR(>F)'],
                'second_num_pvalue': anova_second.loc['C(second_num_simple)', 'PR(>F)'],
                'position_pvalue': anova_position.loc['C(position)', 'PR(>F)'],
                'is_first_cat_selective': anova_first.loc['C(first_cat_simple)', 'PR(>F)'] < 0.05,
                'is_first_num_selective': anova_first.loc['C(first_num_simple)', 'PR(>F)'] < 0.05,
                'is_second_cat_selective': anova_second.loc['C(second_cat_simple)', 'PR(>F)'] < 0.05,
                'is_second_num_selective': anova_second.loc['C(second_num_simple)', 'PR(>F)'] < 0.05,
                'is_position_selective': anova_position.loc['C(position)', 'PR(>F)'] < 0.05,
                'r2_first': model_first.rsquared,
                'r2_second': model_second.rsquared,
                'r2_position': model_position.rsquared,
                'is_any_selective': (
                    (anova_first.loc['C(first_cat_simple)', 'PR(>F)'] < 0.05) or
                    (anova_first.loc['C(first_num_simple)', 'PR(>F)'] < 0.05) or
                    (anova_second.loc['C(second_cat_simple)', 'PR(>F)'] < 0.05) or
                    (anova_second.loc['C(second_num_simple)', 'PR(>F)'] < 0.05) or
                    (anova_position.loc['C(position)', 'PR(>F)'] < 0.05)
                )
            }
            selectivity_stats.append(stats)

        area_results[area] = pd.DataFrame(selectivity_stats)

    areas = list(area_results.keys())
    return area_results, areas
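The position model stacks every trial twice, once per presentation, so that category and number act as covariates and `C(position)` tests for a first-versus-second difference. The sketch below reproduces that reshaping on a made-up three-trial table and additionally keeps a `trial` index, which is not in the original code. OLS treats the two rows of a trial as independent; a `trial` column would allow a model that accounts for the pairing, for example a mixed model with trial as the grouping factor.

```python
import numpy as np
import pandas as pd

# Made-up per-trial table using the column names from the analysis above
unit_df = pd.DataFrame({
    "fr_first": [4.0, 6.0, 5.0],
    "fr_second": [3.0, 2.0, 4.0],
    "first_cat_simple": ["a", "b", "a"],
    "second_cat_simple": ["b", "a", "b"],
    "first_num_simple": [1, 2, 3],
    "second_num_simple": [2, 3, 1],
})

def stack_positions(df):
    """Return a long table with one row per (trial, presentation)."""
    n = len(df)
    return pd.DataFrame({
        "fr": np.concatenate([df["fr_first"], df["fr_second"]]),
        "category": np.concatenate([df["first_cat_simple"], df["second_cat_simple"]]),
        "number": np.concatenate([df["first_num_simple"], df["second_num_simple"]]),
        "position": ["first"] * n + ["second"] * n,
        "trial": np.tile(np.arange(n), 2),  # links the two rows that came from one trial
    })

long_df = stack_positions(unit_df)
print(long_df.groupby("position")["fr"].mean())
```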

In [None]:
print("Analyzing selectivity patterns by brain region...")

# Run the regional analysis
area_selectivity_results, brain_areas = identify_selective_neurons_by_area(data_filtered)

print(f"Analysis complete for {len(brain_areas)} brain regions: {brain_areas}")

# Generate detailed reports and plots for each brain area
for area, df in area_selectivity_results.items():
    print(f"\n{'='*50}")
    print(f"BRAIN AREA: {area}")
    print(f"{'='*50}")
    print(f"Total units: {len(df)}")
    
    if df.empty:
        # Every unit in this area was skipped by the zero-variance check
        print("No units with non-zero firing-rate variance; skipping.")
        continue

    # Count selective units per feature in this area
    n_units = len(df)
    counts = {
        "First Cat": df["is_first_cat_selective"].sum(),
        "First Num": df["is_first_num_selective"].sum(),
        "Second Cat": df["is_second_cat_selective"].sum(),
        "Second Num": df["is_second_num_selective"].sum(),
        "Position": df["is_position_selective"].sum(),
        "Any": df["is_any_selective"].sum()
    }

    # Display selectivity statistics for this area
    print("Selectivity counts and percentages:")
    for label, count in counts.items():
        print(f"  {label}: {count} ({100 * count / n_units:.1f}%)")
    
    # Generate bar plot for this brain area
    plt.figure(figsize=(8, 5))
    bars = plt.bar(counts.keys(), counts.values(), color='skyblue')
    plt.title(f"Selectivity Distribution in {area} (n={n_units})")
    plt.ylabel("Number of Selective Neurons")
    plt.xlabel("Selectivity Type")
    plt.xticks(rotation=45)
    
    # Add value labels on bars
    for bar, count in zip(bars, counts.values()):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1, 
                str(count), ha='center', va='bottom')
    
    plt.tight_layout()
    plt.show()

print("\nRegional analysis complete!")
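The per-area plots report counts but do not test whether the fraction of selective neurons differs between regions. A chi-square test of independence on an area by (selective, not selective) table is one way to check this. The sketch uses `scipy.stats.chi2_contingency`, a hypothetical helper name `area_selectivity_test`, and made-up counts. It treats units as independent observations, which may not hold for units recorded in the same session, and with small expected counts (below about 5) an exact or permutation test is safer.

```python
import numpy as np
from scipy.stats import chi2_contingency

def area_selectivity_test(selective_counts, total_counts):
    """Chi-square test of independence between brain area and selectivity."""
    selective = np.asarray(selective_counts)
    total = np.asarray(total_counts)
    # Rows: areas; columns: selective, not selective
    table = np.column_stack([selective, total - selective])
    chi2, p, dof, expected = chi2_contingency(table)
    return chi2, p, dof, expected

# Made-up counts for three areas
chi2, p, dof, expected = area_selectivity_test([20, 5, 30], [50, 40, 60])
print(f"chi2 = {chi2:.2f}, dof = {dof}, p = {p:.3g}")
```

Applied to the `by_area` summary, the call would be `area_selectivity_test(area_df['is_any_selective'], area_df['total_units'])`.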


### Plot Responsive Neurons - All Trials

Visualize spike rasters and PSTHs for every neuron with significant selectivity. The first cell below imports the specialized spike-visualization helpers (`plot_spikes_with_PSTH`, `get_spikes`) from the `neural_analysis` package. The second generates one plot per unit and selectivity type, showing trial-by-trial spiking aligned to first-stimulus onset, grouped by stimulus condition, together with the corresponding PSTH.
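A PSTH is the spike count per time bin, summed over trials and divided by the number of trials and the bin width, giving a rate in Hz. The helpers below are a minimal NumPy sketch of that computation, not the implementation of `plot_spikes_with_PSTH`; the names `align_spikes` and `psth` and the 100 ms default bin are assumptions for illustration. Times are in seconds, matching the microsecond-to-second conversion used in the plotting cell.

```python
import numpy as np

def align_spikes(spike_times, event_times, window=(-1.0, 8.0)):
    """For each event, return spike times (s) relative to it, keeping those inside `window`."""
    spike_times = np.sort(np.asarray(spike_times, dtype=float))
    aligned = []
    for t0 in np.asarray(event_times, dtype=float):
        start = np.searchsorted(spike_times, t0 + window[0], side="left")
        stop = np.searchsorted(spike_times, t0 + window[1], side="right")
        aligned.append(spike_times[start:stop] - t0)
    return aligned

def psth(aligned, window=(-1.0, 8.0), bin_size=0.1):
    """Trial-averaged firing rate (Hz) in fixed-width bins; returns (bin_centers, rate)."""
    n_bins = int(round((window[1] - window[0]) / bin_size))
    edges = np.linspace(window[0], window[1], n_bins + 1)
    counts = np.zeros(n_bins)
    for rel in aligned:
        counts += np.histogram(rel, bins=edges)[0]
    rate = counts / (len(aligned) * bin_size)  # spikes per trial per second
    centers = (edges[:-1] + edges[1:]) / 2
    return centers, rate

# Made-up example: two trials with onsets at 0 s and 10 s
trials = align_spikes([0.05, 1.05, 1.15, 10.2, 10.25], [0.0, 10.0], window=(-1.0, 2.0))
centers, rate = psth(trials, window=(-1.0, 2.0), bin_size=0.5)
print(rate)
```

Inside the plotting loop, `psth(align_spikes(spikes, alignments))` would give one trial-averaged trace for a unit; splitting `alignments` by condition first gives one trace per condition.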


In [20]:
# Install neural analysis package if needed:
# pip install git+https://github.com/ioqfwfq/rlab_neural_analysis.git@jz

# Import functions for spike visualization and analysis
from neural_analysis.visualize import plot_spikes_with_PSTH
from neural_analysis.spikes import get_spikes

print("Neural analysis libraries loaded successfully.")

In [None]:

print("Preparing spike alignment data for visualization...")

# Extract stimulus onset timestamps for aligning spike plots
epoch_ts = extract_event_timestamps(
    df_sample_new=data_filtered,
    start_idx_col='idxEnc1',     # Align to first stimulus onset
    is_window=True,
    window_size=0                # Zero-width window: use the exact onset timestamps
)

print("Setting up output directory for plots...")

# Create output directory for individual unit plots
os.makedirs("plot_out", exist_ok=True)

print("Generating raster plots for selective neurons...")

# Define the types of selectivity to visualize
selectivity_types = [
    ("first_cat", "first_cat_simple"),      # First stimulus category selectivity
    ("first_num", "first_num_simple"),      # First stimulus number selectivity  
    ("second_cat", "second_cat_simple"),    # Second stimulus category selectivity
    ("second_num", "second_num_simple"),    # Second stimulus number selectivity
    ("position", "position")                # Temporal position selectivity
]

plot_count = 0
# A unit selective for several features is counted once per feature (one plot each)
total_selective = sum(selectivity_results[f"is_{st[0]}_selective"].sum() for st in selectivity_types)
print(f"Will generate up to {total_selective} plots (one per unit and selectivity type) across {len(selectivity_types)} selectivity types...")

# Generate plots for each selectivity type
for select_type, cond in selectivity_types:
    pval_col = f"{select_type}_pvalue"
    flag_col = f"is_{select_type}_selective"

    # Get units that are selective for this particular feature
    selective_units = selectivity_results[selectivity_results[flag_col]]
    print(f"\nProcessing {len(selective_units)} units selective for {select_type}...")
    
    for _, row in selective_units.iterrows():
        unit_id = row["unit_id"]
        pval = row[pval_col]
        area = row["area"]

        # Extract trial data for this unit
        df_unit = data_filtered[data_filtered["unit_id"] == unit_id].reset_index(drop=True)
        if df_unit.empty or unit_id not in epoch_ts:
            print(f"  Skipping unit {unit_id} (no data available)")
            continue

        # Set up trial grouping based on selectivity type
        if cond == "position":
            # For position selectivity: compare first vs second stimulus presentation
            group_labels = ["first"] * len(df_unit) + ["second"] * len(df_unit)
            df_unit_extended = pd.concat([df_unit.copy(), df_unit.copy()], ignore_index=True)
            df_unit_extended["fr_combined"] = pd.concat([
                df_unit["fr_epoch"], df_unit["fr_enc2_epoch"]
            ], ignore_index=True)
            # Note: both copies are aligned to first-stimulus onset (epoch_ts only holds
            # Enc1 timestamps), so the "first" and "second" rasters show identical spikes.
            # Aligning the second copy to second-stimulus onset would need Enc2 timestamps.
            alignments = np.tile(np.asarray(epoch_ts[unit_id][:, 0], dtype=np.float64) / 1e6, 2)
            df_for_stats = df_unit_extended
        else:
            # For stimulus feature selectivity: group by stimulus condition
            group_labels = df_unit[cond].apply(
                lambda x: np.squeeze(x).item() if isinstance(x, (list, np.ndarray)) else x
            )
            alignments = np.asarray(epoch_ts[unit_id][:, 0], dtype=np.float64) / 1e6
            df_unit["fr_combined"] = df_unit["fr_epoch"]
            df_for_stats = df_unit

        # Extract spike timestamps (convert from microseconds to seconds)
        spikes = np.asarray(df_unit["timestamps"].iloc[0]).flatten().astype(np.float64) / 1e6
        spikes = np.sort(spikes)
        
        # Select appropriate variable for statistical annotation
        if select_type in ["first_cat", "second_cat"]:
            # For category selectivity, show number as secondary grouping
            stats_var = "first_num_simple" if select_type == "first_cat" else "second_num_simple"
        elif select_type in ["first_num", "second_num"]:
            # For number selectivity, show category as secondary grouping
            stats_var = "first_cat_simple" if select_type == "first_num" else "second_cat_simple"
        elif select_type == "position":
            # For position selectivity, show probe picture as secondary grouping
            stats_var = "probe_pic"
        else:
            # Default fallback
            stats_var = "first_cat_simple"
        
        stats = df_for_stats[stats_var]

        try:    
            # Generate raster plot and PSTH
            axes = plot_spikes_with_PSTH(
                spikes,
                alignments,
                window=(-1, 8),          # 1s before to 8s after stimulus
                group_labels=group_labels,
                stats=stats,
                plot_stats=False,        # Don't show statistical annotations
                sig_test=True,           # Perform significance testing
                cmap="Set1",            # Color scheme
            )

            # Add baseline firing rate reference line
            unit_baseline = np.mean(df_unit["fr_baseline"])
            xmin, xmax = axes[1].get_xlim()
            axes[1].hlines(
                y=unit_baseline,
                xmin=xmin,
                xmax=xmax,
                colors="gray",
                linestyles="--",
                label=f"baseline = {unit_baseline:.1f} Hz"
            )

            # Add vertical lines at task-epoch boundaries, in s from first-stimulus onset:
            # end of Enc1 (1), end of Delay 1 (2), end of Enc2 (3), end of Delay 2 / probe onset (5.5)
            event_times = [1, 2, 3, 5.5]
            for event_time in event_times:
                for ax in axes:
                    ax.axvline(x=event_time, color="black", linestyle="--", 
                              linewidth=0.5, alpha=0.7)

            # Add informative title and labels
            axes[0].set_title(
                f"{area} Unit {unit_id} — {select_type.replace('_', ' ').title()} Selective\n"
                f"p = {pval:.3g}"
            )
            axes[1].set_xlabel("Time from First Stimulus Onset [s]")
            axes[1].legend()

            # Save the figure that holds these axes, then close it to free memory
            fig = axes[0].figure
            fname = f"plot_out/{area}_{unit_id}_{select_type}_p{pval:.3f}.png"
            fig.savefig(fname, dpi=300, bbox_inches="tight")
            plt.close(fig)

            plot_count += 1

        except Exception as e:
            print(f"  Error plotting unit {unit_id} ({select_type}): {e}")
            plt.close("all")  # avoid accumulating open figures after a failed plot

print(f"\nRaster plot generation complete!")
print(f"Generated {plot_count} plots saved in 'plot_out/' directory.")
print(f"Plot naming convention: [BrainArea]_[UnitID]_[SelectivityType]_p[pvalue].png")
