In [3]:
import sys
path1 = '/Users/steph/berkelab/DA_maze/DA/Modules/'
path2 = '/Users/steph/berkelab/DA_maze/Behavior/Modules/'
sys.path += [path1,path2]

import numpy as np
import pandas as pd
import math
from matplotlib import pyplot as plt
from multi_rat_da import *
from tmat_ops import *
from hexLevelAnalyses import get_sigRats_fromMeanList
from photometryQuantifications import *
#from scipy.stats import wilcoxonz
from celluloid import Camera
from pdf2image import convert_from_path
from mpl_toolkits.axes_grid1 import make_axes_locatable
%matplotlib qt

# Function definitions

In [50]:
def normalize(data):
    """ Linearly rescales an array of numerical data so its minimum maps to 0 and its maximum to 1 """

    lowest = np.min(data)
    value_range = np.max(data) - lowest
    return (data - lowest) / value_range

def scale_to_bounds(data, bounds):
    """ Linearly maps 'data' onto the value range spanned by 'bounds'

    The minimum of 'data' lands on min(bounds) and the maximum of 'data' lands on max(bounds).

    Parameters:
    data (array-like): The values to rescale (must support arithmetic with scalars).
    bounds (array-like): Values whose min and max define the target range.

    Returns:
    scaled_data: 'data' rescaled to lie within [min(bounds), max(bounds)].
    """

    data_min = np.min(data)
    data_span = np.max(data) - data_min
    target_min = np.min(bounds)
    target_span = np.max(bounds) - target_min

    # first squash to 0-1, then stretch and shift onto the target range
    unit_scaled = (data - data_min) / data_span
    return unit_scaled * target_span + target_min


def get_bounds(df):
    """ Returns (min x, min y, max x, max y) for a dataframe with x and y fields, ignoring NaNs """

    x_values, y_values = df.x, df.y
    lower_corner = (np.nanmin(x_values), np.nanmin(y_values))
    upper_corner = (np.nanmax(x_values), np.nanmax(y_values))

    return (*lower_corner, *upper_corner)


def expand_second_list(first_list, second_list):
    """ Look up every value of the first list as a position in the second list.

    Args:
    first_list (list): Indices into second_list (repeats allowed).
    second_list (list): The values to be picked from.

    Returns:
    list: A list the same length as first_list, holding second_list[i] for each i in first_list.

    Example:
    >>> first_list = [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3]
    >>> second_list = [4, 5, 6, 7]
    >>> expanded_second_list = expand_second_list(first_list, second_list)
    >>> print(expanded_second_list)
    [4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7]
    """

    expanded = []
    for position in first_list:
        expanded.append(second_list[position])
    return expanded


def moving_average(data, window_size=125):
    """ Smooth a jumpy signal with a trailing (causal) simple moving average.

    Each output sample is the mean of the current sample and up to window_size-1
    samples before it, so the first few outputs average over a shorter window.

    Parameters:
    data (list): List of values to be smoothed.
    window_size (int): Size of the moving average window. Default of 125 = 0.5s (at 250 Hz sample frequency)

    Returns:
    list: The smoothed values, same length as data.

    Raises:
    ValueError: If window_size is not strictly between 0 and len(data).
    """

    n_samples = len(data)
    if not 0 < window_size < n_samples:
        raise ValueError("Invalid window size")

    def trailing_mean(last):
        first = max(0, last - window_size + 1)
        chunk = data[first:last + 1]
        return sum(chunk) / len(chunk)

    return [trailing_mean(last) for last in range(n_samples)]


def convert_indices_to_time(indices, start_point=0, sample_frequency=250):
    """ Converts an array of sample indices to times in seconds, given a sampling frequency (Hz)

    By default the returned times are shifted to start at 0 seconds (better for plotting).
    Pass start_point = -1 to keep the original offset (time = index / sample_frequency).
    Any other start_point value is treated like the default.
    """

    period = 1/sample_frequency  # seconds per sample

    if start_point == -1:
        return indices*period

    # anything not explicitly set to -1 is assumed to mean "start the time from 0"
    earliest = np.nanmin(indices)
    return (indices-earliest)*period
    

def find_dead_ends(tmat):
    """ Given a transition matrix, returns an array of hexes that are dead ends

    A hex counts as a dead end when its row of the matrix contains a 1.
    Hexes 1, 2 and 3 are always excluded (presumably the reward ports - confirm).
    """

    rows_with_a_one = np.where(tmat == 1)[0]
    candidates = np.unique(rows_with_a_one)
    is_excluded = np.isin(candidates, [1, 2, 3])

    return candidates[~is_excluded]


def path_to_dead_end(tmat, next_hex, path=None):
    """ Given a transition matrix and a dead end hex, recursively finds the path of hexes to that dead end 

    Starting from next_hex, the path is extended through every hex whose transition
    into the current hex has probability exactly 0.5, skipping hexes already on the
    path and hexes 1, 2 and 3.

    Parameters:
    tmat: the transition matrix (2D array)
    next_hex: the hex to start from (the dead end hex on the first call)
    path (list, optional): list to accumulate the path into. Omit it (or pass None)
        to start a fresh path. If a list is passed, it is extended in place.

    Returns: a new list with the hexes on the path, starting at the dead end hex.
    """

    # A mutable default ([]) would be shared between calls and keep growing,
    # so a fresh list is created here whenever the caller does not supply one.
    if path is None:
        path = []

    path.append(next_hex)
    neighbors = np.where(tmat[:, next_hex] == 0.5)[0]

    for neighbor in neighbors:
        if neighbor not in path and neighbor not in (1, 2, 3):
            path_to_dead_end(tmat, neighbor, path)

    return list(path)


def get_all_dead_end_paths(tmat, min_length=1):
    """ Collects the path to every dead end hex of a transition matrix.

    Parameters: 
    tmat: the transition matrix
    min_length (int): paths with fewer hexes than this are left out

    Returns: a list of lists. Each inner list starts with a dead end hex and
    continues (in order) with the hexes on the path leading to that dead end.
    """

    # always hand path_to_dead_end a brand-new list so paths never share state
    candidate_paths = (path_to_dead_end(tmat, dead_end, []) for dead_end in find_dead_ends(tmat))

    return [candidate for candidate in candidate_paths if len(candidate) >= min_length]


def find_dead_end_path_for_hex(hexes, all_dead_end_paths):
    """ Returns the dead end path (or list of paths) a hex (or hexes) is in, or [] if not in a dead end.
    
    Parameters: 
    hexes: A list (or other iterable) of hexes, or a single integer hex
    all_dead_end_paths: List of all dead end paths in the maze (in hexes)
    
    Returns: 
    For a single hex: the first dead end path that contains it, or [] if none does.
    For multiple hexes: a list the same length as hexes, where each entry is the
    dead end path that hex is in (or [] if it is not in any).
    
    (I use this to add a column of dead end paths to the dataframe so I can group by dead end)
    """

    def containing_path(hex_label):
        # search the paths passed in by the caller, not a notebook-level global
        return next((path for path in all_dead_end_paths if hex_label in path), [])

    # a single hex may be a Python int or a numpy integer (e.g. taken from an array)
    if isinstance(hexes, (int, np.integer)):
        return containing_path(hexes)

    # otherwise, look up the dead end path for each hex in turn
    return [containing_path(hex_label) for hex_label in hexes]


def divide_into_sections(arr):
    """ Labels each run of consecutive numbers (each increasing by exactly 1) in a sorted list.
    
    Parameters: 
    arr (list): A sorted list of numbers with potential breaks. 
    
    Returns: 
    result (list): A list the same length as arr. Every element of the first run is
    labeled 1, every element of the second run 2, and so on.
    
    (I use this function to find each distinct time a rat enters a dead end path.)

    Example:
    >>> input_array = [7, 8, 9, 10, 11, 55, 56, 57, 58, 59, 60, 61, 990, 991, 992, 993]
    >>> result_array = divide_into_sections(input_array)
    >>> print(result_array)
    [1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3]
    """

    n_values = len(arr)
    if n_values == 0:
        return []

    labels = [1]
    for position in range(1, n_values):
        # a new section starts wherever the value is not exactly previous + 1
        starts_new_section = arr[position] != arr[position-1] + 1
        labels.append(labels[-1] + 1 if starts_new_section else labels[-1])

    return labels


def count_entries_for_each_dead_end(dead_end_paths, list_of_entries):
    """ Count the number of times the rat entered each specific dead end

    The two lists are walked in parallel, one element per sample.

    Args:
    dead_end_paths (list): For each sample, the dead end path the rat is in
        (the first hex of each path is used to identify the dead end)
    list_of_entries (list): For each sample, the label of the distinct dead end entry it belongs to

    Returns:
    result: A list with, for each sample, how many times the rat has entered that sample's dead end so far
    dead_end_counts: A dictionary mapping each dead end hex to the total number of entries into its path
    """

    # identify each dead end path by its first hex (the dead end itself) to make things easier
    dead_ends = [path[0] for path in dead_end_paths]

    running_counts = []
    dead_end_counts = {}
    seen_entries = set()

    for dead_end, entry in zip(dead_ends, list_of_entries):
        # only the first sample of an entry bumps the count, so an entry is never counted twice
        if entry not in seen_entries:
            dead_end_counts[dead_end] = dead_end_counts.get(dead_end, 0) + 1
            seen_entries.add(entry)

        running_counts.append(dead_end_counts[dead_end])

    return running_counts, dead_end_counts


def find_turnaround_index_hexes(rats_hex_path, dead_end_path):
    """ Find the index where the rat turned around when going down a dead end path, using hexes.
    
    Parameters: 
    rats_hex_path: The path the rat took through the dead end (in hexes); must expose
        its samples through a `.values` array (e.g. a pandas Series)
    dead_end_path: The dead end path that the rat is in (hexes), ordered so that
        hexes deeper in the dead end come first
    
    Returns: 
    entered_index, exited_index: positions (0-based, within rats_hex_path) of the last time
    the rat entered the furthest hex it reached, and of the last sample of that visit.
    If the rat's path starts in the furthest hex, these are the first and last positions.
    If the rat never leaves the furthest hex after its last entry, exited_index is the
    last position of the path.

    Raises:
    ValueError: If rats_hex_path is empty or contains no hex of dead_end_path.
    """

    visited = np.asarray(rats_hex_path.values)
    if visited.size == 0:
        raise ValueError("rats_hex_path is empty")

    # find the furthest hex along the dead end path we get to before turning around
    furthest_hex = next((hex_label for hex_label in dead_end_path if hex_label in visited), None)
    if furthest_hex is None:
        raise ValueError("rats_hex_path does not contain any hex of dead_end_path")

    last_index = len(visited) - 1

    # we didn't get past the first hex, entered/exited indices are just the first and last indices
    if visited[0] == furthest_hex:
        return 0, last_index

    # otherwise, find the last transition into the furthest hex ...
    at_furthest = visited == furthest_hex
    entries = np.flatnonzero(~at_furthest[:-1] & at_furthest[1:]) + 1
    exits = np.flatnonzero(at_furthest[:-1] & ~at_furthest[1:])

    # entries is never empty here: the path does not start in the furthest hex but does visit it
    entered_index = int(entries[-1])

    # ... and the end of that same visit. An exit recorded before the last entry belongs
    # to an earlier visit, so if there is none afterwards the visit runs to the end of the path.
    later_exits = exits[exits >= entered_index]
    exited_index = int(later_exits[-1]) if later_exits.size else last_index

    return entered_index, exited_index


def find_furthest_point_from_endpoints(x_coords, y_coords):
    """ Finds the (x,y) coordinate whose summed distance to the first and last coordinates is largest.
    
    Parameters: 
    x_coords: pandas Series of x coordinates
    y_coords: pandas Series of y coordinates (same length as x_coords)

    Returns:
    furthest_index: The index label (taken from x_coords.index) of the furthest coordinate
    furthest_coordinate (tuple): The (x,y) coordinate that is furthest from both endpoints

    If no point has a summed distance greater than 0 (a single point, or all points
    identical), the first point is returned. On ties the earliest point wins.

    Raises:
    IndexError: If the coordinate series are empty.
    """
    
    start_x, start_y = x_coords.iloc[0], y_coords.iloc[0]
    end_x, end_y = x_coords.iloc[-1], y_coords.iloc[-1]

    # start from the first point so there is always a valid answer to return
    max_combined_distance = 0
    furthest_position = 0
    furthest_coordinate = (start_x, start_y)

    for position, (x, y) in enumerate(zip(x_coords, y_coords)):
        distance_to_start = math.hypot(x - start_x, y - start_y)
        distance_to_end = math.hypot(x - end_x, y - end_y)

        combined_distance = distance_to_start + distance_to_end
        if combined_distance > max_combined_distance:
            max_combined_distance = combined_distance
            furthest_coordinate = (x, y)
            furthest_position = position

    # translate the position within this slice back to the label used in the dataframe
    furthest_index = x_coords.index[furthest_position]

    return furthest_index, furthest_coordinate


def find_dead_end_turnaround(x_coords, y_coords, rats_hex_path, dead_end_path):
    """ Finds the index where the rat turned around in a dead end path

    Parameters:
    x_coords, y_coords: pandas Series of the rat's position for one dead end entry
    rats_hex_path: pandas Series of the hex the rat is in for each of those samples
    dead_end_path: The dead end path that the rat is in (hexes)

    Returns:
    turnaround_index: index of the turnaround sample, as returned by
    find_furthest_point_from_endpoints for the samples spent in the furthest hex
    """
    
    # find where the rat entered and exited the furthest hex in the dead end path
    entered_index, exited_index = find_turnaround_index_hexes(rats_hex_path, dead_end_path)
    
    # entered/exited are positions within this entry, so slice positionally, and include
    # the exit sample itself (a plain [a:b] slice would drop it and can come out empty)
    stay_in_furthest_hex = slice(entered_index, exited_index + 1)
    x = x_coords.iloc[stay_in_furthest_hex]
    y = y_coords.iloc[stay_in_furthest_hex]

    # find the furthest point in that hex from the entry and exit points
    turnaround_index, _ = find_furthest_point_from_endpoints(x, y)
    
    return turnaround_index


def stats_for_dead_end_entry(rats_hex_path, sample_frequency=250):
    """ Given the rat's path in a dead end (in hexes), return hexes traveled and time spent in this dead end

    Parameters:
    rats_hex_path: The hex the rat is in at each sample of this dead end entry
    sample_frequency: Sampling frequency of the path in Hz (default 250)

    Returns:
    hexes_traveled (int): Number of distinct hexes in the path
    time_spent (float): Duration of the entry in seconds
    """
    
    hexes_traveled = len(set(rats_hex_path))
    time_spent = len(rats_hex_path) / sample_frequency # time (s) = number of samples / samples per second
    
    return hexes_traveled, time_spent


def get_stats(df):
    """ Given a dataframe, iterates through each dead end entry and calculates the max hexes traveled and
    the time spent for each dead end entry. Returned lists have one entry for each entry into a dead end.

    Entries are taken from the values actually present in df.dead_end_entry (in ascending
    order), so a dataframe holding only some of the entries (e.g. those for a single dead
    end) gives one result per entry it contains. An empty dataframe gives two empty lists.

    Parameters:
    df: dataframe with 'dead_end_entry' and 'hexlabels' columns
    
    Returns:
    hexes_traveled (list): Number of hexes traveled for each dead end entry
    time_spent (list): The amount of time (in seconds) spent in that dead end entry
    """
    
    hexes_traveled = []
    time_spent = []
    
    for _, hexes in df.groupby('dead_end_entry', sort=True)['hexlabels']:
        # get stats for this entry
        n_hexes, seconds = stats_for_dead_end_entry(hexes)
        hexes_traveled.append(n_hexes)
        time_spent.append(seconds)
    
    return hexes_traveled, time_spent

def get_centered_indices(df):
    """ Given a dataframe, loops through each dead end entry and calculates new indices for that entry
    centered around the rat's turnaround point in that dead end. For example, the rat's turnaround point
    would now have index 0, indices leading up to that point are ..., -2, -1, and indices following the 
    turnaround point are 1, 2, ...

    Parameters:
    df: dataframe with 'dead_end_entry', 'dead_end_path', 'x', 'y' and 'hexlabels' columns.
        Rows are assumed to be ordered by ascending dead_end_entry, since the result is
        built entry by entry.
    
    Returns:
    centered_indices (list): list of the same length as the dataframe indicating the centered indices for
    each of the dead end entries """
    
    centered_indices = []

    # iterate over the entries present in the dataframe that was passed in
    # (not a notebook-level global), in ascending order of entry number
    for _, entry_df in df.groupby('dead_end_entry', sort=True):
        dead_end_path = entry_df.dead_end_path.iloc[0]
        turnaround_index = find_dead_end_turnaround(entry_df.x, entry_df.y, entry_df.hexlabels, dead_end_path)
        indices = (entry_df.index - turnaround_index)
        centered_indices.extend(indices.tolist())

    return centered_indices

def get_mean_and_std(df, data_name, start_index, end_index):
    """ Averages one column of data over all dead end entries, aligned on the turnaround point.

    Parameters:
    df: dataframe with 'centered_indices' and 'dead_end_entry' columns plus the data column
    data_name (str): name of the column to average (ex 'green_z_scored')
    start_index, end_index (int): range of centered indices to include (both inclusive)

    Returns:
    mean_array, std_array: arrays of length end_index - start_index + 1 holding the mean and
    standard deviation across entries at each centered index. Samples an entry does not
    have are treated as NaN and ignored.
    """

    in_window = (df['centered_indices'] >= start_index) & (df['centered_indices'] <= end_index)
    windowed = df[in_window]
    window_length = end_index - start_index + 1
    last_entry = max(df.dead_end_entry)

    traces = []
    for entry in range(1, last_entry + 1):
        this_entry = windowed.dead_end_entry == entry

        # start with all NaN, then drop this entry's samples into their slots in the window
        trace = np.full(window_length, np.nan)
        trace[windowed.centered_indices[this_entry] - start_index] = windowed[data_name][this_entry]
        traces.append(trace)

    return np.nanmean(traces, axis=0), np.nanstd(traces, axis=0)

def plot_mean_with_error(mean_array, error_array, start_index=-500, end_index=500, ylabel=None):
    """ Plots a mean trace in a new figure, with mean +/- error (SD, SEM, etc) drawn as a shaded band
    
    start_index and end_index (inclusive) give the sample indices the arrays cover; they are
    converted to seconds for the x axis using convert_indices_to_time's default sampling
    frequency. The y axis is only labeled when ylabel is given. """

    sample_indices = np.arange(start_index, end_index + 1)
    time = convert_indices_to_time(sample_indices, -1)
    band_bottom = mean_array - error_array
    band_top = mean_array + error_array

    plt.figure()
    plt.plot(time, mean_array, linewidth=2)
    plt.fill_between(time, band_bottom, band_top, color='blue', alpha=0.2)
    plt.xlabel('Time (s)')
    if ylabel is None:
        return
    plt.ylabel(ylabel)

Do the dataframe loading and slicing in its own cell because it is slooow

In [55]:
#specific_path = "IM-1478/07202022"
specific_path = "IM-1478/07172022"
phot_path_base = "/Volumes/Tim/Photometry/"
ephys_path_base = "/Volumes/Tim/Ephys/"
min_dead_end_length = 3

# Load transition matrix and dataframe for this session
transition_mat = np.load(phot_path_base+specific_path+"/tmat.npy")
df = reduce_mem_usage(pd.read_csv(ephys_path_base+specific_path+"/phot_decode_df_withHexStates.csv"))
minx, miny, maxx, maxy = get_bounds(df)

# Load pdf image of hexes from this session to use as the plot background
image = convert_from_path(phot_path_base+specific_path+"/hex_layout.pdf")
image[0].save('hex_background.png', 'png')

# Use the transition matrix to get paths to dead end hexes
dead_end_paths = get_all_dead_end_paths(transition_mat, min_dead_end_length)

# Get subset of dataframe where the rat is in dead ends.
# A single vectorized membership test replaces a Python-level loop over every row.
dead_end_hexes = [hex_label for path in dead_end_paths for hex_label in path]
in_dead_end = df['hexlabels'].isin(dead_end_hexes)
indices = df.index[in_dead_end].tolist()
dead_end_df = df.loc[in_dead_end].copy()

Memory usage of dataframe is 489.81 MB
Memory usage after optimization is: 114.80 MB
Decreased by 76.6%


### Data Preprocessing
- Add a bunch of columns to the dataframe to make it easy to slice and iterate based on different things
- Print some stats about this session

In [56]:
# Add column to the dataframe labeling which dead end path we are in
# (one path per row, looked up from that row's hex label)
print("Paths to dead end hexes:", dead_end_paths)
path = find_dead_end_path_for_hex(dead_end_df.hexlabels, dead_end_paths)
dead_end_df['dead_end_path'] = path

# Add column to the dataframe counting each entry into a dead end path (one count for all dead end paths)
# A new entry starts wherever the dataframe index stops increasing by exactly 1
dead_end_entry = divide_into_sections(dead_end_df.index)
dead_end_df['dead_end_entry'] = dead_end_entry
print("The rat entered a dead end path", max(dead_end_entry), "times in this session.")

# Add column to the dataframe counting the entry into each dead end path (different counts for each dead end path)
individual_dead_end_entry, dead_end_counts = count_entries_for_each_dead_end(path, dead_end_entry)
dead_end_df['individual_dead_end_entry'] = individual_dead_end_entry
print("Entry counts for each dead end path:", dead_end_counts)

# Add columns to the dataframe indicating how far down a dead end path we got and how long we spent in that path
# (get_stats returns one value per entry; expand_second_list repeats it for every row of that entry)
hexes_traveled_list, time_spent_list = get_stats(dead_end_df)
dead_end_entry_from_zero = [x-1 for x in dead_end_entry] # adjust to start at 0 so we can use as an index
dead_end_df['max_hexes_traveled'] = expand_second_list(dead_end_entry_from_zero, hexes_traveled_list)
dead_end_df['time_in_dead_end'] = expand_second_list(dead_end_entry_from_zero, time_spent_list)

# Add column to the dataframe of indices centered around the turnaround point
# (0 at the turnaround, negative before it, positive after it)
dead_end_df['centered_indices'] = get_centered_indices(dead_end_df)

Paths to dead end hexes: [[9, 12, 16, 20]]
The rat entered a dead end path 21 times in this session.
Entry counts for each dead end path: {9: 21}


In [246]:
# Sanity check: list every column now available in the dead end dataframe
print(dead_end_df.columns)

Index(['Unnamed: 0', 'index', 'green_z_scored', 'port', 'rwd', 'x', 'y',
       'nom_rwd_a', 'nom_rwd_b', 'beamA', 'beamB', 'beamC', 'tri', 'block',
       'nom_rwd_c', 'hexlabels', 'lenAC', 'lenBC', 'lenAB', 'dtop', 'fiberloc',
       'session_type', 'session', 'rat', 'date', 'lastleave', 'nom_rwd_chosen',
       'vel', 'pairedHexStates', 'acc', 'decodeDistToRat', 'decodeHPD',
       'x_pred', 'y_pred', 'mua', 'theta_env', 'theta_phase',
       'theta_phase_bin', 'pred_hexlabels', 'pred_hexState', 'dead_end_path',
       'dead_end_entry', 'individual_dead_end_entry', 'max_hexes_traveled',
       'time_in_dead_end', 'centered_indices'],
      dtype='object')


In [59]:
# Desired start and end indices, in samples relative to the turnaround point (index 0)
start_index = -1000
end_index = 1000
# Column of the dataframe to plot (uncomment the one of interest)
#data_name = 'vel'
#data_name = 'green_z_scored'
data_name = 'decodeDistToRat'

# Plot mean and standard deviation of whatever data we're interested in
mean_arr, std_arr = get_mean_and_std(dead_end_df, data_name, start_index, end_index)
plot_mean_with_error(mean_arr, std_arr, start_index, end_index, data_name)

# Plot all of the lines just to see if we have anything weird going on
# (one line per dead end entry, on the same turnaround-centered time axis)
plot_all_lines_check = False
if plot_all_lines_check:
    fig, ax = plt.subplots()
    # keep only the samples inside the requested window around the turnaround
    filtered_df = dead_end_df[(dead_end_df['centered_indices'] >= start_index) & (dead_end_df['centered_indices'] <= end_index)]
    num_dead_end_entries = max(dead_end_df.dead_end_entry)+1
    for entry in list(range(1, num_dead_end_entries)):
        data = filtered_df[data_name][filtered_df.dead_end_entry==entry]
        idx = filtered_df.centered_indices[filtered_df.dead_end_entry==entry]
        # -1 keeps the centered offset, so time 0 is the turnaround
        ax.plot(convert_indices_to_time(idx, -1), data)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel(data_name)
    plt.show()

In [62]:
# Histograms and time courses over every dead end entry in the session
# (plot_histograms is defined in a separate cell, which must be run first)
hexes_traveled_list, time_spent_list = get_stats(dead_end_df)
plot_histograms(hexes_traveled_list, time_spent_list)

# ok, do the same thing for an individual dead end!
# TODO: Fix this part!!!
for path in dead_end_paths:
    # keep only the rows whose dead_end_path matches this path exactly
    slice_df = dead_end_df[dead_end_df['dead_end_path'].apply(lambda x: np.array_equal(x, path))]
    hexes_traveled_list, time_spent_list = get_stats(slice_df)
    plot_histograms(hexes_traveled_list, time_spent_list)

In [61]:
def plot_histograms(hexes_traveled_list, time_spent_list):
    """ Shows a 2x2 figure summarizing dead end entries

    Left column: histograms of time spent (top) and hexes traveled (bottom).
    Right column: the same two quantities plotted in list order, one point per dead end entry.

    Parameters:
    hexes_traveled_list (list): hexes traveled for each dead end entry
    time_spent_list (list): time spent (s) for each dead end entry
    """

    fig, axs = plt.subplots(2, 2)
    (time_hist_ax, time_course_ax), (hex_hist_ax, hex_course_ax) = axs

    # distributions
    time_hist_ax.hist(time_spent_list, bins=20, edgecolor='black', alpha=0.7)
    time_hist_ax.set_xlabel('Time (seconds)')
    time_hist_ax.set_ylabel('Count')
    time_hist_ax.set_title('Distribution of time spent in a dead end')

    hex_hist_ax.hist(hexes_traveled_list, bins=3, edgecolor='black', alpha=0.7)
    hex_hist_ax.set_xlabel('Hexes traveled')
    hex_hist_ax.set_ylabel('Count')
    hex_hist_ax.set_title('Distribution of hexes traveled down dead end')

    # the same values, entry by entry
    time_course_ax.plot(time_spent_list, marker='o', linestyle='-')
    time_course_ax.set_xlabel('Dead end entry')
    time_course_ax.set_ylabel('Time (s)')
    time_course_ax.set_title('Time spent in dead end')

    hex_course_ax.plot(hexes_traveled_list, marker='o', linestyle='-')
    hex_course_ax.set_xlabel('Dead end entry')
    hex_course_ax.set_ylabel('Hexes')
    hex_course_ax.set_title('Hexes traveled down dead end')

    plt.tight_layout()
    plt.show()

### Plot a bunch of things for a single entry into a dead end path

In [273]:
# Which dead end entries to plot (one 2x2 figure per entry)
#entries = list(range(1, 38))
entries = [4]
for entry in entries:

    # Plot instance(s) of when a rat is in a dead end
    # scale, xshift and yshift position the hex background image relative to the tracked coordinates
    scale = 60
    xshift = 10
    yshift = -30

    plot_direction_changes = True # if we want to highlight direction changes on the plot

    # get subsets of data (it's faster than constantly accessing the dataframe (?) )
    x = dead_end_df.x[dead_end_df.dead_end_entry==entry]
    y = dead_end_df.y[dead_end_df.dead_end_entry==entry]
    x_pred = dead_end_df.x_pred[dead_end_df.dead_end_entry==entry]
    y_pred = dead_end_df.y_pred[dead_end_df.dead_end_entry==entry]
    dopamine = dead_end_df.green_z_scored[dead_end_df.dead_end_entry==entry]
    decode_dist = dead_end_df.decodeDistToRat[dead_end_df.dead_end_entry==entry]
    theta_phase = dead_end_df.theta_phase[dead_end_df.dead_end_entry==entry]
    idx = dead_end_df.index[dead_end_df.dead_end_entry==entry]
    #vel = moving_average(dead_end_df.vel[dead_end_df.dead_end_entry==entry], 63)
    # time in seconds since the start of this entry
    time = convert_indices_to_time(idx)
    hexes = dead_end_df.hexlabels[dead_end_df.dead_end_entry==entry]
    
    # get stats for this entry
    hexes_traveled, time_spent = stats_for_dead_end_entry(hexes)
    print("Hexes traveled:", hexes_traveled)
    print("Time spent in this dead end:", time_spent, "seconds")
    

    # plot a bunch of stuff for a dead end entry!
    img = plt.imread("hex_background.png")
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2,2)

    # top left: tracked position over the hex layout, colored by time
    ax1.set_title("Rat's actual position")
    ax1.imshow(img, extent=[minx+xshift-scale, maxx+xshift+scale, maxy+yshift+scale, miny+yshift-scale])
    im = ax1.scatter(x, y, c=time, alpha=0.1, cmap='autumn')
    divider = make_axes_locatable(ax1)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    cbar = fig.colorbar(im, cax=cax, orientation='vertical')
    # draw the colorbar fully opaque even though the scatter points are transparent
    cbar.solids.set(alpha=1)

    # bottom left: decoded position over the same background, colored by time
    ax3.set_title("Decoded position")
    ax3.imshow(img, extent=[minx+xshift-scale, maxx+xshift+scale, maxy+yshift+scale, miny+yshift-scale])
    ax3.scatter(x_pred, y_pred, c=time, alpha=0.5, cmap='autumn')

    # top right: dopamine over time, with theta phase rescaled to the dopamine range behind it
    ax2.set_title("Dopamine")
    t1 = ax2.scatter(time, scale_to_bounds(theta_phase,dopamine), c='grey', alpha=0.1)
    t1.set_label('Theta phase')
    ax2.scatter(time, dopamine)
    #ax2.set_xlabel("Time (s)")
    ax2.legend()

    # bottom right: decode distance over time, with theta phase rescaled to the same range
    ax4.set_title("Decode distance")
    t2 = ax4.scatter(time, scale_to_bounds(theta_phase,decode_dist), c='grey', alpha=0.1)
    t2.set_label('Theta phase')
    ax4.plot(time, decode_dist)   
    ax4.set_xlabel("Time (s)")
    ax4.legend()
    
    #ax4.set_title("Velocity")
    #ax4.set_xlabel("Time (s)")
    #ax4.scatter(time, vel)     

    if plot_direction_changes:
        # find turnaround point
        dead_end_path = dead_end_df.dead_end_path[dead_end_df.dead_end_entry==entry].iloc[0]
        turnaround_index = find_dead_end_turnaround(x, y, hexes, dead_end_path)
    
        # adjust index for plotting because time and smoothed velocity start from 0
        # (turnaround_index is a dataframe index label; adjusted_index is a position within this entry)
        adjusted_index = turnaround_index - min(x.index) 
    
        # highlight this point on the plots in aqua
        ax1.scatter(x[turnaround_index], y[turnaround_index], c='aqua')
        ax3.scatter(x_pred[turnaround_index], y_pred[turnaround_index], c='aqua')
        ax2.scatter(time[adjusted_index], dopamine[turnaround_index], c='aqua')
        ax4.scatter(time[adjusted_index], decode_dist[turnaround_index], c='aqua')
        #ax4.scatter(time[adjusted_index], vel[adjusted_index], c='aqua')

Hexes traveled: 3
Time spent in this dead end: 5.86 seconds


# (De)function definitions 
(aka functions I wrote once and no longer use, but don't want to delete yet, so they are defunct, haha, get it)

In [None]:
def calculate_direction_vector(x1, y1, x2, y2):
    """ Returns the unit vector pointing from (x1, y1) to (x2, y2), or (0, 0) if the points coincide """

    step_x, step_y = x2 - x1, y2 - y1
    length = math.sqrt(step_x**2 + step_y**2)

    # identical points have no direction (and would divide by zero)
    if length == 0:
        return (0, 0)

    unit_x = step_x / length
    unit_y = step_y / length
    return (unit_x, unit_y)


def find_direction_change_indices(x_coords, y_coords, angle_threshold=1.57): # 1.5708 = 90 degrees in radians
    """ Finds the indices where the direction of travel between successive coordinates
    turns by more than angle_threshold radians.

    Returns [] if there are fewer than 3 points or the coordinate lists differ in length.
    Points with no movement into or out of them are skipped.
    """

    n_points = len(x_coords)
    if n_points < 3 or len(y_coords) < 3 or n_points != len(y_coords):
        return []  # Invalid data

    turning_points = []

    for mid in range(1, n_points - 1):
        incoming = calculate_direction_vector(x_coords[mid-1], y_coords[mid-1], x_coords[mid], y_coords[mid])
        outgoing = calculate_direction_vector(x_coords[mid], y_coords[mid], x_coords[mid+1], y_coords[mid+1])

        # a (0, 0) direction means the position did not change, so no angle can be measured
        if (0, 0) in (incoming, outgoing):
            continue

        # angle between the two unit vectors via the dot product,
        # clamped to [-1, 1] so rounding error cannot push acos out of its domain
        cosine = incoming[0] * outgoing[0] + incoming[1] * outgoing[1]
        clamped_cosine = max(-1, min(1, cosine))

        if math.acos(clamped_cosine) > angle_threshold:
            turning_points.append(mid)

    return turning_points

def find_turnaround_index_xy(x, y):
    """ Finds the index of the coordinate furthest from the first coordinate in lists of x, y coordinates

    Returns a one-element list [index], or [] if no coordinate is further than
    distance 0 from the first one (including when the input is empty).
    """

    n_points = len(x)
    if n_points == 0:
        return []

    origin_x, origin_y = x[0], y[0]
    best_distance = 0
    best_index = []

    for position in range(n_points):
        offset_x = x[position] - origin_x
        offset_y = y[position] - origin_y
        distance = math.sqrt(offset_x ** 2 + offset_y ** 2)
        # strict comparison: on ties the earliest coordinate is kept
        if distance > best_distance:
            best_distance, best_index = distance, [position]

    return best_index