In [None]:
#import packages and functions

%load_ext autoreload
%autoreload 2

import pursuit_functions as pursuit
    
import pandas as pd
import numpy as np
from itertools import product
import matplotlib.pyplot as plt
import seaborn as sns


from sklearn.decomposition import PCA
from xgboost import XGBClassifier
from sklearn.neighbors import KNeighborsClassifier

from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.utils import resample
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay

In [None]:
#load data set
# Load the pursuit data set from a single parquet file (the file name indicates
# CA1, CA3 and RSC recordings). The path is relative, so the notebook must be run
# from the directory that contains the file.

all_pursuit_tasks = pd.read_parquet("ca1_ca3_rsc_pursuit_data.parquet", engine="pyarrow")

In [None]:
#drop NA values for RSC, CA1, and CA3 sessions
# Split the sessions by the value of the "region" column, then clean each region
# separately with pursuit.tuning.drop_NA_vals.
# NOTE(review): which columns drop_NA_vals inspects is defined in the pursuit
# package, not in this notebook — confirm there.
RSC_sessions = all_pursuit_tasks[all_pursuit_tasks["region"] == "RSC"]
CA1_sessions = all_pursuit_tasks[all_pursuit_tasks["region"] == "CA1"]
CA3_sessions = all_pursuit_tasks[all_pursuit_tasks["region"] == "CA3"]

RSC_cleaned = pursuit.tuning.drop_NA_vals(RSC_sessions)
CA1_cleaned = pursuit.tuning.drop_NA_vals(CA1_sessions)
CA3_cleaned = pursuit.tuning.drop_NA_vals(CA3_sessions)

In [None]:
# Quick look at the first rows of the cleaned RSC sessions.
RSC_cleaned.head()

In [None]:
#get all coordinate values below 99th percentile and normalize points for all regions 
# NOTE(review): this normalizes the raw all_pursuit_tasks frame (all regions),
# not the *_cleaned frames produced in the previous cell — confirm that is intended
# (normalize_points may do its own NA handling).

normalized_sessions = pursuit.tuning.normalize_points(all_pursuit_tasks)

In [None]:
normalized_sessions.head()

In [None]:
#find the mean center and overall radius of the arena for all normalized data points
#you can specify the percentile value to be considered for the overall radius; default is 95th percentile
#calculates the individual center point for each session
# Returns the boundary table (circle_boundaries) and the single overall radius;
# the radius is printed for a quick sanity check.

circle_boundaries, radius = pursuit.tuning.fit_circle_bounds(normalized_sessions)
print(radius)

In [None]:
circle_boundaries.head()

In [None]:
#find circumference points for plotting using the center coordinates and overall radius
all_circ_points = pursuit.tuning.circumference(circle_boundaries)

In [None]:
#plot normalized concatenated laser and rat paths with center point and boundary
#the function takes the normalized_sessions, circle_boundaries, and all_circ_points dataframes

pursuit.tuning.plot_arena_bounds(normalized_sessions, circle_boundaries, all_circ_points)

In [None]:
#normalize only laser points and make a dataframe containing spike data using the normalized data mask
#function takes the cleaned df
# One laser/spike frame per region, built from that region's cleaned sessions.
RSC_laser_spks = pursuit.tuning.norm_laser_get_spks(RSC_cleaned)
CA1_laser_spks = pursuit.tuning.norm_laser_get_spks(CA1_cleaned)
CA3_laser_spks = pursuit.tuning.norm_laser_get_spks(CA3_cleaned)

In [None]:
CA1_laser_spks.head(50)

In [None]:
#find distance of normalized laser points to circle boundary by each session
#function takes the normalized laser/spikes and circle boundaries dataframes
# circle_boundaries was fitted on the all-region normalized_sessions frame above;
# presumably dist_to_bounds matches each row to its own session's fit — confirm.

RSC_laser_spks_bounds = pursuit.tuning.dist_to_bounds(RSC_laser_spks, circle_boundaries)
CA1_laser_spks_bounds = pursuit.tuning.dist_to_bounds(CA1_laser_spks, circle_boundaries)
CA3_laser_spks_bounds = pursuit.tuning.dist_to_bounds(CA3_laser_spks, circle_boundaries)

In [None]:
CA3_laser_spks_bounds.head(50)

### Binning, tuning curves and heatmaps (work in progress — continue from here)

In [None]:
#put raw spike counts and laser coords into bins calculated from the overall min and max bound_dist values
#function takes laser_spks_bounds dataframes 

def bin_spikes_laser(dataframe, 
                     spk_prefix="spkTable", 
                     dist_col="bound_dist",
                     num_bins=20, 
                     bin_edges=None):
    """Bin per-neuron spike counts and laser occupancy by distance to the boundary.

    For every session (``sessFile``) and every spike column whose name contains
    ``spk_prefix`` and that is not entirely NaN in that session, emit one row per
    distance bin. Each row holds the summed spike count and the number of samples
    (laser occupancy) falling into that bin.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Must contain ``sessFile``, ``dist_col`` and the spike columns.
    spk_prefix : str
        Substring identifying spike-count columns.
    dist_col : str
        Column holding the distance used for binning.
    num_bins : int
        Number of equal-width bins spanning the overall min..max of ``dist_col``.
        Ignored when ``bin_edges`` is given.
    bin_edges : array-like, optional
        Explicit, strictly increasing bin edges (``len(bin_edges) - 1`` bins).

    Returns
    -------
    pd.DataFrame
        Columns ``sessFile``, ``neuron``, ``bin_midpoint`` (rounded to 2 decimals),
        ``spike_count`` and ``laser_occupancy``. Samples whose distance is NaN or
        outside the edges are not counted.
    """
    if bin_edges is None:
        overall_min = dataframe[dist_col].min()
        overall_max = dataframe[dist_col].max()
        bin_edges = np.linspace(overall_min, overall_max, num_bins + 1)
    bin_edges = np.asarray(bin_edges, dtype=float)
    n_bins = len(bin_edges) - 1

    # (left + right) / 2 for each bin, rounded to 2 decimals.
    bin_midpoints = np.round((bin_edges[:-1] + bin_edges[1:]) / 2, 2)

    rows = []

    for sessFile in dataframe["sessFile"].unique():

        session = dataframe[dataframe["sessFile"] == sessFile]

        # Integer bin index per sample. Passing the raw edges (not an IntervalIndex)
        # is what makes include_lowest effective: with IntervalIndex bins pandas
        # ignores it, so samples exactly on the lowest edge (the global minimum)
        # fell outside every right-closed interval and were silently dropped.
        bin_idx = pd.cut(session[dist_col], bins=bin_edges, include_lowest=True, labels=False)
        in_range = bin_idx.notna()
        bin_idx = bin_idx[in_range].astype(int).to_numpy()

        laser_occupancy = np.bincount(bin_idx, minlength=n_bins)

        spk_cols = [col for col in session.columns
                    if spk_prefix in col and not session[col].isna().all()]
        if not spk_cols:
            continue

        # One grouped sum for all neurons; reindex restores bins with no samples as 0.
        spks_by_bin = (
            session.loc[in_range, spk_cols]
            .groupby(bin_idx)
            .sum()
            .reindex(range(n_bins), fill_value=0)
        )

        for neuron in spk_cols:
            neuron_counts = spks_by_bin[neuron]
            for i in range(n_bins):
                rows.append({
                    "sessFile": sessFile,
                    "neuron": neuron,
                    "bin_midpoint": bin_midpoints[i],
                    "spike_count": int(neuron_counts.iloc[i]),
                    "laser_occupancy": int(laser_occupancy[i])
                })

    return pd.DataFrame(
        rows,
        columns=["sessFile", "neuron", "bin_midpoint", "spike_count", "laser_occupancy"],
    )



In [None]:
#put raw spike counts into bins calculated from the overall min and max bound_dist values 

def bin_spike_data(dataframe, spk_prefix="spkTable", num_bins=20, bin_edges=None):
    """Sum each neuron's spikes per ``bound_dist`` bin, separately for each session.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Must contain ``sessFile``, ``bound_dist`` and the spike columns.
    spk_prefix : str
        Substring identifying spike-count columns; columns that are entirely NaN
        within a session are skipped for that session.
    num_bins : int
        Passed to ``pursuit.tuning.find_bin_edges`` when ``bin_edges`` is None.
    bin_edges : array-like, optional
        Explicit, increasing bin edges (``len(bin_edges) - 1`` bins).

    Returns
    -------
    pd.DataFrame
        One row per (session, neuron, bin) with columns ``sessFile``, ``neuron``,
        ``spike_count`` and ``bin_midpoint`` (exact-edge midpoint rounded to
        2 decimals). Empty bins are kept with a count of 0.
    """
    if bin_edges is None:
        bin_edges = pursuit.tuning.find_bin_edges(dataframe, "bound_dist", num_bins)

    # Midpoints depend only on the edges, so compute them once rather than per neuron.
    bin_midpoints = pd.IntervalIndex.from_breaks(bin_edges).to_series().apply(
        lambda interval: round((interval.left + interval.right) / 2, 2)
    ).tolist()

    rows = []

    for sessFile in dataframe["sessFile"].unique():

        session = dataframe[dataframe["sessFile"] == sessFile].copy()

        session["bound_bin"] = pd.cut(session["bound_dist"], bins=bin_edges, include_lowest=True)

        spk_cols = [col for col in session.columns
                    if spk_prefix in col and not session[col].isna().all()]
        if not spk_cols:
            continue

        # observed=False keeps empty bins, so every column has exactly one value per
        # bin, in bin order, and the positional zip with bin_midpoints stays aligned.
        # pandas >= 3 defaults to observed=True, which would drop empty bins and
        # shift the counts onto the wrong midpoints.
        spk_by_bin = session.groupby("bound_bin", observed=False)[spk_cols].sum()

        for spk in spk_cols:
            for bin_mid, spk_count in zip(bin_midpoints, spk_by_bin[spk]):
                rows.append({
                    "sessFile": sessFile,
                    "neuron": spk,
                    "spike_count": spk_count,
                    "bin_midpoint": bin_mid
                })

    binned_spks_df = pd.DataFrame(rows, columns=["sessFile", "neuron", "spike_count", "bin_midpoint"])
    return binned_spks_df




In [None]:
#put laser coordinates into bins calculated from the overall min and max bound_dist values
def bin_laser_data(dataframe, num_bins=20, bin_edges=None):
    """Count laser samples (occupancy) per ``bound_dist`` bin for each session.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Must contain ``sessFile`` and ``bound_dist`` columns.
    num_bins : int
        Passed to ``pursuit.tuning.find_bin_edges`` when ``bin_edges`` is None.
    bin_edges : array-like, optional
        Explicit, increasing bin edges (``len(bin_edges) - 1`` bins).

    Returns
    -------
    pd.DataFrame
        One row per (session, bin) with columns ``sessFile``, ``laser_occupancy``
        and ``bin_midpoint``. Samples whose distance is NaN or outside the edges
        are not counted; the lowest edge itself is included.
    """
    if bin_edges is None:
        bin_edges = pursuit.tuning.find_bin_edges(dataframe, "bound_dist", num_bins)

    # Midpoints come from the exact edges, using the same formula bin_spike_data
    # uses. pd.cut's category labels are rounded to display precision, and with
    # include_lowest=True their first left edge is shifted down by 10**-precision.
    # Label-derived midpoints could therefore disagree with the spike table's and
    # break a later join on bin_midpoint.
    bin_midpoints = pd.IntervalIndex.from_breaks(bin_edges).to_series().apply(
        lambda interval: round((interval.left + interval.right) / 2, 2)
    ).tolist()
    n_bins = len(bin_midpoints)

    rows = []

    for sessFile in dataframe["sessFile"].unique():

        session = dataframe[dataframe["sessFile"] == sessFile]

        # Integer bin index per sample; NaN / out-of-range samples are dropped.
        bin_idx = pd.cut(session["bound_dist"], bins=bin_edges, include_lowest=True, labels=False)
        occupancy = np.bincount(bin_idx.dropna().astype(int).to_numpy(), minlength=n_bins)

        for bin_mid, laser_count in zip(bin_midpoints, occupancy):
            rows.append({
                    "sessFile": sessFile,
                    "laser_occupancy": laser_count,
                    "bin_midpoint": bin_mid
                })

    binned_laser_df = pd.DataFrame(rows, columns=["sessFile", "laser_occupancy", "bin_midpoint"])
    return binned_laser_df


In [None]:
#bin the data!
# Bin spike counts and laser occupancy by boundary distance for each region.
# NOTE(review): these calls use pursuit.tuning.bin_spikes_laser, not the
# bin_spikes_laser defined in this notebook, so local edits to that function have
# no effect here. With default arguments the bin edges are presumably derived from
# each region's own data, so bin midpoints may differ between regions — confirm.
RSC_laser_spikes_binned = pursuit.tuning.bin_spikes_laser(RSC_laser_spks_bounds)
CA1_laser_spikes_binned = pursuit.tuning.bin_spikes_laser(CA1_laser_spks_bounds)
CA3_laser_spikes_binned = pursuit.tuning.bin_spikes_laser(CA3_laser_spks_bounds)

In [None]:
#normalize spike counts by laser occupancy using bins calculated from the overall min and max bound_dist values 

def calculate_tuning(laser_spikes_binned_df):
    """Return a copy of the binned frame with a ``tuning`` column (spikes per laser sample).

    ``tuning = spike_count / laser_occupancy``, computed row-wise. Bins with zero
    laser occupancy have an undefined rate and get NaN (a plain division would
    give ``inf`` whenever the spike count is nonzero). The input frame is left
    unmodified, so re-running the calling cell is idempotent.

    Parameters
    ----------
    laser_spikes_binned_df : pd.DataFrame
        Must contain numeric ``spike_count`` and ``laser_occupancy`` columns.

    Returns
    -------
    pd.DataFrame
        A new frame with all original columns plus ``tuning``.
    """
    occupancy = laser_spikes_binned_df["laser_occupancy"]
    # where() turns zero occupancy into NaN, so the division yields NaN there.
    tuning = laser_spikes_binned_df["spike_count"] / occupancy.where(occupancy > 0)

    return laser_spikes_binned_df.assign(tuning=tuning)

In [None]:
# Occupancy-normalized tuning per region.
# NOTE(review): this uses pursuit.tuning.calculate_tuning, not the calculate_tuning
# defined in this notebook; presumably it also divides spike_count by
# laser_occupancy — confirm in the pursuit package.
RSC_tuning = pursuit.tuning.calculate_tuning(RSC_laser_spikes_binned)
CA1_tuning = pursuit.tuning.calculate_tuning(CA1_laser_spikes_binned)
CA3_tuning = pursuit.tuning.calculate_tuning(CA3_laser_spikes_binned)

In [None]:
#plot all neuron tuning curves
# NOTE(review): only the CA3 tuning curves are plotted here.

pursuit.tuning.plot_tuning_curves(CA3_tuning)

In [None]:
# z-score the tuning values for each region (see pursuit.tuning.z_score_norm for
# the grouping used — presumably per neuron; confirm).
RSC_z_scored = pursuit.tuning.z_score_norm(RSC_tuning)
CA1_z_scored = pursuit.tuning.z_score_norm(CA1_tuning)
CA3_z_scored = pursuit.tuning.z_score_norm(CA3_tuning)

In [None]:
# Reshape and smooth; presumably pivots to one row per neuron and one column per
# bin before smoothing — confirm in pursuit.tuning.pivot_smooth.
RSC_smoothed = pursuit.tuning.pivot_smooth(RSC_z_scored)
CA1_smoothed = pursuit.tuning.pivot_smooth(CA1_z_scored)
CA3_smoothed = pursuit.tuning.pivot_smooth(CA3_z_scored)

In [None]:
# Reorder rows for display; presumably by the position of each neuron's peak — confirm.
RSC_smoothed_sorted = pursuit.tuning.peak_sort(RSC_smoothed)
CA1_smoothed_sorted = pursuit.tuning.peak_sort(CA1_smoothed)
CA3_smoothed_sorted = pursuit.tuning.peak_sort(CA3_smoothed)

In [None]:
# One heatmap per region (RSC, CA1, CA3).
pursuit.tuning.heatmap(RSC_smoothed_sorted)
pursuit.tuning.heatmap(CA1_smoothed_sorted)
pursuit.tuning.heatmap(CA3_smoothed_sorted)