In [None]:
from pathlib import Path
import scipy.io as sio
import pandas as pd
import numpy as np
import matplotlib as plt
import pursuit_functions
import pyarrow as pa

pd.set_option('display.max_columns', 100)  # Show more columns (default is 20)

%load_ext autoreload
%autoreload 2

# TODO: 
# * normalize column names (i.e. sessFile -> dataDir)
# * look into LP03_03_pursuitRoot.mat corruption

In [None]:


def optimize_pyarrow_dtypes(df, lossless=True):
    """
    Downcast numeric columns to the smallest suitable PyArrow-backed dtype.

    Integer columns are narrowed according to their observed value range.
    Float columns need more care: a range check alone cannot tell whether a
    float "fits", because float16/float32 also carry far fewer significant
    digits than float64 (float16 cannot even represent every integer above
    2048). Choosing by range only therefore silently rounds the data.

    Parameters:
        df (pd.DataFrame): Frame whose integer and float columns should be
            narrowed. All other columns are passed through unchanged.
        lossless (bool): When True (default) a float column is narrowed only
            if every value survives the round trip through the narrower type
            unchanged. When False the narrowest type whose *range* covers the
            column is used (float16 within +/-65504, float32 within the int32
            range), accepting the precision loss.

    Returns:
        pd.DataFrame: A new frame with the selected PyArrow-backed dtypes.
    """
    optimized_dtypes = {}

    for col in df.columns:
        series = df[col]

        if pd.api.types.is_integer_dtype(series):
            min_val, max_val = series.min(), series.max()

            # An empty or all-null column has no range to inspect: min/max are
            # NA and comparing NA raises, so leave such a column untouched.
            if pd.isna(min_val) or pd.isna(max_val):
                continue

            # Select the most efficient integer dtype
            if min_val >= 0 and max_val <= 255:
                optimized_dtypes[col] = pd.ArrowDtype(pa.uint8())
            elif min_val >= 0 and max_val <= 65535:
                optimized_dtypes[col] = pd.ArrowDtype(pa.uint16())
            elif min_val >= -128 and max_val <= 127:
                optimized_dtypes[col] = pd.ArrowDtype(pa.int8())
            elif min_val >= -32768 and max_val <= 32767:
                optimized_dtypes[col] = pd.ArrowDtype(pa.int16())
            elif min_val >= -(2**31) and max_val <= (2**31 - 1):
                optimized_dtypes[col] = pd.ArrowDtype(pa.int32())
            elif min_val >= -(2**63) and max_val <= (2**63 - 1):
                optimized_dtypes[col] = pd.ArrowDtype(pa.int64())
            # Otherwise the values only fit an unsigned 64-bit type; keep the
            # existing dtype rather than overflow int64.

        elif pd.api.types.is_float_dtype(series):
            values = series.to_numpy(dtype="float64", na_value=np.nan)
            finite = values[np.isfinite(values)]

            def survives_round_trip(np_dtype):
                # Overflow to inf is expected for out-of-range values; the
                # comparison below then rejects the candidate.
                with np.errstate(over="ignore"):
                    narrowed = values.astype(np_dtype)
                return np.array_equal(narrowed.astype(np.float64), values, equal_nan=True)

            if finite.size == 0:
                # Nothing to measure (empty / all-missing column).
                arrow_type = pa.float64()
            elif lossless:
                if survives_round_trip(np.float16):
                    arrow_type = pa.float16()
                elif survives_round_trip(np.float32):
                    arrow_type = pa.float32()
                else:
                    arrow_type = pa.float64()
            else:
                min_val, max_val = finite.min(), finite.max()
                if min_val >= -65504 and max_val <= 65504:
                    arrow_type = pa.float16()
                elif min_val >= -(2**31) and max_val <= (2**31 - 1):
                    arrow_type = pa.float32()
                else:
                    arrow_type = pa.float64()

            optimized_dtypes[col] = pd.ArrowDtype(arrow_type)

    # Convert DataFrame dtypes
    return df.astype(optimized_dtypes)



In [None]:
# Load the region structures from every file matched by the 'Rsc.mat' argument.
# NOTE(review): data_dir is an absolute, machine-specific path and is re-bound
# to a different directory in the session-loading cell further down.
data_dir = Path("/Volumes/ASA_Lab/Data/Andy/nitzPurusitData")
# Alternative locations used previously:
#"/Users/may/pursuitSessionFiles")
#data_dir = Path("/Volumes/ASA_Lab/Data/Xiaoxiao/ppcRscEVCPoster/pursuitSessionFiles")
region_directories = pursuit_functions.file_reader.load_region_files(data_dir, 'Rsc.mat')
print("Extracted structures:", region_directories.keys())

In [None]:
# Distinct 'sessFile' entries of the 'slRsc' structure (the same expression is
# used later as the include list when loading session files).
np.unique(region_directories['slRsc']['sessFile'])

In [None]:
# Print the Python type and .shape of every field in the 'slRsc' structure.
for key in region_directories['slRsc'].keys():
    print(f"{key}: {type(region_directories['slRsc'][key])}, shape: {region_directories['slRsc'][key].shape}")

In [None]:
# Inspect one entry of 'blocks': first index 0 selects the outer row, second
# index 1 selects the trial block type.
# Per the original author's note there is only one valid block type per outer row.
region_directories['slRsc']['blocks'][0][1] 

In [None]:
#for x in region_directories['slRsc']['spkTimes'][0]:
#    print(f"{x:.4}")
# when spikes were detected in seconds

In [None]:
# Load the pursuit session files, restricted to the sessions that the 'slRsc'
# structure references.
# NOTE(review): this re-binds data_dir (previously the region-file directory) to
# the Sessions sub-directory; re-running the region-loading cell afterwards
# changes it back.
data_dir = Path("/Volumes/ASA_Lab/Data/Andy/nitzPurusitData/Sessions")
#data_dir = Path("/Volumes/ASA_Lab/Data/Xiaoxiao/ppcRscEVCPoster/pursuitSessionFiles")

include_files = np.unique(region_directories['slRsc']['sessFile'])
pursuit_session_files = pursuit_functions.file_reader.load_session_files(data_dir, include_files=include_files)


In [None]:
# Wrap each region's directory structure in a DataFrame.
# NOTE(review): the keys used here are 'ca1SL' / 'ca3SL' / 'rscSL', whereas the
# cells above index 'slRsc' and load files matching 'Rsc.mat' - confirm that
# region_directories really contains all three keys at this point.
ca1_directory = pd.DataFrame(region_directories['ca1SL'])
ca3_directory = pd.DataFrame(region_directories['ca3SL'])
rsc_directory = pd.DataFrame(region_directories['rscSL'])

# A notebook cell renders only its final expression, so three bare .head()
# calls would show just the last frame; display all three explicitly.
display(ca1_directory.head(), ca3_directory.head(), rsc_directory.head())

In [None]:
# Convert each loaded session into a DataFrame keyed by file name, letting
# pandas infer PyArrow-backed dtypes for every column.
pursuit_df = {
    filename: pd.DataFrame(file_data).convert_dtypes(dtype_backend="pyarrow")
    for filename, file_data in pursuit_session_files.items()
}

# Preview one session.
# NOTE(review): the key is hard-coded; this raises KeyError if that session was
# not among the loaded files.
pursuit_df['KB20_09_pursuitRoot.mat'].head()

In [None]:
#extract trial block indices for all 3 region directory dataframes

def extract_trial_blocks(region_directory):
    """Add start/end index columns for the FE1, pursuit and FE2 trial blocks.

    Two input layouts are recognised:

    * a single "blocks" column, where each cell is indexed as
      ``cell[block][0]`` (start) and ``cell[block][1]`` (end) for block
      0 = FE1, 1 = pursuit, 2 = FE2;
    * three separate columns "feBlock", "pursuitBlock" and "feBlock2", where
      each cell is a sequence ``[start, end]``. Cells that are missing, empty
      or too short yield None for the value that cannot be read.

    When both layouts are present, "blocks" takes precedence. The frame is
    modified in place (column names are stripped, the six ``*_start`` /
    ``*_end`` columns are added) and also returned. If neither layout is
    found a warning is printed and the frame is returned without new columns.
    """
    separate_columns = ["feBlock", "pursuitBlock", "feBlock2"]
    prefixes = ["FE1", "pursuit", "FE2"]

    #ensure column names are clean
    region_directory.columns = region_directory.columns.str.strip()

    #determine the correct block column structure
    has_blocks = "blocks" in region_directory.columns
    has_separate_blocks = all(col in region_directory.columns for col in separate_columns)

    if not has_blocks and not has_separate_blocks:
        # Report the column names only; interpolating the frame itself would
        # dump its whole contents into the warning.
        print(
            "Warning: No recognized block column found among columns "
            f"{list(region_directory.columns)}. Skipping extraction."
        )
        return region_directory

    #convert NumPy arrays to lists only if using separate columns (CA3)
    if has_separate_blocks:
        for col in separate_columns:
            region_directory[col] = region_directory[col].apply(
                lambda x: x.tolist() if isinstance(x, np.ndarray) else x
            )

    def element_or_none(position):
        """Build a getter returning block[position], or None when it is absent."""
        def getter(block):
            # The length check must cover the index actually read, otherwise a
            # one-element block raises IndexError when its end is requested.
            if isinstance(block, list) and len(block) > position:
                return block[position]
            return None
        return getter

    # extract trial block start/end indices
    if has_blocks:  # CA1 & RSC (Single "blocks" column)
        for block_idx, prefix in enumerate(prefixes):
            region_directory[f"{prefix}_start"] = region_directory["blocks"].apply(
                lambda x, i=block_idx: x[i][0]
            )
            region_directory[f"{prefix}_end"] = region_directory["blocks"].apply(
                lambda x, i=block_idx: x[i][1]
            )
    else:  # CA3 (Separate columns)
        for source_col, prefix in zip(separate_columns, prefixes):
            region_directory[f"{prefix}_start"] = region_directory[source_col].apply(element_or_none(0))
            region_directory[f"{prefix}_end"] = region_directory[source_col].apply(element_or_none(1))

    return region_directory

# Add the FE1 / pursuit / FE2 start and end columns to each region frame.
# extract_trial_blocks writes the columns onto the frame it is given and returns
# that same frame, so each name still refers to its original object afterwards.
ca1_directory = extract_trial_blocks(ca1_directory)
ca3_directory = extract_trial_blocks(ca3_directory)
rsc_directory = extract_trial_blocks(rsc_directory)


In [None]:
# Preview the CA3 frame after the block start/end columns were added.
ca3_directory.head()


In [None]:
#make the mega dataset

import os
import pandas as pd

def clean_session_filename(file_path):
    """Derive the ``<name>_pursuitRoot.mat`` session file name from a directory path.

    The final path component (per ``os.path.basename``) is taken and every
    backslash in it is replaced by an underscore. A name that already ends in
    ``_pursuitRoot.mat`` is returned as is; otherwise the suffix is appended
    with exactly one underscore between the name and ``pursuitRoot.mat``,
    whether or not the cleaned name already ends in one. (Appending the suffix
    without a separator produced names such as ``KB20_09pursuitRoot.mat``,
    which can never match a ``*_pursuitRoot.mat`` key.)
    """
    suffix = "pursuitRoot.mat"
    cleaned_name = os.path.basename(file_path).replace("\\", "_")

    if cleaned_name.endswith("_" + suffix):
        return cleaned_name

    # Only insert a separator when there is a name to separate and it does not
    # already end in one.
    needs_separator = bool(cleaned_name) and not cleaned_name.endswith("_")
    return cleaned_name + ("_" if needs_separator else "") + suffix

def process_row(row, session_col, pursuit_df, region_name):
    """Slice one region-directory row's trial blocks out of its session frame.

    The session key is read from ``row[session_col]``; when that column is
    "dataDir" it is first converted with ``clean_session_filename``. If the key
    is not present in ``pursuit_df`` a warning is printed and None is returned.

    Otherwise a list is returned with one frame per block (FE1, pursuit, FE2,
    in that order) whose start and end are both non-missing. Each frame holds
    the rows at positions ``start`` through ``end`` inclusive, with ``region``,
    ``trial_block`` and ``sessFile`` columns added.

    NOTE(review): start/end are used directly as 0-based positions with an
    inclusive end - confirm they are not 1-based indices from the source files.
    """
    session_key = row[session_col]
    if session_col == "dataDir":
        session_key = clean_session_filename(session_key)

    session_frame = pursuit_df.get(session_key)
    if session_frame is None:
        print(f"Warning: Session file {session_key} not found in pursuit_df. Skipping.")
        return None

    block_bounds = [
        ("FE1", row["FE1_start"], row["FE1_end"]),
        ("pursuit", row["pursuit_start"], row["pursuit_end"]),
        ("FE2", row["FE2_start"], row["FE2_end"]),
    ]

    slices = []
    for label, first, last in block_bounds:
        if pd.isna(first) or pd.isna(last):
            continue
        window = session_frame.iloc[int(first): int(last) + 1]
        slices.append(
            window.assign(region=region_name, trial_block=label, sessFile=session_key)
        )
    return slices

def extract_and_slice_region_data(region_name, region_directory, pursuit_df):
    """Collect every trial-block slice for one region into a single frame.

    Each row of ``region_directory`` is handed to ``process_row`` (using the
    "sessFile" column when present, otherwise "dataDir"); rows whose session is
    missing contribute nothing. The slices are concatenated, columns that are
    entirely missing are dropped, the three label columns are made categorical
    and moved to the end, and the result is passed through
    ``optimize_pyarrow_dtypes``. An empty DataFrame is returned when no slice
    was produced.
    """
    if "sessFile" in region_directory.columns:
        session_col = "sessFile"
    else:
        session_col = "dataDir"

    extracted_data = []
    for _, row in region_directory.iterrows():
        row_slices = process_row(row, session_col, pursuit_df, region_name)
        if row_slices:
            extracted_data.extend(row_slices)

    if not extracted_data:
        return pd.DataFrame()

    new_cols = ["region", "trial_block", "sessFile"]
    label_dtypes = {name: "category" for name in new_cols}

    combined = pd.concat(extracted_data, ignore_index=True)
    combined = combined.dropna(axis=1, how="all")
    combined = combined.astype(label_dtypes)

    # Original data columns first, label columns last.
    data_cols = [col for col in combined.columns if col not in new_cols]
    return optimize_pyarrow_dtypes(combined[data_cols + new_cols])

# Slice the trial blocks for each region.
ca1_data = extract_and_slice_region_data("CA1", ca1_directory, pursuit_df)
ca3_data = extract_and_slice_region_data("CA3", ca3_directory, pursuit_df)
rsc_data = extract_and_slice_region_data("RSC", rsc_directory, pursuit_df)

# Combine all regions into a single dataset, ordered by session, region and block.
all_regions_data = (
    pd.concat([ca1_data, ca3_data, rsc_data], ignore_index=True)
    .sort_values(by=["sessFile", "region", "trial_block"])
    .reset_index(drop=True)
)

# Save the combined dataset to a parquet file.
# NOTE(review): absolute, user-specific output path - adjust per machine.
output_path = "/Users/may/pursuit/ca1_ca3_rsc_pursuit_data.parquet"
all_regions_data.to_parquet(output_path, engine="pyarrow", index=True)

# Preview the final dataset. This has to be the last expression of the cell:
# a bare .head() in the middle of a cell is evaluated but never rendered.
all_regions_data.head()

In [None]:
# Show the rows belonging to CA3 pursuit blocks.
all_regions_data[(all_regions_data["region"] == "CA3") & (all_regions_data["trial_block"] == "pursuit")]



In [None]:
#identifying pursuit, shortcut, and characteristic trials

df = all_regions_data[(all_regions_data["region"] == "CA3") & (all_regions_data["trial_block"] == "pursuit")].copy()

# Every column used in arithmetic or threshold comparisons below is cast to
# float64 first (laserVel included: it is thresholded exactly like ratVel).
numeric_columns = ["ratMoveDir", "laserMoveDir", "laserDist", "ratVel", "laserVel", "laserBearingMD"]
df[numeric_columns] = df[numeric_columns].astype("float64")

#identifying start of runs

# Smallest absolute angle between the two headings. The raw difference has to be
# wrapped into [-pi, pi] first, otherwise headings on either side of the
# wrap-around point look almost opposite although they are nearly aligned.
# (Radians are assumed, consistent with the np.deg2rad(30) threshold below.)
heading_difference = df["ratMoveDir"] - df["laserMoveDir"]
df["movement_alignment"] = np.abs((heading_difference + np.pi) % (2 * np.pi) - np.pi)

coherent_movement = (df["movement_alignment"] < np.deg2rad(30)) & (df["ratVel"] > 2) & (df["laserVel"] > 2)

# Sample-to-sample changes are taken within each session so that the first row
# of one session is never compared with the last row of the previous one.
# NOTE(review): if a session's rows occur more than once in df, the boundary
# between the copies still produces one spurious difference - confirm.
laser_dist_change = df.groupby("sessFile", observed=True)["laserDist"].diff()
rat_vel_change = df.groupby("sessFile", observed=True)["ratVel"].diff()

approaching_laser = laser_dist_change < 0

df["start_of_run"] = coherent_movement & approaching_laser

#identifying end of runs
rat_reached_laser = df["laserDist"] < 3 #i can change this threshold

# The comparison needs its own parentheses: `|` binds tighter than `>`, so
# `a.isna() | b.abs() > 50` is evaluated as `(a.isna() | b.abs()) > 50`.
tracking_lost = df["laserDist"].isna() | (laser_dist_change.abs() > 50)

rat_stops = rat_vel_change < -2 #looking for velocity drops

df["end_of_run"] = rat_reached_laser | tracking_lost | rat_stops

#identifying trial types
# Median and standard deviation of the egocentric measures, computed per session
# file (the grouping key is sessFile, so every row of a session receives the
# same statistics and therefore the same trial_type).
trial_stats = df.groupby("sessFile", observed=True)[["laserBearingMD", "laserDist"]].agg(["median", "std"])
trial_stats.columns = ['_'.join(col).strip() for col in trial_stats.columns]
trial_stats = trial_stats.reset_index()
df = df.merge(trial_stats, on="sessFile", how="left")

#defining thresholds for trajectory types
rt_threshold = trial_stats["laserBearingMD_std"].quantile(0.30)
ct_threshold = trial_stats["laserBearingMD_std"].quantile(0.20)
sc_threshold = trial_stats["laserDist_median"].quantile(0.15)

# Later assignments override earlier ones, so SC takes precedence over RT/CT.
# Rows between the CT and RT thresholds keep the default "CT" label.
df["trial_type"] = "CT"
df.loc[df["laserBearingMD_std"] > rt_threshold, "trial_type"] = "RT"
df.loc[df["laserBearingMD_std"] < ct_threshold, "trial_type"] = "CT"
df.loc[df["laserDist_median"] < sc_threshold, "trial_type"] = "SC"

# Summary of the per-session statistics (rendered as a table).
trial_stats.describe()


In [None]:
#create a histogram for laserBearingMD_std

import matplotlib.pyplot as plt
import seaborn as sns

# Histogram of the per-session bearing spread with both cut-offs overlaid.
fig, ax = plt.subplots(figsize=(8, 6))
sns.histplot(trial_stats["laserBearingMD_std"], bins=30, kde=True, color="blue", alpha=0.6, ax=ax)

# One dashed vertical line per threshold: (position, colour, legend label).
threshold_lines = [
    (rt_threshold, "red", f"RT Threshold ({rt_threshold:.2f})"),
    (ct_threshold, "green", f"CT Threshold ({ct_threshold:.2f})"),
]
for position, line_color, line_label in threshold_lines:
    ax.axvline(position, color=line_color, linestyle="dashed", linewidth=2, label=line_label)

ax.set_xlabel("laserBearingMD_std")
ax.set_ylabel("Frequency")
ax.set_title("Distribution of laserBearingMD_std with RT/CT Thresholds")
ax.legend()
plt.show()


In [None]:
# Histogram of the per-session median laser distance with the shortcut cut-off.
fig, ax = plt.subplots(figsize=(8, 6))
sns.histplot(trial_stats["laserDist_median"], bins=30, kde=True, color="purple", alpha=0.6, ax=ax)

# Shortcut (SC) threshold as a dashed vertical line.
sc_label = f"SC Threshold ({sc_threshold:.2f})"
ax.axvline(sc_threshold, color="orange", linestyle="dashed", linewidth=2, label=sc_label)

ax.set_xlabel("laserDist_median")
ax.set_ylabel("Frequency")
ax.set_title("Distribution of laserDist_median with SC Threshold")
ax.legend()
plt.show()


In [None]:
#plot pseudorandom, characteristic, and shortcut trials

import matplotlib.pyplot as plt

# Get all session IDs where region == "CA3" and trial_block == "pursuit"
session_ids = df[(df["region"] == "CA3") & (df["trial_block"] == "pursuit")]["sessFile"].unique()
trial_types = ["RT", "CT", "SC"]  # Pseudorandom, Characteristic, Shortcut

# All four position columns are plotted, so all four are converted to float64
# (not only the rat columns) before being handed to matplotlib.
position_columns = ["ratPos_1", "ratPos_2", "laserPos_1", "laserPos_2"]

for session_id in session_ids:  # Loop through all sessions
    plt.figure(figsize=(15, 5))  # Create a wide figure for side-by-side plots

    for i, trial_type in enumerate(trial_types):  # Loop through RT, CT, SC
        df_session = df[(df["sessFile"] == session_id) & (df["trial_type"] == trial_type)].copy()

        # Skip if there are no trials of this type
        if df_session.empty:
            print(f"No {trial_type} trials found for session {session_id}. Skipping {trial_type} plot.")
            continue

        # Convert halffloat columns to float64
        df_session[position_columns] = df_session[position_columns].astype("float64")

        # Create subplot for this trial type
        ax = plt.subplot(1, 3, i + 1)  # 1 row, 3 columns, current index

        # Plot the laser's trajectory
        ax.plot(df_session["laserPos_1"], df_session["laserPos_2"], color="blue", linewidth=0.7, alpha=0.5, label="Laser Trajectory", zorder=1)

        # Plot the rat's trajectory
        ax.plot(df_session["ratPos_1"], df_session["ratPos_2"], color="black", linewidth=0.7, label="Rat Trajectory", zorder=2)

        # Mark start and end points
        ax.scatter(df_session["ratPos_1"].iloc[0], df_session["ratPos_2"].iloc[0], color="green", label="Start", s=100, zorder=3)
        ax.scatter(df_session["ratPos_1"].iloc[-1], df_session["ratPos_2"].iloc[-1], color="red", label="End", s=100, zorder=3)

        ax.set_xlabel("X Position")
        ax.set_ylabel("Y Position")
        ax.set_title(f"{trial_type} - Session {session_id}")
        ax.legend()

    plt.tight_layout()  # Adjust spacing between subplots
    plt.show()


In [None]:
# compute latency between rat position and target position 

import numpy as np
import pandas as pd
from scipy.stats import spearmanr

# Position columns that the latency computation reads.
# No cast is applied to `df_session` here: at this point in the cell that name
# only exists if the trajectory-plot cell's loop left it behind, and it is
# rebuilt from `df` in the example-usage section at the bottom of this cell, so
# a cast here would act on a frame that is about to be discarded (and raise
# NameError on a kernel where the loop never assigned it).
float_columns = ["ratPos_1", "ratPos_2", "laserPos_1", "laserPos_2"]

def compute_latency(df_session, max_shift=60, step_size=2, sample_rate_hz=30):
    """Find the time shift that best aligns the rat's position with the laser's.

    For every candidate shift the two (x, y) tracks are offset against each
    other, rows with a missing coordinate are dropped, and the Spearman
    correlation between the flattened tracks is computed. A table of all
    shifts and correlations is printed.

    Parameters:
        df_session (pd.DataFrame): Must contain "ratPos_1", "ratPos_2",
            "laserPos_1" and "laserPos_2". The columns are converted to float64
            here, so the caller does not have to cast them beforehand.
        max_shift (int): Largest shift, in samples, tried in either direction.
        step_size (int): Spacing, in samples, between candidate shifts.
        sample_rate_hz (float): Sampling rate used to convert samples to
            milliseconds (defaults to the 30 Hz previously hard-coded).

    Returns:
        float: The shift with the highest correlation, in milliseconds. A
        positive value pairs each rat sample with a *later* laser sample, a
        negative value with an earlier one. NaN when no shift yielded a
        correlation (e.g. empty or all-missing input).
    """
    # Convert from samples to milliseconds
    ms_per_sample = 1000 / sample_rate_hz

    def as_float_array(column):
        return df_session[column].to_numpy(dtype="float64", na_value=np.nan)

    rat_pos = np.column_stack((as_float_array("ratPos_1"), as_float_array("ratPos_2")))
    target_pos = np.column_stack((as_float_array("laserPos_1"), as_float_array("laserPos_2")))

    best_corr = None
    optimal_shift = None
    shifts = np.arange(-max_shift, max_shift + 1, step_size)

    print("\nShift (ms) | Spearman Correlation")
    print("-" * 35)

    for shift in shifts:
        if shift > 0:
            shifted_rat = rat_pos[:-shift]
            shifted_target = target_pos[shift:]
        elif shift < 0:
            shifted_rat = rat_pos[-shift:]
            shifted_target = target_pos[:shift]
        else:
            shifted_rat = rat_pos
            shifted_target = target_pos

        # Mask to remove NaNs before correlation
        mask = ~np.isnan(shifted_rat).any(axis=1) & ~np.isnan(shifted_target).any(axis=1)
        shifted_rat, shifted_target = shifted_rat[mask], shifted_target[mask]

        # A single remaining row would correlate just its own x against y,
        # which says nothing about alignment, so at least two rows are required.
        if len(shifted_rat) < 2:
            print(f"Skipping shift {shift * ms_per_sample:>9.0f} ms due to too few valid samples.")
            continue

        # Compute Spearman correlation
        corr, _ = spearmanr(shifted_rat.ravel(), shifted_target.ravel())

        print(f"{shift * ms_per_sample:>9.0f} ms | {corr:.4f}")

        # The correlation is NaN when either track is constant; ignore it.
        if np.isnan(corr):
            continue

        if best_corr is None or corr > best_corr:
            best_corr = corr
            optimal_shift = shift

    if optimal_shift is None:
        return float("nan")

    return float(optimal_shift * ms_per_sample)

#TODO: make this dynamic for all sessions
# Example usage for a session
session_id = df["sessFile"].unique()[1]  # Select a session

# Only RT rows of the selected session. Restricting to one session keeps the
# shifted comparison from pairing samples that belong to different recordings.
df_session = df[(df["sessFile"] == session_id) & (df["trial_type"] == "RT")].copy()

# Convert the position columns of the frame that is actually analysed.
position_columns = ["ratPos_1", "ratPos_2", "laserPos_1", "laserPos_2"]
df_session[position_columns] = df_session[position_columns].astype("float64")

if df_session.empty:
    print(f"No RT rows found for session {session_id}; the latency below is not meaningful.")

latency = compute_latency(df_session)
print(f"Optimal Latency: {latency:.2f} ms")


In [None]:
# Preview the frame that was passed to compute_latency in the previous cell.
df_session.head()


In [None]:
#plot heatmaps of neurons by session
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import zscore

### 1. LOAD & FILTER THE DATA ###
# Keep only CA3 rows from pursuit blocks.
# NOTE(review): this re-binds `df`, replacing the frame built in the trial-type
# cell (the one carrying trial_type / start_of_run / end_of_run).
df = all_regions_data[(all_regions_data["region"] == "CA3") & (all_regions_data["trial_block"] == "pursuit")].copy()

# Guard against a missing 'time' column. The 'region' and 'trial_block' checks
# cannot fail here because both columns were already used in the filter above.
if "time" not in df.columns or "trial_block" not in df.columns or "region" not in df.columns:
    raise ValueError("Dataset must have 'time', 'trial_block', and 'region' columns!")

# Spike columns are detected by name: every column containing "spkTable".
spike_columns = [col for col in df.columns if "spkTable" in col]

### 2. NORMALIZE TIME PER SESSION ###
# Make sessFile categorical before grouping on it.
df["sessFile"] = df["sessFile"].astype("category")

# Subtract each session's earliest timestamp so every session starts at 0.
df["time"] = df["time"].astype("float64")  # Convert time to float64
df["start_time"] = df.groupby("sessFile")["time"].transform("min")
df["normalized_time"] = df["time"] - df["start_time"]


# Alternative kept by the author for the case that time is in milliseconds:
# df["normalized_time"] = (df["time"] / 1000) - df.groupby("sessFile")["time"].transform("min")

### 3. CREATE TIME BINS ###
# NOTE(review): the units are inconsistent. bin_width (16.67) is applied to
# normalized_time and described as seconds, yet 'time_seconds' is computed as
# time_bin * (16.67 / 60) instead of time_bin * bin_width, and the plots label
# that axis "Time (s)". Confirm the unit of `time` and the intended conversion.
bin_width = 16.67  # Time bin size in seconds
df["time_bin"] = (df["normalized_time"] // bin_width).astype(int)
df["time_seconds"] = df["time_bin"] * (16.67 / 60)

### 4. AGGREGATE SPIKE COUNTS PER TIME BIN ###
# Convert spike data to integers, replacing missing values with 0
df[spike_columns] = df[spike_columns].fillna(0).astype(int)

# Sum spikes per (session, bin) so sessions are never mixed.
df_binned = df.groupby(["sessFile", "time_bin"])[spike_columns].sum()
df_binned = df_binned.reset_index()  # Reset index to keep 'time_bin'
df_binned["time_seconds"] = df_binned["time_bin"] * (16.67 / 60)  # Convert to seconds
df_binned = df_binned.set_index(["sessFile", "time_seconds"])  # Re-index using seconds

### 5. APPLY Z-SCORE NORMALIZATION PER NEURON, PER SESSION ###
# Columns with zero spread within a session are left unscaled to avoid dividing
# by zero. Every remaining column is standardized, which includes 'time_bin';
# that column is dropped again before plotting.
df_zscored = df_binned.groupby("sessFile").apply(lambda x: x.apply(lambda col: zscore(col, nan_policy="omit") if col.std() > 0 else col, axis=0))

# Verify results
print(df_zscored.head())

### 6. PLOT HEATMAP FOR EACH SESSION ###
for session in df["sessFile"].unique():
    session_data = df_zscored.loc[session]  # Get session-specific data

    # Remove 'time_bin' if it exists
    if "time_bin" in session_data.columns:
        session_data = session_data.drop(columns=["time_bin"])

    # Neurons on the y axis (hence the transpose), time bins on the x axis.
    plt.figure(figsize=(12, 6))
    sns.heatmap(session_data.T, cmap="coolwarm", center=0, cbar=True, xticklabels=10)
    plt.xlabel("Time (s)")
    plt.ylabel("Neurons")
    plt.title(f"Z-Scored Neural Spike Activity - Session {session}")
    plt.show()



### 7. OPTIONAL: RASTER PLOT (SPIKES OVER TIME) ###
# One figure for the whole filtered frame: rows from every session share the
# same time axis here, and each neuron is placed on the y axis by column name.
plt.figure(figsize=(12, 6))

for neuron in spike_columns:
    neuron_data = df[df[neuron] > 0]  # Get rows where spikes occurred
    plt.scatter(neuron_data["time_seconds"], np.full(len(neuron_data), neuron), s=2, color="black")

plt.xlabel("Time (s)")  # Updated label
plt.ylabel("Neurons")
plt.title("Spike Raster Plot")
plt.show()



In [None]:
# Read the block boundaries from the first row of `indices`.
# NOTE(review): `indices` is not assigned anywhere in the visible notebook, so
# this cell raises NameError on a fresh kernel unless it is defined elsewhere.
# Presumably it is the region-directory row(s) for the session used below - confirm.
pursuit_start = (indices["pursuit_start"].values[0])
pursuit_end = (indices["pursuit_end"].values[0])
FE1_start = (indices["FE1_start"].values[0])
FE1_end = (indices["FE1_end"].values[0])
FE2_start = (indices["FE2_start"].values[0])
FE2_end = (indices["FE2_end"].values[0])

# Number of rows in the (hard-coded) session frame.
max_rows = len(pursuit_df['KB20_09_pursuitRoot.mat'])

# Pass each pair through get_block_rows() before slicing.
# NOTE(review): the `is not None` checks below suggest it can return None for a
# block; its exact behaviour is defined in pursuit_functions - confirm.
FE1_start, FE1_end = pursuit_functions.index_utils.get_block_rows(FE1_start, FE1_end, max_rows)
pursuit_start, pursuit_end = pursuit_functions.index_utils.get_block_rows(pursuit_start, pursuit_end, max_rows)
FE2_start, FE2_end = pursuit_functions.index_utils.get_block_rows(FE2_start, FE2_end, max_rows)

# Slice the rows of each trial block, end row included (FE = free explore).
# A block whose start is None is set to None instead of being sliced.
FE1_block = pursuit_df['KB20_09_pursuitRoot.mat'].iloc[FE1_start:FE1_end +1]if FE1_start is not None else None
pursuit_block = pursuit_df['KB20_09_pursuitRoot.mat'].iloc[pursuit_start:pursuit_end +1] if pursuit_start is not None else None
FE2_block = pursuit_df['KB20_09_pursuitRoot.mat'].iloc[FE2_start:FE2_end +1] if FE2_start is not None else None

print("\ncheck out the blocks")
print(pursuit_block)
print(FE1_block)
print(FE2_block)

In [None]:
import matplotlib.pyplot as plt 

# Rat position samples of the pursuit block.
x = pursuit_block["ratPos_1"]
y = pursuit_block["ratPos_2"]

# Extremes of each coordinate, read once.
x_low, x_high = x.min(), x.max()
y_low, y_high = y.min(), y.max()

# Midpoint and half-width of the data along each axis.
x_center = (x_high + x_low) / 2
y_center = (y_high + y_low) / 2
x_range = (x_high - x_low) / 2
y_range = (y_high - y_low) / 2

# The larger half-width is used for both axes so the view is square.
limit = max(x_range, y_range)

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(x, y, c='blue', label="Points", s=0.05)

# Square window centred on the data.
ax.set_xlim(x_center - limit, x_center + limit)
ax.set_ylim(y_center - limit, y_center + limit)
ax.set_aspect('equal', adjustable='datalim')

# Cross-hair through the centre plus a light grid.
ax.axhline(y_center, color="gray", linewidth=1)
ax.axvline(x_center, color="gray", linewidth=1)
ax.grid(True, linestyle="--", alpha=0.6)

ax.set_xlabel("X Coordinate")
ax.set_ylabel("Y Coordinate")

plt.show()