# Neural Delay Activity and Distractor Resistance Analysis

This notebook analyzes delay period neural activity and distractor resistance in working memory binding tasks.

## 🚀 **QUICK START GUIDE**

### **1. Install Required Packages**
```bash
pip install numpy pandas matplotlib seaborn scipy statsmodels tqdm
```

### **2. Core Analysis (Run cells 1-32)**
- ✅ **Cells 1-32**: Core analysis pipeline (data loading → selectivity → distractor resistance)
- ⚠️ **Cells 33+**: Advanced plotting (requires optional neural_analysis package)

### **3. Expected Results**
- **Selectivity analysis** identifies responsive neurons
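- **Distractor resistance analysis** finds neurons whose stim1 category selectivity persists through delay 1 and past the distractor (stim2) into delay 2
- **CSV outputs** (`principled_distractor_analysis.csv`, `distractor_resistant_neurons.csv`) saved to the working directory when distractor-resistant neurons are found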

---

## Analysis Overview

1. **Data Loading & Preprocessing** - Load MATLAB data and prepare neural firing rate calculations
2. **Single Unit Selectivity Analysis** - Identify neurons selective for categories and numerosities  
3. **Delay Activity Analysis** - Find neurons maintaining selectivity during delay periods
4. **Distractor Resistance Analysis** - Test which neurons resist interference from distractors
5. **Visualization** - Plot distractor-resistant neurons and interference patterns (optional)

---

## 1. Data Loading & Preprocessing

### Environment Setup

In [None]:
import sys
print(sys.executable)

In [2]:
# pip install numpy pandas matplotlib seaborn scipy statsmodels tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.io import loadmat
from scipy.stats import ttest_1samp, ttest_ind
import statsmodels.formula.api as smf
import statsmodels.api as sm
from tqdm import tqdm
import glob
import os

### Loading MATLAB Data Files

In [None]:
# Find all WMB_P*_v7.mat files in the current directory (sorted so unit order, and hence unit_id, is reproducible)
mat_files = sorted(glob.glob('./WMB_P*_v7.mat'))
print(f"Found {len(mat_files)} .mat files: \n{[os.path.basename(f) for f in mat_files]}\n")

# Initialize lists to store data from each file
cell_mats = []
total_mats = []

# Load each file and append its data
for mat_file in mat_files:
    print(f"\nLoading {mat_file}...")
    mat_data = loadmat(mat_file)
    cell_mats.append(mat_data['cellStatsAll'])
    total_mats.append(mat_data['totStats'])

# Print shapes of loaded data for debugging
print("\nShapes of loaded data:")
for i, (cell, total) in enumerate(zip(cell_mats, total_mats)):
    print(f"File {i}: cell_mat shape: {cell.shape}, total_mat shape: {total.shape}")

In [None]:
# Combine the data
# Each cellStatsAll loads as a (1, n_units) struct array: take row 0 and
# turn every struct element into a {field: value} record
all_cell_records = []
for cell_mat in cell_mats:
    cell_list = cell_mat[0]  # (1, n) -> (n,)
    for cell in cell_list:
        all_cell_records.append({key: cell[key] for key in cell.dtype.names})

# Convert combined records to DataFrame
df = pd.DataFrame(all_cell_records)

# For total_mat, we can concatenate directly since they have the same structure
total_mat = np.concatenate(total_mats, axis=0)

print(f"\nCombined data shape - total_mat: {total_mat.shape}")
print(f"\nCombined data shape - df: {df.shape}")
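The flattening step above can be tried without the data files. The sketch below builds a small synthetic stand-in for a `loadmat` struct array (field names and values are made up for illustration) and converts it with `struct_array_to_records`, a hypothetical helper that is not part of the notebook.

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for loadmat(...)['cellStatsAll']: a (1, n) structured array
# with object fields. Field names and values are illustrative only.
n_units = 3
cell_mat = np.empty((1, n_units), dtype=[("brainAreaOfCell", object), ("timestamps", object)])
for i in range(n_units):
    cell_mat["brainAreaOfCell"][0, i] = np.array([[i + 1]])
    cell_mat["timestamps"][0, i] = (np.arange(5, dtype=float) * 1e6).reshape(-1, 1)


def struct_array_to_records(struct_arr):
    """Flatten a MATLAB-style struct array into a list of {field: value} dicts."""
    flat = np.ravel(struct_arr)  # works for (1, n) and (n, 1) layouts alike
    return [{name: elem[name] for name in flat.dtype.names} for elem in flat]


df_demo = pd.DataFrame(struct_array_to_records(cell_mat))
print(df_demo.shape)
print(int(df_demo.loc[2, "brainAreaOfCell"][0, 0]))
```

Using `np.ravel` instead of `cell_mat[0]` is a design choice: it does not depend on whether MATLAB saved the struct array as a row or a column.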

### Brain Area Code Translation

In [5]:
from collections import Counter

def count_area_codes(area_column):
    """Map numeric brain-area codes to hemisphere-specific labels and count units per label."""
    mapping = {
        1: 'RH', 2: 'LH', 3: 'RA', 4: 'LA', 5: 'RAC', 6: 'LAC',
        7: 'RSMA', 8: 'LSMA', 9: 'RPT', 10: 'LPT', 11: 'ROFC', 12: 'LOFC',
        50: 'RFFA', 51: 'REC', 52: 'RCM', 53: 'LCM', 54: 'RPUL', 55: 'LPUL',
        56: 'N/A', 57: 'RPRV', 58: 'LPRV'
    }
    
    labels = []
    for code in area_column:
        label = mapping.get(code, 'Unknown')
        labels.append(label)
    
    return dict(Counter(labels))

# Number of units in each area (column index 3 of totStats holds the area code)
area_codes = total_mat[:, 3]

counts = count_area_codes(area_codes)
print("Area counts (hemisphere-specific labels):")
for area, count in counts.items():
    print(f"{area}: {count}")
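A vectorized equivalent of `count_area_codes` can use pandas `Series.map` and `value_counts`. This is a sketch: the mapping is a subset copied from above, and `count_area_codes_vectorized` is our name, not part of the notebook.

```python
import numpy as np
import pandas as pd

# Illustrative subset of the notebook's area-code mapping.
AREA_LABELS = {1: "RH", 2: "LH", 3: "RA", 4: "LA", 11: "ROFC", 12: "LOFC"}


def count_area_codes_vectorized(area_codes, mapping=AREA_LABELS):
    """Count units per area label; codes absent from `mapping` become 'Unknown'."""
    # Codes loaded from MATLAB are typically float64; cast to int before the dict lookup.
    codes = pd.Series(np.ravel(area_codes)).astype(int)
    labels = codes.map(mapping).fillna("Unknown")
    return labels.value_counts().to_dict()


counts = count_area_codes_vectorized(np.array([1.0, 1.0, 2.0, 12.0, 99.0]))
print(counts)
```

This assumes the code column contains no NaNs, because `astype(int)` raises on missing values.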

### Data Preprocessing & Filtering

In [7]:
collapsed_area_map = {
    1: 'H', 2: 'H',
    3: 'A', 4: 'A',
    5: 'AC', 6: 'AC',
    7: 'SMA', 8: 'SMA',
    9: 'PT', 10: 'PT',
    11: 'OFC', 12: 'OFC',
    50: 'FFA', 51: 'EC',
    52: 'CM', 53: 'CM',
    54: 'PUL', 55: 'PUL',
    56: 'N/A', 57: 'PRV', 58: 'PRV'
}
# Convert brain area codes in the DataFrame
df['brainAreaOfCell'] = df['brainAreaOfCell'].apply(
    lambda x: collapsed_area_map.get(int(x[0, 0]), 'Unknown') if isinstance(x, np.ndarray) else collapsed_area_map.get(x, 'Unknown')
)

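# Filter out units with a mean firing rate at or below 0.1 Hz (timestamps are in microseconds)
def _mean_rate_hz(ts):
    ts = np.ravel(ts)  # handles (n,), (n, 1) and (1, n) spike-time arrays alike
    return len(ts) / (ts[-1] - ts[0]) * 1e6

fr = df['timestamps'].apply(_mean_rate_hz)
df_sample_new = df[fr > 0.1].reset_index(drop=True)

# Add unit identifiers (row position after filtering)
df_sample_new["unit_id"] = df_sample_new.index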

### Trial Information Extraction

In [9]:
def extract_trial_info(trials_struct, unit_id):
    # Build a DataFrame from the trials struct (one row per trial)
    df_trial = pd.DataFrame({field: trials_struct[field].squeeze()
                             for field in trials_struct.dtype.names})
    # Tag rows with the unit so trials can later be matched to that unit's firing rates
    df_trial["unit_id"] = unit_id
    # MATLAB trial numbers are 1-based; convert to 0-based to match the firing-rate rows
    df_trial["trial_nr"] = df_trial["trial"].apply(
        lambda x: np.squeeze(x).item() if isinstance(x, (list, np.ndarray)) else x
    ) - 1
    return df_trial

trial_info_list = []
for idx, row in df_sample_new.iterrows():
    # Use the unit identifier from this row
    unit_id = row["unit_id"]  
    # Extract the trial DataFrame, including the unit identifier.
    trial_info_list.append(extract_trial_info(row["Trials"], unit_id))

# Concatenate the list of trial info DataFrames into one.
trial_info = pd.concat(trial_info_list, ignore_index=True)

### Firing Rate Calculation Across Task Epochs

In [11]:
# Event ts extraction
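def extract_event_timestamps(df_sample_new, start_idx_col='idxEnc1', end_idx_col='idxDel1', is_window=False, window_size=0.5):
    """Return {unit_id: (n_trials, 2) array of [start, end] times in microseconds}.

    With is_window=False, each epoch runs from the event at start_idx_col to the
    event at end_idx_col. With is_window=True, it is a +/- window_size (s) window
    centered on the start_idx_col event, and end_idx_col is ignored.
    Index columns hold 1-based MATLAB indices into the unit's `events` table.
    """
    epoch_ts = {}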
    
    for i, row in df_sample_new.iterrows():
        unit_id = row['unit_id']
        events = row['events'].squeeze()       # Ensure it's 1D array
        
        if is_window:
            # For window around an event (±window_size seconds)
            idxs = np.atleast_1d(row[start_idx_col].squeeze()).astype(int) - 1  # MATLAB 1-based -> 0-based
            extracted = events[idxs]
            center_times = extracted[:, 0]
            window_start = center_times - window_size * 1e6  # window_size before in microseconds
            window_end = center_times + window_size * 1e6    # window_size after in microseconds
            combined = np.column_stack((window_start, window_end))
        else:
            # For regular epochs between two events
            # MATLAB 1-based indices -> 0-based; atleast_1d keeps single-trial units 1-D
            idxs_start = np.atleast_1d(row[start_idx_col].squeeze()).astype(int) - 1
            idxs_end = np.atleast_1d(row[end_idx_col].squeeze()).astype(int) - 1
            
            # Handle case where start and end indices have different lengths
            min_length = min(len(idxs_start), len(idxs_end))
            idxs_start = idxs_start[:min_length]
            idxs_end = idxs_end[:min_length]
            
            # Index into events using the adjusted indices
            extracted_start = events[idxs_start]   # shape (n_trials, 3)
            extracted_end = events[idxs_end]   # shape (n_trials, 3)

            # Store as event start/end times
            combined = np.column_stack((extracted_start[:, 0], extracted_end[:, 0]))
        
        epoch_ts[unit_id] = combined
        
    return epoch_ts

# Compute per-trial firing rates (Hz) for one epoch; the first call also computes a pre-Enc1 baseline
def compute_firing_rates(df_sample_new, start_idx_col='idxEnc1', end_idx_col='idxDel1', 
                         fr_prefix='fr', is_window=False, window_size=0.5):
    # Extract timestamps for the specified epoch
    epoch_ts = extract_event_timestamps(df_sample_new, start_idx_col, end_idx_col, is_window, window_size)
    
    # On the first call, also compute a baseline firing rate in the 1-s window before each Enc1 onset
    if 'fr_baseline' not in df_sample_new.columns:
        # First get the encoding timestamps
        enc_ts = extract_event_timestamps(df_sample_new, 'idxEnc1', 'idxEnc1')
        
        # Create baseline timestamps 1 second before encoding
        baseline_ts = {}
        for unit_id, timestamps in enc_ts.items():
            # For each trial, create a 1-second window ending at the encoding start
            baseline_start = timestamps[:, 0] - 1e6  # 1 second before in microseconds
            baseline_end = timestamps[:, 0]  # End at encoding start
            baseline_ts[unit_id] = np.column_stack((baseline_start, baseline_end))
        df_sample_new['fr_baseline'] = df_sample_new.apply(
            lambda row: [
                np.sum((np.ravel(row["timestamps"]) >= baseline_on) & 
                       (np.ravel(row["timestamps"]) < baseline_off)) / ((baseline_off - baseline_on) / 1e6)
                for baseline_on, baseline_off in baseline_ts[row["unit_id"]]
            ],
            axis=1
        )
    
    # Per-trial firing rate (Hz) in each [epoch_on, epoch_off) window of the requested epoch
    epoch_col = f"{fr_prefix}_epoch"
    df_sample_new[epoch_col] = df_sample_new.apply(
        lambda row: [
            np.sum((np.ravel(row["timestamps"]) >= epoch_on) & 
                   (np.ravel(row["timestamps"]) < epoch_off)) / ((epoch_off - epoch_on) / 1e6)
            for epoch_on, epoch_off in epoch_ts[row["unit_id"]]
        ],
        axis=1
    )
    
    # Add trial_nr column to track trial numbers if not already present
    if "trial_nr" not in df_sample_new.columns:
        df_sample_new["trial_nr"] = df_sample_new[epoch_col].apply(lambda x: np.arange(len(x)))
    
    return df_sample_new

# Compute firing rates for different epochs
# Encoding 1 period
df_sample_new = compute_firing_rates(df_sample_new, 'idxEnc1', 'idxDel1')

# Delay 1 period
df_sample_new = compute_firing_rates(df_sample_new, 'idxDel1', 'idxEnc2', fr_prefix='fr_del1')

# Encoding 2 period
df_sample_new = compute_firing_rates(df_sample_new, 'idxEnc2', 'idxDel2', fr_prefix='fr_enc2')

# Delay 2 period
df_sample_new = compute_firing_rates(df_sample_new, 'idxDel2', 'idxProbeOn', fr_prefix='fr_del2')

# Response period (±0.5s window around response)
df_sample_new = compute_firing_rates(df_sample_new, 'idxResp', None, fr_prefix='fr_resp', is_window=True, window_size=0.5)

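# Explode to one row per (unit, trial). Multi-column explode needs pandas >= 1.3, and each
# unit's lists in these columns must have equal lengths (ValueError otherwise).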
columns_to_explode = ['fr_baseline', 'fr_epoch', 'fr_del1_epoch', 'fr_enc2_epoch', 'fr_del2_epoch', 'fr_resp_epoch', "trial_nr"]
df_sample_new = df_sample_new.explode(columns_to_explode)
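`compute_firing_rates` counts spikes with a boolean mask over the whole spike train for every trial. If that becomes slow for long recordings, one possible alternative (a swap-in, not what the notebook does) is binary search with `np.searchsorted`. The helper `window_rates_hz` below is hypothetical and assumes microsecond times and positive-length windows.

```python
import numpy as np


def window_rates_hz(spike_times_us, windows_us):
    """Firing rate (Hz) in each [start, end) window, with times in microseconds.

    Counts match the boolean-mask sum used in compute_firing_rates, but each window
    is located by binary search (np.searchsorted) on the sorted spike train.
    Windows must have positive length.
    """
    t = np.sort(np.ravel(np.asarray(spike_times_us, dtype=float)))
    w = np.asarray(windows_us, dtype=float).reshape(-1, 2)
    starts, ends = w[:, 0], w[:, 1]
    # side="left" returns how many spikes fall strictly before each bound
    counts = np.searchsorted(t, ends, side="left") - np.searchsorted(t, starts, side="left")
    return counts / ((ends - starts) / 1e6)


spikes = np.array([0.0, 0.5e6, 1.0e6, 1.5e6, 2.5e6])  # spike times in microseconds
windows = np.array([[0.0, 1.0e6], [1.0e6, 2.0e6], [2.0e6, 4.0e6]])
rates = window_rates_hz(spikes, windows)
print(rates)
```

Each window then costs O(log n) instead of O(n). Spikes that fall exactly on a window start are included and spikes on a window end are excluded, as in the mask version.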

In [None]:
# Merge neural data with trial information
df_sample_new = df_sample_new.reset_index(drop=True)
trial_info = trial_info.reset_index(drop=True)

data = pd.merge(
    df_sample_new,
    trial_info,
    on=["unit_id", "trial_nr"],
    how="left",
).infer_objects()

print(f"Merged data shape: {data.shape}")
print(f"Available columns: {list(data.columns)}")
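Because this is a left join, neural rows with no matching trial silently receive NaN trial fields. The toy tables below use made-up values and show how pandas' `validate` and `indicator` arguments to `pd.merge` could surface that problem. This is a sketch and is not applied to `data` above.

```python
import pandas as pd

# Toy neural and trial tables keyed like the notebook's (unit_id, trial_nr).
neural = pd.DataFrame({"unit_id": [0, 0, 0], "trial_nr": [0, 1, 2], "fr_epoch": [3.0, 4.5, 2.0]})
trials = pd.DataFrame({"unit_id": [0, 0], "trial_nr": [0, 1], "first_cat": ["A", "B"]})

merged = pd.merge(
    neural,
    trials,
    on=["unit_id", "trial_nr"],
    how="left",
    validate="one_to_one",  # raises MergeError if either side has duplicate keys
    indicator=True,         # adds a '_merge' column: 'both' or 'left_only'
)
unmatched = merged[merged["_merge"] == "left_only"]
print(f"{len(unmatched)} neural rows had no trial info")
```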


In [12]:
cols_to_keep = [
    "unit_id", "timestamps", "brainAreaOfCell", "fr_epoch", "fr_baseline", "fr_del1_epoch", 
    "fr_enc2_epoch", "fr_del2_epoch", "fr_resp_epoch", "trial_nr",
    "first_cat", "second_cat", "first_num", "second_num",
    "first_pic", "second_pic", "probe_cat", "probe_pic",
    "probe_validity", "probe_num", "correct_answer",
    "rt", "acc", "key", "cat_comparison", "events", "nTrials", "Trials", "idxEnc1", "idxEnc2", "idxDel1",
    "idxDel2", "idxProbeOn", "idxResp", "nrProcessed", "periods_Enc1", "periods_Enc2", "periods_Del1",
    "periods_Del2", "periods_Probe", "periods_Resp", "prestimEnc", "prestimMaint", "prestimProbe",
    "prestimButtonPress", "poststimEnc", "poststimMaint", "poststimProbe", "poststimButtonPress", "sessionIdx", "channel", "cellNr", "sessionID", "origClusterID"
]

data_filtered = data[cols_to_keep].copy()
print(f"Filtered data shape: {data_filtered.shape}")
print("Final dataset ready for analysis!")

In [13]:
# Convert categories and numbers to simple, hashable strings (usable as ANOVA factors)
def _to_simple_str(x):
    return str(np.squeeze(x)) if isinstance(x, (list, np.ndarray)) else str(x)

for col in ["first_cat", "second_cat", "first_num", "second_num"]:
    data_filtered[f"{col}_simple"] = data_filtered[col].apply(_to_simple_str)

---

## 2. Single Unit Selectivity Analysis

### Overview

Identify neurons whose firing during the two encoding epochs (stim1 and stim2) is selective for stimulus category or numerosity, and neurons whose firing differs by stimulus position (first vs second).

In [None]:
# % ttl values
# c.marker.expstart        = 89;
# c.marker.expend          = 90;
# c.marker.fixOnset        = 10;
# c.marker.pic1            = 1;
# c.marker.delay1          = 2;
# c.marker.pic2            = 3;
# c.marker.delay2          = 4;
# c.marker.probeOnset      = 5;
# c.marker.response        = 6;
# c.marker.break           = 91;

# List the columns available for analysis
data_filtered.columns

### Neuron Selectivity Analysis

In [15]:
import statsmodels.formula.api as smf
import statsmodels.api as sm
import numpy as np
import pandas as pd
from tqdm import tqdm

def identify_selective_neurons(data_filtered):
    """
    Analyzes neural data to identify neurons that show selective responses to different stimulus categories and numerosities.
    
    This function performs three main analyses for each neuron:
    1. First stimulus analysis: Tests if the neuron responds differently to different categories/numbers of the first stimulus
    2. Second stimulus analysis: Tests if the neuron responds differently to different categories/numbers of the second stimulus
    3. Position analysis: Tests if the neuron responds differently based on stimulus position (first vs second)
    
    Args:
        data_filtered: DataFrame containing neural data with columns for firing rates, categories, and numerosities
        
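    Returns:
        DataFrame with one row per neuron (neurons with zero firing-rate variance are skipped)
        holding p-values, uncorrected p < 0.05 selectivity flags, and model R-squared values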
    """
    selectivity_stats = []

    for unit_id, unit_df in tqdm(data_filtered.groupby("unit_id")):
        unit_df = unit_df.copy()
        
        # Extract firing rates for first and second stimulus presentations
        unit_df["fr_first"] = unit_df["fr_epoch"]
        unit_df["fr_second"] = unit_df["fr_enc2_epoch"]

        # Skip neurons with no variance in firing rates
        if unit_df["fr_first"].std() == 0 or unit_df["fr_second"].std() == 0:
            continue

        # Analyze selectivity to first stimulus using ANOVA
        # Tests if firing rate varies with category or number of first stimulus
        model_first = smf.ols(
            "fr_first ~ C(first_cat_simple) + C(first_num_simple)", 
            data=unit_df
        ).fit()
        anova_first = sm.stats.anova_lm(model_first, typ=2)

        # Analyze selectivity to second stimulus using ANOVA
        # Tests if firing rate varies with category or number of second stimulus
        model_second = smf.ols(
            "fr_second ~ C(second_cat_simple) + C(second_num_simple)", 
            data=unit_df
        ).fit()
        anova_second = sm.stats.anova_lm(model_second, typ=2)

        # Analyze position selectivity
        # Combines data from both stimuli to test if firing rate varies with stimulus position
        position_df = pd.DataFrame({
            'fr': np.concatenate([unit_df['fr_first'], unit_df['fr_second']]),
            'category': np.concatenate([unit_df['first_cat_simple'], unit_df['second_cat_simple']]),
            'number': np.concatenate([unit_df['first_num_simple'], unit_df['second_num_simple']]),
            'position': ['first'] * len(unit_df) + ['second'] * len(unit_df)
        })

        model_position = smf.ols(
            "fr ~ C(category) + C(number) + C(position)", 
            data=position_df
        ).fit()
        anova_position = sm.stats.anova_lm(model_position, typ=2)

        # Compile statistical results for this neuron
        stats = {
            'unit_id': unit_id,
            'area': unit_df['brainAreaOfCell'].iloc[0],
            # P-values for category and number selectivity in first stimulus
            'first_cat_pvalue': anova_first.loc['C(first_cat_simple)', 'PR(>F)'],
            'first_num_pvalue': anova_first.loc['C(first_num_simple)', 'PR(>F)'],
            # P-values for category and number selectivity in second stimulus
            'second_cat_pvalue': anova_second.loc['C(second_cat_simple)', 'PR(>F)'],
            'second_num_pvalue': anova_second.loc['C(second_num_simple)', 'PR(>F)'],
            # P-value for position selectivity
            'position_pvalue': anova_position.loc['C(position)', 'PR(>F)'],
            # Binary flags indicating significant selectivity (p < 0.05)
            'is_first_cat_selective': anova_first.loc['C(first_cat_simple)', 'PR(>F)'] < 0.05,
            'is_first_num_selective': anova_first.loc['C(first_num_simple)', 'PR(>F)'] < 0.05,
            'is_second_cat_selective': anova_second.loc['C(second_cat_simple)', 'PR(>F)'] < 0.05,
            'is_second_num_selective': anova_second.loc['C(second_num_simple)', 'PR(>F)'] < 0.05,
            'is_position_selective': anova_position.loc['C(position)', 'PR(>F)'] < 0.05,
            # R-squared values indicating model fit
            'r2_first': model_first.rsquared,
            'r2_second': model_second.rsquared,
            'r2_position': model_position.rsquared,
            # Overall selectivity flag
            'is_any_selective': (
                (anova_first.loc['C(first_cat_simple)', 'PR(>F)'] < 0.05) or
                (anova_first.loc['C(first_num_simple)', 'PR(>F)'] < 0.05) or
                (anova_second.loc['C(second_cat_simple)', 'PR(>F)'] < 0.05) or
                (anova_second.loc['C(second_num_simple)', 'PR(>F)'] < 0.05) or
                (anova_position.loc['C(position)', 'PR(>F)'] < 0.05)
            )
        }
        selectivity_stats.append(stats)

    results_df = pd.DataFrame(selectivity_stats)
    return results_df

In [None]:
# Run the selectivity analysis
selectivity_results = identify_selective_neurons(data_filtered)

---

## 3. Delay Activity & Distractor Resistance Analysis

### Principled Distractor Resistance Analysis

This analysis identifies neurons that:
1. Show category selectivity during encoding (Step 1)
2. Maintain selectivity during delay period (Step 2) 
3. Resist interference from distractors (Step 3)
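The three steps above act as nested filters. The toy sketch below uses made-up p-values, a hypothetical helper `classify_stage`, and alpha = 0.05 as in the notebook. It shows how each neuron lands in the furthest stage it passes. Unit 2 is selective in delay2 but stops at `encoding_only` because it fails Step 2.

```python
import pandas as pd

ALPHA = 0.05


def classify_stage(enc_p, del1_p, del2_p, alpha=ALPHA):
    """Assign a neuron to the furthest step of the 3-step filter it passes."""
    if enc_p >= alpha:
        return "not_selective"
    if del1_p >= alpha:
        return "encoding_only"
    if del2_p >= alpha:
        return "delay_only"
    return "distractor_resistant"


# Hypothetical p-values for four neurons.
pvals = pd.DataFrame({
    "unit_id": [1, 2, 3, 4],
    "enc_p": [0.001, 0.01, 0.02, 0.30],
    "del1_p": [0.002, 0.20, 0.01, 0.01],
    "del2_p": [0.004, 0.01, 0.40, 0.01],
})
pvals["stage"] = [
    classify_stage(*row) for row in pvals[["enc_p", "del1_p", "del2_p"]].itertuples(index=False)
]
print(pvals[["unit_id", "stage"]])
```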


In [None]:
# ===== PRINCIPLED DISTRACTOR RESISTANCE ANALYSIS =====

def find_distractor_resistant_neurons_principled(data_filtered, selectivity_results):
    """
    Identify neurons that maintain stim1 category information through distraction.
    
    This implements a principled 3-step filtering approach:
    1. Start with stim1 category-selective neurons (encoding period)
    2. Filter to delay neurons (maintain stim1 selectivity in delay1)
    3. Among delay neurons, keep those that also maintain stim1 selectivity in delay2
       (distractor-resistant)
    
    Args:
        data_filtered: Trial-level neural data
        selectivity_results: Output from identify_selective_neurons()
        
    Returns:
        DataFrame with distractor resistance analysis for each stim1-selective neuron
    """
    print("=== PRINCIPLED DISTRACTOR RESISTANCE ANALYSIS ===")
    
    # Step 1: Start with neurons selective for stim1 category during encoding
    stim1_selective = selectivity_results[selectivity_results["is_first_cat_selective"]].copy()
    print(f"Step 1: Found {len(stim1_selective)} neurons selective for stim1 category during encoding")
    
    if len(stim1_selective) == 0:
        print("No stim1-selective neurons found!")
        return pd.DataFrame()
    
    # Analyze each stim1-selective neuron for distractor resistance
    distractor_analysis = []
    
    for _, neuron_row in tqdm(stim1_selective.iterrows(), total=len(stim1_selective),
                              desc="Analyzing stim1-selective neurons"):
        unit_id = neuron_row['unit_id']
        unit_df = data_filtered[data_filtered['unit_id'] == unit_id].copy()
        
        if len(unit_df) == 0:
            continue
        
        # Flag trials where stim1 and stim2 belong to the same category
        unit_df['category_match'] = (unit_df['first_cat_simple'] == unit_df['second_cat_simple'])
        
        try:
            # Step 2: Test if neuron maintains stim1 selectivity during delay1 (pre-distractor)
            model_del1 = smf.ols("fr_del1_epoch ~ C(first_cat_simple)", data=unit_df).fit()
            anova_del1 = sm.stats.anova_lm(model_del1, typ=2)
            del1_pval = anova_del1.loc['C(first_cat_simple)', 'PR(>F)']
            is_delay_neuron = bool(del1_pval < 0.05)
            
            # Step 3: Test if neuron maintains stim1 selectivity during delay2 (post-distractor).
            # Only delay neurons (Step 2) can qualify, so the stages stay nested.
            model_del2 = smf.ols("fr_del2_epoch ~ C(first_cat_simple)", data=unit_df).fit()
            anova_del2 = sm.stats.anova_lm(model_del2, typ=2)
            del2_pval = anova_del2.loc['C(first_cat_simple)', 'PR(>F)']
            is_distractor_resistant = is_delay_neuron and bool(del2_pval < 0.05)
            
            # Split trials by whether the distractor (stim2) shares stim1's category
            same_cat_trials = unit_df[unit_df['category_match']]
            diff_cat_trials = unit_df[~unit_df['category_match']]
            
            # Delay2 stim1 selectivity in same-category trials (p = 1.0 if untestable)
            same_cat_del2_pval = 1.0
            if len(same_cat_trials) >= 10 and same_cat_trials['fr_del2_epoch'].std() > 0:
                try:
                    model_same = smf.ols("fr_del2_epoch ~ C(first_cat_simple)", data=same_cat_trials).fit()
                    anova_same = sm.stats.anova_lm(model_same, typ=2)
                    same_cat_del2_pval = anova_same.loc['C(first_cat_simple)', 'PR(>F)']
                except Exception:
                    pass
            
            # Delay2 stim1 selectivity in different-category trials (p = 1.0 if untestable)
            diff_cat_del2_pval = 1.0
            if len(diff_cat_trials) >= 10 and diff_cat_trials['fr_del2_epoch'].std() > 0:
                try:
                    model_diff = smf.ols("fr_del2_epoch ~ C(first_cat_simple)", data=diff_cat_trials).fit()
                    anova_diff = sm.stats.anova_lm(model_diff, typ=2)
                    diff_cat_del2_pval = anova_diff.loc['C(first_cat_simple)', 'PR(>F)']
                except Exception:
                    pass
            
            # Correlate stim1-category tuning (category-mean rates) between delay1 and delay2
            del1_means = unit_df.groupby('first_cat_simple')['fr_del1_epoch'].mean()
            del2_means = unit_df.groupby('first_cat_simple')['fr_del2_epoch'].mean()
            
            correlation = 0
            if len(del1_means) > 1 and len(del2_means) > 1:
                correlation = np.corrcoef(del1_means.values, del2_means.values)[0, 1]
            
            # Mean delay2 rate per stim1 category, separately for same- and different-category trials
            same_cat_responses = {}
            diff_cat_responses = {}
            
            for cat in unit_df['first_cat_simple'].unique():
                same_cat_mask = same_cat_trials['first_cat_simple'] == cat
                diff_cat_mask = diff_cat_trials['first_cat_simple'] == cat
                
                same_cat_responses[cat] = same_cat_trials[same_cat_mask]['fr_del2_epoch'].mean() if same_cat_mask.any() else np.nan
                diff_cat_responses[cat] = diff_cat_trials[diff_cat_mask]['fr_del2_epoch'].mean() if diff_cat_mask.any() else np.nan
            
            # Pattern similarity: correlation of those two delay2 tuning profiles across categories
            valid_cats = [cat for cat in same_cat_responses.keys()
                          if not np.isnan(same_cat_responses[cat]) and not np.isnan(diff_cat_responses[cat])]
            
            pattern_similarity = 0
            if len(valid_cats) > 1:
                same_pattern = [same_cat_responses[cat] for cat in valid_cats]
                diff_pattern = [diff_cat_responses[cat] for cat in valid_cats]
                if np.std(same_pattern) > 0 and np.std(diff_pattern) > 0:
                    pattern_similarity = np.corrcoef(same_pattern, diff_pattern)[0, 1]
            
        except Exception as e:
            print(f"Error analyzing unit {unit_id}: {e}")
            continue
        
        # Compile results for this neuron
        result = {
            'unit_id': unit_id,  # Unique identifier for each neuron
            'area': unit_df['brainAreaOfCell'].iloc[0],  # Brain area where neuron was recorded
            'stim1_encoding_pval': neuron_row['first_cat_pvalue'],  # P-value for stim1 category selectivity during encoding
            'delay1_stim1_pval': del1_pval,  # P-value for stim1 category selectivity during delay1
            'is_delay_neuron': is_delay_neuron,  # True if the neuron maintains stim1 selectivity in delay1
            'delay2_stim1_pval': del2_pval,  # P-value for stim1 category selectivity during delay2
            'is_distractor_resistant': is_distractor_resistant,  # True if a delay neuron is also stim1-selective in delay2
            'delay2_same_cat_pval': same_cat_del2_pval,  # P-value for stim1 selectivity in same-category trials during delay2
            'delay2_diff_cat_pval': diff_cat_del2_pval,  # P-value for stim1 selectivity in different-category trials during delay2
            'pattern_similarity': pattern_similarity,  # Correlation of delay2 tuning in same- vs different-category trials
            'delay1_delay2_correlation': correlation,  # Correlation of stim1-category tuning between delay1 and delay2
            'n_same_cat_trials': len(same_cat_trials),  # Number of trials where stim1 and stim2 categories match
            'n_diff_cat_trials': len(diff_cat_trials),  # Number of trials where stim1 and stim2 categories differ
            'analysis_stage': (  # Furthest step of the 3-step filter this neuron passes
                'encoding_only' if not is_delay_neuron else
                'delay_only' if not is_distractor_resistant else
                'distractor_resistant'
            )
        }
        
        distractor_analysis.append(result)
    
    results_df = pd.DataFrame(distractor_analysis)
    
    # Print analysis summary
    if len(results_df) > 0:
        print(f"\nStep 2: {results_df['is_delay_neuron'].sum()}/{len(results_df)} are delay neurons")
        print(f"Step 3: {results_df['is_distractor_resistant'].sum()}/{len(results_df)} are distractor resistant")
        
        print("\nBreakdown by analysis stage:")
        stage_counts = results_df['analysis_stage'].value_counts()
        for stage, count in stage_counts.items():
            print(f"  {stage}: {count}")
    
    return results_df

# Execute the principled distractor resistance analysis
print("Running principled distractor resistance analysis...")
distractor_results_principled = find_distractor_resistant_neurons_principled(data_filtered, selectivity_results)

# Extract the distractor-resistant neurons (empty if no neuron reached the analysis)
if distractor_results_principled.empty:
    distractor_resistant_neurons = distractor_results_principled.copy()
else:
    distractor_resistant_neurons = distractor_results_principled[
        distractor_results_principled['is_distractor_resistant']
    ].copy()

print("\n=== DISTRACTOR-RESISTANT NEURONS ===")
if len(distractor_resistant_neurons) > 0:
    print(f"Found {len(distractor_resistant_neurons)} distractor-resistant neurons:")
    for _, row in distractor_resistant_neurons.iterrows():
        print(f"  Unit {row['unit_id']} ({row['area']}) - Pattern similarity: {row['pattern_similarity']:.3f}")
    
    # Save complete analysis results
    distractor_results_principled.to_csv("principled_distractor_analysis.csv", index=False)
    distractor_resistant_neurons.to_csv("distractor_resistant_neurons.csv", index=False)
    print("Results saved to CSV files")
    
else:
    print("No distractor-resistant neurons found!")

---

## 4. Visualization of Distractor-Resistant Neurons

### Setup for Plotting

In [18]:
# Neural analysis package setup and function definitions
# Install with: pip install git+https://github.com/ioqfwfq/rlab_neural_analysis.git@jz
try:
    from neural_analysis.visualize import plot_spikes_with_PSTH
    from neural_analysis.spikes import get_spikes
    NEURAL_ANALYSIS_AVAILABLE = True
    print("✅ Neural analysis package loaded successfully")
except ImportError:
    NEURAL_ANALYSIS_AVAILABLE = False
    print("⚠️ Warning: neural_analysis package not found.")
    print("Install with: pip install git+https://github.com/ioqfwfq/rlab_neural_analysis.git@jz")
    print("Some visualization functions will be skipped without this package.")
    
    # Define a dummy function to prevent errors
    def plot_spikes_with_PSTH(*args, **kwargs):
        print("plot_spikes_with_PSTH is not available - neural_analysis package not installed")
        return None, None

# Alignment times for plotting: the Enc1 event timestamp (microseconds) of each trial.
# Named differently from extract_event_timestamps above so that function is not shadowed.
def extract_alignment_timestamps(df, start_idx_col='idxEnc1'):
    """
    Return {unit_id: (n_trials, 1) array of event timestamps in microseconds}.

    `start_idx_col` holds 1-based MATLAB indices into the unit's `events` table
    (column 0 = timestamp). Trial-level rows repeat these per-unit arrays, so
    only the first row of each unit is read.
    """
    epoch_ts = {}
    for unit_id, unit_df in df.groupby("unit_id"):
        row = unit_df.iloc[0]
        events = np.asarray(row["events"]).squeeze()
        idxs = np.atleast_1d(np.asarray(row[start_idx_col]).squeeze()).astype(int) - 1
        epoch_ts[unit_id] = events[idxs, 0].reshape(-1, 1)

    print(f"Created epoch timestamps for {len(epoch_ts)} units")
    return epoch_ts

# Create epoch timestamps for plotting
print("Creating epoch timestamps...")
epoch_ts = extract_alignment_timestamps(
    df=data_filtered,
    start_idx_col='idxEnc1',
)
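If `neural_analysis` is not installed, the plotting cell below skips every unit. A minimal NumPy fallback for the PSTH part could look like the sketch below. It is our own helper (`psth_by_group`, hypothetical) with fixed-width bins, no smoothing, and no significance test, so it does not reimplement `plot_spikes_with_PSTH`. It takes the inputs the plotting cell prepares: spike and alignment times in seconds and one group label per trial.

```python
import numpy as np


def psth_by_group(spikes_s, alignments_s, labels, window=(-1.0, 8.0), bin_s=0.1):
    """Trial-averaged firing rate (Hz) per time bin, separately for each label.

    spikes_s: spike times in seconds; alignments_s: one alignment time per trial (s);
    labels: one group label per trial. Returns (bin_centers, {label: rates}).
    """
    n_bins = int(round((window[1] - window[0]) / bin_s))
    edges = np.linspace(window[0], window[1], n_bins + 1)
    spikes_s = np.asarray(spikes_s, dtype=float)
    alignments_s = np.asarray(alignments_s, dtype=float)
    labels = np.asarray(labels)
    rates = {}
    for lab in np.unique(labels):
        # Spike counts per bin for every trial in this group, relative to its alignment time
        counts = [np.histogram(spikes_s - t0, bins=edges)[0] for t0 in alignments_s[labels == lab]]
        rates[lab] = np.mean(counts, axis=0) / bin_s
    centers = (edges[:-1] + edges[1:]) / 2
    return centers, rates


alignments = np.array([10.0, 20.0, 30.0])
spikes = alignments + 0.05  # one spike 50 ms after each alignment
centers, rates = psth_by_group(spikes, alignments, ["face", "face", "house"], window=(0.0, 1.0))
print(rates["face"][:3])
```

Inside the plotting loop, `psth_by_group(spikes, alignments, group_labels)` could then be drawn with `plt.plot(centers, rates[label])` for each label.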

In [None]:
# ===== PLOT STEP 3 DISTRACTOR-RESISTANT NEURONS =====

# Make sure output directory exists
os.makedirs("step3_distractor_plots", exist_ok=True)

# Get neurons that passed step 3 (distractor-resistant) 
if 'distractor_results_principled' in locals() and not distractor_results_principled.empty:
    step3_neurons = distractor_results_principled[
        distractor_results_principled['is_distractor_resistant']
    ].copy()
    
    print(f"Found {len(step3_neurons)} step 3 distractor-resistant neurons")
    
    if len(step3_neurons) > 0:
        # Use all step3 neurons, sorted by strongest delay2 selectivity (lowest p-value)
        all_step3 = step3_neurons.sort_values('delay2_stim1_pval')
        all_step3_ids = all_step3['unit_id'].tolist()
        
        print(f"All {len(all_step3)} step 3 neurons by delay2 selectivity:")
        for _, row in all_step3.iterrows():
            print(f"  Unit {row['unit_id']} ({row['area']}) - Delay2 p={row['delay2_stim1_pval']:.4f}")
        
        # Plot each neuron
        for unit_id in all_step3_ids:
            print(f"\nPlotting Unit {unit_id}...")
            
            df_unit = data_filtered[data_filtered["unit_id"] == unit_id].reset_index(drop=True)
            if df_unit.empty or unit_id not in epoch_ts:
                print(f"  Skipping unit {unit_id} - no data")
                continue
            
            # Get unit info
            unit_stats = step3_neurons[step3_neurons['unit_id'] == unit_id].iloc[0]
            area = unit_stats['area']
            encoding_pval = unit_stats['stim1_encoding_pval']
            delay1_pval = unit_stats['delay1_stim1_pval'] 
            delay2_pval = unit_stats['delay2_stim1_pval']
            correlation = unit_stats['delay1_delay2_correlation']
            
            try:
                spikes = np.asarray(df_unit["timestamps"].iloc[0]).flatten().astype(np.float64) / 1e6
                spikes = np.sort(spikes)
                
                # Group by stim1 category
                group_labels = df_unit['first_cat_simple'].apply(
                    lambda x: np.squeeze(x).item() if isinstance(x, (list, np.ndarray)) else x
                )
                
                # Use stim1 numerosity for additional info
                stats = df_unit['first_num_simple'].apply(
                    lambda x: np.squeeze(x).item() if isinstance(x, (list, np.ndarray)) else x
                )
                
                alignments = np.asarray(epoch_ts[unit_id][:, 0], dtype=np.float64) / 1e6
                
                # Plot spikes with PSTH (if neural_analysis package is available)
                if NEURAL_ANALYSIS_AVAILABLE:
                    axes = plot_spikes_with_PSTH(
                        spikes,
                        alignments,
                        window=(-1, 8),
                        group_labels=group_labels,
                        stats=stats,
                        plot_stats=False,
                        sig_test=True,
                        cmap="Set1",
                    )
                else:
                    print(f"  Skipping plot for unit {unit_id} - neural_analysis package not available")
                    continue
                
                if axes is None:
                    print(f"  Failed to create plot for unit {unit_id}")
                    continue
                
                # Add baseline
                unit_baseline = np.mean(df_unit["fr_baseline"])
                xmin, xmax = axes[1].get_xlim()
                axes[1].hlines(
                    y=unit_baseline,
                    xmin=xmin,
                    xmax=xmax,
                    colors="gray",
                    linestyles="--",
                    alpha=0.7,
                    label=f"baseline = {unit_baseline:.1f}"
                )
                
                # Add task event lines
                event_times = [1, 2, 3, 5.5]
                event_labels = ['Stim1', 'Delay1', 'Stim2', 'Delay2']
                
                for v, label in zip(event_times, event_labels):
                    for ax in axes:
                        ax.axvline(x=v, color="black", linestyle="--", linewidth=1, alpha=0.8)
                    axes[0].text(v, axes[0].get_ylim()[1]*0.95, label, 
                               rotation=90, ha='right', va='top', fontsize=8)
                
                # Highlight delay2 period
                for ax in axes:
                    ax.axvspan(3, 5.5, alpha=0.2, color='yellow', label='Delay2')
                
                # Title with statistics
                axes[0].set_title(
                    f"{area} Unit {unit_id} — Distractor-Resistant Neuron\n"
                    f"Encoding p={encoding_pval:.4f}, Delay1 p={delay1_pval:.4f}, "
                    f"Delay2 p={delay2_pval:.4f}\n"
                    f"Delay1-Delay2 correlation: {correlation:.3f}"
                )
                
                axes[1].set_xlabel("Time from Stim1 onset [s]")
                axes[1].set_ylabel("Firing Rate (Hz)")
                
                # Save figure
                fname = f"step3_distractor_plots/{area}_Unit{unit_id}_distractor_resistant.png"
                plt.savefig(fname, dpi=300, bbox_inches="tight")
                plt.close()
                print(f"  Saved: {fname}")
                
            except Exception as e:
                plt.close("all")  # don't leave a partially drawn figure open
                print(f"  Error plotting unit {unit_id}: {e}")
    
    print(f"\nCompleted plotting all step 3 distractor-resistant neurons")
    
else:
    print("Need to run the principled distractor analysis first!")
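`plot_spikes_with_PSTH` comes from the optional neural_analysis package, so its internals are not shown here. A peri-stimulus time histogram (PSTH) of this kind is commonly built in three steps: bin each trial's aligned spikes, divide by the bin width to get Hz, and average the trials within each group. The sketch below does this with NumPy for trials grouped by a label such as `first_cat_simple`. The name `grouped_psth`, the 100 ms default bin, and the absence of smoothing are assumptions, not the package's documented behavior.

```python
import numpy as np

def grouped_psth(aligned_trials, labels, window=(-1.0, 8.0), bin_size=0.1):
    """Trial-averaged firing rate (Hz) per time bin, computed separately for each label.

    aligned_trials: list of 1-D arrays, spike times (s) relative to each trial's alignment
    labels: one group label per trial (e.g. the Stim1 category)
    Returns (bin_centers, {label: mean_rate_per_bin})
    """
    n_bins = int(round((window[1] - window[0]) / bin_size))
    edges = np.linspace(window[0], window[1], n_bins + 1)
    centers = (edges[:-1] + edges[1:]) / 2
    # Spike counts per (trial, bin), converted to Hz by dividing by each bin's width
    counts = np.array([np.histogram(t, bins=edges)[0] for t in aligned_trials], dtype=np.float64)
    rates = counts / np.diff(edges)
    labels_arr = np.asarray(list(labels), dtype=object)
    psth = {}
    for lab in dict.fromkeys(labels_arr.tolist()):  # unique labels in first-seen order
        psth[lab] = rates[labels_arr == lab].mean(axis=0)
    return centers, psth
```

Averaging rates rather than pooled counts weights every trial equally. This matters when categories have different trial counts.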

In [None]:

# Summary of distractor-resistant neurons found
if 'distractor_results_principled' in locals():
    print(f"\n=== DISTRACTOR RESISTANCE SUMMARY ===")
    step1_count = len(distractor_results_principled)
    step2_count = distractor_results_principled['is_delay_neuron'].sum()
    step3_count = distractor_results_principled['is_distractor_resistant'].sum()
    
    print(f"Step 1 (Stim1-selective): {step1_count} neurons")
    print(f"Step 2 (Delay neurons): {step2_count} neurons ({step2_count/max(step1_count, 1)*100:.1f}%)")
    print(f"Step 3 (Distractor-resistant): {step3_count} neurons ({step3_count/max(step1_count, 1)*100:.1f}%)")
else:
    print("Run the principled distractor analysis first to get results")

In [None]:
# ===== ADDITIONAL PLOTTING FUNCTIONS =====
# Note: Advanced plotting functions require the neural_analysis package
# Install with: pip install git+https://github.com/ioqfwfq/rlab_neural_analysis.git@jz

print("Advanced plotting functions are available only when the neural_analysis package is installed")
print("Run the core analysis (cells 1-32) first to get basic results")

# Basic analysis summary function that works without additional packages
def summarize_results():
    """Print which analysis results exist in the notebook namespace."""
    # Inside a function, locals() only holds the function's own variables,
    # so notebook-level results are looked up in globals()
    if 'selectivity_results' in globals():
        print("✅ Selectivity analysis completed")
        print(f"   Analyzed {len(selectivity_results)} neurons "
              f"({selectivity_results['is_any_selective'].sum()} selective)")
    else:
        print("⚠️ Run selectivity analysis first")
        
    if 'distractor_results_principled' in globals():
        print("✅ Distractor resistance analysis completed") 
        print(f"   Found {distractor_results_principled['is_distractor_resistant'].sum()} distractor-resistant neurons")
    else:
        print("⚠️ Run distractor resistance analysis first")
        
    print("\nTo run advanced plotting:")
    print("1. Install neural_analysis package")
    print("2. Run plotting cells (33+)")

summarize_results()
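`summarize_results` checks whether earlier results exist by testing their names. Those checks only succeed through `globals()`. Inside a function body, `locals()` returns just that function's own variables, so a notebook-level name such as `selectivity_results` never appears there. The toy check below shows the difference. `results_demo` and `check_scopes` are hypothetical names.

```python
results_demo = {"n": 3}  # stands in for a notebook-level result such as selectivity_results

def check_scopes():
    # locals() here only sees check_scopes' own (empty) namespace;
    # globals() sees the module/notebook namespace where results_demo lives
    return "results_demo" in locals(), "results_demo" in globals()

print(check_scopes())  # (False, True)
```

At the top level of a notebook cell, `locals()` and `globals()` are the same dictionary. This is why the `in locals()` checks used outside functions elsewhere in the notebook still work.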

### Proactive Interference

In [None]:
# ===== STIM2-TUNED TRIALS GROUPED BY STIM1 CATEGORY (PROACTIVE INTERFERENCE) =====

# Make sure output directory exists
os.makedirs("stim2_tuned_by_stim1_plots", exist_ok=True)

def plot_stim2_tuned_by_stim1(data_filtered, step3_neurons, unit_ids, epoch_ts):
    """
    Test proactive interference: How does stim1 affect stim2 maintenance?
    
    For each neuron:
    1. Find the neuron's best stim2 category (highest response during encoding2)
    2. Filter to trials where stim2 was that best category
    3. Group those trials by stim1 category (the "proactive distractor")
    4. Plot delay2 activity to see if stim1 interferes with stim2 maintenance
    
    This reverses the previous analysis: here stim1 acts as the "distractor".
    """
    
    print(f"Analyzing proactive interference (stim1 → stim2) for {len(unit_ids)} neurons...")
    
    for unit_id in unit_ids:
        print(f"\nAnalyzing Unit {unit_id} for proactive interference...")
        
        try:
            df_unit = data_filtered[data_filtered["unit_id"] == unit_id].reset_index(drop=True)
            if df_unit.empty or unit_id not in epoch_ts:
                print(f"  Skipping unit {unit_id} - no data")
                continue
            
            # Get unit info
            unit_stats = step3_neurons[step3_neurons['unit_id'] == unit_id].iloc[0]
            area = unit_stats['area']
            
            # Step 1: Find the neuron's best stim2 category
            # Use encoding2 responses to determine best stim2 category
            stim2_responses = df_unit.groupby('second_cat_simple')['fr_enc2_epoch'].mean()
            best_stim2_category = stim2_responses.idxmax()
            best_stim2_response = stim2_responses.max()
            
            print(f"  Best stim2 category: {best_stim2_category} (mean response: {best_stim2_response:.2f} Hz)")
            print(f"  All stim2 responses: {stim2_responses.to_dict()}")
            
            # Step 2: Filter to trials where stim2 was the best category
            best_stim2_trials = df_unit[df_unit['second_cat_simple'] == best_stim2_category].copy()
            
            if len(best_stim2_trials) < 10:
                print(f"  Insufficient trials with best stim2 category ({len(best_stim2_trials)}) - skipping")
                continue
            
            print(f"  Trials with best stim2 category: {len(best_stim2_trials)}")
            
            # Step 3: Group by stim1 category (the proactive distractor)
            stim1_counts = best_stim2_trials['first_cat_simple'].value_counts()
            print(f"  Stim1 category distribution: {stim1_counts.to_dict()}")
            
            # Check if we have enough trials in each stim1 category
            min_trials_per_stim1 = 3
            valid_stim1_cats = stim1_counts[stim1_counts >= min_trials_per_stim1].index.tolist()
            
            if len(valid_stim1_cats) < 2:
                print(f"  Insufficient stim1 category diversity - skipping")
                continue
            
            print(f"  Valid stim1 categories (≥{min_trials_per_stim1} trials): {valid_stim1_cats}")
            
            # Filter to only trials with valid stim1 categories
            plot_trials = best_stim2_trials[best_stim2_trials['first_cat_simple'].isin(valid_stim1_cats)].copy()
            
            # Step 4: Plot using existing infrastructure
            spikes = np.asarray(df_unit["timestamps"].iloc[0]).flatten().astype(np.float64) / 1e6
            spikes = np.sort(spikes)
            
            # Get alignments for the filtered trials
            all_alignments = np.asarray(epoch_ts[unit_id][:, 0], dtype=np.float64) / 1e6
            plot_alignments = all_alignments[plot_trials.index.values]
            
            # Group labels = stim1 categories (the proactive distractors)
            group_labels = plot_trials['first_cat_simple'].apply(
                lambda x: np.squeeze(x).item() if isinstance(x, (list, np.ndarray)) else x
            )
            
            # Stats = stim2 numerosity for additional info
            stats = plot_trials['second_num_simple'].apply(
                lambda x: np.squeeze(x).item() if isinstance(x, (list, np.ndarray)) else x
            )
            
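            # plot_spikes_with_PSTH comes from the optional neural_analysis package
            if not NEURAL_ANALYSIS_AVAILABLE:
                print(f"  Skipping plot for unit {unit_id} - neural_analysis package not available")
                continue
            
            # Use a different colormap from the previous analysis to distinguish the plots
            axes = plot_spikes_with_PSTH(
                spikes,
                plot_alignments,
                window=(-1, 8),
                group_labels=group_labels,  # Groups by stim1 category (proactive distractor)
                stats=stats,
                plot_stats=False,
                sig_test=True,
                cmap="viridis",
            )
            if axes is None:
                print(f"  Failed to create plot for unit {unit_id}")
                continue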
            
            # Add baseline
            unit_baseline = np.mean(df_unit["fr_baseline"])
            xmin, xmax = axes[1].get_xlim()
            axes[1].hlines(
                y=unit_baseline,
                xmin=xmin,
                xmax=xmax,
                colors="gray",
                linestyles="--",
                alpha=0.7,
                label=f"baseline = {unit_baseline:.1f}"
            )
            
            # Add task event lines (no labels for cleaner appearance)
            event_times = [1, 2, 3, 5.5]
            for v in event_times:
                for ax in axes:
                    ax.axvline(x=v, color="black", linestyle="--", linewidth=1, alpha=0.8)
            
            # Highlight key periods with different color scheme
            for ax in axes:
                ax.axvspan(0, 1, alpha=0.1, color='red', label='Stim1 (proactive distractor)')
                ax.axvspan(2, 3, alpha=0.1, color='green', label='Stim2 (best category)')
                ax.axvspan(3, 5.5, alpha=0.2, color='yellow', label='Delay2 (interference test)')
            
            # Title explaining the analysis
            axes[0].set_title(
                f"{area} Unit {unit_id} — Best Stim2 Category: '{best_stim2_category}'\n"
                f"Grouped by Stim1 Category (Lines = different Stim1, n={len(plot_trials)} trials)\n"
                f"Question: Does Stim1 interfere with Stim2 maintenance? (Proactive interference)"
            )
            
            axes[1].set_xlabel("Time from Stim1 onset [s]")
            axes[1].set_ylabel("Firing Rate (Hz)")
            # axes[1].legend(loc='upper right', fontsize=8)
            
            # Save the figure
            fname = f"stim2_tuned_by_stim1_plots/{area}_Unit{unit_id}_best_stim2_{best_stim2_category}_by_stim1.png"
            plt.savefig(fname, dpi=300, bbox_inches="tight")
            plt.close()
            print(f"  Saved: {fname}")
            
            # Detailed analysis of proactive interference
            print(f"  === Proactive Interference Analysis ===")
            
            # Compare encoding2 responses by stim1 category
            stim1_enc2_responses = plot_trials.groupby('first_cat_simple')['fr_enc2_epoch'].agg(['mean', 'sem', 'count'])
            print(f"  Stim2 encoding responses by stim1 category:")
            for stim1_cat, row in stim1_enc2_responses.iterrows():
                print(f"    Stim1={stim1_cat}: {row['mean']:.2f}±{row['sem']:.2f} Hz (n={int(row['count'])})")
            
            # Compare delay2 responses (key test for proactive interference)
            stim1_del2_responses = plot_trials.groupby('first_cat_simple')['fr_del2_epoch'].agg(['mean', 'sem', 'count'])
            print(f"  Delay2 responses by stim1 category (proactive interference test):")
            for stim1_cat, row in stim1_del2_responses.iterrows():
                print(f"    Stim1={stim1_cat}: {row['mean']:.2f}±{row['sem']:.2f} Hz (n={int(row['count'])})")
            
            # Test if stim1 category affects stim2 maintenance (proactive interference)
            if len(valid_stim1_cats) >= 2:
                try:
                    # Test effect on encoding2 period
                    model_stim1_enc2 = smf.ols("fr_enc2_epoch ~ C(first_cat_simple)", data=plot_trials).fit()
                    anova_stim1_enc2 = sm.stats.anova_lm(model_stim1_enc2, typ=2)
                    stim1_enc2_pval = anova_stim1_enc2.loc['C(first_cat_simple)', 'PR(>F)']
                    
                    # Test effect on delay2 period (main test for proactive interference)
                    model_stim1_del2 = smf.ols("fr_del2_epoch ~ C(first_cat_simple)", data=plot_trials).fit()
                    anova_stim1_del2 = sm.stats.anova_lm(model_stim1_del2, typ=2)
                    stim1_del2_pval = anova_stim1_del2.loc['C(first_cat_simple)', 'PR(>F)']
                    
                    print(f"  Stim1 effect on stim2 encoding: p = {stim1_enc2_pval:.4f} {'*' if stim1_enc2_pval < 0.05 else ''}")
                    print(f"  Stim1 effect on stim2 maintenance (delay2): p = {stim1_del2_pval:.4f} {'*' if stim1_del2_pval < 0.05 else ''}")
                    
                    # Determine interference pattern
                    if stim1_del2_pval < 0.05:
                        print(f"  → PROACTIVE INTERFERENCE DETECTED: Stim1 interferes with stim2 maintenance")
                    else:
                        print(f"  → No proactive interference: Stim2 maintenance resistant to stim1")
                        
                except Exception as e:
                    print(f"  Could not test proactive interference: {e}")
            # Compare with category congruence
            same_cat_mask = plot_trials['first_cat_simple'] == plot_trials['second_cat_simple']
            if same_cat_mask.any() and (~same_cat_mask).any():
                same_cat_del2 = plot_trials[same_cat_mask]['fr_del2_epoch'].mean()
                diff_cat_del2 = plot_trials[~same_cat_mask]['fr_del2_epoch'].mean()
                print(f"  Same category (stim1=stim2): {same_cat_del2:.2f} Hz")
                print(f"  Different categories: {diff_cat_del2:.2f} Hz")
                print(f"  Category congruence effect: {same_cat_del2 - diff_cat_del2:.2f} Hz")
            
        except Exception as e:
            plt.close("all")  # don't leave a partially drawn figure open
            print(f"  Error analyzing unit {unit_id}: {e}")
    
    print(f"\nCompleted proactive interference analysis for all step 3 neurons")
    print("Key insights to look for:")
    print("- Do colored lines (stim1 categories) affect delay2 responses when stim2 is optimal?")
    print("- Strong proactive interference = lines diverge and stay separated in delay2")
    print("- Weak proactive interference = lines converge together in delay2")
    print("- Compare with previous analysis to see bidirectional vs unidirectional interference")

# Run the proactive interference analysis
if 'all_step3_ids' in locals() and 'step3_neurons' in locals() and 'epoch_ts' in locals():
    plot_stim2_tuned_by_stim1(data_filtered, step3_neurons, all_step3_ids, epoch_ts)
else:
    print("Need to run the step 3 analysis first to get:")
    print("- all_step3_ids")
    print("- step3_neurons")
    print("- epoch_ts")
    print("\nOr run the setup variables code block first.")
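The proactive-interference test in the cell above fits `fr_del2_epoch ~ C(first_cat_simple)` with OLS and reads the ANOVA p-value. That test assumes roughly normal, equal-variance residuals. With as few as 3 trials per stim1 category (`min_trials_per_stim1 = 3`), a label-permutation version of the same one-way F test is one possible assumption-light cross-check. It is an alternative technique, not the notebook's method. `one_way_f`, `permutation_anova_p`, `n_perm=5000`, and the fixed seed are illustrative choices.

```python
import numpy as np

def one_way_f(values, labels):
    """One-way ANOVA F statistic for `values` grouped by `labels`."""
    values = np.asarray(values, dtype=np.float64)
    labels = np.asarray(labels)
    groups = [values[labels == g] for g in np.unique(labels)]
    grand_mean = values.mean()
    ss_between = sum(len(g) * (g.mean() - grand_mean) ** 2 for g in groups)
    ss_within = sum(((g - g.mean()) ** 2).sum() for g in groups)
    df_between = len(groups) - 1
    df_within = len(values) - len(groups)
    # A shuffle can leave every group constant (zero within-group variance), giving F = inf
    with np.errstate(divide="ignore", invalid="ignore"):
        return (ss_between / df_between) / (ss_within / df_within)

def permutation_anova_p(values, labels, n_perm=5000, seed=0):
    """Permutation p-value for the one-way F test: shuffle labels, recompute F."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    observed = one_way_f(values, labels)
    null = np.array([one_way_f(values, rng.permutation(labels)) for _ in range(n_perm)])
    # Count the observed labelling as one of the permutations so p is never exactly 0
    p_value = (np.sum(null >= observed) + 1) / (n_perm + 1)
    return observed, p_value
```

For a single unit, a call such as `permutation_anova_p(plot_trials['fr_del2_epoch'], plot_trials['first_cat_simple'])` could be compared against `stim1_del2_pval`. Because `plot_trials` is local to `plot_stim2_tuned_by_stim1`, the call would have to be placed inside that function. With very few trials, the number of distinct labellings is small, so the smallest attainable p-value is bounded away from zero.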

### Category Congruence Effects During Stim2
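The cell below takes each selective neuron and computes the mean change in encoding rate from Stim1 to Stim2 (`fr_enc2_epoch` minus `fr_epoch`). It compares this change between same-category and different-category trials, then assigns one of five labels. The decision rule is pulled out below as a pure function. `classify_congruence` is a hypothetical helper that mirrors the cell's if/else logic; the notebook does not call it.

```python
def classify_congruence(same_diff_mean, diff_diff_mean, p_value, alpha=0.05):
    """Label a neuron's congruence effect with the same rules as the analysis cell.

    same_diff_mean: mean (Stim2 - Stim1) encoding rate in Hz on same-category trials
    diff_diff_mean: the same quantity on different-category trials
    p_value: the smaller of the neuron's difference and ratio t-test p-values
    """
    if p_value >= alpha:
        return "none"
    if same_diff_mean < diff_diff_mean:
        # Same-category trials change less: a net drop is adaptation
        return "adaptation" if same_diff_mean < 0 else "less_facilitation"
    # Same-category trials change more (ties land here too): a net rise is facilitation
    return "facilitation" if same_diff_mean > 0 else "less_adaptation"
```

The same-versus-different comparison sets the direction. The sign of the same-category change then chooses between "adaptation"/"facilitation" and their weaker "less_" variants.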

In [None]:
# ===== CATEGORY CONGRUENCE ANALYSIS - ALL NEURONS =====

def analyze_category_congruence_all_neurons(data_filtered, selectivity_results):
    """
    Analyze category congruence effects (stim1 vs stim2 relationship) across ALL neurons.
    
    This tests adaptation/facilitation effects in the broader population, not just 
    distractor-resistant neurons, since the core question is about encoding responses.
    """
    
    print("Analyzing category congruence effects across ALL neurons...")
    
    # Get all neurons with any selectivity (broader than just stim1 category selective)
    responsive_neurons = selectivity_results[selectivity_results["is_any_selective"]].copy()
    print(f"Found {len(responsive_neurons)} responsive neurons to analyze")
    
    congruence_results = []
    
    for _, neuron_row in tqdm(responsive_neurons.iterrows(), total=len(responsive_neurons), desc="Analyzing neurons"):
        unit_id = neuron_row['unit_id']
        
        try:
            df_unit = data_filtered[data_filtered["unit_id"] == unit_id].reset_index(drop=True)
            if df_unit.empty:
                continue
            
            area = neuron_row['area']
            
            # Create category match variable
            df_unit['category_match'] = (df_unit['first_cat_simple'] == df_unit['second_cat_simple'])
            
            # Split into same vs different category trials
            same_cat_trials = df_unit[df_unit['category_match']].copy()
            diff_cat_trials = df_unit[~df_unit['category_match']].copy()
            
            if len(same_cat_trials) < 5 or len(diff_cat_trials) < 5:
                continue
            
            # === CORE ANALYSIS: Response Differences (Stim2 - Stim1) ===
            same_cat_trials['response_diff'] = same_cat_trials['fr_enc2_epoch'] - same_cat_trials['fr_epoch']
            diff_cat_trials['response_diff'] = diff_cat_trials['fr_enc2_epoch'] - diff_cat_trials['fr_epoch']
            
            same_cat_diff_mean = same_cat_trials['response_diff'].mean()
            same_cat_diff_sem = same_cat_trials['response_diff'].sem()
            diff_cat_diff_mean = diff_cat_trials['response_diff'].mean()
            diff_cat_diff_sem = diff_cat_trials['response_diff'].sem()
            
            # Statistical test (ttest_ind is imported at the top of the notebook)
            diff_ttest_stat, diff_ttest_pval = ttest_ind(
                same_cat_trials['response_diff'].dropna(), 
                diff_cat_trials['response_diff'].dropna()
            )
            
            # === Response Ratios ===
            epsilon = 0.01  # Avoid division by zero
            same_cat_trials['response_ratio'] = (same_cat_trials['fr_enc2_epoch'] + epsilon) / (same_cat_trials['fr_epoch'] + epsilon)
            diff_cat_trials['response_ratio'] = (diff_cat_trials['fr_enc2_epoch'] + epsilon) / (diff_cat_trials['fr_epoch'] + epsilon)
            
            same_cat_ratio_mean = same_cat_trials['response_ratio'].mean()
            diff_cat_ratio_mean = diff_cat_trials['response_ratio'].mean()
            
            # Test ratio differences
            ratio_ttest_stat, ratio_ttest_pval = ttest_ind(
                same_cat_trials['response_ratio'].dropna(), 
                diff_cat_trials['response_ratio'].dropna()
            )
            
            # === Correlations ===
            same_cat_corr = same_cat_trials[['fr_epoch', 'fr_enc2_epoch']].corr().iloc[0,1]
            diff_cat_corr = diff_cat_trials[['fr_epoch', 'fr_enc2_epoch']].corr().iloc[0,1]
            
            # === Effect Classification ===
            adaptation_effect = "none"
            if min(diff_ttest_pval, ratio_ttest_pval) < 0.05:
                if same_cat_diff_mean < diff_cat_diff_mean:
                    if same_cat_diff_mean < 0:
                        adaptation_effect = "adaptation"  # Same category shows reduction
                    else:
                        adaptation_effect = "less_facilitation"
                else:
                    if same_cat_diff_mean > 0:
                        adaptation_effect = "facilitation"  # Same category shows enhancement
                    else:
                        adaptation_effect = "less_adaptation"
            
            # === Include Neuron Type Classification ===
            # Determine what type of neuron this is
            neuron_type = "encoding_only"
            if 'distractor_results_principled' in globals():
                distractor_info = distractor_results_principled[distractor_results_principled['unit_id'] == unit_id]
                if len(distractor_info) > 0:
                    neuron_type = distractor_info.iloc[0]['analysis_stage']
            
            # === Store Results ===
            result = {
                'unit_id': unit_id,
                'area': area,
                'neuron_type': neuron_type,
                'n_same_cat_trials': len(same_cat_trials),
                'n_diff_cat_trials': len(diff_cat_trials),
                
                # Response differences
                'same_cat_diff_mean': same_cat_diff_mean,
                'same_cat_diff_sem': same_cat_diff_sem,
                'diff_cat_diff_mean': diff_cat_diff_mean,
                'diff_cat_diff_sem': diff_cat_diff_sem,
                'difference_pvalue': diff_ttest_pval,
                
                # Response ratios
                'same_cat_ratio_mean': same_cat_ratio_mean,
                'diff_cat_ratio_mean': diff_cat_ratio_mean,
                'ratio_pvalue': ratio_ttest_pval,
                
                # Correlations
                'same_cat_correlation': same_cat_corr,
                'diff_cat_correlation': diff_cat_corr,
                
                # Classification
                'adaptation_effect': adaptation_effect,
                'significant_congruence_effect': min(diff_ttest_pval, ratio_ttest_pval) < 0.05,
                
                # Selectivity info
                'is_first_cat_selective': neuron_row['is_first_cat_selective'],
                'is_first_num_selective': neuron_row['is_first_num_selective'],
                'is_second_cat_selective': neuron_row['is_second_cat_selective'],
                'is_second_num_selective': neuron_row['is_second_num_selective'],
            }
            
            congruence_results.append(result)
            
        except Exception as e:
            print(f"Error analyzing unit {unit_id}: {e}")
            continue
    
    # Convert to DataFrame
    congruence_df = pd.DataFrame(congruence_results)
    
    if len(congruence_df) == 0:
        print("No neurons analyzed successfully!")
        return pd.DataFrame()
    
    # === SUMMARY ANALYSIS ===
    print(f"\n=== CATEGORY CONGRUENCE SUMMARY (ALL NEURONS) ===")
    print(f"Successfully analyzed: {len(congruence_df)} neurons")
    print(f"Significant congruence effects: {congruence_df['significant_congruence_effect'].sum()}/{len(congruence_df)} ({congruence_df['significant_congruence_effect'].mean()*100:.1f}%)")
    
    # Effect type distribution
    print(f"\n=== EFFECT TYPE DISTRIBUTION ===")
    effect_counts = congruence_df['adaptation_effect'].value_counts()
    for effect, count in effect_counts.items():
        print(f"  {effect}: {count} neurons ({count/len(congruence_df)*100:.1f}%)")
    
    # Average effects
    avg_same_diff = congruence_df['same_cat_diff_mean'].mean()
    avg_diff_diff = congruence_df['diff_cat_diff_mean'].mean()
    print(f"\n=== AVERAGE RESPONSE DIFFERENCES ===")
    print(f"  Same category trials: {avg_same_diff:.3f} Hz")
    print(f"  Different category trials: {avg_diff_diff:.3f} Hz")
    print(f"  Overall congruence effect: {avg_same_diff - avg_diff_diff:.3f} Hz")
    
    if avg_same_diff < avg_diff_diff:
        print("  → Population mean suggests ADAPTATION (same category reduces stim2 response; no population-level test applied)")
    else:
        print("  → Population mean suggests FACILITATION (same category enhances stim2 response; no population-level test applied)")
    
    # === ANALYSIS BY NEURON TYPE ===
    if 'neuron_type' in congruence_df.columns:
        print(f"\n=== EFFECTS BY NEURON TYPE ===")
        
        for neuron_type in congruence_df['neuron_type'].unique():
            type_data = congruence_df[congruence_df['neuron_type'] == neuron_type]
            n_total = len(type_data)
            n_sig = type_data['significant_congruence_effect'].sum()
            avg_effect = type_data['same_cat_diff_mean'].mean() - type_data['diff_cat_diff_mean'].mean()
            
            print(f"  {neuron_type}: {n_sig}/{n_total} significant ({n_sig/n_total*100:.1f}%), avg effect: {avg_effect:.3f} Hz")
    
    # === ANALYSIS BY BRAIN AREA ===
    print(f"\n=== EFFECTS BY BRAIN AREA ===")
    
    for area in congruence_df['area'].unique():
        area_data = congruence_df[congruence_df['area'] == area]
        if len(area_data) >= 3:  # Only show areas with enough neurons
            n_total = len(area_data)
            n_sig = area_data['significant_congruence_effect'].sum()
            n_adapt = (area_data['adaptation_effect'] == 'adaptation').sum()
            avg_effect = area_data['same_cat_diff_mean'].mean() - area_data['diff_cat_diff_mean'].mean()
            
            print(f"  {area}: {n_sig}/{n_total} significant, {n_adapt} adaptation, avg effect: {avg_effect:.3f} Hz")
    
    # Save results
    os.makedirs("category_congruence_analysis", exist_ok=True)
    congruence_df.to_csv("category_congruence_analysis/all_neurons_congruence_analysis.csv", index=False)
    print(f"\nResults saved to: category_congruence_analysis/all_neurons_congruence_analysis.csv")
    
    return congruence_df


# === RUN THE ANALYSIS ===
if 'selectivity_results' in locals() and 'data_filtered' in locals():
    print("Running category congruence analysis on all responsive neurons...")
    all_neurons_congruence = analyze_category_congruence_all_neurons(data_filtered, selectivity_results)
    
    # Quick comparison with distractor-resistant neurons if available
    if 'all_step3_ids' in locals() and len(all_neurons_congruence) > 0:
        print(f"\n=== COMPARISON: ALL NEURONS vs DISTRACTOR-RESISTANT ===")
        
        all_neurons_effects = all_neurons_congruence['significant_congruence_effect'].mean()
        distractor_resistant_data = all_neurons_congruence[all_neurons_congruence['neuron_type'] == 'distractor_resistant']
        
        if len(distractor_resistant_data) > 0:
            distractor_effects = distractor_resistant_data['significant_congruence_effect'].mean()
            print(f"All neurons: {all_neurons_effects*100:.1f}% show significant congruence effects")
            print(f"Distractor-resistant neurons: {distractor_effects*100:.1f}% show significant congruence effects")
        else:
            print("No distractor-resistant neurons found in the analysis")
    
else:
    print("Need selectivity_results and data_filtered to run this analysis!")
    print("Make sure you've run the neuron selectivity analysis first.")
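`significant_congruence_effect` flags a neuron when the smaller of two p-values (the difference and ratio t-tests) falls below 0.05. No correction is applied across the many neurons tested, so some flags are expected by chance alone. The Benjamini-Hochberg step-up procedure is one standard way to control the false discovery rate across neurons. The NumPy sketch below illustrates it and is not part of the original pipeline. `benjamini_hochberg` is a hypothetical helper name.

```python
import numpy as np

def benjamini_hochberg(p_values, alpha=0.05):
    """Boolean mask of discoveries under Benjamini-Hochberg FDR control."""
    p = np.asarray(p_values, dtype=np.float64)
    n = p.size
    order = np.argsort(p)
    ranked = p[order]
    # The k-th smallest p-value is compared against alpha * k / n
    thresholds = alpha * np.arange(1, n + 1) / n
    below = ranked <= thresholds
    reject = np.zeros(n, dtype=bool)
    if below.any():
        # Step-up: reject everything up to the largest rank that passes its threshold
        k = np.max(np.nonzero(below)[0])
        reject[order[: k + 1]] = True
    return reject
```

For example, `benjamini_hochberg(all_neurons_congruence['difference_pvalue'])` would give an FDR-controlled mask for the difference test alone. Taking the minimum of two p-values per neuron, as the notebook does, adds another layer of multiplicity that this mask does not address. statsmodels, which the notebook already imports, provides the same correction via `statsmodels.stats.multitest.multipletests(pvals, alpha=0.05, method='fdr_bh')`.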

In [None]:
# ===== CATEGORY CONGRUENCE VISUALIZATIONS FOR DISTRACTOR RESISTANCE =====

# import matplotlib.pyplot as plt
# import seaborn as sns
# import numpy as np
# import pandas as pd

# Make sure we have the results
if 'all_neurons_congruence' not in locals():
    print("Need to run the category congruence analysis first!")
else:
    # Create output directory
    os.makedirs("category_congruence_plots", exist_ok=True)
    
    print("Creating category congruence visualizations...")
    
    # === PLOT 1: Distractor Resistance vs Congruence Effects ===
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    # Plot 1A: Percentage with significant effects by neuron type
    
    neuron_types = []
    percentages = []
    counts = []
    
    for ntype in all_neurons_congruence['neuron_type'].unique():
        type_data = all_neurons_congruence[all_neurons_congruence['neuron_type'] == ntype]
        pct = type_data['significant_congruence_effect'].mean() * 100
        n_sig = type_data['significant_congruence_effect'].sum()
        n_total = len(type_data)
        
        neuron_types.append(f"{ntype}\n(n={n_total})")
        percentages.append(pct)
        counts.append(f"{n_sig}/{n_total}")
    
    bars = axes[0,0].bar(neuron_types, percentages, 
                        color=['lightblue', 'orange', 'red'])
    axes[0,0].set_ylabel('% with Significant Congruence Effects')
    axes[0,0].set_title('Congruence Effects by Neuron Type')
    axes[0,0].set_ylim(0, max(max(percentages) * 1.2, 1))  # avoid a zero-height axis when no effects are significant
    
    # Add count labels on bars
    for bar, count in zip(bars, counts):
        height = bar.get_height()
        axes[0,0].text(bar.get_x() + bar.get_width()/2., height + 0.5,
                      count, ha='center', va='bottom', fontsize=10)
    
    # Plot 1B: Effect size by neuron type
    type_effects = []
    type_labels = []
    for ntype in all_neurons_congruence['neuron_type'].unique():
        type_data = all_neurons_congruence[all_neurons_congruence['neuron_type'] == ntype]
        effect = type_data['same_cat_diff_mean'].mean() - type_data['diff_cat_diff_mean'].mean()
        type_effects.append(effect)
        type_labels.append(ntype)
    
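    # Colour bars by neuron type (not by position) so colours stay consistent across panels
    palette = {'encoding_only': 'lightblue', 'delay_only': 'orange', 'distractor_resistant': 'red'}
    bars = axes[0,1].bar(type_labels, type_effects,
                        color=[palette.get(t, 'gray') for t in type_labels])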
    axes[0,1].set_ylabel('Average Congruence Effect (Hz)')
    axes[0,1].set_title('Effect Size by Neuron Type')
    axes[0,1].axhline(y=0, color='black', linestyle='--', alpha=0.5)
    plt.setp(axes[0,1].xaxis.get_majorticklabels(), rotation=45)
    
    # Plot 1C: Brain area effects
    area_data = []
    for area in all_neurons_congruence['area'].unique():
        area_subset = all_neurons_congruence[all_neurons_congruence['area'] == area]
        if len(area_subset) >= 10:  # Only areas with enough neurons
            pct_sig = area_subset['significant_congruence_effect'].mean() * 100
            avg_effect = area_subset['same_cat_diff_mean'].mean() - area_subset['diff_cat_diff_mean'].mean()
            n_neurons = len(area_subset)
            area_data.append({'area': area, 'pct_sig': pct_sig, 'avg_effect': avg_effect, 'n_neurons': n_neurons})
    
    area_df = pd.DataFrame(area_data)
    if len(area_df) > 0:
        bars = axes[1,0].bar(area_df['area'], area_df['pct_sig'])
        axes[1,0].set_ylabel('% with Significant Effects')
        axes[1,0].set_title('Congruence Effects by Brain Area')
        plt.setp(axes[1,0].xaxis.get_majorticklabels(), rotation=45)
        
        # Add sample size labels
        for i, (bar, n) in enumerate(zip(bars, area_df['n_neurons'])):
            height = bar.get_height()
            axes[1,0].text(bar.get_x() + bar.get_width()/2., height + 0.5,
                          f'n={n}', ha='center', va='bottom', fontsize=8)
    
    # Plot 1D: Effect types distribution
    effect_counts = all_neurons_congruence['adaptation_effect'].value_counts()
    colors = {'none': 'lightgray', 'adaptation': 'blue', 'facilitation': 'red', 
              'less_adaptation': 'lightblue', 'less_facilitation': 'pink'}
    
    pie_colors = [colors.get(effect, 'gray') for effect in effect_counts.index]
    axes[1,1].pie(effect_counts.values, labels=effect_counts.index, autopct='%1.1f%%',
                 colors=pie_colors, startangle=90)
    axes[1,1].set_title('Distribution of Effect Types')
    
    plt.tight_layout()
    plt.savefig("category_congruence_plots/congruence_by_neuron_type.png", dpi=300, bbox_inches="tight")
    plt.close()
    print("  Saved: congruence_by_neuron_type.png")
    
    # === PLOT 2: Individual Neuron Effects ===
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    # Plot 2A: Scatter plot of same vs different category effects
    axes[0,0].scatter(all_neurons_congruence['same_cat_diff_mean'], 
                     all_neurons_congruence['diff_cat_diff_mean'],
                     c=all_neurons_congruence['significant_congruence_effect'].map({True: 'red', False: 'lightblue'}),
                     alpha=0.6)
    
    # Add diagonal line
    max_val = max(all_neurons_congruence['same_cat_diff_mean'].max(), 
                  all_neurons_congruence['diff_cat_diff_mean'].max())
    min_val = min(all_neurons_congruence['same_cat_diff_mean'].min(), 
                  all_neurons_congruence['diff_cat_diff_mean'].min())
    axes[0,0].plot([min_val, max_val], [min_val, max_val], 'k--', alpha=0.5)
    
    axes[0,0].set_xlabel('Same Category Effect (Hz)')
    axes[0,0].set_ylabel('Different Category Effect (Hz)')
    axes[0,0].set_title('Same vs Different Category Effects')
    axes[0,0].axhline(y=0, color='gray', linestyle='--', alpha=0.3)
    axes[0,0].axvline(x=0, color='gray', linestyle='--', alpha=0.3)
    
    # Add quadrant labels
    axes[0,0].text(0.7*max_val, 0.1*max_val, 'Same category\nfacilitation only', 
                  ha='center', va='center', bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.5))
    
    # Plot 2B: Distribution of congruence effects
    axes[0,1].hist(all_neurons_congruence['same_cat_diff_mean'] - all_neurons_congruence['diff_cat_diff_mean'], 
                  bins=30, alpha=0.7, color='skyblue', edgecolor='black')
    axes[0,1].axvline(x=0, color='red', linestyle='--', linewidth=2, label='No effect')
    axes[0,1].set_xlabel('Congruence Effect Size (Hz)')
    axes[0,1].set_ylabel('Number of Neurons')
    axes[0,1].set_title('Distribution of Congruence Effect Sizes')
    axes[0,1].legend()
    
    # Plot 2C: Effect size vs significance
    effect_sizes = all_neurons_congruence['same_cat_diff_mean'] - all_neurons_congruence['diff_cat_diff_mean']
    # Smaller of the two per-neuron p-values; not corrected for taking the minimum or for the number of neurons
    p_values = np.minimum(all_neurons_congruence['difference_pvalue'], all_neurons_congruence['ratio_pvalue'])
    
    # Color by neuron type
    type_colors = {'encoding_only': 'blue', 'delay_only': 'orange', 'distractor_resistant': 'red'}
    colors = [type_colors.get(t, 'gray') for t in all_neurons_congruence['neuron_type']]
    
    axes[1,0].scatter(effect_sizes, -np.log10(p_values), c=colors, alpha=0.6)
    p_line = axes[1,0].axhline(y=-np.log10(0.05), color='red', linestyle='--', label='p=0.05')
    axes[1,0].axvline(x=0, color='gray', linestyle='--', alpha=0.5)
    axes[1,0].set_xlabel('Congruence Effect Size (Hz)')
    axes[1,0].set_ylabel('-log10(p-value)')
    axes[1,0].set_title('Volcano Plot: Effect Size vs Significance')
    
    # One combined legend: neuron-type colours plus the p=0.05 threshold
    # (a second legend() call would replace the first and drop the threshold entry)
    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor=color, label=ntype) 
                      for ntype, color in type_colors.items()]
    axes[1,0].legend(handles=legend_elements + [p_line], loc='upper right')
    
    # Plot 2D: Response correlations same vs different
    axes[1,1].scatter(all_neurons_congruence['same_cat_correlation'], 
                     all_neurons_congruence['diff_cat_correlation'],
                     c=all_neurons_congruence['significant_congruence_effect'].map({True: 'red', False: 'lightblue'}),
                     alpha=0.6)
    axes[1,1].plot([-1, 1], [-1, 1], 'k--', alpha=0.5)
    axes[1,1].set_xlabel('Same Category Stim1-Stim2 Correlation')
    axes[1,1].set_ylabel('Different Category Stim1-Stim2 Correlation')
    axes[1,1].set_title('Response Correlations: Same vs Different Categories')
    
    plt.tight_layout()
    plt.savefig("category_congruence_plots/individual_neuron_effects.png", dpi=300, bbox_inches="tight")
    plt.close()
    print("  Saved: individual_neuron_effects.png")
    
    # === PLOT 3: Distractor-Resistant Neuron Focus ===
    distractor_resistant = all_neurons_congruence[all_neurons_congruence['neuron_type'] == 'distractor_resistant']
    
    if len(distractor_resistant) > 0:
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        
        # Plot 3A: Distractor-resistant neurons individual effects
        axes[0,0].scatter(distractor_resistant['same_cat_diff_mean'], 
                         distractor_resistant['diff_cat_diff_mean'],
                         c=distractor_resistant['significant_congruence_effect'].map({True: 'red', False: 'blue'}),
                         s=100, alpha=0.7)
        
        # Add unit ID labels
        for _, row in distractor_resistant.iterrows():
            axes[0,0].annotate(f"U{row['unit_id']}", 
                              (row['same_cat_diff_mean'], row['diff_cat_diff_mean']),
                              xytext=(5, 5), textcoords='offset points', fontsize=8)
        
        axes[0,0].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
        axes[0,0].axvline(x=0, color='gray', linestyle='--', alpha=0.5)
        axes[0,0].set_xlabel('Same Category Effect (Hz)')
        axes[0,0].set_ylabel('Different Category Effect (Hz)')
        axes[0,0].set_title('Distractor-Resistant Neurons: Individual Effects')
        
        # Plot 3B: Effect types in distractor-resistant neurons
        dr_effects = distractor_resistant['adaptation_effect'].value_counts()
        if len(dr_effects) > 0:
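            # Colour by effect type, matching the pie chart in Plot 1D
            effect_palette = {'none': 'lightgray', 'adaptation': 'blue', 'facilitation': 'red',
                              'less_adaptation': 'lightblue', 'less_facilitation': 'pink'}
            axes[0,1].bar(dr_effects.index, dr_effects.values,
                         color=[effect_palette.get(e, 'gray') for e in dr_effects.index])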
            axes[0,1].set_ylabel('Number of Neurons')
            axes[0,1].set_title('Effect Types in Distractor-Resistant Neurons')
            plt.setp(axes[0,1].xaxis.get_majorticklabels(), rotation=45)
        
        # Plot 3C: Brain area distribution of distractor-resistant neurons
        dr_areas = distractor_resistant['area'].value_counts()
        axes[1,0].bar(dr_areas.index, dr_areas.values)
        axes[1,0].set_ylabel('Number of Neurons')
        axes[1,0].set_title('Distractor-Resistant Neurons by Brain Area')
        
        # Add congruence effect info
        for i, area in enumerate(dr_areas.index):
            area_data = distractor_resistant[distractor_resistant['area'] == area]
            n_sig = area_data['significant_congruence_effect'].sum()
            axes[1,0].text(i, dr_areas.iloc[i] + 0.1, f'{n_sig} sig', 
                          ha='center', va='bottom', fontsize=8)
        
        # Plot 3D: Comparison with other neuron types
        comparison_data = []
        for ntype in all_neurons_congruence['neuron_type'].unique():
            type_data = all_neurons_congruence[all_neurons_congruence['neuron_type'] == ntype]
            avg_effect = type_data['same_cat_diff_mean'].mean() - type_data['diff_cat_diff_mean'].mean()
            pct_sig = type_data['significant_congruence_effect'].mean() * 100
            comparison_data.append({'type': ntype, 'avg_effect': avg_effect, 'pct_sig': pct_sig})
        
        comp_df = pd.DataFrame(comparison_data)
        
        # Create a dual-axis plot
        ax1 = axes[1,1]
        ax2 = ax1.twinx()
        
        x_pos = np.arange(len(comp_df))
        bars1 = ax1.bar(x_pos - 0.2, comp_df['avg_effect'], width=0.4, 
                       label='Avg Effect Size', color='skyblue', alpha=0.7)
        bars2 = ax2.bar(x_pos + 0.2, comp_df['pct_sig'], width=0.4, 
                       label='% Significant', color='orange', alpha=0.7)
        
        ax1.set_xlabel('Neuron Type')
        ax1.set_ylabel('Average Effect Size (Hz)', color='blue')
        ax2.set_ylabel('% with Significant Effects', color='orange')
        ax1.set_title('Congruence Effects: Type Comparison')
        
        ax1.set_xticks(x_pos)
        ax1.set_xticklabels(comp_df['type'], rotation=45)
        ax1.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
        
        # Add legends
        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
        
        plt.tight_layout()
        plt.savefig("category_congruence_plots/distractor_resistant_focus.png", dpi=300, bbox_inches="tight")
        plt.close()
        print("  Saved: distractor_resistant_focus.png")
    
    # === PLOT 4: Example Neurons with Strong Effects ===
    # Find neurons with strongest effects
    significant_neurons = all_neurons_congruence[all_neurons_congruence['significant_congruence_effect']].copy()
    
    if len(significant_neurons) > 0:
        # Rank by absolute congruence effect (the .copy() above avoids a SettingWithCopyWarning)
        significant_neurons['abs_effect'] = (significant_neurons['same_cat_diff_mean'] - significant_neurons['diff_cat_diff_mean']).abs()
        top_neurons = significant_neurons.nlargest(4, 'abs_effect')
        
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        axes = axes.flatten()
        
        for i, (_, neuron) in enumerate(top_neurons.iterrows()):
            if i >= 4:
                break
                
            unit_id = neuron['unit_id']
            
            # Get unit data for detailed plotting
            unit_data = data_filtered[data_filtered['unit_id'] == unit_id].copy()
            if len(unit_data) > 0:
                unit_data['category_match'] = unit_data['first_cat_simple'] == unit_data['second_cat_simple']
                
                same_cat = unit_data[unit_data['category_match']]
                diff_cat = unit_data[~unit_data['category_match']]
                
                # Create comparison plot
                x_pos = [0, 1]
                same_means = [same_cat['fr_epoch'].mean(), same_cat['fr_enc2_epoch'].mean()]
                same_sems = [same_cat['fr_epoch'].sem(), same_cat['fr_enc2_epoch'].sem()]
                diff_means = [diff_cat['fr_epoch'].mean(), diff_cat['fr_enc2_epoch'].mean()]
                diff_sems = [diff_cat['fr_epoch'].sem(), diff_cat['fr_enc2_epoch'].sem()]
                
                axes[i].errorbar(x_pos, same_means, yerr=same_sems, 
                               marker='o', label='Same Category', linewidth=2, markersize=8)
                axes[i].errorbar(x_pos, diff_means, yerr=diff_sems, 
                               marker='s', label='Different Category', linewidth=2, markersize=8)
                
                axes[i].set_xticks(x_pos)
                axes[i].set_xticklabels(['Stim1', 'Stim2'])
                axes[i].set_ylabel('Firing Rate (Hz)')
                axes[i].set_title(f"Unit {unit_id} ({neuron['area']}, {neuron['neuron_type']})\n"
                                f"Effect: {neuron['same_cat_diff_mean'] - neuron['diff_cat_diff_mean']:.3f} Hz")
                axes[i].legend()
                axes[i].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig("category_congruence_plots/example_neurons_with_effects.png", dpi=300, bbox_inches="tight")
        plt.close()
        print("  Saved: example_neurons_with_effects.png")
    
    print("\nFinished creating category congruence plots in category_congruence_plots/")

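The volcano plot above uses uncorrected per-neuron p-values (the smaller of `difference_pvalue` and `ratio_pvalue`). If you want to control the false discovery rate across neurons, one option is the Benjamini-Hochberg step-up procedure. The notebook does not currently apply this correction. Below is a minimal numpy sketch. `benjamini_hochberg` is a hypothetical helper, and the example p-values are made up.

```python
import numpy as np

def benjamini_hochberg(p_values, q=0.05):
    """Return (adjusted p-values, reject mask) from the Benjamini-Hochberg step-up procedure."""
    p = np.asarray(p_values, dtype=float)
    m = p.size
    order = np.argsort(p)
    ranked = p[order] * m / np.arange(1, m + 1)  # p_(i) * m / i for sorted p-values
    # Running minimum from the largest rank downward keeps adjusted values monotone; cap at 1
    adjusted_sorted = np.minimum(np.minimum.accumulate(ranked[::-1])[::-1], 1.0)
    adjusted = np.empty(m)
    adjusted[order] = adjusted_sorted  # restore the original order
    return adjusted, adjusted <= q

# Made-up p-values for illustration
pvals = [0.001, 0.008, 0.039, 0.041, 0.042, 0.06, 0.074, 0.205, 0.212, 0.216]
adj, reject = benjamini_hochberg(pvals, q=0.05)
print(np.round(adj, 3), reject)
```

Before calling it on the notebook's results, drop NaN p-values (units whose fits failed, for example with `.dropna()`), because NaN would propagate through the running minimum.
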
### Mixed Selectivity During Stim 2

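The two-factor models in the next cell include a term such as `C(first_cat_simple):C(second_cat_simple)`. This term asks whether the effect of the stim2 category depends on the stim1 category. When each factor has two levels, the term reduces to a single "difference of differences" contrast, and t² for that contrast equals the interaction F of the full two-way model. Below is a minimal sketch on simulated firing rates. It assumes numpy and scipy are installed. `interaction_contrast_test` is a hypothetical helper, and the labels 'A'/'B', trial counts and rates are made up.

```python
import numpy as np
from scipy import stats

def interaction_contrast_test(rates):
    """
    Hypothetical helper. rates: dict mapping (stim1_cat, stim2_cat) -> 1-D array of
    stim2-epoch firing rates for a 2 x 2 design with levels 'A' and 'B'.
    Returns (contrast in Hz, t statistic, two-sided p-value).
    """
    cells = [('A', 'A'), ('A', 'B'), ('B', 'A'), ('B', 'B')]
    means = {c: np.mean(rates[c]) for c in cells}
    ns = {c: len(rates[c]) for c in cells}
    # Difference of differences: does the stim2 effect change with the stim1 category?
    contrast = (means[('A', 'A')] - means[('A', 'B')]) - (means[('B', 'A')] - means[('B', 'B')])
    # Pooled within-cell variance (the ANOVA mean squared error of the full model)
    sse = sum(np.sum((np.asarray(rates[c]) - means[c]) ** 2) for c in cells)
    df = sum(ns.values()) - len(cells)
    mse = sse / df
    se = np.sqrt(mse * sum(1.0 / ns[c] for c in cells))
    t_stat = contrast / se
    p_value = 2 * stats.t.sf(abs(t_stat), df)
    return contrast, t_stat, p_value

rng = np.random.default_rng(0)
n = 40  # made-up number of trials per condition
# Simulated neuron with nonlinear mixed selectivity: responds only to the A-then-B combination
base = {('A', 'A'): 5.0, ('A', 'B'): 12.0, ('B', 'A'): 5.0, ('B', 'B'): 5.0}
rates = {cond: rng.normal(mu, 2.0, size=n) for cond, mu in base.items()}
contrast, t_stat, p = interaction_contrast_test(rates)
print(f"contrast = {contrast:.2f} Hz, t = {t_stat:.2f}, p = {p:.2g}")
```

A purely additive neuron (for example, stim1 adds a fixed offset and stim2 adds another) gives a contrast near zero, even when both main effects are large.
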
In [None]:
# ===== MIXED SELECTIVITY: STIM1 x STIM2 INTERACTION ANALYSIS =====
# Testing for interaction effects during stim2 presentation
# Inspired by Parthasarathy et al. (2017) Nature Neuroscience

import os  # used to create the output directory for results
import statsmodels.formula.api as smf
import statsmodels.api as sm
import numpy as np
import pandas as pd
from tqdm import tqdm

def analyze_mixed_selectivity_interactions(data_filtered, selectivity_results):
    """
    Test for mixed selectivity during stim2 presentation:
    1. CategoryStim1 x CategoryStim2 interaction effects
    2. NumerosityStim1 x NumerosityStim2 interaction effects
    
    This tests whether stim2 responses depend on specific combinations of 
    stim1 and stim2 features, not just individual features.
    
    Reference: Parthasarathy et al. (2017) Nature Neuroscience
    "Mixed selectivity morphs population codes in prefrontal cortex"
    """
    
    print("=== MIXED SELECTIVITY INTERACTION ANALYSIS ===")
    print("Testing for Stim1 x Stim2 interaction effects during stim2 presentation...")
    
    # Use ALL neurons (like Parthasarathy et al.) - no pre-selection for individual selectivity
    all_units = data_filtered['unit_id'].unique()
    print(f"Analyzing {len(all_units)} total neurons (no pre-selection)")
    print("Note: Using all neurons to avoid bias against pure mixed selectivity")
    
    interaction_results = []
    
    for unit_id in tqdm(all_units, desc="Testing interactions"):
        try:
            df_unit = data_filtered[data_filtered["unit_id"] == unit_id].reset_index(drop=True)
            if df_unit.empty:
                continue
            
            # Look up area and selectivity flags once (the unit may be missing from selectivity_results)
            selectivity_info = selectivity_results[selectivity_results['unit_id'] == unit_id]
            if len(selectivity_info) > 0:
                area = selectivity_info.iloc[0]['area']
                is_first_cat_selective = selectivity_info.iloc[0]['is_first_cat_selective']
                is_second_cat_selective = selectivity_info.iloc[0]['is_second_cat_selective']
                is_first_num_selective = selectivity_info.iloc[0]['is_first_num_selective']
                is_second_num_selective = selectivity_info.iloc[0]['is_second_num_selective']
            else:
                # Neuron wasn't in selectivity analysis (low firing rate, etc.):
                # take the area from data_filtered and treat it as non-selective
                area = df_unit['brainAreaOfCell'].iloc[0]
                is_first_cat_selective = False
                is_second_cat_selective = False
                is_first_num_selective = False
                is_second_num_selective = False
            
            # Check data quality for interaction analysis
            n_stim1_cats = df_unit['first_cat_simple'].nunique()
            n_stim2_cats = df_unit['second_cat_simple'].nunique()
            n_stim1_nums = df_unit['first_num_simple'].nunique()
            n_stim2_nums = df_unit['second_num_simple'].nunique()
            
            if n_stim1_cats < 2 or n_stim2_cats < 2:
                continue
            
            # === ANALYSIS 1: Category x Category Interaction ===
            cat_interaction_pval = np.nan
            cat_main_stim1_pval = np.nan
            cat_main_stim2_pval = np.nan
            cat_model_r2 = np.nan
            
            try:
                # 2-way ANOVA with interaction: stim2_response ~ stim1_cat * stim2_cat
                cat_model = smf.ols(
                    "fr_enc2_epoch ~ C(first_cat_simple) * C(second_cat_simple)", 
                    data=df_unit
                ).fit()
                cat_anova = sm.stats.anova_lm(cat_model, typ=2)
                
                # Extract p-values
                cat_main_stim1_pval = cat_anova.loc['C(first_cat_simple)', 'PR(>F)']
                cat_main_stim2_pval = cat_anova.loc['C(second_cat_simple)', 'PR(>F)']
                cat_interaction_pval = cat_anova.loc['C(first_cat_simple):C(second_cat_simple)', 'PR(>F)']
                cat_model_r2 = cat_model.rsquared
                
            except Exception as e:
                pass  # Keep NaN values
            
            # === ANALYSIS 2: Numerosity x Numerosity Interaction ===
            num_interaction_pval = np.nan
            num_main_stim1_pval = np.nan
            num_main_stim2_pval = np.nan
            num_model_r2 = np.nan
            
            if n_stim1_nums >= 2 and n_stim2_nums >= 2:
                try:
                    # 2-way ANOVA with interaction: stim2_response ~ stim1_num * stim2_num
                    num_model = smf.ols(
                        "fr_enc2_epoch ~ C(first_num_simple) * C(second_num_simple)", 
                        data=df_unit
                    ).fit()
                    num_anova = sm.stats.anova_lm(num_model, typ=2)
                    
                    # Extract p-values
                    num_main_stim1_pval = num_anova.loc['C(first_num_simple)', 'PR(>F)']
                    num_main_stim2_pval = num_anova.loc['C(second_num_simple)', 'PR(>F)']
                    num_interaction_pval = num_anova.loc['C(first_num_simple):C(second_num_simple)', 'PR(>F)']
                    num_model_r2 = num_model.rsquared
                    
                except Exception as e:
                    pass  # Keep NaN values
            
            # === ANALYSIS 3: Mixed Model with All Features ===
            mixed_model_r2 = np.nan
            mixed_cat1_pval = np.nan
            mixed_cat2_pval = np.nan
            mixed_num1_pval = np.nan
            mixed_num2_pval = np.nan
            
            try:
                # Additive model with all four features (no interaction terms, to keep the number of parameters manageable)
                mixed_model = smf.ols(
                    "fr_enc2_epoch ~ C(first_cat_simple) + C(second_cat_simple) + C(first_num_simple) + C(second_num_simple)", 
                    data=df_unit
                ).fit()
                mixed_anova = sm.stats.anova_lm(mixed_model, typ=2)
                
                mixed_model_r2 = mixed_model.rsquared
                mixed_cat1_pval = mixed_anova.loc['C(first_cat_simple)', 'PR(>F)']
                mixed_cat2_pval = mixed_anova.loc['C(second_cat_simple)', 'PR(>F)']
                mixed_num1_pval = mixed_anova.loc['C(first_num_simple)', 'PR(>F)']
                mixed_num2_pval = mixed_anova.loc['C(second_num_simple)', 'PR(>F)']
                
            except Exception as e:
                pass
            
            # === Neuron Type Classification ===
            # Default: unit absent from selectivity_results, or present but not selective for any feature
            neuron_type = "no_selectivity"
            if len(selectivity_info) > 0 and bool(selectivity_info.iloc[0].get('is_any_selective', False)):
                neuron_type = "encoding_only"  # selective for at least one feature
                if 'distractor_results_principled' in globals():
                    distractor_info = distractor_results_principled[distractor_results_principled['unit_id'] == unit_id]
                    if len(distractor_info) > 0:
                        neuron_type = distractor_info.iloc[0]['analysis_stage']
            
            # Mixed selectivity classification
            mixed_selectivity_type = "none"
            if not np.isnan(cat_interaction_pval) and cat_interaction_pval < 0.05:
                if not np.isnan(num_interaction_pval) and num_interaction_pval < 0.05:
                    mixed_selectivity_type = "both_interactions"
                else:
                    mixed_selectivity_type = "category_interaction"
            elif not np.isnan(num_interaction_pval) and num_interaction_pval < 0.05:
                mixed_selectivity_type = "numerosity_interaction"
            
            # Store results
            result = {
                'unit_id': unit_id,
                'area': area,
                'neuron_type': neuron_type,
                
                # Data quality
                'n_trials': len(df_unit),
                'n_stim1_categories': n_stim1_cats,
                'n_stim2_categories': n_stim2_cats,
                'n_stim1_numerosities': n_stim1_nums,
                'n_stim2_numerosities': n_stim2_nums,
                
                # Category interaction analysis
                'cat_interaction_pval': cat_interaction_pval,
                'cat_main_stim1_pval': cat_main_stim1_pval,
                'cat_main_stim2_pval': cat_main_stim2_pval,
                'cat_model_r2': cat_model_r2,
                'cat_interaction_significant': cat_interaction_pval < 0.05 if not np.isnan(cat_interaction_pval) else False,
                
                # Numerosity interaction analysis  
                'num_interaction_pval': num_interaction_pval,
                'num_main_stim1_pval': num_main_stim1_pval,
                'num_main_stim2_pval': num_main_stim2_pval,
                'num_model_r2': num_model_r2,
                'num_interaction_significant': num_interaction_pval < 0.05 if not np.isnan(num_interaction_pval) else False,
                
                # Mixed model
                'mixed_model_r2': mixed_model_r2,
                'mixed_cat1_pval': mixed_cat1_pval,
                'mixed_cat2_pval': mixed_cat2_pval,
                'mixed_num1_pval': mixed_num1_pval,
                'mixed_num2_pval': mixed_num2_pval,
                
                # Classification
                'mixed_selectivity_type': mixed_selectivity_type,
                'any_interaction_significant': (
                    (cat_interaction_pval < 0.05 if not np.isnan(cat_interaction_pval) else False) or
                    (num_interaction_pval < 0.05 if not np.isnan(num_interaction_pval) else False)
                ),
                
                # Selectivity info (may be False if neuron wasn't in original analysis)
                'is_first_cat_selective': is_first_cat_selective,
                'is_second_cat_selective': is_second_cat_selective,
                'is_first_num_selective': is_first_num_selective,
                'is_second_num_selective': is_second_num_selective,
            }
            
            interaction_results.append(result)
            
        except Exception as e:
            print(f"Error analyzing unit {unit_id}: {e}")
            continue
    
    # Convert to DataFrame
    interaction_df = pd.DataFrame(interaction_results)
    
    if len(interaction_df) == 0:
        print("No neurons analyzed successfully!")
        return pd.DataFrame()
    
    # === SUMMARY ANALYSIS ===
    print("\n=== MIXED SELECTIVITY RESULTS ===")
    print(f"Successfully analyzed: {len(interaction_df)} neurons (all recorded neurons except those with <2 stim1 or stim2 categories)")
    print("Note: This includes neurons with no individual feature selectivity")
    print("Mixed selectivity can emerge in apparently 'non-selective' neurons")
    
    # Overall interaction statistics
    cat_interactions = interaction_df['cat_interaction_significant'].sum()
    num_interactions = interaction_df['num_interaction_significant'].sum()
    any_interactions = interaction_df['any_interaction_significant'].sum()
    
    print(f"\n=== INTERACTION EFFECTS ===")
    print(f"Category x Category interactions: {cat_interactions}/{len(interaction_df)} ({cat_interactions/len(interaction_df)*100:.1f}%)")
    print(f"Numerosity x Numerosity interactions: {num_interactions}/{len(interaction_df)} ({num_interactions/len(interaction_df)*100:.1f}%)")
    print(f"Any interaction effects: {any_interactions}/{len(interaction_df)} ({any_interactions/len(interaction_df)*100:.1f}%)")
    
    # Mixed selectivity type distribution
    print(f"\n=== MIXED SELECTIVITY TYPES ===")
    type_counts = interaction_df['mixed_selectivity_type'].value_counts()
    for stype, count in type_counts.items():
        print(f"  {stype}: {count} neurons ({count/len(interaction_df)*100:.1f}%)")
    
    # Analysis by neuron type
    print(f"\n=== BY NEURON TYPE ===")
    for ntype in interaction_df['neuron_type'].unique():
        type_data = interaction_df[interaction_df['neuron_type'] == ntype]
        n_total = len(type_data)
        n_any_interact = type_data['any_interaction_significant'].sum()
        n_cat_interact = type_data['cat_interaction_significant'].sum()
        n_num_interact = type_data['num_interaction_significant'].sum()
        
        print(f"  {ntype} (n={n_total}):")
        print(f"    Any interactions: {n_any_interact}/{n_total} ({n_any_interact/n_total*100:.1f}%)")
        print(f"    Category interactions: {n_cat_interact}/{n_total} ({n_cat_interact/n_total*100:.1f}%)")
        print(f"    Numerosity interactions: {n_num_interact}/{n_total} ({n_num_interact/n_total*100:.1f}%)")
    
    # Analysis by brain area
    print(f"\n=== BY BRAIN AREA ===")
    for area in interaction_df['area'].unique():
        area_data = interaction_df[interaction_df['area'] == area]
        if len(area_data) >= 5:  # Only show areas with enough neurons
            n_total = len(area_data)
            n_any_interact = area_data['any_interaction_significant'].sum()
            n_cat_interact = area_data['cat_interaction_significant'].sum()
            avg_r2 = area_data['cat_model_r2'].mean()
            
            print(f"  {area} (n={n_total}): {n_any_interact} with any interaction ({n_any_interact/n_total*100:.1f}%), avg category-model R²={avg_r2:.3f}")
    
    # Model performance comparison
    print(f"\n=== MODEL PERFORMANCE ===")
    avg_cat_r2 = interaction_df['cat_model_r2'].mean()
    avg_num_r2 = interaction_df['num_model_r2'].mean() 
    avg_mixed_r2 = interaction_df['mixed_model_r2'].mean()
    
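    print(f"Average R² - Category model: {avg_cat_r2:.3f}")
    print(f"Average R² - Numerosity model: {avg_num_r2:.3f}")
    print(f"Average R² - Mixed model: {avg_mixed_r2:.3f}")
    print("Note: raw R² is not adjusted for model size; models with more terms fit better by construction")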
    
    # Save results
    os.makedirs("mixed_selectivity_analysis", exist_ok=True)
    interaction_df.to_csv("mixed_selectivity_analysis/stim1_stim2_interactions.csv", index=False)
    print(f"\nResults saved to: mixed_selectivity_analysis/stim1_stim2_interactions.csv")
    
    # === COMPARISON TO PARTHASARATHY ET AL. ===
    print(f"\n=== COMPARISON TO PARTHASARATHY ET AL. (2017) ===")
    print(f"Mixed selectivity prevalence in this dataset: {any_interactions/len(interaction_df)*100:.1f}%")
    print("Parthasarathy et al. found ~30-50% of PFC neurons with mixed selectivity")
    print("Like this analysis, they used all recorded neurons rather than pre-selected ones")
    
    # Additional breakdown
    has_individual_selectivity = interaction_df[
        (interaction_df['is_first_cat_selective']) | 
        (interaction_df['is_second_cat_selective']) |
        (interaction_df['is_first_num_selective']) | 
        (interaction_df['is_second_num_selective'])
    ]
    
    no_individual_selectivity = interaction_df[
        (~interaction_df['is_first_cat_selective']) & 
        (~interaction_df['is_second_cat_selective']) &
        (~interaction_df['is_first_num_selective']) & 
        (~interaction_df['is_second_num_selective'])
    ]
    
    print(f"\n=== BREAKDOWN BY INDIVIDUAL SELECTIVITY ===")
    print(f"Neurons with individual feature selectivity: {len(has_individual_selectivity)}")
    if len(has_individual_selectivity) > 0:
        print(f"  Mixed selectivity in these: {has_individual_selectivity['any_interaction_significant'].sum()}/{len(has_individual_selectivity)} ({has_individual_selectivity['any_interaction_significant'].mean()*100:.1f}%)")
    
    print(f"Neurons with NO individual selectivity: {len(no_individual_selectivity)}")
    if len(no_individual_selectivity) > 0:
        print(f"  Mixed selectivity in these: {no_individual_selectivity['any_interaction_significant'].sum()}/{len(no_individual_selectivity)} ({no_individual_selectivity['any_interaction_significant'].mean()*100:.1f}%)")
        print("  → Interactions in these neurons suggest mixed selectivity without detectable single-feature tuning")
    
   
    print("\nThis analysis uses the Parthasarathy approach: all neurons, no pre-selection")
    print("Mixed selectivity = neurons encoding feature combinations, not just individual features")
    
    return interaction_df


# === RUN THE MIXED SELECTIVITY ANALYSIS ===
if 'selectivity_results' in locals() and 'data_filtered' in locals():
    print("Running mixed selectivity interaction analysis...")
    mixed_selectivity_results = analyze_mixed_selectivity_interactions(data_filtered, selectivity_results)
    
    print("\nMixed selectivity analysis complete!")
    print("This tests whether neurons encode combinations of stim1 and stim2 features,")
    print("as described in Parthasarathy et al. (2017) Nature Neuroscience.")
    
else:
    print("Need selectivity_results and data_filtered to run this analysis!")
    print("Make sure you've run the neuron selectivity analysis first.")

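The model-performance summary in `analyze_mixed_selectivity_interactions` compares average raw R² across the category, numerosity and additive mixed models. Raw R² never decreases when terms are added, so the four-factor mixed model will tend to look best regardless of whether the extra terms help. Adjusted R² penalises extra parameters. statsmodels exposes it as `rsquared_adj` on fitted OLS results, so it could be stored next to `cat_model.rsquared`. Below is a minimal numpy sketch of the formula. `r2_and_adjusted` is a hypothetical helper, and the predictors and rates are simulated.

```python
import numpy as np

def r2_and_adjusted(y, X):
    """Hypothetical helper: fit OLS with an intercept; return (R², adjusted R²). X: (n_samples, n_predictors)."""
    y = np.asarray(y, dtype=float)
    design = np.column_stack([np.ones(len(y)), X])
    beta, *_ = np.linalg.lstsq(design, y, rcond=None)
    residuals = y - design @ beta
    ss_res = np.sum(residuals ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    r2 = 1 - ss_res / ss_tot
    n, k = len(y), design.shape[1] - 1  # k predictors, excluding the intercept
    adj_r2 = 1 - (1 - r2) * (n - 1) / (n - k - 1)
    return r2, adj_r2

rng = np.random.default_rng(1)
n_trials = 60
informative = rng.integers(0, 2, size=n_trials)            # simulated stim2 category dummy
noise_predictors = rng.integers(0, 2, size=(n_trials, 3))  # simulated dummies unrelated to the rate
rates = 5 + 3 * informative + rng.normal(0, 2, size=n_trials)

small = r2_and_adjusted(rates, informative[:, None])
large = r2_and_adjusted(rates, np.column_stack([informative, noise_predictors]))
print(f"1 predictor : R² = {small[0]:.3f}, adjusted = {small[1]:.3f}")
print(f"4 predictors: R² = {large[0]:.3f}, adjusted = {large[1]:.3f}")
```

The larger model's raw R² is at least as high as the smaller one's, even though three of its predictors are pure noise. The adjusted value removes some of that advantage.
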
In [None]:
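# ===== MIXED SELECTIVITY: STIM1 x STIM2 INTERACTION ANALYSIS (SELECTIVE NEURONS ONLY) =====
# Alternative version restricted to neurons with individual selectivity (is_any_selective).
# Running this cell redefines analyze_mixed_selectivity_interactions() and replaces the all-neurons version above.
# Inspired by Parthasarathy et al. (2017) Nature Neuroscience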

# import statsmodels.formula.api as smf
# import statsmodels.api as sm
# import numpy as np
# import pandas as pd
# from tqdm import tqdm

def analyze_mixed_selectivity_interactions(data_filtered, selectivity_results):
    """
    Test for mixed selectivity during stim2 presentation:
    1. CategoryStim1 x CategoryStim2 interaction effects
    2. NumerosityStim1 x NumerosityStim2 interaction effects
    
    This tests whether stim2 responses depend on specific combinations of 
    stim1 and stim2 features, not just individual features.
    
    Reference: Parthasarathy et al. (2017) Nature Neuroscience
    "Mixed selectivity morphs population codes in prefrontal cortex"
    """
    
    print("=== MIXED SELECTIVITY INTERACTION ANALYSIS ===")
    print("Testing for Stim1 x Stim2 interaction effects during stim2 presentation...")
    
    # Restrict to neurons flagged as selective for any feature in the earlier selectivity analysis
    responsive_neurons = selectivity_results[selectivity_results["is_any_selective"]].copy()
    print(f"Analyzing {len(responsive_neurons)} selective neurons")
    
    interaction_results = []
    
    for _, neuron_row in tqdm(responsive_neurons.iterrows(), total=len(responsive_neurons), desc="Testing interactions"):
        unit_id = neuron_row['unit_id']
        
        try:
            df_unit = data_filtered[data_filtered["unit_id"] == unit_id].reset_index(drop=True)
            if df_unit.empty:
                continue
            
            area = neuron_row['area']
            
            # Check if we have enough data for interaction analysis
            # Need multiple levels of each factor
            n_stim1_cats = df_unit['first_cat_simple'].nunique()
            n_stim2_cats = df_unit['second_cat_simple'].nunique()
            n_stim1_nums = df_unit['first_num_simple'].nunique()
            n_stim2_nums = df_unit['second_num_simple'].nunique()
            
            if n_stim1_cats < 2 or n_stim2_cats < 2:
                continue
            
            # === ANALYSIS 1: Category x Category Interaction ===
            cat_interaction_pval = np.nan
            cat_main_stim1_pval = np.nan
            cat_main_stim2_pval = np.nan
            cat_model_r2 = np.nan
            
            try:
                # 2-way ANOVA with interaction: stim2_response ~ stim1_cat * stim2_cat
                cat_model = smf.ols(
                    "fr_enc2_epoch ~ C(first_cat_simple) * C(second_cat_simple)", 
                    data=df_unit
                ).fit()
                cat_anova = sm.stats.anova_lm(cat_model, typ=2)
                
                # Extract p-values
                cat_main_stim1_pval = cat_anova.loc['C(first_cat_simple)', 'PR(>F)']
                cat_main_stim2_pval = cat_anova.loc['C(second_cat_simple)', 'PR(>F)']
                cat_interaction_pval = cat_anova.loc['C(first_cat_simple):C(second_cat_simple)', 'PR(>F)']
                cat_model_r2 = cat_model.rsquared
                
            except Exception:
                pass  # Fit failed or term missing (e.g., empty stim1/stim2 cells); keep NaN values
            
            # === ANALYSIS 2: Numerosity x Numerosity Interaction ===
            num_interaction_pval = np.nan
            num_main_stim1_pval = np.nan
            num_main_stim2_pval = np.nan
            num_model_r2 = np.nan
            
            if n_stim1_nums >= 2 and n_stim2_nums >= 2:
                try:
                    # 2-way ANOVA with interaction: stim2_response ~ stim1_num * stim2_num
                    num_model = smf.ols(
                        "fr_enc2_epoch ~ C(first_num_simple) * C(second_num_simple)", 
                        data=df_unit
                    ).fit()
                    num_anova = sm.stats.anova_lm(num_model, typ=2)
                    
                    # Extract p-values
                    num_main_stim1_pval = num_anova.loc['C(first_num_simple)', 'PR(>F)']
                    num_main_stim2_pval = num_anova.loc['C(second_num_simple)', 'PR(>F)']
                    num_interaction_pval = num_anova.loc['C(first_num_simple):C(second_num_simple)', 'PR(>F)']
                    num_model_r2 = num_model.rsquared
                    
                except Exception:
                    pass  # Fit failed or term missing; keep NaN values
            
            # === ANALYSIS 3: Mixed Model with Both Features ===
            mixed_model_r2 = np.nan
            mixed_cat1_pval = np.nan
            mixed_cat2_pval = np.nan
            mixed_num1_pval = np.nan
            mixed_num2_pval = np.nan
            
            try:
                # Additive model with all four features (no interaction terms, to limit the parameter count).
                # "Mixed" here means multiple features, not a mixed-effects model.
                mixed_model = smf.ols(
                    "fr_enc2_epoch ~ C(first_cat_simple) + C(second_cat_simple) + C(first_num_simple) + C(second_num_simple)", 
                    data=df_unit
                ).fit()
                mixed_anova = sm.stats.anova_lm(mixed_model, typ=2)
                
                mixed_model_r2 = mixed_model.rsquared
                mixed_cat1_pval = mixed_anova.loc['C(first_cat_simple)', 'PR(>F)']
                mixed_cat2_pval = mixed_anova.loc['C(second_cat_simple)', 'PR(>F)']
                mixed_num1_pval = mixed_anova.loc['C(first_num_simple)', 'PR(>F)']
                mixed_num2_pval = mixed_anova.loc['C(second_num_simple)', 'PR(>F)']
                
            except Exception:
                pass  # Fit failed or term missing; keep NaN values
            
            # === Classification ===
            # Determine neuron type
            neuron_type = "encoding_only"
            if 'distractor_results_principled' in globals():
                distractor_info = distractor_results_principled[distractor_results_principled['unit_id'] == unit_id]
                if len(distractor_info) > 0:
                    neuron_type = distractor_info.iloc[0]['analysis_stage']
            
            # Mixed selectivity classification
            mixed_selectivity_type = "none"
            if not np.isnan(cat_interaction_pval) and cat_interaction_pval < 0.05:
                if not np.isnan(num_interaction_pval) and num_interaction_pval < 0.05:
                    mixed_selectivity_type = "both_interactions"
                else:
                    mixed_selectivity_type = "category_interaction"
            elif not np.isnan(num_interaction_pval) and num_interaction_pval < 0.05:
                mixed_selectivity_type = "numerosity_interaction"
            
            # Store results
            result = {
                'unit_id': unit_id,
                'area': area,
                'neuron_type': neuron_type,
                
                # Data quality
                'n_trials': len(df_unit),
                'n_stim1_categories': n_stim1_cats,
                'n_stim2_categories': n_stim2_cats,
                'n_stim1_numerosities': n_stim1_nums,
                'n_stim2_numerosities': n_stim2_nums,
                
                # Category interaction analysis
                'cat_interaction_pval': cat_interaction_pval,
                'cat_main_stim1_pval': cat_main_stim1_pval,
                'cat_main_stim2_pval': cat_main_stim2_pval,
                'cat_model_r2': cat_model_r2,
                'cat_interaction_significant': cat_interaction_pval < 0.05 if not np.isnan(cat_interaction_pval) else False,
                
                # Numerosity interaction analysis  
                'num_interaction_pval': num_interaction_pval,
                'num_main_stim1_pval': num_main_stim1_pval,
                'num_main_stim2_pval': num_main_stim2_pval,
                'num_model_r2': num_model_r2,
                'num_interaction_significant': num_interaction_pval < 0.05 if not np.isnan(num_interaction_pval) else False,
                
                # Mixed model
                'mixed_model_r2': mixed_model_r2,
                'mixed_cat1_pval': mixed_cat1_pval,
                'mixed_cat2_pval': mixed_cat2_pval,
                'mixed_num1_pval': mixed_num1_pval,
                'mixed_num2_pval': mixed_num2_pval,
                
                # Classification
                'mixed_selectivity_type': mixed_selectivity_type,
                'any_interaction_significant': (
                    (cat_interaction_pval < 0.05 if not np.isnan(cat_interaction_pval) else False) or
                    (num_interaction_pval < 0.05 if not np.isnan(num_interaction_pval) else False)
                ),
                
                # Selectivity info from original analysis
                'is_first_cat_selective': neuron_row['is_first_cat_selective'],
                'is_second_cat_selective': neuron_row['is_second_cat_selective'],
                'is_first_num_selective': neuron_row['is_first_num_selective'],
                'is_second_num_selective': neuron_row['is_second_num_selective'],
            }
            
            interaction_results.append(result)
            
        except Exception as e:
            print(f"Error analyzing unit {unit_id}: {e}")
            continue
    
    # Convert to DataFrame
    interaction_df = pd.DataFrame(interaction_results)
    
    if len(interaction_df) == 0:
        print("No neurons analyzed successfully!")
        return pd.DataFrame()
    
    # === SUMMARY ANALYSIS ===
    print(f"\n=== MIXED SELECTIVITY RESULTS ===")
    print(f"Successfully analyzed: {len(interaction_df)} neurons")
    
    # Overall interaction statistics
    cat_interactions = interaction_df['cat_interaction_significant'].sum()
    num_interactions = interaction_df['num_interaction_significant'].sum()
    any_interactions = interaction_df['any_interaction_significant'].sum()
    
    print(f"\n=== INTERACTION EFFECTS ===")
    print(f"Category x Category interactions: {cat_interactions}/{len(interaction_df)} ({cat_interactions/len(interaction_df)*100:.1f}%)")
    print(f"Numerosity x Numerosity interactions: {num_interactions}/{len(interaction_df)} ({num_interactions/len(interaction_df)*100:.1f}%)")
    print(f"Any interaction effects: {any_interactions}/{len(interaction_df)} ({any_interactions/len(interaction_df)*100:.1f}%)")
    
    # Mixed selectivity type distribution
    print(f"\n=== MIXED SELECTIVITY TYPES ===")
    type_counts = interaction_df['mixed_selectivity_type'].value_counts()
    for stype, count in type_counts.items():
        print(f"  {stype}: {count} neurons ({count/len(interaction_df)*100:.1f}%)")
    
    # Analysis by neuron type
    print(f"\n=== BY NEURON TYPE ===")
    for ntype in interaction_df['neuron_type'].unique():
        type_data = interaction_df[interaction_df['neuron_type'] == ntype]
        n_total = len(type_data)
        n_any_interact = type_data['any_interaction_significant'].sum()
        n_cat_interact = type_data['cat_interaction_significant'].sum()
        n_num_interact = type_data['num_interaction_significant'].sum()
        
        print(f"  {ntype} (n={n_total}):")
        print(f"    Any interactions: {n_any_interact}/{n_total} ({n_any_interact/n_total*100:.1f}%)")
        print(f"    Category interactions: {n_cat_interact}/{n_total} ({n_cat_interact/n_total*100:.1f}%)")
        print(f"    Numerosity interactions: {n_num_interact}/{n_total} ({n_num_interact/n_total*100:.1f}%)")
    
    # Analysis by brain area
    print(f"\n=== BY BRAIN AREA ===")
    for area in interaction_df['area'].unique():
        area_data = interaction_df[interaction_df['area'] == area]
        if len(area_data) >= 5:  # Only show areas with enough neurons
            n_total = len(area_data)
            n_any_interact = area_data['any_interaction_significant'].sum()
            n_cat_interact = area_data['cat_interaction_significant'].sum()
            avg_r2 = area_data['cat_model_r2'].mean()
            
            print(f"  {area} (n={n_total}): {n_any_interact} interactions ({n_any_interact/n_total*100:.1f}%), mean category-model R²={avg_r2:.3f}")
    
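    # Model performance comparison
    # Note: raw R² is not penalized for the number of parameters, and these models differ
    # in parameter count, so the averages are not directly comparable (see rsquared_adj).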
    print(f"\n=== MODEL PERFORMANCE ===")
    avg_cat_r2 = interaction_df['cat_model_r2'].mean()
    avg_num_r2 = interaction_df['num_model_r2'].mean()
    avg_mixed_r2 = interaction_df['mixed_model_r2'].mean()
    
    print(f"Average R² - Category model: {avg_cat_r2:.3f}")
    print(f"Average R² - Numerosity model: {avg_num_r2:.3f}")
    print(f"Average R² - Mixed model: {avg_mixed_r2:.3f}")
    
    # Save results
    os.makedirs("mixed_selectivity_analysis", exist_ok=True)
    interaction_df.to_csv("mixed_selectivity_analysis/stim1_stim2_interactions.csv", index=False)
    print(f"\nResults saved to: mixed_selectivity_analysis/stim1_stim2_interactions.csv")
    
    # === COMPARISON TO PARTHASARATHY ET AL. ===
    print(f"\n=== COMPARISON TO PARTHASARATHY ET AL. (2017) ===")
    prevalence = any_interactions / len(interaction_df)
    print(f"Mixed selectivity prevalence in your data: {prevalence*100:.1f}%")
    print("Parthasarathy et al. found ~30-50% of PFC neurons with mixed selectivity")
    # The 30% / 15% cut-offs below are heuristic. With up to two tests per neuron at
    # p < 0.05, roughly 5-10% of neurons are expected to show an interaction by chance.
    if prevalence > 0.3:
        print("→ Your data shows HIGH levels of mixed selectivity, comparable to the PFC estimate above")
    elif prevalence > 0.15:
        print("→ Your data shows MODERATE levels of mixed selectivity")
    else:
        print("→ Your data shows LOW levels of mixed selectivity")
    
    if prevalence > 0.15:
        print("This suggests population coding may rely partly on feature combinations rather than pure selectivity")
    
    return interaction_df


# === RUN THE MIXED SELECTIVITY ANALYSIS ===
if 'selectivity_results' in locals() and 'data_filtered' in locals():
    print("Running mixed selectivity interaction analysis...")
    mixed_selectivity_results = analyze_mixed_selectivity_interactions(data_filtered, selectivity_results)
    
    print("\nMixed selectivity analysis complete!")
    print("This tests whether neurons encode combinations of stim1 and stim2 features,")
    print("an approach inspired by Parthasarathy et al. (2017) Nature Neuroscience.")
    
else:
    print("Need selectivity_results and data_filtered to run this analysis!")
    print("Make sure you've run the neuron selectivity analysis first.")
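
Each neuron above is flagged when its category or numerosity interaction p-value falls below 0.05, so some neurons will be flagged even if none carries a true interaction. If both tests are run and treated as independent (an assumption, since they share trials), the expected false-positive rate per neuron is 1 - 0.95² ≈ 9.75%. The sketch below uses hypothetical helper names (`interaction_prevalence_vs_chance`, `fdr_flags`). It compares the observed count against that rate with a one-sided binomial test (`scipy.stats.binomtest`), and shows a Benjamini-Hochberg FDR correction across neurons (`statsmodels.stats.multitest.multipletests`) as an alternative to the uncorrected threshold.

```python
# Sketch (hypothetical helpers): chance-level and FDR checks for the
# per-neuron interaction tests. Assumes up to two tests per neuron at alpha = 0.05.
import pandas as pd
from scipy.stats import binomtest
from statsmodels.stats.multitest import multipletests


def interaction_prevalence_vs_chance(n_flagged, n_neurons, alpha=0.05, n_tests=2):
    """One-sided binomial test of the flagged count against the chance rate.

    The chance rate assumes n_tests independent tests per neuron, each at alpha
    (an approximation: the tests share trials and some neurons get only one test).
    """
    chance_rate = 1.0 - (1.0 - alpha) ** n_tests
    result = binomtest(n_flagged, n_neurons, p=chance_rate, alternative="greater")
    return chance_rate, result.pvalue


def fdr_flags(pvals, alpha=0.05):
    """Benjamini-Hochberg significance flags; NaN p-values (untested) stay False."""
    pvals = pd.Series(pvals, dtype=float)
    flags = pd.Series(False, index=pvals.index)
    valid = pvals.dropna()
    if len(valid) > 0:
        reject, _, _, _ = multipletests(valid.to_numpy(), alpha=alpha, method="fdr_bh")
        flags.loc[valid.index] = reject
    return flags


# Usage on the results table produced above, if it exists
if "mixed_selectivity_results" in globals() and len(mixed_selectivity_results) > 0:
    res = mixed_selectivity_results
    n_flagged = int(res["any_interaction_significant"].sum())
    chance_rate, p_binom = interaction_prevalence_vs_chance(n_flagged, len(res))
    print(f"Flagged {n_flagged}/{len(res)} neurons; chance rate ≈ {chance_rate*100:.2f}%, "
          f"binomial p = {p_binom:.3g}")
    cat_fdr = fdr_flags(res["cat_interaction_pval"])
    num_fdr = fdr_flags(res["num_interaction_pval"])
    print(f"Category interactions after FDR: {int(cat_fdr.sum())}")
    print(f"Numerosity interactions after FDR: {int(num_fdr.sum())}")
```

Because Benjamini-Hochberg never rejects a p-value above alpha, the FDR-corrected counts can only match or fall below the uncorrected per-feature counts printed by the analysis cell.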