In [None]:
#import packages and functions

%load_ext autoreload
%autoreload 2

import pursuit_functions as pursuit
    
import pandas as pd
import numpy as np
from itertools import product
import matplotlib.pyplot as plt
import seaborn as sns


from sklearn.decomposition import PCA
from xgboost import XGBClassifier
from sklearn.neighbors import KNeighborsClassifier

from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.utils import resample
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay

In [None]:
# Load the pursuit dataset from a parquet file (path is relative to the working directory; read with the pyarrow engine).

all_pursuit_tasks = pd.read_parquet("ca1_ca3_rsc_pursuit_data.parquet", engine="pyarrow")

# Normalize points and find circle boundaries.

In [None]:
# Normalize the point coordinates for all regions at once with pursuit.tuning.normalize_points.
# NOTE(review): the original note says only coordinate values below the 99th percentile are kept —
# confirm against the helper's implementation.

normalized_sessions = pursuit.tuning.normalize_points(all_pursuit_tasks)

In [None]:
# Preview the first rows of the normalized sessions.
normalized_sessions.head()

In [None]:
# Fit the arena boundary from the normalized points. fit_circle_bounds returns two values,
# unpacked here as circle_boundaries and radius; the radius is printed for inspection.
# NOTE(review): per the original notes, the helper finds the mean center and an overall radius,
# calculates an individual center point for each session, and accepts a percentile for the radius
# (default said to be the 95th) — confirm in pursuit.tuning.

circle_boundaries, radius = pursuit.tuning.fit_circle_bounds(normalized_sessions)
print(radius)

In [None]:
# Preview the first rows of the fitted circle boundaries.
circle_boundaries.head()

In [None]:
# Compute circumference points from circle_boundaries; they are passed to the arena plot later.
# NOTE(review): the original note says the center coordinates and overall radius are used —
# presumably both are carried inside circle_boundaries; confirm.
all_circ_points = pursuit.tuning.circumference(circle_boundaries)

# Plot the laser coordinates and boundaries.

In [None]:
# Plot the arena: per the original note, the normalized, concatenated laser and rat paths
# together with the center point and the boundary.
# Inputs are the normalized_sessions, circle_boundaries, and all_circ_points dataframes.

pursuit.tuning.plot_arena_bounds(normalized_sessions, circle_boundaries, all_circ_points)

# Clean data and pull spike data.

In [None]:
# Slice the full dataset into one frame per value of the "region" column (RSC, CA1, CA3).
RSC_sessions, CA1_sessions, CA3_sessions = (
    all_pursuit_tasks[all_pursuit_tasks["region"] == region_name]
    for region_name in ("RSC", "CA1", "CA3")
)



In [None]:
# Within each region, keep only the rows whose "trial_block" is "pursuit".
RSC_pursuit, CA1_pursuit, CA3_pursuit = (
    region_df[region_df["trial_block"] == "pursuit"]
    for region_df in (RSC_sessions, CA1_sessions, CA3_sessions)
)

In [None]:
# Drop NA values from the RSC, CA1, and CA3 pursuit sessions (same helper applied to each region, in that order).
RSC_cleaned, CA1_cleaned, CA3_cleaned = map(
    pursuit.tuning.drop_NA_vals,
    (RSC_pursuit, CA1_pursuit, CA3_pursuit),
)

In [None]:
# There is some data compression that collapses the 120 Hz recording time points as the recordings get longer.
# First we normalize the time within each session so that the first observation starts at 0 seconds,
# then we collapse the time points to whole seconds,
# and finally we calculate the normalized minute each observation belongs to.

def normalize_time(dataframe):
    """Return a copy of ``dataframe`` with per-session relative time columns added.

    The ``time`` column is cast to float and three columns are appended:

    * ``relative_time`` -- ``time`` minus the smallest ``time`` within the same ``sessFile``
    * ``norm_sec`` -- ``relative_time`` truncated to an integer
      (whole seconds, assuming ``time`` is recorded in seconds)
    * ``norm_min`` -- ``norm_sec`` floor-divided by 60

    The input frame itself is not modified.
    """
    out = dataframe.copy()
    out["time"] = out["time"].astype(float)

    # Earliest timestamp of each session, broadcast back onto that session's rows.
    session_start = out.groupby("sessFile")["time"].transform("min")
    out["relative_time"] = out["time"] - session_start

    whole_seconds = out["relative_time"].astype(int)
    out["norm_sec"] = whole_seconds
    out["norm_min"] = whole_seconds // 60

    return out


In [None]:
# Add the relative_time, norm_sec, and norm_min columns to the cleaned RSC data.
RSC_clean_time = normalize_time(RSC_cleaned)

In [None]:
# Display the time-normalized RSC data.
RSC_clean_time

In [None]:
# to split each session into epochs, we separate them by first/second half of the recording or odd/even minutes
# for epoch_half, True:1, False:2 will result in the first half labeled as 1 (<= cutoff) and the second half labeled as 2
# for epoch_odd_even, we take the minutes divided by 2 and find the remainder. If the remainder is 1 (odd), it will be labeled as 1 and if the remainder is 0 (even), it will be labeled as 2

def assign_epochs(dataframe):
    """Label every row with two epoch assignments and return a new frame.

    Added columns:

    * ``epoch_half`` -- 1 for rows whose ``norm_min`` lies in the first half of the
      session's distinct minutes, 2 for the rest. The split is made per ``sessFile``
      on the sorted distinct ``norm_min`` values; with an odd number of distinct
      minutes the middle minute goes to the first half.
    * ``epoch_odd_even`` -- 1 for odd ``norm_min``, 2 for even ``norm_min``.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Must contain ``sessFile`` and ``norm_min`` columns.

    Returns
    -------
    pandas.DataFrame
        Copy of ``dataframe`` with the two epoch columns added; the input is not modified.
    """
    df = dataframe.copy()

    # separate epochs by half

    def _first_half_cutoff(minutes):
        # Last minute that still belongs to the first half. Indexing with (n - 1) // 2
        # instead of n // 2 keeps an even number of minutes evenly split
        # (10 minutes -> 5/5, 2 minutes -> 1/1 rather than 2/0).
        distinct = sorted(minutes.unique())
        return distinct[(len(distinct) - 1) // 2]

    # transform broadcasts each session's scalar cutoff back onto its own rows, so the
    # result stays row-aligned even when the frame holds a single session
    # (groupby.apply would collapse that case into a one-row wide DataFrame).
    cutoff = df.groupby("sessFile")["norm_min"].transform(_first_half_cutoff)
    df["epoch_half"] = (df["norm_min"] <= cutoff).map({True: 1, False: 2})

    # separate epochs by odd/even minutes: remainder 1 (odd) -> 1, remainder 0 (even) -> 2

    df["epoch_odd_even"] = (df["norm_min"] % 2).map({1: 1, 0: 2})

    return df


In [None]:
# Label every RSC row with its epoch assignments:
#   epoch_half:     1 = first half of the session (<= cutoff), 2 = second half
#   epoch_odd_even: 1 = odd minute (remainder 1 after dividing by 2), 2 = even minute (remainder 0)

RSC_clean_time_epochs = assign_epochs(RSC_clean_time)

RSC_clean_time_epochs

In [None]:
# Spot-check the epoch labels for a single session.
RSC_clean_time_epochs[RSC_clean_time_epochs["sessFile"] == "KB10_02_pursuitRoot.mat"]

In [None]:
#make a new df with only sessFile, laser, epoch, and spike data
def epoch_laser_spks(dataframe, laser_x="laserPos_1", laser_y="laserPos_2", upper_percentile=99):
    """Build one table of origin-normalized laser positions, epoch labels, and spikes.

    For each ``sessFile`` the laser coordinates are cast to float64, rows whose x or y
    value lies outside that session's [0th, ``upper_percentile``] percentile range are
    dropped, and the surviving coordinates are shifted so the session minimum sits at 0.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Needs ``sessFile``, ``epoch_half``, ``epoch_odd_even``, the two laser columns,
        and any number of columns whose name contains ``"spkTable"``.
    laser_x, laser_y : str
        Names of the laser coordinate columns.
    upper_percentile : float, default 99
        Upper percentile bound (0-100), computed per session and per coordinate.

    Returns
    -------
    pandas.DataFrame
        Columns ``sessFile``, ``laser_x_normalized``, ``laser_y_normalized``,
        ``epoch_half``, ``epoch_odd_even`` followed by the spike columns, with a fresh
        0..n-1 index. When ``dataframe`` contains no sessions, an empty frame with
        those columns is returned.
    """
    spk_columns = [col for col in dataframe.columns if "spkTable" in col]
    session_frames = []

    for sessFile in dataframe["sessFile"].unique():
        session = dataframe[dataframe["sessFile"] == sessFile]

        laser_x_vals = session[laser_x].astype("float64")
        laser_y_vals = session[laser_y].astype("float64")

        # Per-session bounds: the minimum (0th percentile) and the upper percentile.
        x_low, x_high = np.percentile(laser_x_vals, [0, upper_percentile])
        y_low, y_high = np.percentile(laser_y_vals, [0, upper_percentile])

        # Inclusive on both ends. Named in_bounds (not `filter`) so the builtin is not shadowed.
        in_bounds = (
            laser_x_vals.between(x_low, x_high)
            & laser_y_vals.between(y_low, y_high)
        )
        kept = session[in_bounds]

        # Shift the retained points so each session's minimum maps to the origin.
        normalized_df = pd.DataFrame({
            "sessFile": sessFile,
            "laser_x_normalized": (laser_x_vals[in_bounds] - float(x_low)).values,
            "laser_y_normalized": (laser_y_vals[in_bounds] - float(y_low)).values,
            "epoch_half": kept["epoch_half"].values,
            "epoch_odd_even": kept["epoch_odd_even"].values,
        })

        # Spike columns for the same retained rows, re-indexed to line up positionally.
        spk_df = kept[spk_columns].reset_index(drop=True)

        session_frames.append(pd.concat([normalized_df, spk_df], axis=1))

    if not session_frames:
        # pd.concat([]) raises; return an empty, correctly-shaped frame instead.
        return pd.DataFrame(columns=[
            "sessFile",
            "laser_x_normalized",
            "laser_y_normalized",
            "epoch_half",
            "epoch_odd_even",
        ] + spk_columns)

    # Stack every session's frame into a single table.
    return pd.concat(session_frames, ignore_index=True)

In [None]:
# Build the RSC table of normalized laser positions, epoch labels, and spike columns.
RSC_epoch_laser_spks = epoch_laser_spks(RSC_clean_time_epochs)

In [None]:
# Spot-check the combined laser/epoch/spike rows for a single session.
RSC_epoch_laser_spks[RSC_epoch_laser_spks["sessFile"] == "KB10_02_pursuitRoot.mat"]

In [None]:
#function for pulling epochs from each session along with associated sessFile, laser, and spike data

def pull_epochs(dataframe,
                spk_prefix="spkTable"):
    """Split every session into its four epochs and stack them across sessions.

    For each ``sessFile`` the rows are selected by ``epoch_half`` (1 / 2) and by
    ``epoch_odd_even`` (1 / 2). Each selection keeps ``sessFile``, the two normalized
    laser columns, and the spike columns (names containing ``spk_prefix``) that are
    not entirely NaN within that session.

    Returns a 4-tuple of DataFrames in this order: first half, second half,
    odd minutes, even minutes. Each one concatenates all sessions with a fresh index.
    """
    id_and_laser_cols = ["sessFile", "laser_x_normalized", "laser_y_normalized"]

    # (label column, label value) for each returned frame, in return order.
    epoch_selectors = (
        ("epoch_half", 1),
        ("epoch_half", 2),
        ("epoch_odd_even", 1),
        ("epoch_odd_even", 2),
    )
    per_epoch_frames = [[] for _ in epoch_selectors]

    for sess_name in dataframe["sessFile"].unique():
        sess_rows = dataframe[dataframe["sessFile"] == sess_name]

        # Spike columns with at least one non-NaN value in this session.
        active_spk_cols = [
            name for name in sess_rows.columns
            if spk_prefix in name and not sess_rows[name].isna().all()
        ]
        wanted_cols = id_and_laser_cols + active_spk_cols

        for frames, (label_col, label_val) in zip(per_epoch_frames, epoch_selectors):
            selected = sess_rows.loc[sess_rows[label_col] == label_val, wanted_cols]
            frames.append(selected.reset_index(drop=True))

    return tuple(
        pd.concat(frames, ignore_index=True) for frames in per_epoch_frames
    )
    


In [None]:
# Split the RSC table into four frames: first half, second half, odd minutes, even minutes.
RSC_epoch_first_half, RSC_epoch_second_half, RSC_epoch_odd_min, RSC_epoch_even_min = pull_epochs(RSC_epoch_laser_spks)

In [None]:
# Display the first-half epoch frame for RSC.
RSC_epoch_first_half

In [None]:
# TODO: tuning curves for first half and second half epochs

In [None]:
# TODO: tuning curves for odd and even min epochs

In [None]:
# TODO: shift spikes circularly for bootstrapping x1000

In [None]:
def unique_time_entries(dataframe):
    """Summarize the timestamps of every session.

    Returns a DataFrame with one row per ``sessFile`` (in order of first appearance):

    * ``sessFile`` -- the session name
    * ``times_unique`` -- number of distinct ``time`` values (after casting to float64)
    * ``unique_times_count`` -- dict mapping each distinct ``time`` value to its row count
    """
    def _summarize(sess_name):
        # Cast once and reuse for both the distinct count and the per-value counts.
        stamps = dataframe.loc[dataframe["sessFile"] == sess_name, "time"].astype("float64")
        return {
            "sessFile": sess_name,
            "times_unique": stamps.nunique(),
            "unique_times_count": stamps.value_counts().to_dict(),
        }

    return pd.DataFrame([_summarize(name) for name in dataframe["sessFile"].unique()])

In [None]:
# For every RSC session, count the distinct timestamps and how often each one occurs, then display the summary.
RSC_sessions_times = unique_time_entries(RSC_sessions)
RSC_sessions_times

In [None]:
def inspect_time_counts(session_row):
    """Expand one session's ``unique_times_count`` mapping into a table.

    ``session_row`` is indexed with ``"unique_times_count"`` and must yield a dict of
    time value -> count. The result is a DataFrame with columns ``time`` and
    ``count``, ordered by ascending time.
    """
    counts_by_time = session_row["unique_times_count"]
    ordered_rows = [(stamp, counts_by_time[stamp]) for stamp in sorted(counts_by_time)]
    return pd.DataFrame(ordered_rows, columns=["time", "count"])

In [None]:
# Pull the summary row for one session and expand its per-timestamp counts into a sorted table.
lp03_25 = RSC_sessions_times[RSC_sessions_times["sessFile"] == "LP03_25_pursuitRoot.mat"]

# .iloc[0] raises IndexError if that session name is not present in the summary.
lp03_25_row = lp03_25.iloc[0]

lp03_25_inspection = inspect_time_counts(lp03_25_row)


lp03_25_inspection

In [None]:
# Preview the first 10 rows of the cleaned RSC data.
RSC_cleaned.head(10)

In [None]:
# Normalize only the laser points and build a dataframe of spike data using the normalized data mask.
# The helper takes each region's cleaned df; regions are processed in RSC, CA1, CA3 order.
RSC_laser_spks, CA1_laser_spks, CA3_laser_spks = map(
    pursuit.tuning.norm_laser_get_spks,
    (RSC_cleaned, CA1_cleaned, CA3_cleaned),
)

In [None]:
# Preview the first 50 rows of the CA1 laser/spike table.
CA1_laser_spks.head(50)

# Find distance of laser points to boundary and bin data by distance.

In [None]:
# Find the distance of the normalized laser points to the circle boundary, session by session.
# The helper takes a region's normalized laser/spikes dataframe plus the shared circle_boundaries dataframe.
RSC_laser_spks_bounds, CA1_laser_spks_bounds, CA3_laser_spks_bounds = (
    pursuit.tuning.dist_to_bounds(region_laser_spks, circle_boundaries)
    for region_laser_spks in (RSC_laser_spks, CA1_laser_spks, CA3_laser_spks)
)

In [None]:
# Preview the first 50 rows of the CA3 table after the boundary-distance step.
CA3_laser_spks_bounds.head(50)

In [None]:
# Bin the data: apply pursuit.tuning.bin_spikes_laser to each region's boundary-distance table.
RSC_laser_spikes_binned, CA1_laser_spikes_binned, CA3_laser_spikes_binned = [
    pursuit.tuning.bin_spikes_laser(region_bounds)
    for region_bounds in (RSC_laser_spks_bounds, CA1_laser_spks_bounds, CA3_laser_spks_bounds)
]

# Normalize spike counts, smooth data, peak sort neurons, and plot tuning curves.

In [None]:
# Calculate the raw tuning curves for each region from its binned data.
RSC_tuning, CA1_tuning, CA3_tuning = map(
    pursuit.tuning.calculate_tuning,
    (RSC_laser_spikes_binned, CA1_laser_spikes_binned, CA3_laser_spikes_binned),
)

In [None]:
# Sanity check: plot the tuning curves of all neurons in the CA3 tuning table
# (only CA3 is plotted in this cell).

pursuit.tuning.plot_tuning_curves(CA3_tuning)

In [None]:
# Z-score and normalize each region's raw tuning data.
RSC_z_scored, CA1_z_scored, CA3_z_scored = (
    pursuit.tuning.z_score_norm(region_tuning)
    for region_tuning in (RSC_tuning, CA1_tuning, CA3_tuning)
)

In [None]:
# Apply smoothing and pivot each region's z-scored data.
RSC_smoothed, CA1_smoothed, CA3_smoothed = map(
    pursuit.tuning.pivot_smooth,
    (RSC_z_scored, CA1_z_scored, CA3_z_scored),
)

In [None]:
# Peak-sort each region's smoothed data.
RSC_smoothed_sorted, CA1_smoothed_sorted, CA3_smoothed_sorted = [
    pursuit.tuning.peak_sort(region_smoothed)
    for region_smoothed in (RSC_smoothed, CA1_smoothed, CA3_smoothed)
]

In [None]:
# Plot one heatmap per region (RSC, CA1, CA3) from the peak-sorted, smoothed data.
pursuit.tuning.heatmap(RSC_smoothed_sorted)
pursuit.tuning.heatmap(CA1_smoothed_sorted)
pursuit.tuning.heatmap(CA3_smoothed_sorted)

In [None]:
# Run pursuit.tuning.count_neurons on each region's cleaned data (RSC, CA1, CA3 order).
RSC_count, CA1_count, CA3_count = map(
    pursuit.tuning.count_neurons,
    (RSC_cleaned, CA1_cleaned, CA3_cleaned),
)

In [None]:
# Display the count_neurons result for RSC.
RSC_count


In [None]:
# Display the count_neurons result for CA1.
CA1_count

In [None]:
# Display the count_neurons result for CA3.
CA3_count