# WM conjunction coding
Analysis of conjunctive (category × number) coding by single neurons during a working-memory (WM) task.

In [None]:
import sys
# Show which Python interpreter this notebook kernel runs (useful for checking
# the environment before installing packages).
print(sys.executable)

In [2]:
# pip install numpy pandas matplotlib seaborn scipy mat73 statsmodels tqdm scikit-learn
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.io import loadmat
from scipy.stats import ttest_1samp
import glob
import os
# import mat73

## Load data

### Loading MATLAB objects

In [None]:
# Find all WMB_P*_v7.mat files in the current directory.
# glob returns files in an arbitrary, filesystem-dependent order. Sorting makes
# the load order reproducible across machines and runs, and with it the row
# order and the unit_id values assigned later from the row index.
mat_files = sorted(glob.glob('./WMB_P*_v7.mat'))
print(f"Found {len(mat_files)} .mat files: {[os.path.basename(f) for f in mat_files]}")
if not mat_files:
    print("Warning: no WMB_P*_v7.mat files found in the current directory; later cells will fail.")

# Initialize lists to store data from each file
cell_mats = []
total_mats = []

# Load each file and keep its per-cell struct array and summary matrix
for mat_file in mat_files:
    print(f"\nLoading {mat_file}...")
    mat_data = loadmat(mat_file)
    cell_mats.append(mat_data['cellStatsAll'])
    total_mats.append(mat_data['totStats'])

# Print shapes of loaded data for debugging
print("\nShapes of loaded data:")
for i, (cell, total) in enumerate(zip(cell_mats, total_mats)):
    print(f"File {i}: cell_mat shape: {cell.shape}, total_mat shape: {total.shape}")

### Combine data

In [None]:

# Flatten every cellStatsAll struct array into per-unit records: one dict per
# unit, with one key per struct field.
all_cell_records = []
for cell_mat in cell_mats:
    # ravel() accepts the struct array whether loadmat returned it as (1, n)
    # or (n, 1). Indexing with cell_mat[0] assumed (1, n) and would silently
    # keep only the first unit of an (n, 1) array.
    all_cell_records.extend(
        {key: cell[key] for key in cell.dtype.names} for cell in cell_mat.ravel()
    )

# Convert combined records to DataFrame (one row per unit)
df = pd.DataFrame(all_cell_records)

# For total_mat, we can concatenate directly since they have the same structure
total_mat = np.concatenate(total_mats, axis=0)

print(f"\nCombined data shape - total_mat: {total_mat.shape}")
print(f"\nCombined data shape - df: {df.shape}")

### Translate area codes

In [5]:
from collections import Counter

def count_area_codes(area_column):
    """
    Count how many entries of ``area_column`` fall in each brain area.

    Numeric area codes are translated to hemisphere-specific labels
    (e.g. 1 -> 'RH', 2 -> 'LH'). Codes missing from the table are counted
    under 'Unknown'.

    Parameters
    ----------
    area_column : iterable of numbers
        Area codes, one per neuron.

    Returns
    -------
    dict
        ``label -> count``, in order of each label's first appearance.
    """
    code_to_label = {
        1: 'RH', 2: 'LH', 3: 'RA', 4: 'LA', 5: 'RAC', 6: 'LAC',
        7: 'RSMA', 8: 'LSMA', 9: 'RPT', 10: 'LPT', 11: 'ROFC', 12: 'LOFC',
        50: 'RFFA', 51: 'REC', 52: 'RCM', 53: 'LCM', 54: 'RPUL', 55: 'LPUL',
        56: 'N/A', 57: 'RPRV', 58: 'LPRV',
    }
    return dict(Counter(code_to_label.get(code, 'Unknown') for code in area_column))

In [None]:
# neuron number in each area
# Column 3 of total_mat is used as each neuron's area code
# (TODO confirm against the totStats layout).
area_codes = total_mat[:, 3]

# count_area_codes maps codes to hemisphere-specific labels such as 'RH'/'LH'.
# NOTE(review): the header below says "no prefix", but those labels keep
# their R/L hemisphere prefix. Confirm the intended wording.
counts = count_area_codes(area_codes)
print("Area counts (no prefix):")
for area, count in counts.items():
    print(f"{area}: {count}")

### Format cell data

All brain areas (left- and right-hemisphere codes collapsed into one label per region)

In [7]:
# Collapse hemisphere-specific area codes (left/right pairs) onto a single
# label per region. Codes 50, 51 and 56 have no partner code.
collapsed_area_map = {
    code: label
    for label, codes in (
        ('H', (1, 2)),
        ('A', (3, 4)),
        ('AC', (5, 6)),
        ('SMA', (7, 8)),
        ('PT', (9, 10)),
        ('OFC', (11, 12)),
        ('FFA', (50,)),
        ('EC', (51,)),
        ('CM', (52, 53)),
        ('PUL', (54, 55)),
        ('N/A', (56,)),
        ('PRV', (57, 58)),
    )
    for code in codes
}

In [8]:
def _collapse_area_code(value):
    """
    Map one raw ``brainAreaOfCell`` entry to its collapsed area label.

    ``value`` may be a numeric array of any shape (e.g. MATLAB's (1, 1)) or a
    plain number. The first element is used as the area code and looked up in
    ``collapsed_area_map``; unknown codes and empty arrays give 'Unknown'.

    Strings are returned unchanged. They are labels produced by an earlier
    run of this cell, and mapping them a second time would silently turn
    every area into 'Unknown'.
    """
    if isinstance(value, str):
        return value
    if isinstance(value, np.ndarray):
        if value.size == 0:
            return 'Unknown'
        value = int(value.ravel()[0])
    return collapsed_area_map.get(value, 'Unknown')


# Convert brain area codes in the DataFrame
df['brainAreaOfCell'] = df['brainAreaOfCell'].apply(_collapse_area_code)

### Filter out units with low firing rate

In [9]:
# Filter out units with low firing rate.
# Units whose mean rate over their recording span is at or below this
# threshold (Hz) are dropped before any analysis.
MIN_FIRING_RATE_HZ = 0.1


def _mean_firing_rate_hz(timestamps):
    """
    Return a unit's mean firing rate in Hz.

    The rate is the spike count divided by the time between the earliest and
    latest spike. Timestamps are assumed to be in microseconds, hence the 1e6
    factor. Arrays of any shape are flattened first, so (n, 1) and (1, n)
    MATLAB vectors give the same result. Fewer than two spikes, or a zero
    span, give 0.0 instead of raising IndexError or dividing by zero.
    """
    ts = np.ravel(timestamps)
    if ts.size < 2:
        return 0.0
    span = ts.max() - ts.min()
    if span <= 0:
        return 0.0
    return ts.size / span * 1e6


fr = df['timestamps'].apply(_mean_firing_rate_hz)
df_sample_new = df[fr > MIN_FIRING_RATE_HZ].reset_index(drop=True)

# unit id: the row position after filtering (0 .. n_units - 1)
df_sample_new["unit_id"] = df_sample_new.index

## Extract trial info

In [10]:
def extract_trial_info(trials_struct, unit_id):
    """
    Flatten a MATLAB ``Trials`` struct array into a per-trial DataFrame.

    Parameters
    ----------
    trials_struct : numpy structured array
        Struct array as returned by ``scipy.io.loadmat``. Each field becomes
        one column (singleton dimensions are squeezed away). Must contain a
        ``trial`` field holding MATLAB 1-based trial numbers.
    unit_id : int
        Identifier copied into every row, so trials can later be matched to
        their unit.

    Returns
    -------
    pandas.DataFrame
        One row per trial with the struct's fields, plus ``unit_id`` and
        ``trial_nr`` (the ``trial`` field converted to 0-based).
    """
    def field_column(name):
        # squeeze() drops MATLAB's singleton dimensions. atleast_1d() keeps a
        # single-trial struct as a length-1 column: a 0-d scalar would make
        # the DataFrame constructor demand an explicit index.
        return np.atleast_1d(trials_struct[name].squeeze())

    def as_scalar(value):
        # Trial numbers may still be wrapped in (1, 1) arrays or lists
        return np.squeeze(value).item() if isinstance(value, (list, np.ndarray)) else value

    df_trial = pd.DataFrame({field: field_column(field) for field in trials_struct.dtype.names})
    df_trial["unit_id"] = unit_id
    df_trial["trial_nr"] = df_trial["trial"].apply(as_scalar) - 1  # 0-based trial index
    return df_trial

In [11]:
# Build one per-trial table for every unit, tagged with that unit's id.
trial_info_list = [
    extract_trial_info(trials_struct, unit_id)
    for trials_struct, unit_id in zip(df_sample_new["Trials"], df_sample_new["unit_id"])
]

# Stack all units' trial tables into a single DataFrame.
trial_info = pd.concat(trial_info_list, ignore_index=True)

In [None]:
# % ttl values
# Reference only: TTL event-marker codes, apparently copied from the MATLAB
# task configuration. Nothing in this cell uses them.
# c.marker.expstart        = 89;
# c.marker.expend          = 90;
# c.marker.fixOnset        = 10;
# c.marker.pic1            = 1;
# c.marker.delay1          = 2;
# c.marker.pic2            = 3;
# c.marker.delay2          = 4;
# c.marker.probeOnset      = 5;
# c.marker.response        = 6;
# c.marker.break           = 91;

# what names are in the df (displayed as the cell's output)
df_sample_new.columns

## Event timestamp extraction

In [13]:

# Event ts extraction
def extract_event_timestamps(df_sample_new):
    """
    Collect per-trial (start, end) timestamps of the first encoding epoch.

    The epoch starts at the event indexed by ``idxEnc1`` and ends at the event
    indexed by ``idxDel1`` (onset of the first delay). Both index columns hold
    MATLAB 1-based indices into the unit's ``events`` table, and only the
    first column of ``events`` is used as the timestamp.

    Parameters
    ----------
    df_sample_new : pandas.DataFrame
        One row per unit with ``unit_id``, ``events``, ``idxEnc1`` and
        ``idxDel1``.

    Returns
    -------
    dict
        ``unit_id -> ndarray of shape (n_trials, 2)``, with columns
        ``(epoch_on, epoch_off)``.
    """
    def zero_based(raw_indices):
        # atleast_1d keeps a single-trial unit as a length-1 vector (squeeze
        # alone would give a 0-d array). The intp cast lets float-typed
        # MATLAB indices be used for indexing.
        return np.atleast_1d(np.asarray(raw_indices).squeeze()).astype(np.intp) - 1

    epoch_ts = {}
    for _, row in df_sample_new.iterrows():
        # atleast_2d keeps a single-event table indexable as [row, column]
        events = np.atleast_2d(np.asarray(row['events']).squeeze())
        epoch_on = events[zero_based(row['idxEnc1']), 0]
        epoch_off = events[zero_based(row['idxDel1']), 0]  # delay onset ends the epoch
        epoch_ts[row['unit_id']] = np.column_stack((epoch_on, epoch_off))

    return epoch_ts

# Get event timestamps
# epoch_ts maps unit_id -> per-trial (epoch_on, epoch_off) pairs
epoch_ts = extract_event_timestamps(df_sample_new)

# Now compute firing rates for baseline and stimulus epochs
def compute_firing_rates(df_sample_new, epoch_ts, baseline_s=1.0):
    """
    Compute per-trial baseline and epoch firing rates, one row per trial.

    - Baseline rate: spikes in ``[epoch_on - baseline_s, epoch_on)`` divided
      by ``baseline_s``.
    - Epoch rate: spikes in ``[epoch_on, epoch_off)`` divided by the epoch
      duration in seconds. A zero or negative duration gives NaN instead of
      a division by zero.

    Spike and epoch timestamps are assumed to be in microseconds, matching the
    1e6 conversion factor.

    Parameters
    ----------
    df_sample_new : pandas.DataFrame
        One row per unit with ``unit_id`` and ``timestamps``. It is not
        modified.
    epoch_ts : dict
        ``unit_id -> iterable of (epoch_on, epoch_off)`` pairs.
    baseline_s : float, optional
        Length of the pre-epoch baseline window in seconds (default 1.0).

    Returns
    -------
    pandas.DataFrame
        Copy of ``df_sample_new`` exploded to one row per (unit, trial), with
        ``fr_baseline``, ``fr_epoch`` and 0-based ``trial_nr`` columns. Index
        labels repeat within a unit, as produced by ``DataFrame.explode``.
    """
    out = df_sample_new.copy()
    baseline_us = baseline_s * 1e6

    baseline_rates, epoch_rates, trial_numbers = [], [], []
    for unit_id, timestamps in zip(out["unit_id"], out["timestamps"]):
        spikes = np.ravel(timestamps)  # flatten once per unit, not per trial
        unit_baseline, unit_epoch = [], []
        for epoch_on, epoch_off in epoch_ts[unit_id]:
            n_baseline = np.sum((spikes >= epoch_on - baseline_us) & (spikes < epoch_on))
            unit_baseline.append(n_baseline / baseline_s)

            duration_s = (epoch_off - epoch_on) / 1e6
            if duration_s > 0:
                n_epoch = np.sum((spikes >= epoch_on) & (spikes < epoch_off))
                unit_epoch.append(n_epoch / duration_s)
            else:
                unit_epoch.append(np.nan)
        baseline_rates.append(unit_baseline)
        epoch_rates.append(unit_epoch)
        trial_numbers.append(np.arange(len(unit_epoch)))

    # dtype=object stores each per-unit list as a single cell value
    out["fr_baseline"] = pd.Series(baseline_rates, index=out.index, dtype=object)
    out["fr_epoch"] = pd.Series(epoch_rates, index=out.index, dtype=object)
    out["trial_nr"] = pd.Series(trial_numbers, index=out.index, dtype=object)

    # Explode the dataframe so each trial is a row
    return out.explode(["fr_baseline", "fr_epoch", "trial_nr"])

# Compute firing rates
# Rebinds df_sample_new to the exploded frame (one row per unit and trial)
df_sample_new = compute_firing_rates(df_sample_new, epoch_ts)

In [14]:
# Start both frames from a clean RangeIndex before merging.
df_sample_new, trial_info = (frame.reset_index(drop=True) for frame in (df_sample_new, trial_info))

# Attach the per-trial info to every (unit, trial) firing-rate row.
data = df_sample_new.merge(
    trial_info,
    how="left",
    on=["unit_id", "trial_nr"],
).infer_objects()

# Columns carried forward into the analysis, in this order.
cols_to_keep = [
    # unit, spikes, rates and trial index
    "unit_id", "timestamps", "brainAreaOfCell", "fr_epoch", "fr_baseline", "trial_nr",
    # stimuli
    "first_cat", "second_cat", "first_num", "second_num",
    "first_pic", "second_pic",
    # probe and response fields
    "probe_cat", "probe_pic", "probe_validity", "probe_num", "correct_answer",
    "rt", "acc", "key", "cat_comparison",
    # event table and epoch indices
    "events", "nTrials", "Trials",
    "idxEnc1", "idxEnc2", "idxDel1", "idxDel2", "idxProbeOn", "idxResp", "nrProcessed",
    "periods_Enc1", "periods_Enc2", "periods_Del1", "periods_Del2", "periods_Probe", "periods_Resp",
    "prestimEnc", "prestimMaint", "prestimProbe", "prestimButtonPress",
    "poststimEnc", "poststimMaint", "poststimProbe", "poststimButtonPress",
    # recording metadata
    "sessionIdx", "channel", "cellNr", "sessionID", "origClusterID",
]

data_filtered = data.loc[:, cols_to_keep].copy()

In [15]:
# Convert to simpler, hashable values for categories and numbers.
# These fields may arrive wrapped in arrays or lists. Squeezing and
# stringifying them yields plain labels, stored in new "<column>_simple"
# columns, that can serve as categorical levels.
for source_col in ("first_cat", "second_cat", "first_num", "second_num"):
    data_filtered[source_col + "_simple"] = data_filtered[source_col].apply(
        lambda value: str(np.squeeze(value))
        if isinstance(value, (list, np.ndarray))
        else str(value)
    )

## Conjunction coding

In [16]:
import statsmodels.formula.api as smf
import statsmodels.api as sm
from tqdm import tqdm

def identify_conjunctive_neurons(data_filtered, alpha=0.05):
    """
    Identify neurons that show conjunctive coding of stimulus features.

    For every unit, the baseline-subtracted firing rate
    (``fr_epoch - fr_baseline``) is modelled with two OLS models:

    1. main effects only: ``C(first_cat_simple) + C(first_num_simple)``
    2. main effects + interaction: ``C(first_cat_simple) * C(first_num_simple)``

    The interaction is tested with an F-test comparing the two nested models.
    Main effects are tested with a Type II ANOVA on model 2. A unit is
    *conjunctive* if the interaction p-value is below ``alpha``. Otherwise it
    is *feature-selective* if the category or number p-value is below
    ``alpha``.

    Some units are skipped and appear in none of the outputs:

    - units whose normalized rate has zero or undefined (NaN) variance;
    - units where either factor has fewer than two levels, since no contrast
      can be estimated.

    Parameters
    ----------
    data_filtered : pandas.DataFrame
        Neural data with columns:
        - unit_id: unique identifier for each neuron
        - fr_epoch: firing rate during the epoch
        - fr_baseline: baseline firing rate
        - first_cat_simple: stimulus category
        - first_num_simple: stimulus number
        - brainAreaOfCell: brain region of recorded neuron
    alpha : float, optional
        Significance threshold for every test (default 0.05).

    Returns
    -------
    mixed_selectivity_stats : pandas.DataFrame
        One row per analysed unit: p-values, R² of both models, R² gain from
        the interaction, and boolean significance flags. The columns are
        present even when no unit was analysed.
    conjunctive_units : list
        Unit IDs with a significant interaction.
    feature_selective_units : list
        Unit IDs without a significant interaction but with at least one
        significant main effect.
    """
    stats_columns = [
        'unit_id', 'area', 'interaction_pvalue', 'cat_pvalue', 'num_pvalue',
        'r2_independent', 'r2_conjunctive', 'r2_increase',
        'is_conjunctive', 'is_cat_selective', 'is_num_selective',
    ]
    formula_independent = "fr_normalized ~ C(first_cat_simple) + C(first_num_simple)"
    formula_conjunctive = "fr_normalized ~ C(first_cat_simple) * C(first_num_simple)"

    conjunctive_units = []
    feature_selective_units = []
    mixed_selectivity_stats = []

    for unit_id, unit_df in data_filtered.groupby("unit_id"):
        # Center firing rates around baseline (trial by trial)
        unit_df = unit_df.copy()
        unit_df["fr_normalized"] = unit_df["fr_epoch"] - unit_df["fr_baseline"]

        # Skip units with no usable variance; `not std > 0` also catches NaN,
        # e.g. from a unit with a single trial
        if not unit_df["fr_normalized"].std() > 0:
            continue
        # Both factors need at least two levels for any contrast to exist
        if unit_df["first_cat_simple"].nunique() < 2 or unit_df["first_num_simple"].nunique() < 2:
            continue

        results_independent = smf.ols(formula_independent, data=unit_df).fit()
        results_conjunctive = smf.ols(formula_conjunctive, data=unit_df).fit()

        # Formal F-test comparing the nested models; row 1 tests the added
        # interaction terms
        comparison = sm.stats.anova_lm(results_independent, results_conjunctive)
        interaction_pvalue = comparison["Pr(>F)"].iloc[1]

        # Main-effect tests on the full model
        anova_table = sm.stats.anova_lm(results_conjunctive, typ=2)
        cat_pvalue = anova_table.loc['C(first_cat_simple)', 'PR(>F)']
        num_pvalue = anova_table.loc['C(first_num_simple)', 'PR(>F)']

        # NaN p-values (e.g. no interaction degrees of freedom) compare False
        is_conjunctive = interaction_pvalue < alpha
        is_cat_selective = cat_pvalue < alpha
        is_num_selective = num_pvalue < alpha

        mixed_selectivity_stats.append({
            'unit_id': unit_id,
            'area': unit_df['brainAreaOfCell'].iloc[0],
            'interaction_pvalue': interaction_pvalue,
            'cat_pvalue': cat_pvalue,
            'num_pvalue': num_pvalue,
            'r2_independent': results_independent.rsquared,
            'r2_conjunctive': results_conjunctive.rsquared,
            'r2_increase': results_conjunctive.rsquared - results_independent.rsquared,
            'is_conjunctive': is_conjunctive,
            'is_cat_selective': is_cat_selective,
            'is_num_selective': is_num_selective,
        })

        # Classify the neuron (interaction takes precedence)
        if is_conjunctive:
            conjunctive_units.append(unit_id)
        elif is_cat_selective or is_num_selective:
            feature_selective_units.append(unit_id)

    return (
        pd.DataFrame(mixed_selectivity_stats, columns=stats_columns),
        conjunctive_units,
        feature_selective_units,
    )


def visualize_conjunctive_responses(data_filtered, unit_id):
    """
    Plot one unit's mean epoch firing rate for every category-number pair.

    The left panel shows the raw mean ``fr_epoch`` (Hz). The right panel shows
    the same matrix minus the unit's mean ``fr_baseline`` over all trials.

    Returns
    -------
    matplotlib.figure.Figure
        The two-panel figure.
    """
    rows = data_filtered.loc[data_filtered["unit_id"] == unit_id].copy()

    # Category x number matrix of mean epoch rates
    mean_rates = rows.pivot_table(
        values="fr_epoch",
        index="first_cat_simple",
        columns="first_num_simple",
        aggfunc="mean",
    )
    rates_minus_baseline = mean_rates - rows["fr_baseline"].mean()

    fig, (ax_raw, ax_norm) = plt.subplots(1, 2, figsize=(15, 6))

    panels = (
        (ax_raw, mean_rates, {"cmap": "viridis"},
         f"Unit {unit_id}: Raw firing rates (Hz)"),
        (ax_norm, rates_minus_baseline, {"cmap": "coolwarm", "center": 0},
         f"Unit {unit_id}: Normalized firing rates (change from baseline)"),
    )
    for ax, matrix, heatmap_kwargs, title in panels:
        sns.heatmap(matrix, annot=True, fmt=".1f", ax=ax, **heatmap_kwargs)
        ax.set_title(title)
        ax.set_xlabel("Number")
        ax.set_ylabel("Category")

    plt.tight_layout()
    return fig

def _rank_interaction_terms(coef_table):
    """
    Return the interaction rows of a statsmodels coefficient table, labelled and ranked.

    Adds ``Category`` and ``Number`` columns parsed from terms such as
    ``C(first_cat_simple)[T.face]:C(first_num_simple)[T.2]``, and an
    ``Abs_Coef`` column. Rows are sorted by ``Abs_Coef``, descending.
    """
    interaction = coef_table[coef_table.index.str.contains(":", regex=False)].copy()
    terms = interaction.index
    # .to_numpy() assigns the parsed labels by position. Assigning the
    # extract() result directly would align its RangeIndex against the
    # term-name index and fill the column with NaN.
    interaction["Category"] = terms.str.extract(
        r'C\(first_cat_simple\)\[T\.(.*?)\]', expand=False).to_numpy()
    interaction["Number"] = terms.str.extract(
        r'C\(first_num_simple\)\[T\.(.*?)\]', expand=False).to_numpy()
    interaction["Abs_Coef"] = interaction["Coef."].abs()
    return interaction.sort_values("Abs_Coef", ascending=False)


def characterize_conjunction_pattern(data_filtered, unit_id, alpha=0.05):
    """
    Analyze the pattern of conjunctive coding for one neuron.

    Fits ``fr_epoch ~ C(first_cat_simple) * C(first_num_simple)`` and inspects
    the interaction coefficients.

    Parameters
    ----------
    data_filtered : pandas.DataFrame
        Trial-level data with ``unit_id``, ``fr_epoch``, ``first_cat_simple``
        and ``first_num_simple``.
    unit_id : int
        Unit to analyse.
    alpha : float, optional
        Threshold for counting an interaction coefficient as significant
        (default 0.05).

    Returns
    -------
    dict
        - ``top_conjunctions``: up to three strongest interaction terms
          (Category, Number, Coef., P>|t|); empty DataFrame if there are none.
        - ``selectivity_index``: largest |coef| divided by the mean |coef| of
          the remaining terms. It is ``inf`` when that mean is zero or there
          is only one term, and NaN when the model has no interaction terms.
        - ``num_significant_conjunctions``: count of interaction terms with
          p < ``alpha``.
    """
    unit_data = data_filtered[data_filtered["unit_id"] == unit_id].copy()

    # Fit the full model with interaction
    model = smf.ols(
        "fr_epoch ~ C(first_cat_simple) * C(first_num_simple)",
        data=unit_data,
    ).fit()
    ranked = _rank_interaction_terms(model.summary2().tables[1])

    n_terms = len(ranked)
    if n_terms == 0:
        # No interaction terms: selectivity is undefined
        selectivity_index = float('nan')
    else:
        top_conj = ranked["Abs_Coef"].max()
        others_mean = ranked["Abs_Coef"].iloc[1:].mean() if n_terms > 1 else 0
        selectivity_index = top_conj / others_mean if others_mean > 0 else float('inf')

    return {
        "top_conjunctions": (
            ranked[["Category", "Number", "Coef.", "P>|t|"]].head(3) if n_terms > 0 else pd.DataFrame()
        ),
        "selectivity_index": selectivity_index,
        "num_significant_conjunctions": int((ranked["P>|t|"] < alpha).sum()) if n_terms > 0 else 0,
    }

In [None]:
# Run the main analysis to identify conjunctive neurons across all data
print("Running conjunction analysis across all neurons...")
mixed_selectivity_stats, conjunctive_units, feature_selective_units = identify_conjunctive_neurons(data_filtered)

# Summarize findings.
# n_total counts every unit in data_filtered, including units that
# identify_conjunctive_neurons skipped without producing a result (e.g. zero
# variance). Those units land in the "Non-responsive" bucket.
n_total = data_filtered['unit_id'].nunique()
n_conj = len(conjunctive_units)
n_feat = len(feature_selective_units)
n_none = n_total - n_conj - n_feat
# Denominator for percentages; max(..., 1) avoids ZeroDivisionError on empty data
pct_base = max(n_total, 1)

print(f"\nOVERALL RESULTS:")
print(f"Total units analyzed: {n_total}")
print(f"Conjunctive-coding units: {n_conj} ({n_conj/pct_base:.1%})")
print(f"Feature-selective units: {n_feat} ({n_feat/pct_base:.1%})")
print(f"Non-responsive units: {n_none} ({n_none/pct_base:.1%})")

# Visualize overall distribution of neuron types
plt.figure(figsize=(10, 6))
plt.bar(['Conjunctive', 'Feature-selective', 'Non-responsive'],
        [n_conj, n_feat, n_none])
plt.title("Distribution of Neuron Types")
plt.ylabel("Number of Neurons")
plt.tight_layout()
plt.savefig("neuron_types_overall.png", dpi=300)
plt.show()

# Analyze distribution by brain region (only units that have a stats row)
region_stats = mixed_selectivity_stats.groupby('area').agg(
    total_units=('unit_id', 'count'),
    conjunctive=('is_conjunctive', 'sum'),
    category_selective=('is_cat_selective', 'sum'),
    number_selective=('is_num_selective', 'sum')
).reset_index()

region_stats['conj_percent'] = region_stats['conjunctive'] / region_stats['total_units'] * 100
region_stats['cat_percent'] = region_stats['category_selective'] / region_stats['total_units'] * 100
region_stats['num_percent'] = region_stats['number_selective'] / region_stats['total_units'] * 100

# Sort by percentage of conjunctive neurons
region_stats_sorted = region_stats.sort_values('conj_percent', ascending=False)

# Visualize regional distribution of conjunctive neurons
plt.figure(figsize=(12, 8))
plt.bar(region_stats_sorted['area'], region_stats_sorted['conj_percent'])
plt.title("Percentage of Conjunctive Neurons by Brain Region")
plt.ylabel("Percent of Units")
plt.xlabel("Brain Region")
plt.axhline(y=n_conj / pct_base * 100, color='r', linestyle='--', label='Overall Average')
plt.legend()
plt.tight_layout()
plt.savefig("conjunctive_neurons_by_region.png", dpi=300)
plt.show()

# Analyze the top conjunctive neurons in detail, ranked by the R² gained from
# the interaction term. Only units that passed the interaction test are
# ranked; ranking all units would let non-significant units be reported as
# "conjunctive".
conj_stats = mixed_selectivity_stats[mixed_selectivity_stats['is_conjunctive'].astype(bool)]
top_conj_units = conj_stats.sort_values('r2_increase', ascending=False)['unit_id'].iloc[:5]
print(f"\nAnalyzing top {len(top_conj_units)} conjunctive neurons in detail:")

for unit_id in top_conj_units:
    unit_info = conj_stats[conj_stats['unit_id'] == unit_id].iloc[0]
    print(f"\nUnit {unit_id} (Area: {unit_info['area']}):")
    print(f"  R² increase from conjunction: {unit_info['r2_increase']:.3f}")
    print(f"  Category p-value: {unit_info['cat_pvalue']:.4f}")
    print(f"  Number p-value: {unit_info['num_pvalue']:.4f}")
    print(f"  Interaction p-value: {unit_info['interaction_pvalue']:.4f}")

    # Create response heatmap; save this specific figure, not whatever is current
    fig = visualize_conjunctive_responses(data_filtered, unit_id)
    fig.savefig(f"unit_{unit_id}_response_matrix.png", dpi=300)
    plt.close(fig)

    # Analyze conjunction pattern
    pattern = characterize_conjunction_pattern(data_filtered, unit_id)
    print(f"  Selectivity index: {pattern['selectivity_index']:.2f}")
    print(f"  Significant conjunctions: {pattern['num_significant_conjunctions']}")
    if len(pattern['top_conjunctions']) > 0:
        print("  Top conjunctions:")
        print(pattern['top_conjunctions'])

In [None]:
# Calculate binding index to quantify the degree of conjunction coding
def calculate_binding_index(data_filtered, unit_id):
    """
    Quantify how strongly a unit's firing depends on category-number conjunctions.

    Binding Index = Conjunction Selectivity / (Category Selectivity + Numerosity Selectivity)

    Each selectivity is the range (max - min) of mean ``fr_epoch`` across the
    relevant condition levels:

    - category: one mean per ``first_cat_simple`` level;
    - numerosity: one mean per ``first_num_simple`` level;
    - conjunction: one mean per observed (category, number) pair.

    The index is 0 when the denominator is not positive.

    Parameters
    ----------
    data_filtered : DataFrame
        Trial-level data with ``unit_id``, ``first_cat_simple``,
        ``first_num_simple`` and ``fr_epoch``.
    unit_id : int
        Unit identifier.

    Returns
    -------
    dict
        ``category_selectivity``, ``numerosity_selectivity``,
        ``conjunction_selectivity`` and ``binding_index``.
    """
    rows = data_filtered[data_filtered['unit_id'] == unit_id].copy()
    cat_labels = rows['first_cat_simple'].values
    num_labels = rows['first_num_simple'].values
    rates = rows['fr_epoch'].values

    def spread(means):
        return np.max(means) - np.min(means)

    cat_levels = np.unique(cat_labels)
    num_levels = np.unique(num_labels)

    category_selectivity = spread([np.mean(rates[cat_labels == level]) for level in cat_levels])
    numerosity_selectivity = spread([np.mean(rates[num_labels == level]) for level in num_levels])

    # Mean rate for every (category, number) cell that has at least one trial
    cell_means = []
    for cat_level in cat_levels:
        in_category = cat_labels == cat_level
        for num_level in num_levels:
            in_cell = in_category & (num_labels == num_level)
            if np.any(in_cell):
                cell_means.append(np.mean(rates[in_cell]))
    conjunction_selectivity = spread(cell_means)

    main_effect_total = category_selectivity + numerosity_selectivity
    if main_effect_total > 0:
        binding_index = conjunction_selectivity / main_effect_total
    else:
        binding_index = 0

    return {
        'category_selectivity': category_selectivity,
        'numerosity_selectivity': numerosity_selectivity,
        'conjunction_selectivity': conjunction_selectivity,
        'binding_index': binding_index
    }

# Calculate binding indices for all units
print("Calculating binding indices for all units...")

# Area label per unit, taken from the unit's first row in data_filtered (the
# same row identify_conjunctive_neurons reads it from). Looking the area up in
# mixed_selectivity_stats instead raises IndexError for any unit that
# identify_conjunctive_neurons skipped (e.g. zero-variance units). The
# one-off Series lookup also avoids rescanning the stats table per unit.
unit_area = data_filtered.drop_duplicates('unit_id').set_index('unit_id')['brainAreaOfCell']

binding_indices = []
for unit_id in data_filtered['unit_id'].unique():
    indices = calculate_binding_index(data_filtered, unit_id)
    indices['unit_id'] = unit_id
    indices['area'] = unit_area.loc[unit_id]
    binding_indices.append(indices)

binding_df = pd.DataFrame(binding_indices)

# Get conjunctive neurons (units with a significant interaction)
conjunctive_units = mixed_selectivity_stats.loc[
    mixed_selectivity_stats['is_conjunctive'].astype(bool), 'unit_id'
].values
binding_df['is_conjunctive'] = binding_df['unit_id'].isin(conjunctive_units)

# Summary statistics for all neurons
print(f"\nBinding Index Summary (All Neurons):")
print(f"Mean: {binding_df['binding_index'].mean():.3f}")
print(f"Median: {binding_df['binding_index'].median():.3f}")
print(f"Range: {binding_df['binding_index'].min():.3f} - {binding_df['binding_index'].max():.3f}")
print(f"n: {len(binding_df)}")

# Summary statistics for conjunctive neurons
conj_df = binding_df[binding_df['is_conjunctive']]
print(f"\nBinding Index Summary (Conjunctive Neurons):")
print(f"Mean: {conj_df['binding_index'].mean():.3f}")
print(f"Median: {conj_df['binding_index'].median():.3f}")
print(f"Range: {conj_df['binding_index'].min():.3f} - {conj_df['binding_index'].max():.3f}")
print(f"n: {len(conj_df)}")

# Visualize distribution of binding indices
plt.figure(figsize=(10, 6))
plt.hist(binding_df[~binding_df['is_conjunctive']]['binding_index'], bins=20,
         alpha=0.7, label='Non-conjunctive Neurons')
plt.hist(conj_df['binding_index'], bins=20,
         alpha=0.7, label='Conjunctive Neurons')
plt.title("Distribution of Binding Indices")
plt.xlabel("Binding Index")
plt.ylabel("Number of Units")
plt.axvline(binding_df['binding_index'].mean(), color='r', linestyle='--',
            label=f'All Mean: {binding_df["binding_index"].mean():.3f}')
# The conjunctive mean is NaN when there are no conjunctive units; skip its line then
if not conj_df.empty:
    plt.axvline(conj_df['binding_index'].mean(), color='g', linestyle='--',
                label=f'Conjunctive Mean: {conj_df["binding_index"].mean():.3f}')
plt.legend()
plt.savefig("binding_index_distribution.png", dpi=300)
plt.show()

# Create 3D visualization of neurons in selectivity space
fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection='3d')

# Color points by brain area
areas = binding_df['area'].unique()
colors = plt.cm.tab10(np.linspace(0, 1, len(areas)))
area_color_map = dict(zip(areas, colors))

for area in areas:
    area_data = binding_df[binding_df['area'] == area]
    ax.scatter(
        area_data['category_selectivity'],
        area_data['numerosity_selectivity'],
        area_data['binding_index'],
        label=area,
        alpha=0.7,
        s=50,
        color=area_color_map[area]
    )

# Highlight conjunctive neurons (same set as binding_df['is_conjunctive'])
conjunctive_neurons = binding_df[binding_df['is_conjunctive']]
ax.scatter(
    conjunctive_neurons['category_selectivity'],
    conjunctive_neurons['numerosity_selectivity'],
    conjunctive_neurons['binding_index'],
    color='red',
    s=100,
    marker='*',
    label='Conjunctive Neurons'
)

ax.set_xlabel('Category Selectivity')
ax.set_ylabel('Numerosity Selectivity')
ax.set_zlabel('Binding Index')
ax.set_title('3D Distribution of Neuronal Selectivity')

# Add legend with smaller font size
ax.legend(fontsize='small', loc='upper right')

# Add grid for better visualization
ax.grid(True)

# Adjust view angle for better visualization
ax.view_init(elev=30, azim=45)

plt.tight_layout()
plt.savefig("selectivity_3d_space.png", dpi=300)
plt.show()

# Analyze relationship between binding index and brain regions
region_binding = binding_df.groupby('area').agg(
    mean_binding=('binding_index', 'mean'),
    median_binding=('binding_index', 'median'),
    max_binding=('binding_index', 'max'),
    unit_count=('unit_id', 'count')
).reset_index()

# Sort by mean binding index
region_binding_sorted = region_binding.sort_values('mean_binding', ascending=False)

# Visualize binding index by brain region
plt.figure(figsize=(12, 8))
plt.bar(region_binding_sorted['area'], region_binding_sorted['mean_binding'])
plt.title("Mean Binding Index by Brain Region")
plt.ylabel("Mean Binding Index")
plt.xlabel("Brain Region")
plt.axhline(y=binding_df['binding_index'].mean(), color='r', linestyle='--', label='Overall Average')
plt.legend()
plt.tight_layout()
plt.savefig("binding_index_by_region.png", dpi=300)
plt.show()


### Temporal Dynamics of Binding

In [19]:
from scipy.ndimage import gaussian_filter1d
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score
from sklearn.preprocessing import StandardScaler

def extract_trial_aligned_spikes(data_filtered, unit_id, event_marker='idxEnc1', window=(-1.0, 7.0)):
    """
    Extract spike times aligned to a specific event for a given unit.

    Spike timestamps and event timestamps (first column of ``events``) are
    divided by 1e6 to convert them to seconds.

    Parameters
    ----------
    data_filtered : DataFrame
        Trial-level data. The unit's spike train (``timestamps``), ``events``
        table and ``event_marker`` column are read from its first row. Trial
        labels (``first_cat_simple``, ``first_num_simple``) come from the row
        whose ``trial_nr`` equals the trial's position in the event-index
        vector. Without a ``trial_nr`` column, the row position is used.
    unit_id : int
        Unit identifier.
    event_marker : str
        Column of MATLAB 1-based indices into ``events`` to align to
        ('idxEnc1', 'idxDel1', 'idxProbeOn', etc.).
    window : tuple of float
        ``(start, end)`` offsets in seconds relative to the event; ``start``
        is normally negative. Both ends are inclusive.

    Returns
    -------
    dict
        ``trial_nr -> {'spikes': event-relative spike times (s),
        'first_cat': ..., 'first_num': ...}``. Trials without a matching label
        row are omitted, and the dict is empty if the unit has no rows.
    """
    unit_data = data_filtered[data_filtered['unit_id'] == unit_id]
    if unit_data.empty:
        return {}
    first_row = unit_data.iloc[0]

    # Full spike train in seconds, sorted so each window can be located by
    # binary search instead of masking the whole train on every trial
    spikes = np.sort(np.asarray(first_row['timestamps']).flatten().astype(np.float64) / 1e6)

    events = np.asarray(first_row['events']).squeeze()
    # Convert MATLAB 1-based indices to 0-based ints. This accepts (n, 1)
    # numeric arrays, nested (1, 1) arrays and single-trial units alike.
    event_indices = [int(np.ravel(value)[0]) - 1 for value in np.ravel(first_row[event_marker])]

    # Trial labels keyed by trial number, so a missing row cannot shift the
    # labels of later trials
    if 'trial_nr' in unit_data.columns:
        trial_keys = unit_data['trial_nr'].tolist()
    else:
        trial_keys = list(range(len(unit_data)))
    labels = {}
    for key, first_cat, first_num in zip(trial_keys,
                                         unit_data['first_cat_simple'].tolist(),
                                         unit_data['first_num_simple'].tolist()):
        if pd.notna(key):
            labels.setdefault(int(key), (first_cat, first_num))

    start_offset, end_offset = window
    trial_aligned_spikes = {}
    for trial_nr, event_idx in enumerate(event_indices):
        if trial_nr not in labels:
            continue  # no trial info for this trial

        event_time = events[event_idx, 0] / 1e6  # Convert to seconds
        # Spikes with event_time + start <= t <= event_time + end
        lo = np.searchsorted(spikes, event_time + start_offset, side='left')
        hi = np.searchsorted(spikes, event_time + end_offset, side='right')

        first_cat, first_num = labels[trial_nr]
        trial_aligned_spikes[trial_nr] = {
            'spikes': spikes[lo:hi] - event_time,
            'first_cat': first_cat,
            'first_num': first_num
        }

    return trial_aligned_spikes

def _sliding_windows(t_start, t_end, window_width, step_size):
    """Return ``(start, end)`` windows (ms) of width ``window_width``, stepped by ``step_size``, inside ``[t_start, t_end]``."""
    # "+ 1" lets the last window end exactly at t_end
    return [(t, t + window_width) for t in range(t_start, t_end - window_width + 1, step_size)]


def _neg_log10_p(pvals):
    """Return ``-log10(p)`` for plotting, mapping NaN p-values to 1 and clipping 0 so every value is finite."""
    p = np.asarray(pvals, dtype=float)
    p = np.where(np.isnan(p), 1.0, p)
    p = np.clip(p, np.finfo(float).tiny, 1.0)
    return -np.log10(p)


def _window_feature_stats(temp_df):
    """
    Fit main-effect and interaction OLS models to one window's spike counts.

    Returns ``(category_p, number_p, interaction_p, r2_increase)``. Terms
    missing from the Type II ANOVA table get p = 1.0, and the whole result
    falls back to ``(1.0, 1.0, 1.0, 0.0)`` if fitting raises.
    """
    import statsmodels.formula.api as smf
    import statsmodels.api as sm

    try:
        model_main = smf.ols("window_fr ~ C(first_cat_simple) + C(first_num_simple)", data=temp_df).fit()
        model_inter = smf.ols("window_fr ~ C(first_cat_simple) * C(first_num_simple)", data=temp_df).fit()
        anova = sm.stats.anova_lm(model_inter, typ=2)
    except Exception:
        # Deliberate best-effort: a window whose design cannot be fit is
        # reported as "no effect" rather than aborting the whole time scan
        return 1.0, 1.0, 1.0, 0.0

    def term_p(term):
        return anova.loc[term, 'PR(>F)'] if term in anova.index else 1.0

    return (
        term_p('C(first_cat_simple)'),
        term_p('C(first_num_simple)'),
        term_p('C(first_cat_simple):C(first_num_simple)'),
        model_inter.rsquared - model_main.rsquared,
    )


def analyze_binding_dynamics(data_filtered, unit_id, window_width=50, step_size=25,
                             t_start=-500, t_end=2000):
    """
    Analyze how conjunction coding emerges over time for a specific neuron.

    Spike counts in sliding windows (aligned to ``idxEnc1``) are modelled
    with category, number and category x number terms. P-values and the R²
    gained from the interaction are tracked across windows.

    Parameters
    ----------
    data_filtered : DataFrame
        Trial-level data accepted by ``extract_trial_aligned_spikes``.
    unit_id : int
        Unit identifier.
    window_width, step_size : int, optional
        Window width and step in milliseconds (defaults 50 and 25).
    t_start, t_end : int, optional
        Scan range in milliseconds relative to stimulus onset (defaults -500
        and 2000). The last window may end exactly at ``t_end``.

    Returns
    -------
    DataFrame
        One row per window: center time (ms), raw p-values, R² increase, and
        p < 0.05 flags.
    matplotlib.figure.Figure
        Smoothed ``-log10(p)`` curves and R² increase over time.
    """
    # Keep the original (-1 s, 7 s) extraction window; widen it only if the
    # requested scan reaches beyond it
    aligned_spikes = extract_trial_aligned_spikes(
        data_filtered, unit_id,
        window=(min(-1.0, t_start / 1000), max(7.0, t_end / 1000)),
    )
    trials = list(aligned_spikes.values())
    trial_cats = [trial['first_cat'] for trial in trials]
    trial_nums = [trial['first_num'] for trial in trials]

    windows = _sliding_windows(t_start, t_end, window_width, step_size)
    window_centers = [start + window_width / 2 for start, _ in windows]

    cat_pvals, num_pvals, interaction_pvals, r2_increases = [], [], [], []
    for start_ms, end_ms in windows:
        lo_s, hi_s = start_ms / 1000, end_ms / 1000
        # Spike count per trial in [lo_s, hi_s)
        spike_counts = [int(np.sum((trial['spikes'] >= lo_s) & (trial['spikes'] < hi_s))) for trial in trials]

        if sum(spike_counts) == 0:
            # No spikes anywhere in this window: nothing to model
            cat_p, num_p, int_p, r2_inc = 1.0, 1.0, 1.0, 0.0
        else:
            cat_p, num_p, int_p, r2_inc = _window_feature_stats(pd.DataFrame({
                'window_fr': spike_counts,
                'first_cat_simple': trial_cats,
                'first_num_simple': trial_nums,
            }))
        cat_pvals.append(cat_p)
        num_pvals.append(num_p)
        interaction_pvals.append(int_p)
        r2_increases.append(r2_inc)

    # Create result dataframe (raw p-values; NaN compares False in the flags)
    result = pd.DataFrame({
        'time': window_centers,
        'category_p': cat_pvals,
        'number_p': num_pvals,
        'interaction_p': interaction_pvals,
        'r2_increase': r2_increases,
        'category_sig': np.array(cat_pvals) < 0.05,
        'number_sig': np.array(num_pvals) < 0.05,
        'conjunction_sig': np.array(interaction_pvals) < 0.05
    })

    # Smooth for visualization. NaN or zero p-values would otherwise spread
    # NaN/inf through the Gaussian filter and blank the whole curve.
    smoothed_cat = gaussian_filter1d(_neg_log10_p(cat_pvals), sigma=2)
    smoothed_num = gaussian_filter1d(_neg_log10_p(num_pvals), sigma=2)
    smoothed_int = gaussian_filter1d(_neg_log10_p(interaction_pvals), sigma=2)
    smoothed_r2 = gaussian_filter1d(np.nan_to_num(np.asarray(r2_increases, dtype=float)), sigma=2)

    fig, axes = plt.subplots(2, 1, figsize=(12, 10), sharex=True)

    # Plot significance
    axes[0].plot(window_centers, smoothed_cat, 'b-', label='Category', linewidth=2)
    axes[0].plot(window_centers, smoothed_num, 'g-', label='Number', linewidth=2)
    axes[0].plot(window_centers, smoothed_int, 'r-', label='Conjunction', linewidth=2)
    axes[0].axhline(-np.log10(0.05), color='k', linestyle='--', label='p=0.05')
    axes[0].set_ylabel('-log10(p-value)')
    axes[0].set_title(f'Unit {unit_id}: Temporal Evolution of Feature Encoding')

    # Plot R² increase
    axes[1].plot(window_centers, smoothed_r2, 'k-', linewidth=2)
    axes[1].set_ylabel('R² increase from interaction')
    axes[1].set_xlabel('Time from stimulus onset (ms)')

    # Mark key epochs
    for ax in axes:
        ax.axvline(0, color='k', linestyle='-', alpha=0.3, label='Stimulus onset')
        # NOTE(review): assumes the first delay starts 1000 ms after stimulus
        # onset; confirm against the task timing
        ax.axvline(1000, color='r', linestyle='-', alpha=0.3, label='Delay 1')

    # Single legend on the top panel, keeping the first handle of each label
    handles, labels = axes[0].get_legend_handles_labels()
    first_handle = {}
    for handle, label in zip(handles, labels):
        first_handle.setdefault(label, handle)
    axes[0].legend(list(first_handle.values()), list(first_handle.keys()), loc='upper right')

    # Set reasonable y-limits
    axes[0].set_ylim(bottom=0)
    if np.max(smoothed_r2) > 0:
        axes[1].set_ylim(bottom=0)

    plt.tight_layout()

    return result, fig

def _ctd_folds(stratify_labels, n_folds, random_state):
    """Return the (train_idx, test_idx) trial splits used for cross-temporal decoding.

    ``stratify_labels`` are integer-encoded conjunction labels (0..K-1, as
    produced by LabelEncoder). Stratifying on the conjunction also balances
    category and number across folds. The number of folds is capped by the
    rarest conjunction; if some conjunction occurs only once, a shuffled
    (unstratified) K-fold split is used instead. ``n_folds`` of ``None`` or < 2
    returns a single split that trains and tests on every trial (in-sample).
    """
    from sklearn.model_selection import KFold, StratifiedKFold

    all_trials = np.arange(len(stratify_labels))
    if n_folds is None or n_folds < 2 or len(all_trials) < 2:
        return [(all_trials, all_trials)]

    n_splits = min(n_folds, int(np.bincount(stratify_labels).min()))
    if n_splits >= 2:
        splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
        return list(splitter.split(all_trials, stratify_labels))

    splitter = KFold(n_splits=min(n_folds, len(all_trials)), shuffle=True, random_state=random_state)
    return list(splitter.split(all_trials))


def _fit_decoder(X_train, y_train):
    """Fit one logistic-regression decoder and return its ``predict`` callable.

    LogisticRegression cannot be fit on a single class, so a training fold that
    contains only one class yields a constant predictor for that class.
    """
    from sklearn.linear_model import LogisticRegression

    classes = np.unique(y_train)
    if len(classes) < 2:
        only_class = classes[0]
        return lambda X: np.full(len(X), only_class, dtype=classes.dtype)
    model = LogisticRegression(max_iter=1000, C=1.0, solver='liblinear')
    model.fit(X_train, y_train)
    return model.predict


def _above_chance_span(diag, chance, window_centers, margin=0.05):
    """Return the time (ms) between the first and last window whose accuracy exceeds chance + margin.

    Returns 0 when no window exceeds the threshold. Note this is the span from
    first to last supra-threshold window, not the length of a contiguous run.
    """
    above = np.flatnonzero(diag > (chance + margin))
    if above.size == 0:
        return 0
    return window_centers[above[-1]] - window_centers[above[0]]


def cross_temporal_binding_analysis(data_filtered, responsive_units, n_folds=5, random_state=0):
    """
    Perform cross-temporal decoding to track binding information across the neural population.

    Spike counts of every unit in ``responsive_units`` are pooled into a
    (trials x units) population matrix per time window. Category, number and
    their conjunction are decoded by training a logistic-regression decoder on
    one window and testing it on every window. Accuracy is computed on held-out
    trials (k-fold cross-validation), because scoring on the training trials
    reports in-sample accuracy that is inflated, most of all for the
    many-class conjunction decoder.

    Parameters
    ----------
    data_filtered : DataFrame
        DataFrame containing neural data; passed to ``extract_trial_aligned_spikes``.
    responsive_units : list
        List of unit IDs that form the decoded population.
    n_folds : int or None, default 5
        Number of cross-validation folds (capped by the rarest conjunction).
        ``None`` or a value < 2 trains and tests on all trials (in-sample).
    random_state : int, default 0
        Seed for fold shuffling, so results are reproducible.

    Returns
    -------
    results : dict
        Smoothed accuracy matrices ('cat_accuracy', 'num_accuracy',
        'conj_accuracy'; rows = training window, columns = testing window),
        their unsmoothed versions ('*_accuracy_raw'), peak times, binding
        delay, persistence, window centers and both figures.
    fig : matplotlib.figure.Figure
        Cross-temporal decoding matrices.
    fig2 : matplotlib.figure.Figure
        Time course of same-time (diagonal) decoding accuracy.

    Raises
    ------
    ValueError
        If no unit contributes any trial.
    """
    from scipy.ndimage import gaussian_filter
    from sklearn.metrics import balanced_accuracy_score
    from sklearn.preprocessing import LabelEncoder, StandardScaler
    from tqdm import tqdm

    # Define time windows covering the entire trial (all times in ms)
    window_width = 100  # smaller window for more time bins
    step_size = 50      # smaller step size for more time bins
    t_start = -500      # start 500ms before stimulus
    t_end = 2000        # end 2 seconds after stimulus

    # "+ 1" keeps the final window that ends exactly at t_end.
    window_starts = np.arange(t_start, t_end - window_width + 1, step_size)
    window_centers = window_starts + window_width / 2
    n_windows = len(window_starts)
    # Window edges in seconds (spike times are compared in seconds), computed once.
    starts_s = window_starts / 1000
    ends_s = (window_starts + window_width) / 1000

    n_units = len(responsive_units)

    # Rows are keyed by trial number (the keys of the aligned-spike dict) so the
    # same trial occupies the same row for every unit, even when units
    # contribute different trial sets. Labels come from the first unit that
    # provides the trial.
    trial_rows = {}     # trial_nr -> row index, in order of first appearance
    trial_labels = []   # (first_cat, first_num) per row
    unit_counts = []    # per unit: {row: spike count in each window}

    print("Extracting spike data for each unit and time window...")
    for unit_id in tqdm(responsive_units):
        aligned_spikes = extract_trial_aligned_spikes(
            data_filtered, unit_id, window=(t_start / 1000, t_end / 1000))
        counts_by_row = {}
        for trial_nr, trial_data in aligned_spikes.items():
            if trial_nr not in trial_rows:
                trial_rows[trial_nr] = len(trial_rows)
                trial_labels.append((trial_data['first_cat'], trial_data['first_num']))
            spikes = np.sort(np.asarray(trial_data['spikes'], dtype=float).ravel())
            # Number of spikes in [start, end) for every window at once:
            # #(t < end) - #(t < start).
            counts = (np.searchsorted(spikes, ends_s, side='left')
                      - np.searchsorted(spikes, starts_s, side='left'))
            counts_by_row[trial_rows[trial_nr]] = counts
        unit_counts.append(counts_by_row)

    n_trials = len(trial_rows)
    if n_trials == 0:
        raise ValueError("cross_temporal_binding_analysis: no trials found for the given units")

    # X[w] is the (n_trials, n_units) population matrix of window w; a unit
    # lacking a trial contributes zero spikes to that row.
    X = np.zeros((n_windows, n_trials, n_units))
    for unit_idx, counts_by_row in enumerate(unit_counts):
        for row, counts in counts_by_row.items():
            X[:, row, unit_idx] = counts

    # Convert labels to arrays; conjunctions get integer codes.
    categories = np.array([cat for cat, _ in trial_labels])
    numbers = np.array([num for _, num in trial_labels])
    conjunction_labels = LabelEncoder().fit_transform(
        [f"{cat}_{num}" for cat, num in trial_labels])

    targets = {'cat': categories, 'num': numbers, 'conj': conjunction_labels}
    folds = _ctd_folds(conjunction_labels, n_folds, random_state)

    # Held-out predictions pooled across folds, indexed [train window, test window, trial].
    pooled = {name: np.empty((n_windows, n_windows, n_trials), dtype=y.dtype)
              for name, y in targets.items()}

    print(f"Performing cross-temporal decoding ({len(folds)} fold(s))...")
    with tqdm(total=len(folds) * n_windows) as progress:
        for train_trials, test_trials in folds:
            for train_w in range(n_windows):
                # Standardize with statistics of the training trials only
                scaler = StandardScaler()
                X_train = scaler.fit_transform(X[train_w, train_trials])
                decoders = {name: _fit_decoder(X_train, y[train_trials])
                            for name, y in targets.items()}

                # Test on all time windows, using the held-out trials
                for test_w in range(n_windows):
                    X_test = scaler.transform(X[test_w, test_trials])
                    for name, predict in decoders.items():
                        pooled[name][train_w, test_w, test_trials] = predict(X_test)
                progress.update(1)

    accuracy = {
        name: np.array([[balanced_accuracy_score(y, pooled[name][i, j])
                         for j in range(n_windows)]
                        for i in range(n_windows)])
        for name, y in targets.items()
    }

    # Apply smoothing for better visualization
    cat_accuracy_smooth = gaussian_filter(accuracy['cat'], sigma=1.0)
    num_accuracy_smooth = gaussian_filter(accuracy['num'], sigma=1.0)
    conj_accuracy_smooth = gaussian_filter(accuracy['conj'], sigma=1.0)

    # Calculate chance levels
    cat_chance = 1.0 / len(np.unique(categories))
    num_chance = 1.0 / len(np.unique(numbers))
    conj_chance = 1.0 / len(np.unique(conjunction_labels))

    # Set vmin and vmax for consistent color scaling
    vmin = min(cat_chance, num_chance, conj_chance)
    vmax = max(np.max(cat_accuracy_smooth), np.max(num_accuracy_smooth), np.max(conj_accuracy_smooth))

    # imshow's extent is given in pixel edges; pad the window centers by half a
    # step so each pixel is centered on its window.
    half_step = step_size / 2
    extent = [window_centers[0] - half_step, window_centers[-1] + half_step,
              window_centers[-1] + half_step, window_centers[0] - half_step]

    # Visualize cross-temporal matrices
    fig, axes = plt.subplots(1, 3, figsize=(20, 6))
    panels = [
        (cat_accuracy_smooth, "Category", cat_chance),
        (num_accuracy_smooth, "Number", num_chance),
        (conj_accuracy_smooth, "Conjunction", conj_chance),
    ]
    for ax, (matrix, title, chance) in zip(axes, panels):
        im = ax.imshow(matrix, cmap='viridis', vmin=vmin, vmax=vmax, extent=extent)
        ax.set_title(f"{title} Decoding (chance={chance:.2f})")
        ax.set_xlabel("Testing time (ms)")
        fig.colorbar(im, ax=ax)
    axes[0].set_ylabel("Training time (ms)")

    # Add reference lines for key task events
    for ax in axes:
        # Diagonal line (training=testing)
        ax.plot([window_centers[0], window_centers[-1]], [window_centers[0], window_centers[-1]],
                'k--', alpha=0.5)

        # Stimulus onsets and task events
        ax.axvline(0, color='r', linestyle='-', alpha=0.5, label='Stim 1')
        ax.axhline(0, color='r', linestyle='-', alpha=0.5)

        # Delay 1 onset
        ax.axvline(1000, color='b', linestyle='--', alpha=0.5, label='Delay 1')
        ax.axhline(1000, color='b', linestyle='--', alpha=0.5)

        # Add a legend to the first plot only
        if ax is axes[0]:
            handles, labels = ax.get_legend_handles_labels()
            by_label = dict(zip(labels, handles))
            ax.legend(by_label.values(), by_label.keys(), loc='upper right')

    fig.tight_layout()

    # Extract diagonal accuracies (same-time decoding)
    cat_diag = np.diag(cat_accuracy_smooth)
    num_diag = np.diag(num_accuracy_smooth)
    conj_diag = np.diag(conj_accuracy_smooth)

    # Create a second figure showing time course of decoding accuracy
    fig2, ax2 = plt.subplots(figsize=(14, 7))
    ax2.plot(window_centers, cat_diag, 'b-', label='Category', linewidth=2)
    ax2.plot(window_centers, num_diag, 'g-', label='Number', linewidth=2)
    ax2.plot(window_centers, conj_diag, 'r-', label='Conjunction', linewidth=2)

    # Add chance levels
    ax2.axhline(y=cat_chance, color='b', linestyle='--', alpha=0.5, label='Cat. chance')
    ax2.axhline(y=num_chance, color='g', linestyle='--', alpha=0.5, label='Num. chance')
    ax2.axhline(y=conj_chance, color='r', linestyle='--', alpha=0.5, label='Conj. chance')

    # Shaded regions for key trial phases: Stimulus 1 (0-1000ms), Delay 1 (1000-2000ms)
    ax2.axvspan(0, 1000, color='lightgray', alpha=0.3, label='Stim 1')
    ax2.axvspan(1000, 2000, color='lightblue', alpha=0.2, label='Delay 1')

    # Add vertical lines for key events
    ax2.axvline(x=0, color='k', linestyle='-', alpha=0.5, label='Stim 1 onset')
    ax2.axvline(x=1000, color='k', linestyle='--', alpha=0.5, label='Delay 1 onset')

    ax2.set_xlabel('Time (ms)')
    ax2.set_ylabel('Decoding Accuracy')
    ax2.set_title('Time Course of Information Representation Throughout Trial')

    # Legend lists only the three decoders and their chance lines (first six
    # artists added above) to avoid overcrowding.
    handles, labels = ax2.get_legend_handles_labels()
    ax2.legend(handles[:6], labels[:6], loc='upper right')

    ax2.set_ylim(vmin - 0.05, vmax + 0.05)
    ax2.set_xlim(t_start, t_end)
    fig2.tight_layout()

    # Calculate when each type of information reaches peak
    cat_peak_time = window_centers[np.argmax(cat_diag)]
    num_peak_time = window_centers[np.argmax(num_diag)]
    conj_peak_time = window_centers[np.argmax(conj_diag)]

    # Binding delay: how far the conjunction peak trails the later-peaking... 
    # (smallest lag to either feature), or 0 if it does not trail both.
    if conj_peak_time > cat_peak_time and conj_peak_time > num_peak_time:
        binding_delay = min(conj_peak_time - cat_peak_time, conj_peak_time - num_peak_time)
    else:
        binding_delay = 0

    # Persistence: span of windows decoded above chance + 0.05
    cat_persistence = _above_chance_span(cat_diag, cat_chance, window_centers)
    num_persistence = _above_chance_span(num_diag, num_chance, window_centers)
    conj_persistence = _above_chance_span(conj_diag, conj_chance, window_centers)

    results = {
        'cat_accuracy': cat_accuracy_smooth,
        'num_accuracy': num_accuracy_smooth,
        'conj_accuracy': conj_accuracy_smooth,
        'cat_accuracy_raw': accuracy['cat'],
        'num_accuracy_raw': accuracy['num'],
        'conj_accuracy_raw': accuracy['conj'],
        'cat_peak_time': cat_peak_time,
        'num_peak_time': num_peak_time,
        'conj_peak_time': conj_peak_time,
        'binding_delay': binding_delay,
        'cat_persistence': cat_persistence,
        'num_persistence': num_persistence,
        'conj_persistence': conj_persistence,
        'window_centers': window_centers,
        'time_course_fig': fig2,
        'cross_temporal_fig': fig
    }

    return results, fig, fig2

In [None]:
# Analyze temporal dynamics for each brain area separately
print("\nAnalyzing temporal dynamics of binding by brain area...")

# Imported here so this cell does not depend on an earlier/later cell having imported it
from tqdm import tqdm

# Significance columns produced by analyze_binding_dynamics
SIG_COLUMNS = ('category_sig', 'number_sig', 'conjunction_sig')

# Get unique brain regions
brain_regions = data_filtered['brainAreaOfCell'].unique()
print(f"Found {len(brain_regions)} brain regions: {brain_regions.tolist()}")

# Dictionary to store binding delays for each region
region_binding_delays = {}

# For each brain region, analyze all units
for region in brain_regions:
    print(f"\nAnalyzing temporal dynamics for region: {region}")

    # Select only data from this region
    region_data = data_filtered[data_filtered['brainAreaOfCell'] == region].copy()

    # Get all units from this region
    region_units = region_data['unit_id'].unique().tolist()

    # Skip regions with too few units
    if len(region_units) < 5:
        print(f"  Skipping {region} - only {len(region_units)} units (minimum 5 required)")
        continue

    print(f"  Analyzing {len(region_units)} units from {region}")

    # Store binding delays for this region
    binding_delays = []

    # Analyze each unit in this region
    for unit_id in tqdm(region_units):
        dynamics, unit_fig = analyze_binding_dynamics(region_data, unit_id)
        # The per-unit figure is not used here; close it so one figure per unit
        # does not accumulate in memory (and, in a notebook, get rendered at the end of the cell).
        plt.close(unit_fig)

        # First time at which each signal becomes significant (None if never)
        onsets = {}
        for column in SIG_COLUMNS:
            sig = dynamics[column]
            onsets[column] = dynamics.loc[sig.idxmax(), 'time'] if sig.any() else None

        # Binding delay = conjunction onset minus the earlier feature onset,
        # only when all three signals become significant.
        if all(onset is not None for onset in onsets.values()):
            feature_onset = min(onsets['category_sig'], onsets['number_sig'])
            binding_delays.append(onsets['conjunction_sig'] - feature_onset)

    # Store results for this region
    if binding_delays:
        region_binding_delays[region] = binding_delays
        print(f"  {region}: {len(binding_delays)} units with valid binding delays")
        print(f"  Average binding delay: {np.mean(binding_delays):.0f} ms")
        print(f"  Range: {np.min(binding_delays):.0f} - {np.max(binding_delays):.0f} ms")
    else:
        print(f"  {region}: No units with valid binding delays")

# Create bar plot with error bars for binding delays by region
if region_binding_delays:
    plt.figure(figsize=(12, 6))

    summary = []  # (mean delay, region label, standard error)
    for region, delays in region_binding_delays.items():
        n_delays = len(delays)
        if n_delays == 0:
            continue
        # Standard error of the mean uses the sample SD (ddof=1); a single unit has no spread estimate.
        sem = np.std(delays, ddof=1) / np.sqrt(n_delays) if n_delays > 1 else 0.0
        # str() keeps the x-axis categorical even if region labels are numeric codes
        summary.append((np.mean(delays), str(region), sem))

    # Sort by mean binding delay
    summary.sort(key=lambda item: item[0])
    sorted_means, sorted_regions, sorted_errors = (list(col) for col in zip(*summary))

    plt.bar(sorted_regions, sorted_means, yerr=sorted_errors, capsize=10)
    plt.ylabel('Binding Delay (ms)')
    plt.title('Average Binding Delay by Brain Region')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig("binding_delay_by_region.png", dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("No regions with valid binding delays to plot")

In [None]:
# Perform population-level temporal dynamics analysis for each brain area separately
print("\nPerforming population-level temporal dynamics analysis by brain area...")

# Make sure we have imported the required modules
import traceback

from scipy.ndimage import gaussian_filter1d
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

# Task periods (ms from stimulus onset) whose mean same-time decoding accuracy
# is reported. Each period is summarised independently, and only if the
# decoding time axis (ct_results['window_centers']) actually covers it.
DELAY_PERIODS = {'Delay 1': (1000, 2000), 'Delay 2': (3000, 5500)}
# Display name -> key prefix in ct_results
FEATURE_KEYS = {'Category': 'cat', 'Number': 'num', 'Conjunction': 'conj'}


def _format_delay_line(period, means):
    """Return the one-line accuracy summary for a delay period."""
    return (f"{period} - Category: {means['Category']:.3f}, "
            f"Number: {means['Number']:.3f}, Conjunction: {means['Conjunction']:.3f}")


# Get unique brain regions
brain_regions = data_filtered['brainAreaOfCell'].unique()
print(f"Found {len(brain_regions)} brain regions: {brain_regions.tolist()}")

# For each brain region, perform independent analysis
for region in brain_regions:
    print(f"\nAnalyzing temporal dynamics for region: {region}")

    # Select only data from this region
    region_data = data_filtered[data_filtered['brainAreaOfCell'] == region].copy()

    # Get all units from this region
    region_units = region_data['unit_id'].unique().tolist()

    # Skip regions with too few units
    if len(region_units) < 5:
        print(f"  Skipping {region} - only {len(region_units)} units (minimum 5 required)")
        continue

    print(f"  Using all {len(region_units)} units from this region")

    try:
        ct_results, ct_fig, time_course_fig = cross_temporal_binding_analysis(region_data, region_units)

        # Title and save the cross-temporal decoding matrices
        ct_fig.suptitle(f"Cross-Temporal Decoding - {region}", fontsize=16)
        ct_fig.savefig(f"{region}_cross_temporal_decoding.png", dpi=300, bbox_inches='tight')

        # Title, save and close the time course figure
        time_course_fig.suptitle(f"Decoding Time Course - {region}", fontsize=16)
        time_course_fig.savefig(f"{region}_decoding_time_course.png", dpi=300, bbox_inches='tight')
        plt.close(time_course_fig)

        # Report timing results for this region
        print(f"\nTemporal dynamics in {region}:")
        print(f"Category information peaks at {ct_results['cat_peak_time']:.0f} ms")
        print(f"Number information peaks at {ct_results['num_peak_time']:.0f} ms")
        print(f"Conjunction information peaks at {ct_results['conj_peak_time']:.0f} ms")

        if ct_results['binding_delay'] > 0:
            print(f"Binding delay: {ct_results['binding_delay']:.0f} ms")
        else:
            print("No binding delay observed")

        print(f"\nPersistence of information in {region}:")
        print(f"Category information persists for {ct_results['cat_persistence']:.0f} ms")
        print(f"Number information persists for {ct_results['num_persistence']:.0f} ms")
        print(f"Conjunction information persists for {ct_results['conj_persistence']:.0f} ms")

        # Mean same-time (diagonal) decoding accuracy in each delay period.
        # Periods are handled independently: requiring both before reporting
        # either meant nothing was reported whenever Delay 2 lies outside the
        # decoded time range.
        centers = ct_results['window_centers']
        diagonals = {feature: np.diag(ct_results[f"{key}_accuracy"])
                     for feature, key in FEATURE_KEYS.items()}
        delay_means = {}
        for period, (lo, hi) in DELAY_PERIODS.items():
            in_period = (centers >= lo) & (centers < hi)
            if in_period.any():
                delay_means[period] = {feature: np.mean(diag[in_period])
                                       for feature, diag in diagonals.items()}
            else:
                print(f"{period} ({lo}-{hi} ms) lies outside the decoded time range; not reported")

        if delay_means:
            print(f"\nDelay period information maintenance in {region}:")
            for period, means in delay_means.items():
                print(_format_delay_line(period, means))

        # Relative change in decoding accuracy between the two delays
        if 'Delay 1' in delay_means and 'Delay 2' in delay_means:
            for feature in FEATURE_KEYS:
                acc_d1 = delay_means['Delay 1'][feature]
                acc_d2 = delay_means['Delay 2'][feature]
                if acc_d1 > 0:
                    print(f"{feature} change: {(acc_d2 - acc_d1) / acc_d1 * 100:.1f}%")

        # Update summary report with region-specific information
        with open("conjunction_coding_summary.txt", "a", encoding="utf-8") as f:
            # str() so non-string region labels (e.g. numeric codes) do not crash .upper()
            f.write(f"\nTEMPORAL DYNAMICS IN {str(region).upper()}\n")
            f.write("-" * 40 + "\n")
            f.write(f"Number of units: {len(region_units)}\n")
            f.write(f"Category information peaks at {ct_results['cat_peak_time']:.0f} ms\n")
            f.write(f"Number information peaks at {ct_results['num_peak_time']:.0f} ms\n")
            f.write(f"Conjunction information peaks at {ct_results['conj_peak_time']:.0f} ms\n\n")

            if ct_results['binding_delay'] > 0:
                f.write(f"Binding delay: {ct_results['binding_delay']:.0f} ms\n\n")
            else:
                f.write("No binding delay observed\n\n")

            f.write("Persistence of information:\n")
            f.write(f"Category: {ct_results['cat_persistence']:.0f} ms, ")
            f.write(f"Number: {ct_results['num_persistence']:.0f} ms, ")
            f.write(f"Conjunction: {ct_results['conj_persistence']:.0f} ms\n\n")

            if delay_means:
                f.write("Delay period information maintenance:\n")
                for period, means in delay_means.items():
                    f.write(_format_delay_line(period, means) + "\n")
                f.write("\n")

    except Exception as e:
        # Best effort: continue with the remaining regions, but keep the traceback visible.
        print(f"Error analyzing {region}: {str(e)}")
        traceback.print_exc()

### Compare conjunction coding between valid and invalid probe conditions