# WM binding

In [23]:
# import sys
# print(sys.executable)

In [24]:
# pip install numpy pandas matplotlib seaborn scipy mat73
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.io import loadmat
# import mat73

# Load data

## Loading MATLAB objects

Reading `.mat` files can be done either with `scipy.io.loadmat` (well maintained, but it only handles simple MATLAB objects and cannot read v7.3/HDF5 files) or with `pymatreader.read_mat` (supports all types of objects). Either way, you typically end up with a dictionary of values or NumPy arrays.

In [25]:
# Load the .mat file once and pull out both saved variables
# (the original cell read the same file from disk twice).
MAT_FILE = './WMB_P106_1_20250511.mat'
mat_contents = loadmat(MAT_FILE)
cell_mat = mat_contents['cellStatsAll']  # struct array, one element per recorded cell/unit
total_mat = mat_contents['totStats']     # numeric summary table (column 3 is used as the area code)

In [26]:
# # OR Load the CSV file
# file_path = "D:/neuro1/code/psychophysics/WM_binding_pilot/data/P106_session_1_WM_binding_202527_9381_behav_cleaned.csv"
# df = pd.read_csv(file_path)

## Basic database stats

In [27]:
from collections import Counter

def count_area_codes(area_column, addMicroPrefix=False):
    """
    Translates numeric area codes to text labels and counts their occurrences.

    Args:
        area_column (array-like): A list or array of area code numbers. Float
            codes (as exported from MATLAB) match too, since e.g. 3.0 == 3.
        addMicroPrefix (bool): If True, prepend 'micro-' to each known area
            label. Codes missing from the mapping are always counted as
            'Unknown' (never prefixed).

    Returns:
        dict: A dictionary mapping area label to its count, in order of first
        appearance.
    """
    mapping = {
        1: 'RH', 2: 'LH', 3: 'RA', 4: 'LA', 5: 'RAC', 6: 'LAC',
        7: 'RSMA', 8: 'LSMA', 9: 'RPT', 10: 'LPT', 11: 'ROFC', 12: 'LOFC',
        50: 'RFFA', 51: 'REC', 52: 'RCM', 53: 'LCM', 54: 'RPUL', 55: 'LPUL',
        56: 'N/A', 57: 'RPRV', 58: 'LPRV'
    }
    prefix = 'micro-' if addMicroPrefix else ''

    labels = (
        prefix + mapping[code] if code in mapping else 'Unknown'
        for code in area_column
    )
    return dict(Counter(labels))


In [28]:
# Count neurons per brain area, using the area code stored in column 3 of totStats.
area_codes = total_mat[:, 3]
counts = count_area_codes(area_codes)

print("Area counts (no prefix):")
for area_label, n_units in counts.items():
    print(f"{area_label}: {n_units}")


Area counts (no prefix):
LAC: 17
LSMA: 10
LA: 9
LH: 13
RAC: 4
RSMA: 16
RA: 16
RH: 7
LOFC: 17
LPT: 5
ROFC: 11


# Format cell data

### Brain area

In [29]:
# Hemisphere-collapsed labels: the right- and left-hemisphere codes of a region
# (e.g. 1 'RH' and 2 'LH') share one label ('H'). Single-code regions map alone.
collapsed_area_map = {
    code: label
    for codes, label in (
        ((1, 2), 'H'),
        ((3, 4), 'A'),
        ((5, 6), 'AC'),
        ((7, 8), 'SMA'),
        ((9, 10), 'PT'),
        ((11, 12), 'OFC'),
        ((50,), 'FFA'),
        ((51,), 'EC'),
        ((52, 53), 'CM'),
        ((54, 55), 'PUL'),
        ((56,), 'N/A'),
        ((57, 58), 'PRV'),
    )
    for code in codes
}


In [30]:
# Replace each unit's numeric area code in cell_mat['brainAreaOfCell'] with its
# hemisphere-collapsed label from collapsed_area_map (unmapped codes -> 'Unknown').
# cell_mat is modified in place, so the cell first checks whether the conversion
# already happened; re-running it would otherwise fail on the string entries.
flat_array = cell_mat['brainAreaOfCell'].flatten()  # (1, n_units) -> (n_units,)

if all(isinstance(item, str) for item in flat_array):
    converted_labels = list(flat_array)  # already converted by an earlier run
else:
    # Each entry wraps one MATLAB scalar in a 2-D array, e.g. array([[3]]).
    flattened_codes = [int(item[0, 0]) for item in flat_array]
    converted_labels = [collapsed_area_map.get(code, 'Unknown') for code in flattened_codes]
    # Overwrite the original field with the new labels (the field has a single row).
    cell_mat['brainAreaOfCell'][0] = converted_labels



In [31]:
# cell_mat['brainAreaOfCell']


### Convert

In [32]:
# Turn the MATLAB struct array into a DataFrame: one row per cell, one column
# per struct field.
cell_list = cell_mat[0]  # the struct array has a single row -> shape (n,)

records = [
    {field: unit_struct[field] for field in unit_struct.dtype.names}
    for unit_struct in cell_list
]

df = pd.DataFrame(records)

In [8]:
# Optionally analyse a random subset of units for quick testing.
N_SAMPLE_UNITS = None  # e.g. 10 for a quick run; None keeps all units
df_sample = df.sample(N_SAMPLE_UNITS, random_state=0) if N_SAMPLE_UNITS else df

### Filter out units with low firing rate

In [9]:
# Filter out units with a low mean firing rate.
MIN_FIRING_RATE_HZ = 0.1


def mean_firing_rate_hz(timestamps):
    """Mean firing rate (Hz) of one unit over the span of its recorded spikes.

    Args:
        timestamps (array-like): spike times, assumed to be in microseconds
            (hence the * 1e6 below); any shape is accepted and flattened, so
            row and column vectors exported from MATLAB behave the same.

    Returns:
        float: n_spikes / (last - first spike time) * 1e6, or 0.0 when fewer
        than two spikes (or a zero time span) leave the rate undefined.
    """
    spikes = np.ravel(timestamps)
    if spikes.size < 2:
        return 0.0
    # max - min instead of [-1] - [0] so unsorted spike trains are handled too.
    duration = spikes.max() - spikes.min()
    if duration <= 0:
        return 0.0
    return spikes.size / duration * 1e6


fr = df_sample['timestamps'].apply(mean_firing_rate_hz)
df_sample_new = df_sample[fr > MIN_FIRING_RATE_HZ].reset_index(drop=True)

# unit id = position of the unit in the filtered table
df_sample_new["unit_id"] = df_sample_new.index


# Extract trial info

In [10]:
len(df_sample_new), len(df_sample_new.columns)  # (n_units, n_fields) after the filter

(299, 34)

In [11]:
# Peek at the first unit's Trials struct: its Python type and its field names.
trial = df_sample_new.iloc[0]["Trials"]
for info in (type(trial), trial.dtype.names):
    print(info)


<class 'numpy.ndarray'>
('trial', 'rt', 'acc', 'key', 't_pre_stim', 't_delay1', 't_delay2', 'first_cat', 'second_cat', 'first_num', 'second_num', 'first_pic', 'second_pic', 'probe_cat', 'probe_pic', 'probe_validity', 'probe_num', 'correct_answer', 'rt_sliding_mean', 'cat_comparison', 'error_type')


In [None]:
def extract_trial_info(trials_struct, unit_id):
    """Build a per-trial DataFrame from one unit's MATLAB ``Trials`` struct array.

    Args:
        trials_struct (np.ndarray): structured array as returned by
            ``scipy.io.loadmat`` (presumably one element per trial); every
            field becomes a column.
        unit_id: identifier written to the "unit_id" column so trials from
            different units can be separated after concatenation.

    Returns:
        pd.DataFrame: one row per struct element, all fields as columns plus
        "unit_id" and "trial_nr" (the "trial" value minus 1, i.e. 0-based).
    """
    # Flatten each field to 1-D. reshape(-1) rather than squeeze(): a
    # single-trial struct must stay a length-1 column, not a 0-d scalar that
    # pandas rejects ("If using all scalar values, you must pass an index").
    df_trial = pd.DataFrame({
        field: np.asarray(trials_struct[field]).reshape(-1)
        for field in trials_struct.dtype.names
    })
    df_trial["unit_id"] = unit_id
    # "trial" entries may still be nested MATLAB arrays such as [[3]]; reduce
    # them to their scalar, then shift to 0-based numbering.
    df_trial["trial_nr"] = df_trial["trial"].apply(
        lambda x: np.squeeze(x).item() if isinstance(x, (list, np.ndarray)) else x
    ) - 1
    return df_trial


In [None]:

# One trial-info frame per unit, tagged with its unit_id, stacked into one long table.
trial_info_list = [
    extract_trial_info(trials, unit_id)
    for trials, unit_id in zip(df_sample_new["Trials"], df_sample_new["unit_id"])
]

trial_info = pd.concat(trial_info_list, ignore_index=True)


In [None]:
trial_info.head(5)  # first five trial rows, all columns

# Single unit analyses

In [None]:
# TTL marker values (copied from the MATLAB config, c.marker.*):
#   expstart   = 89    expend   = 91    break  = 90
#   fixOnset   = 10
#   pic1       = 1     delay1   = 2
#   pic2       = 3     delay2   = 4
#   probeOnset = 5     response = 6

# Column names available in the per-unit table
df_sample_new.keys()

## Baseline

### Event ts extraction

In [None]:
# Event ts extraction: for each unit, look up the rows of its `events` table
# referenced by the 1-based MATLAB indices in `idxEnc1` and keep column 0 of
# those rows as the anchor timestamp of every trial.
# Both anchor columns ("on" and "off") come from idxEnc1, so they are equal; the
# actual window is set by the offsets applied in the firing-rate cell.
# NOTE(review): if the "off" anchor was meant to use another index field, change idxs2.
result_temp = []
for events_raw, idx_enc1 in zip(df_sample_new['events'], df_sample_new['idxEnc1']):
    # Event table as 2-D (one row per event), even when a unit has a single event.
    events = np.atleast_2d(np.squeeze(events_raw))
    # MATLAB indices are 1-based and may be stored as doubles: convert to 0-based
    # ints; reshape(-1) keeps a 1-D vector even for a single trial.
    idxs1 = np.asarray(idx_enc1).reshape(-1).astype(int) - 1
    idxs2 = idxs1
    extracted1 = events[idxs1]   # shape (n_trials, n_event_columns)
    extracted2 = events[idxs2]   # shape (n_trials, n_event_columns)

    # (n_trials, 2): column 0 = "on" anchor, column 1 = "off" anchor
    combined = np.column_stack((extracted1[:, 0], extracted2[:, 0]))
    result_temp.append(combined)

# One entry per unit, indexed by its position in df_sample_new (== unit_id).
# Fill element-wise so the result is always a 1-D object array of (n_trials, 2)
# arrays; np.array(list, dtype=object) would build a 3-D array instead whenever
# all units have the same number of trials.
result_array_temp = np.empty(len(result_temp), dtype=object)
for k, unit_epochs in enumerate(result_temp):
    result_array_temp[k] = unit_epochs

epoch_ts = result_array_temp

In [None]:
# epoch_ts

### Compute baseline FR

In [None]:
# Baseline spike counts: for every unit and trial, count spikes in the window
# [on - BASELINE_PRE_S, off + BASELINE_POST_S], with the on/off anchors taken
# from epoch_ts. Timestamps are treated as microseconds (offsets are scaled by 1e6).
# When the on and off anchors coincide, the window is 1 s long and the count
# equals the rate in Hz.
BASELINE_PRE_S = 1.0   # seconds before the "on" anchor
BASELINE_POST_S = 0.0  # seconds after the "off" anchor

baseline_counts = []
for unit_idx, timestamps in df_sample_new["timestamps"].items():
    # searchsorted requires ascending spike times.
    spikes = np.sort(np.ravel(timestamps))
    windows = np.asarray(epoch_ts[unit_idx], dtype=np.float64).reshape(-1, 2)
    starts = windows[:, 0] - BASELINE_PRE_S * 1e6
    ends = windows[:, 1] + BASELINE_POST_S * 1e6
    # side="left" at the start and side="right" at the end make both edges inclusive.
    counts = np.searchsorted(spikes, ends, side="right") - np.searchsorted(spikes, starts, side="left")
    baseline_counts.append(list(counts))

df_sample_new["fr_baseline"] = pd.Series(baseline_counts, index=df_sample_new.index, dtype=object)

# Trial numbers are positional (0..n-1) within each unit's epoch_ts rows.
df_sample_new["trial_nr"] = df_sample_new["fr_baseline"].apply(lambda x: np.arange(len(x)))

## Epoch of interest

### Event ts extraction

In [None]:
# Event ts extraction for the epoch of interest: the same idxEnc1 anchors as the
# baseline section. Column 0 of the referenced `events` rows is each trial's anchor
# timestamp; the epoch window is set by the offsets in the firing-rate cell.
# NOTE(review): if the "off" anchor was meant to use another index field, change idxs2.
result_temp = []
for events_raw, idx_enc1 in zip(df_sample_new['events'], df_sample_new['idxEnc1']):
    # Event table as 2-D (one row per event), even when a unit has a single event.
    events = np.atleast_2d(np.squeeze(events_raw))
    # 1-based MATLAB indices (possibly doubles) -> 0-based int vector.
    idxs1 = np.asarray(idx_enc1).reshape(-1).astype(int) - 1
    idxs2 = idxs1
    extracted1 = events[idxs1]   # shape (n_trials, n_event_columns)
    extracted2 = events[idxs2]   # shape (n_trials, n_event_columns)

    # (n_trials, 2): column 0 = "on" anchor, column 1 = "off" anchor
    combined = np.column_stack((extracted1[:, 0], extracted2[:, 0]))
    result_temp.append(combined)

# Always a 1-D object array (one (n_trials, 2) array per unit), regardless of
# whether the units share the same trial count.
result_array_temp = np.empty(len(result_temp), dtype=object)
for k, unit_epochs in enumerate(result_temp):
    result_array_temp[k] = unit_epochs

epoch_ts = result_array_temp

### Compute epoch FR

In [None]:
# Epoch spike counts: for every unit and trial, count spikes in the window
# [on - EPOCH_PRE_S, off + EPOCH_POST_S], with the on/off anchors taken from
# epoch_ts. Timestamps are treated as microseconds (offsets are scaled by 1e6).
# When the on and off anchors coincide, the window is 1 s long and the count
# equals the rate in Hz.
EPOCH_PRE_S = 0.0   # seconds before the "on" anchor
EPOCH_POST_S = 1.0  # seconds after the "off" anchor

epoch_counts = []
for unit_idx, timestamps in df_sample_new["timestamps"].items():
    # searchsorted requires ascending spike times.
    spikes = np.sort(np.ravel(timestamps))
    windows = np.asarray(epoch_ts[unit_idx], dtype=np.float64).reshape(-1, 2)
    starts = windows[:, 0] - EPOCH_PRE_S * 1e6
    ends = windows[:, 1] + EPOCH_POST_S * 1e6
    # side="left" at the start and side="right" at the end make both edges inclusive.
    counts = np.searchsorted(spikes, ends, side="right") - np.searchsorted(spikes, starts, side="left")
    epoch_counts.append(list(counts))

df_sample_new["fr_epoch"] = pd.Series(epoch_counts, index=df_sample_new.index, dtype=object)

## Save data

In [None]:
# Convert lists to rows
df_sample_new = df_sample_new.explode(column=["fr_epoch", "trial_nr", "fr_baseline"])  # one row per unit x trial

In [None]:
# Drop the raw Trials struct column; its content was unpacked into trial_info.
# errors="ignore" keeps the cell re-runnable once the column is gone.
df_sample_new = df_sample_new.drop(columns=["Trials"], errors="ignore")

In [None]:
df_sample_new.iloc[:5]  # first five unit x trial rows

## Join df

In [None]:
# trial_info.head()

In [None]:
# Attach each unit x trial row's behavioral trial info, matched on (unit_id, trial_nr).
# A left join keeps every spike-count row, even if no trial info matches it.
# NOTE(review): trial_nr is positional in df_sample_new but "trial - 1" in
# trial_info; these agree only if idxEnc1 covers every trial in order — confirm.
df_sample_new = df_sample_new.reset_index(drop=True)
trial_info = trial_info.reset_index(drop=True)

data = df_sample_new.merge(
    trial_info,
    how="left",
    on=["unit_id", "trial_nr"],
).infer_objects()


In [None]:
cols_to_keep = [
    "unit_id", "timestamps", "brainAreaOfCell", "fr_epoch", "fr_baseline", "trial_nr",
    "first_cat", "second_cat", "first_num", "second_num",
    "first_pic", "second_pic", "probe_cat", "probe_pic",
    "probe_validity", "probe_num", "correct_answer",
    "rt", "acc", "key"
]

# .copy() makes data_filtered an independent frame, so adding columns to it
# does not raise SettingWithCopyWarning (it would otherwise be a slice of `data`).
data_filtered = data[cols_to_keep].copy()


In [None]:
# data_filtered.timestamps[0]

## Tuning Analysis

In [None]:
from tqdm import tqdm
import statsmodels.formula.api as smf
import statsmodels.api as sm

In [None]:
# Convert to simpler, hashable values.
data_filtered["first_cat_simple"] = data_filtered["first_cat"].apply(
    lambda x: str(np.squeeze(x)) if isinstance(x, (list, np.ndarray)) else str(x)
)
data_filtered["second_cat_simple"] = data_filtered["second_cat"].apply(
    lambda x: str(np.squeeze(x)) if isinstance(x, (list, np.ndarray)) else str(x)
)
data_filtered["first_num_simple"] = data_filtered["first_num"].apply(
    lambda x: str(np.squeeze(x)) if isinstance(x, (list, np.ndarray)) else str(x)
)
data_filtered["second_num_simple"] = data_filtered["second_num"].apply(
    lambda x: str(np.squeeze(x)) if isinstance(x, (list, np.ndarray)) else str(x)
)

In [None]:
data_filtered.keys()  # column names after adding the *_simple labels

In [None]:
# Tuning analysis: for each unit, model the trial-wise epoch response as a function
# of the first item's category and number, and keep one p-value per model term.
#   use_poisson = False -> OLS on baseline-subtracted counts (fr_epoch - fr_baseline),
#                          additive model, Type II ANOVA.
#   use_poisson = True  -> Poisson GLM with the category x number interaction,
#                          Wald test per term. A Poisson model needs non-negative
#                          counts, so it is fitted on the raw epoch counts
#                          (baseline subtraction can make them negative).
use_poisson = False
records = []


def _fit_tuning_model(unit_df, poisson):
    """Fit one unit's tuning model and return a per-term table with a "pvalue" column.

    Args:
        unit_df (pd.DataFrame): the unit's trials with fr_epoch,
            first_cat_simple and first_num_simple columns.
        poisson (bool): Poisson GLM + Wald tests if True, else OLS + Type II ANOVA.

    Returns:
        pd.DataFrame: indexed by model term.
    """
    if poisson:
        model = smf.glm(
            formula="fr_epoch ~ C(first_cat_simple) * C(first_num_simple)",
            data=unit_df,
            family=sm.families.Poisson(),
        )
        return model.fit().wald_test_terms(scalar=True).table
    model = smf.ols(
        formula="fr_epoch ~ C(first_cat_simple) + C(first_num_simple)",
        data=unit_df,
    )
    # The last row of the ANOVA table is the residual term, not a predictor.
    table = sm.stats.anova_lm(model.fit(), typ=2)[:-1]
    return table.rename(columns={"PR(>F)": "pvalue"})


for unit_id, unit_df in tqdm(data_filtered.groupby("unit_id"), desc="Tuning analysis per unit"):
    if not use_poisson:
        # assign() builds a new frame instead of writing into the groupby slice.
        unit_df = unit_df.assign(fr_epoch=unit_df["fr_epoch"] - unit_df["fr_baseline"])

    # Skip units with fewer than two distinct response values (constant response
    # or a single trial): the fit would be singular.
    if unit_df["fr_epoch"].nunique() < 2:
        continue

    results = _fit_tuning_model(unit_df, use_poisson)

    # Tag every term with the unit and, if available, its brain area.
    results["unit_id"] = unit_id
    if "brainAreaOfCell" in unit_df.columns:
        results["brainAreaOfCell"] = unit_df["brainAreaOfCell"].iloc[0]

    records.append(results)

# One row per (unit, term); the term name moves from the index into "predictor".
records = pd.concat(records).reset_index(names="predictor")

In [None]:
# Long format for plotting: one row per (unit, predictor), p-value in "value".
# is_significant flags p < 0.05 (no multiple-comparison correction).
df_stats = records.melt(
    id_vars=["unit_id", "brainAreaOfCell", "predictor"],
    value_vars="pvalue",
)
df_stats["is_significant"] = df_stats["value"].lt(0.05)

df_stats

In [None]:
# Plot, for each model term, how many units per brain area are significant.
# hue_order plus a dict palette pin False -> red and True -> green even in a
# facet where only one of the two levels occurs (a bare list palette would
# then give the single level the first colour).
fg = sns.catplot(
    data=df_stats[df_stats["predictor"] != "Intercept"],
    x="brainAreaOfCell",
    # order=["AMY", "HPC", "dACC", "preSMA", "vmPFC", "VTC"],
    hue="is_significant",
    hue_order=[False, True],
    col="predictor",
    kind="count",
    palette={False: "tab:red", True: "tab:green"},
)
plt.show()

In [None]:
# Same plot as above, showing percentages instead of raw counts.
# (stat= on count plots requires seaborn >= 0.13.)
fg = sns.catplot(
    data=df_stats.query("predictor != 'Intercept'"),
    x="brainAreaOfCell",
    hue="is_significant",
    hue_order=[False, True],
    col="predictor",
    kind="count",
    stat="percent",
    palette={False: "tab:red", True: "tab:green"},
)

# Format the y-axis ticks as percents (values are already on a 0-100 scale).
for ax in fg.axes.flat:
    ax.yaxis.set_major_formatter(mpl.ticker.PercentFormatter())

plt.show()


### Return sig neuron idx

In [None]:
# Units with a significant C(first_cat_simple) term, in order of appearance.
# Other selections (e.g. the C(first_cat_simple):C(first_num_simple) interaction,
# which only exists in the Poisson model) work the same way with a different
# predictor name in the query.
sig_units = (
    df_stats
    .query("predictor == 'C(first_cat_simple)' and is_significant")["unit_id"]
    .unique()
    .tolist()
)

print("Significant units:", sig_units)
# print("Significant units for first_num:", sig_num_units)


## Plot

In [87]:
# pip install git+https://github.com/ioqfwfq/rlab_neural_analysis.git@jz
from neural_analysis.visualize import plot_spikes_with_PSTH
from neural_analysis.spikes import get_spikes

In [None]:
# Raster + PSTH for each selected unit, aligned to the "on" anchors in epoch_ts.
# sig_units was selected on the TUNED_ON predictor; trials are grouped by `cond`.
unit_ids = sig_units  # or any other list of unit IDs you want to plot
TUNED_ON = "first_cat_simple"  # predictor used to select the units (shown in the title)
cond = "second_cat_simple"     # column whose values group the trials in the plot
cmap = "Set1"

for unit_id in unit_ids:
    # All trial rows of this unit.
    df_unit = data_filtered[data_filtered["unit_id"] == unit_id].reset_index(drop=True)
    area = df_unit["brainAreaOfCell"].iloc[0]

    group_labels = df_unit[cond].apply(
        lambda x: np.squeeze(x).item() if isinstance(x, (list, np.ndarray)) else x
    )
    stats = df_unit[TUNED_ON]

    # Timestamps are divided by 1e6, i.e. converted from microseconds to seconds.
    alignments = np.asarray(epoch_ts[unit_id][:, 0], dtype=np.float64) / 1e6

    # Full, sorted spike train of the unit (seconds).
    spikes = np.asarray(df_unit["timestamps"].iloc[0]).flatten().astype(np.float64) / 1e6
    spikes = np.sort(spikes)

    axes = plot_spikes_with_PSTH(
        spikes,
        alignments,
        window=(-1, 8),
        group_labels=group_labels,
        stats=stats,
        plot_stats=False,
        sig_test=True,
        cmap=cmap,
    )

    # adjust plot visuals
    # [ax.axvline(np.mean(trial_info["rt"]), color="red", ls="--") for ax in axes]
    axes[1].set_xlabel("Time from stim onset [s]")
    axes[0].set_title(f"{area} {unit_id} [{TUNED_ON} tuned, grouped by {cond}]")

    plt.show()
    # plt.savefig(f"{area}_{unit_id}_2.png", dpi=300, bbox_inches="tight")
    # plt.close()

# Population

In [None]:
from sklearn.datasets import load_iris
from sklearn.svm import LinearSVC
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import cross_val_score
from sklearn.decomposition import PCA
from sklearn.manifold import MDS

In [None]:
# Select which units enter the population analysis.
USE_SIG_UNITS_ONLY = False  # True: restrict to the tuned units in sig_units
if USE_SIG_UNITS_ONLY:
    sub_df = data_filtered.query("unit_id in @sig_units").reset_index(drop=True)
else:
    sub_df = data_filtered


In [None]:
# Build a pseudo-population: from every unit, randomly draw N_TRIALS_PER_COND
# trials per (first_num, first_cat) condition. The resulting df is sorted by the
# group keys, so within each unit the sampled rows follow the same condition
# order, and row i of X lines up across units.
N_TRIALS_PER_COND = 10
SEED = 0
sub_df = sub_df.groupby(["unit_id", "first_num_simple", "first_cat_simple"]).sample(
    N_TRIALS_PER_COND, random_state=SEED
)

# Sanity check: every unit must contribute the same label sequence, otherwise
# rows of X would mix different conditions across units.
labels_per_unit = sub_df.groupby("unit_id")["first_num_simple"].agg(list)
assert all(labels == labels_per_unit.iloc[0] for labels in labels_per_unit), \
    "units differ in their sampled condition order"

# Design matrix (pseudo-trials x units) + labels taken from the first unit's rows.
X = np.column_stack(sub_df.groupby("unit_id")["fr_epoch"].agg(list))
y = sub_df["first_num_simple"].iloc[:len(X)].to_numpy(str)

# Fit a linear SVM with cross-validation.
# Don't expect good performance since this is single-session.
pipe = Pipeline([("scaler", StandardScaler()), ("clf", LinearSVC())])
scores = cross_val_score(pipe, X, y, cv=5)
print(f"CV Accuracy: {np.mean(scores):.2f} ± {np.std(scores):.2f}")

In [None]:
from sklearn.preprocessing import LabelEncoder

# Encode string labels into integers for colouring; keep the encoder to name the legend entries.
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

# Dimensionality reduction on standardized features (MDS seeded for reproducibility).
pca = Pipeline([("scaler", StandardScaler()), ("pca", PCA(n_components=3))])
mds = Pipeline([("scaler", StandardScaler()), ("mds", MDS(n_components=3, random_state=0))])
X_pca = pca.fit_transform(X)
X_mds = mds.fit_transform(X)

# Visualize the low-D representation: 2-D PCA panel, 3-D MDS panel.
# Note: matplotlib's 3D plotting is basic; use an interactive library if needed.
PCA_DIMS = (0, 1)  # principal components shown in the 2-D panel (0-based)

fig = plt.figure(figsize=(16, 8))
ax1 = fig.add_subplot(1, 2, 1)
ax2 = fig.add_subplot(1, 2, 2, projection="3d")

ax1.scatter(X_pca[:, PCA_DIMS[0]], X_pca[:, PCA_DIMS[1]], c=y_encoded)
ax1.set(title="PCA", xlabel=f"PC{PCA_DIMS[0] + 1}", ylabel=f"PC{PCA_DIMS[1] + 1}")

scatter = ax2.scatter(X_mds[:, 0], X_mds[:, 1], X_mds[:, 2], c=y_encoded)
# num=None -> one handle per unique encoded value, in sorted order, which matches
# label_encoder.classes_.
handles, _ = scatter.legend_elements(num=None)
ax2.legend(handles, list(label_encoder.classes_), title="first_num")
ax2.set_title("MDS")

plt.show()