In [None]:
#import packages and functions

%load_ext autoreload
%autoreload 2

import pursuit_functions as pursuit
    
import pandas as pd
import numpy as np
from itertools import product
import matplotlib.pyplot as plt
import seaborn as sns


from sklearn.decomposition import PCA
from xgboost import XGBClassifier
from sklearn.neighbors import KNeighborsClassifier

from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.utils import resample
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay

In [None]:
# Load the pursuit data set from a parquet file (path is relative to the
# notebook's working directory) using the pyarrow engine.

all_pursuit_tasks = pd.read_parquet("ca1_ca3_rsc_pursuit_data.parquet", engine="pyarrow")

In [None]:
# Split the full table by recording region, then drop NA values per region.
# The filtered frames are built first and cleaned second, one per region.
RSC_sessions, CA1_sessions, CA3_sessions = (
    all_pursuit_tasks[all_pursuit_tasks["region"] == region_name]
    for region_name in ("RSC", "CA1", "CA3")
)

RSC_cleaned, CA1_cleaned, CA3_cleaned = (
    pursuit.tuning.drop_NA_vals(region_sessions)
    for region_sessions in (RSC_sessions, CA1_sessions, CA3_sessions)
)

In [None]:
# Keep all coordinate values below the 99th percentile and normalize the points
# for all regions at once.
# NOTE(review): this runs on the full `all_pursuit_tasks` frame, not on the
# per-region *_cleaned frames built above -- confirm that is intended.

normalized_sessions = pursuit.tuning.normalize_points(all_pursuit_tasks)

In [None]:
# Preview the first rows of the normalized data as a sanity check.
normalized_sessions.head()

In [None]:
# Find the mean center and the overall radius of the arena from all normalized data points.
# A percentile for the overall radius can be specified; the default is the 95th percentile.
# An individual center point is also calculated for each session.
# The call returns two values: the per-session boundary table and the overall radius.

circle_boundaries, radius = pursuit.tuning.fit_circle_bounds(normalized_sessions)
print(radius)

In [None]:
# Preview the first rows of the fitted circle boundaries.
circle_boundaries.head()

In [None]:
# Compute the circumference points used for plotting, from the center coordinates
# and overall radius stored in circle_boundaries.
all_circ_points = pursuit.tuning.circumference(circle_boundaries)

In [None]:
# Plot the normalized, concatenated laser and rat paths together with the center
# point and the arena boundary.
# The function takes the normalized_sessions, circle_boundaries, and all_circ_points dataframes.

pursuit.tuning.plot_arena_bounds(normalized_sessions, circle_boundaries, all_circ_points)

In [None]:
# Normalize only the laser points and build, per region, a dataframe that holds
# the spike data selected with the normalized-data mask.
# Each call receives the cleaned dataframe of one region.
RSC_laser_spks, CA1_laser_spks, CA3_laser_spks = (
    pursuit.tuning.norm_laser_get_spks(cleaned_region)
    for cleaned_region in (RSC_cleaned, CA1_cleaned, CA3_cleaned)
)

In [None]:
# Preview the first rows of the RSC laser/spike table.
RSC_laser_spks.head()

In [None]:
# Compute, session by session, the distance of each normalized laser point to the
# circle boundary. Each call takes one region's laser/spike dataframe plus the
# shared circle_boundaries dataframe.

RSC_laser_spks_bounds, CA1_laser_spks_bounds, CA3_laser_spks_bounds = (
    pursuit.tuning.dist_to_bounds(region_laser_spks, circle_boundaries)
    for region_laser_spks in (RSC_laser_spks, CA1_laser_spks, CA3_laser_spks)
)

In [None]:
# Preview the first rows of the RSC table after the boundary distance was added.
RSC_laser_spks_bounds.head()

In [None]:
#put raw spike counts into bins calculated from the overall min and max bound_dist values
def bin_spike_data(dataframe, spk_prefix="spkTable", num_bins=20, bin_edges=None):
    """Sum raw spike counts per boundary-distance bin, per session and neuron.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Must contain "sessFile", "bound_dist" and the spike-count columns.
    spk_prefix : str
        Substring that identifies spike-count columns. Columns that are
        entirely NaN within a session are skipped for that session.
    num_bins : int
        Number of bins requested from ``pursuit.tuning.find_bin_edges`` when
        ``bin_edges`` is not supplied.
    bin_edges : sequence of float, optional
        Explicit bin edges. When None they are computed once from the whole
        ``dataframe`` so every session shares the same bins.

    Returns
    -------
    pd.DataFrame
        Long-format table with columns "sessFile", "neuron", "spike_count"
        and "bin_midpoint"; one row per (session, neuron, bin). Bins with no
        samples get a spike count of 0.
    """
    if bin_edges is None:
        bin_edges = pursuit.tuning.find_bin_edges(dataframe, "bound_dist", num_bins)

    rows = []

    for sessFile in dataframe["sessFile"].unique():

        session = dataframe[dataframe["sessFile"] == sessFile].copy()
        session["bound_bin"] = pd.cut(session["bound_dist"], bins=bin_edges, include_lowest=True)

        spk_cols = [col for col in session.columns if spk_prefix in col and not session[col].isna().all()]
        if not spk_cols:
            continue

        # observed=False keeps every bin (including empty ones) in the result, so
        # counts and midpoints stay aligned no matter which pandas default applies.
        spk_by_bin = session.groupby("bound_bin", observed=False)[spk_cols].sum()

        # Midpoints are taken from the pd.cut intervals themselves -- the same
        # way bin_laser_data computes them -- so the two tables share identical
        # "bin_midpoint" keys when they are merged later.
        bin_midpoints = [
            round((interval.left + interval.right) / 2, 2)
            for interval in spk_by_bin.index
        ]

        for spk in spk_cols:
            for bin_mid, spk_count in zip(bin_midpoints, spk_by_bin[spk]):
                rows.append({
                    "sessFile": sessFile,
                    "neuron": spk,
                    "spike_count": spk_count,
                    "bin_midpoint": bin_mid
                })

    # Explicit columns keep the schema stable even when no rows were produced.
    binned_spks_df = pd.DataFrame(
        rows, columns=["sessFile", "neuron", "spike_count", "bin_midpoint"]
    )
    return binned_spks_df




In [None]:
#put laser coordinates into bins calculated from the overall min and max bound_dist values
def bin_laser_data(dataframe, num_bins=20, bin_edges=None):
    """Count laser samples (occupancy) per boundary-distance bin for each session.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Must contain "sessFile" and "bound_dist".
    num_bins : int
        Number of bins requested from ``pursuit.tuning.find_bin_edges`` when
        ``bin_edges`` is not supplied.
    bin_edges : sequence of float, optional
        Explicit bin edges. When None they are computed once from the whole
        ``dataframe`` so every session shares the same bins.

    Returns
    -------
    pd.DataFrame
        Long-format table with columns "sessFile", "laser_occupancy" and
        "bin_midpoint"; one row per (session, bin). Empty bins are kept with
        an occupancy of 0.
    """
    if bin_edges is None:
        # Same helper bin_spike_data uses; the previously referenced
        # `find_overall_bin_edges` is not defined anywhere in this notebook.
        bin_edges = pursuit.tuning.find_bin_edges(dataframe, "bound_dist", num_bins)

    rows = []

    for sessFile in dataframe["sessFile"].unique():

        session = dataframe[dataframe["sessFile"] == sessFile].copy()
        session["bound_bin"] = pd.cut(session["bound_dist"], bins=bin_edges, include_lowest=True)

        # observed=False keeps bins that received no samples (size 0).
        coords_by_bin = session.groupby("bound_bin", observed=False).size()

        for bin_interval, laser_count in coords_by_bin.items():
            bin_mid = round((bin_interval.left + bin_interval.right) / 2, 2)

            rows.append({
                "sessFile": sessFile,
                "laser_occupancy": laser_count,
                "bin_midpoint": bin_mid
            })

    # Explicit columns keep the schema stable even when no rows were produced.
    binned_laser_df = pd.DataFrame(
        rows, columns=["sessFile", "laser_occupancy", "bin_midpoint"]
    )
    return binned_laser_df


In [None]:
#bin the data!
# Inputs are the per-region frames that already carry the "bound_dist" column
# (the outputs of dist_to_bounds above). The same frame feeds both the
# occupancy and the spike binning, so both share the same bin edges.

RSC_laser_binned = bin_laser_data(RSC_laser_spks_bounds)
CA1_laser_binned = bin_laser_data(CA1_laser_spks_bounds)
CA3_laser_binned = bin_laser_data(CA3_laser_spks_bounds)

RSC_spikes_binned = bin_spike_data(RSC_laser_spks_bounds)
CA1_spikes_binned = bin_spike_data(CA1_laser_spks_bounds)
CA3_spikes_binned = bin_spike_data(CA3_laser_spks_bounds)

In [None]:
#normalize spike counts by laser occupancy using bins calculated from the overall min and max bound_dist values 

def make_tuning_curve(spike_df, laser_df):
    """Attach laser occupancy to binned spike counts and compute the tuning value.

    The two tables are left-joined on ("sessFile", "bin_midpoint"), so every
    spike row is kept; rows with no matching occupancy get NaN. The new
    "tuning" column is spike_count divided by laser_occupancy.

    Returns the merged dataframe with the extra "tuning" column.
    """
    join_keys = ["sessFile", "bin_midpoint"]
    combined = spike_df.merge(laser_df, how="left", on=join_keys)

    combined["tuning"] = combined["spike_count"].div(combined["laser_occupancy"])
    return combined

In [None]:
# Build the occupancy-normalized tuning table for each region from its binned
# spike counts and binned laser occupancy.
RSC_binned_tuning, CA1_binned_tuning, CA3_binned_tuning = (
    make_tuning_curve(region_spikes, region_laser)
    for region_spikes, region_laser in (
        (RSC_spikes_binned, RSC_laser_binned),
        (CA1_spikes_binned, CA1_laser_binned),
        (CA3_spikes_binned, CA3_laser_binned),
    )
)

In [None]:
def plot_tuning_curves(dataframe):
    """Draw every neuron's tuning curve from every session on one figure.

    For each session the long-format table is pivoted to one row per neuron
    and one column per bin midpoint (missing values become 0), and each row
    is drawn as a line with circle markers. Lines carry the neuron name as
    their label, but no legend is drawn.
    """
    plt.figure(figsize=(12, 8))

    for session_id in dataframe["sessFile"].unique():
        session_rows = dataframe[dataframe["sessFile"] == session_id]

        curves = session_rows.pivot(index="neuron", columns="bin_midpoint", values="tuning")
        curves = curves.fillna(0)
        bin_positions = curves.columns

        for neuron_name, curve in curves.iterrows():
            plt.plot(bin_positions, curve, marker='o', linestyle='-', label=f"{neuron_name}")

    plt.xlabel("Boundary Distance (bin midpoint)")
    plt.ylabel("Tuning (spike count / laser occupancy)")
    plt.title("All Neuron Tuning Curves")
    plt.grid(True)
    plt.tight_layout()
    plt.show()



In [None]:
# Plot all RSC tuning curves on a single figure.
plot_tuning_curves(RSC_binned_tuning)

In [None]:
# Preview the first rows of the RSC tuning table.
RSC_binned_tuning.head()

In [None]:
#z-score each neuron's binned tuning curve, then min-max rescale it to the range [0, 1]

def z_score_norm(dataframe):
    """Add a "z_score" column: per-neuron z-scored tuning rescaled to [0, 1].

    Within every ("sessFile", "neuron") group the "tuning" values are first
    z-scored (all zeros when the group's standard deviation is not positive)
    and then min-max rescaled, so the stored values lie between 0 and 1
    rather than being raw z-scores. Groups with no spread end up as zeros.

    The input dataframe is not modified; a copy with the new column is
    returned.
    """
    def _standardize_then_rescale(values):
        spread = values.std()
        if spread > 0:
            standardized = (values - values.mean()) / spread
        else:
            standardized = values * 0

        low = standardized.min()
        high = standardized.max()
        if not high > low:
            # No spread left (or nothing comparable): collapse to zeros.
            return standardized * 0
        return (standardized - low) / (high - low)

    per_neuron_tuning = dataframe.groupby(["sessFile", "neuron"])["tuning"]
    return dataframe.assign(z_score=per_neuron_tuning.transform(_standardize_then_rescale))
        

        

In [None]:
# Add the rescaled z-score column to each region's tuning table.
RSC_z_scored, CA1_z_scored, CA3_z_scored = (
    z_score_norm(region_tuning)
    for region_tuning in (RSC_binned_tuning, CA1_binned_tuning, CA3_binned_tuning)
)

In [None]:
#plot heatmap for z-scored data: RSC sessions

# Wide matrix: one row per (session, neuron), one column per bin midpoint;
# bins without a value are filled with 0.
pivoted_binned_zscore_neurons = RSC_z_scored.pivot_table(
    index=["sessFile", "neuron"], 
    columns="bin_midpoint", 
    values="z_score"
    ).fillna(0)

plt.figure(figsize=(12,8))

# The Axes is stored as `heatmap_ax` rather than `heatmap`, so re-running this
# cell can never overwrite the heatmap() helper function defined in this notebook.
heatmap_ax = sns.heatmap(pivoted_binned_zscore_neurons, cmap="viridis", annot=False, fmt=".2f", yticklabels=False)

x_labels = heatmap_ax.get_xticklabels()

# Re-format the tick labels to two decimals; empty labels are left empty.
rounded_labels = [f"{float(label.get_text()):.2f}" if label.get_text() != "" else "" for label in x_labels]

heatmap_ax.set_xticklabels(rounded_labels, rotation=45, ha="right")


plt.title("Z-scored Spike Activity by Boundary Distance: RSC Sessions")
plt.xlabel("Boundary Distance (bin midpoint)")
plt.ylabel("Neurons")
plt.show()


In [None]:
#pivot table and apply gaussian smoothing for plotting

def pivot_smooth(dataframe, window_size=3, window_type='gaussian', std=1):
    """Pivot z-scored data to wide form and smooth each neuron's row.

    The long-format table is pivoted to one row per ("sessFile", "neuron")
    and one column per "bin_midpoint" (missing cells become 0). Each row is
    then smoothed with a centered, weighted rolling mean of ``window_size``
    bins using the ``window_type`` window; ``std`` is forwarded to the
    window (it is the width parameter of the gaussian window).

    Positions where the centered window is incomplete (the row edges) come
    out as NaN from the rolling mean and are replaced by 0.

    Returns the smoothed wide dataframe.
    """
    def _smooth_row(row):
        roller = row.rolling(window=window_size, win_type=window_type, center=True)
        return roller.mean(std=std)

    wide = dataframe.pivot_table(
        values="z_score",
        index=["sessFile", "neuron"],
        columns="bin_midpoint",
    )
    wide = wide.fillna(0)

    smoothed = wide.apply(_smooth_row, axis=1)
    return smoothed.fillna(0)

In [None]:
#plot heatmap for z-scored data

def heatmap(dataframe):
    """Show a viridis heatmap of a wide (neurons x bin midpoints) dataframe.

    Row tick labels are hidden; column tick labels are re-formatted to two
    decimals (empty labels stay empty) and rotated 45 degrees.
    """
    plt.figure(figsize=(12, 8))

    ax = sns.heatmap(dataframe, cmap="viridis", annot=False, fmt=".2f", yticklabels=False)

    formatted_ticks = []
    for tick in ax.get_xticklabels():
        text = tick.get_text()
        formatted_ticks.append("" if text == "" else f"{float(text):.2f}")

    ax.set_xticklabels(formatted_ticks, rotation=45, ha="right")

    plt.title("Z-scored Spike Activity by Boundary Distance")
    plt.xlabel("Boundary Distance (bin midpoint)")
    plt.ylabel("Neurons")
    plt.show()


In [None]:
#plot heatmap for z-scored data: CA3 Sessions

# Wide matrix: one row per (session, neuron), one column per bin midpoint;
# bins without a value are filled with 0.
pivoted_binned_zscore_neurons = CA3_z_scored.pivot_table(
    index=["sessFile", "neuron"], 
    columns="bin_midpoint", 
    values="z_score"
    ).fillna(0)

plt.figure(figsize=(12,8))

# The Axes is stored as `heatmap_ax` rather than `heatmap`: assigning to
# `heatmap` here would replace the heatmap() helper function defined above
# with an Axes object and break any later call to it.
heatmap_ax = sns.heatmap(pivoted_binned_zscore_neurons, cmap="viridis", annot=False, fmt=".2f", yticklabels=False)

x_labels = heatmap_ax.get_xticklabels()

# Re-format the tick labels to two decimals; empty labels are left empty.
rounded_labels = [f"{float(label.get_text()):.2f}" if label.get_text() != "" else "" for label in x_labels]

heatmap_ax.set_xticklabels(rounded_labels, rotation=45, ha="right")


plt.title("Z-scored Spike Activity by Boundary Distance: CA3 Sessions")
plt.xlabel("Boundary Distance (bin midpoint)")
plt.ylabel("Neurons")
plt.show()

In [None]:
#count number of neurons by sessFile and region dataframe

def count_neurons(dataframe, spk_prefix="spkTable"):
    """Count the recorded neurons per session and in total.

    A neuron is a column whose name starts with ``spk_prefix`` and that is
    not entirely NaN within the rows of the session being counted.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Must contain a "sessFile" column plus the spike-count columns.
    spk_prefix : str
        Prefix identifying spike-count columns.

    Returns
    -------
    (dict, int)
        Mapping of session id -> neuron count, and the sum over all sessions.
    """
    # Candidate columns do not depend on the session, so find them once.
    # Non-string column labels cannot carry the prefix and are ignored.
    candidate_cols = [
        col for col in dataframe.columns
        if isinstance(col, str) and col.startswith(spk_prefix)
    ]

    session_neuron_counts = {}

    for sess in dataframe["sessFile"].unique():
        # Select only rows for that session.
        session = dataframe[dataframe["sessFile"] == sess]

        # Keep spike columns that hold at least one non-NaN value in this session.
        spk_cols = [col for col in candidate_cols if not session[col].isna().all()]

        session_neuron_counts[sess] = len(spk_cols)

    # Sum the counts over all sessions
    total_neurons = sum(session_neuron_counts.values())

    return session_neuron_counts, total_neurons