In [1]:
import sys
path1 = '/Users/steph/berkelab/DA_maze/DA/Modules/'
path2 = '/Users/steph/berkelab/DA_maze/Behavior/Modules/'
sys.path += [path1,path2]

import numpy as np
import pandas as pd
import math
from matplotlib import pyplot as plt
from turnaround_functions import *
from multi_rat_da import *
#from tmat_ops import *
#from hexLevelAnalyses import get_sigRats_fromMeanList
#from photometryQuantifications import *
#from scipy.stats import wilcoxonz
from celluloid import Camera
from pdf2image import convert_from_path
from mpl_toolkits.axes_grid1 import make_axes_locatable
%matplotlib qt

### Data Preprocessing
- Load dataframe, transition matrix, and hex background
- Add a bunch of columns to the dataframe to make it easy to slice and iterate based on different things
- Print some stats about this session

In [2]:
# --- Session configuration ---------------------------------------------------
specific_path = "IM-1478/07202022"
#specific_path = "IM-1478/07172022"
phot_path_base = "/Volumes/Tim/Photometry/"
ephys_path_base = "/Volumes/Tim/Ephys/"
min_dead_end_length = 3

# Load transition matrix and dataframe for this session
transition_mat = np.load(phot_path_base+specific_path+"/tmat.npy")
df = reduce_mem_usage(pd.read_csv(ephys_path_base+specific_path+"/phot_decode_df_withHexStates.csv"))
minx, miny, maxx, maxy = get_bounds(df)

# Load pdf image of hexes from this session to use as the plot background
image = convert_from_path(phot_path_base+specific_path+"/hex_layout.pdf")
image[0].save('hex_background.png', 'png')

# Use the transition matrix to get paths to dead end hexes
dead_end_paths = get_all_dead_end_paths(transition_mat, min_dead_end_length)

# Get subset of dataframe where the rat is in dead ends.
# A vectorized membership test replaces the per-row Python loop, which was very
# slow on a dataframe of this size; the selected rows are the same.
dead_end_hexes = [hex_id for dead_end in dead_end_paths for hex_id in dead_end]
in_dead_end = df['hexlabels'].isin(dead_end_hexes)
indices = df.index[in_dead_end].tolist()  # kept in case later cells still use it
dead_end_df = df.loc[in_dead_end].copy()

# Add column to the dataframe labeling which dead end path we are in
print("Paths to dead end hexes:", dead_end_paths)
path = find_dead_end_path_for_hex(dead_end_df.hexlabels, dead_end_paths)
dead_end_df['dead_end_path'] = path

# Add column to the dataframe counting each entry into a dead end path (one count for all dead end paths)
dead_end_entry = divide_into_sections(dead_end_df.index)
dead_end_df['dead_end_entry'] = dead_end_entry
print("The rat entered a dead end path", max(dead_end_entry), "times in this session.")

# Add column to the dataframe counting the entry into each dead end path (different counts for each dead end path)
individual_dead_end_entry, dead_end_counts = count_entries_for_each_dead_end(path, dead_end_entry)
dead_end_df['individual_dead_end_entry'] = individual_dead_end_entry
print("Entry counts for each dead end path:", dead_end_counts)

# Add columns to the dataframe indicating how far down a dead end path we got and how long we spent in that path
hexes_traveled_list, time_spent_list = get_stats(dead_end_df)
dead_end_entry_from_zero = [x-1 for x in dead_end_entry] # adjust to start at 0 so we can use as an index
dead_end_df['max_hexes_traveled'] = expand_second_list(dead_end_entry_from_zero, hexes_traveled_list)
dead_end_df['time_in_dead_end'] = expand_second_list(dead_end_entry_from_zero, time_spent_list)

# Add column to the dataframe of indices centered around the turnaround point
dead_end_df['centered_indices'] = get_centered_indices(dead_end_df)

Memory usage of dataframe is 522.47 MB
Memory usage after optimization is: 122.46 MB
Decreased by 76.6%
Paths to dead end hexes: [[11, 8, 6], [18, 22, 27], [21, 17, 13]]
The rat entered a dead end path 37 times in this session.
Entry counts for each dead end path: {21: 6, 11: 14, 18: 17}


In [3]:
# Print the column names of dead_end_df so we remember which fields
# (original session data plus the dead-end columns added above) are available
print(dead_end_df.columns)

Index(['Unnamed: 0', 'index', 'green_z_scored', 'port', 'rwd', 'x', 'y',
       'nom_rwd_a', 'nom_rwd_b', 'beamA', 'beamB', 'beamC', 'tri', 'block',
       'nom_rwd_c', 'hexlabels', 'lenAC', 'lenBC', 'lenAB', 'dtop', 'fiberloc',
       'session_type', 'session', 'rat', 'date', 'lastleave', 'nom_rwd_chosen',
       'vel', 'pairedHexStates', 'acc', 'decodeDistToRat', 'decodeHPD',
       'x_pred', 'y_pred', 'mua', 'theta_env', 'theta_phase',
       'theta_phase_bin', 'pred_hexlabels', 'pred_hexState', 'dead_end_path',
       'dead_end_entry', 'individual_dead_end_entry', 'max_hexes_traveled',
       'time_in_dead_end', 'centered_indices'],
      dtype='object')


### Look at data across all entries into dead end paths
Plot the mean (with a shaded error band) of the rat's velocity, dopamine signal, and decode-to-rat distance, aligned to the point where the rat turns around in a dead end.

In [4]:
data_to_plot = ['vel','green_z_scored','decodeDistToRat']

# Desired start and end indices (250 indices = 1 second)
start_index = -500
end_index = 500

# Calculate and plot mean and standard deviation
for data_name in data_to_plot:
    mean_arr, std_arr = get_mean_and_std(dead_end_df, data_name, start_index, end_index)
    plot_mean_with_error(mean_arr, std_arr, start_index, end_index, data_name)

# Optional - Plot all of the lines just to see if we have anything weird going on (e.g. one session skewing the mean)
plot_all_lines_check = False
if plot_all_lines_check:
    # Keep only the samples inside the window around the turnaround point
    in_window = (dead_end_df['centered_indices'] >= start_index) & (dead_end_df['centered_indices'] <= end_index)
    filtered_df = dead_end_df[in_window]
    num_dead_end_entries = max(dead_end_df.dead_end_entry)+1

    # One figure per signal. Previously this block reused the `data_name` left
    # over from the loop above, so only the last signal was ever checked.
    for data_name in data_to_plot:
        fig, ax = plt.subplots()
        for entry in range(1, num_dead_end_entries):
            this_entry = filtered_df.dead_end_entry == entry
            data = filtered_df[data_name][this_entry]
            idx = filtered_df.centered_indices[this_entry]
            ax.plot(convert_indices_to_time(idx, -1), data)
        ax.set_xlabel("Time (s)")
        ax.set_ylabel(data_name)
    plt.show()

### Take a closer look at a single entry into a dead end path

In [5]:
#entries = list(range(1, max(dead_end_df.dead_end_entry)+1)) # if we want to look at all entries individually
entries = [4] # or just a single entry

# Plot settings are the same for every entry, so set them once up front
scale = 60
xshift = 10
yshift = -30
plot_direction_changes = True # if we want to highlight direction changes on the plot

# Background image and its placement do not depend on the entry either
img = plt.imread("hex_background.png")
background_extent = [minx+xshift-scale, maxx+xshift+scale, maxy+yshift+scale, miny+yshift-scale]

for entry in entries:

    # Slice the dataframe once for this entry instead of re-evaluating the same
    # boolean mask for every column we need
    entry_df = dead_end_df[dead_end_df.dead_end_entry==entry]
    if entry_df.empty:
        print("No data for dead end entry", entry, "- skipping")
        continue

    x = entry_df.x
    y = entry_df.y
    x_pred = entry_df.x_pred
    y_pred = entry_df.y_pred
    dopamine = entry_df.green_z_scored
    decode_dist = entry_df.decodeDistToRat
    theta_phase = entry_df.theta_phase
    idx = entry_df.index
    #vel = moving_average(entry_df.vel, 63)
    time = convert_indices_to_time(idx)
    hexes = entry_df.hexlabels

    # get stats for this entry
    hexes_traveled, time_spent = stats_for_dead_end_entry(hexes)
    print("Hexes traveled:", hexes_traveled)
    print("Time spent in this dead end:", time_spent, "seconds")

    # plot a bunch of stuff for a dead end entry!
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2,2)

    # Top left: actual position over the hex layout, colored by time
    ax1.set_title("Rat's actual position")
    ax1.imshow(img, extent=background_extent)
    im = ax1.scatter(x, y, c=time, alpha=0.1, cmap='autumn')
    divider = make_axes_locatable(ax1)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    cbar = fig.colorbar(im, cax=cax, orientation='vertical')
    cbar.solids.set(alpha=1) # keep the colorbar opaque even though the points are faint

    # Bottom left: decoded position over the same background
    ax3.set_title("Decoded position")
    ax3.imshow(img, extent=background_extent)
    ax3.scatter(x_pred, y_pred, c=time, alpha=0.5, cmap='autumn')

    # Top right: dopamine, with theta phase rescaled onto the same axis
    ax2.set_title("Dopamine")
    t1 = ax2.scatter(time, scale_to_bounds(theta_phase,dopamine), c='grey', alpha=0.1)
    t1.set_label('Theta phase')
    ax2.scatter(time, dopamine)
    #ax2.set_xlabel("Time (s)")
    ax2.legend()

    # Bottom right: decode distance, with theta phase rescaled onto the same axis
    ax4.set_title("Decode distance")
    t2 = ax4.scatter(time, scale_to_bounds(theta_phase,decode_dist), c='grey', alpha=0.1)
    t2.set_label('Theta phase')
    ax4.plot(time, decode_dist)
    ax4.set_xlabel("Time (s)")
    ax4.legend()

    #ax4.set_title("Velocity")
    #ax4.set_xlabel("Time (s)")
    #ax4.scatter(time, vel)

    if plot_direction_changes:
        # find turnaround point
        dead_end_path = entry_df.dead_end_path.iloc[0]
        turnaround_index = find_dead_end_turnaround(x, y, hexes, dead_end_path)

        # adjust index for plotting because time and smoothed velocity start from 0
        adjusted_index = turnaround_index - min(x.index)

        # highlight this point on the plots in blue
        # (the Series are indexed by dataframe label, `time` by offset within the entry)
        ax1.scatter(x[turnaround_index], y[turnaround_index], c='aqua')
        ax3.scatter(x_pred[turnaround_index], y_pred[turnaround_index], c='aqua')
        ax2.scatter(time[adjusted_index], dopamine[turnaround_index], c='aqua')
        ax4.scatter(time[adjusted_index], decode_dist[turnaround_index], c='aqua')
        #ax4.scatter(time[adjusted_index], vel[adjusted_index], c='aqua')

Hexes traveled: 3
Time spent in this dead end: 5.86 seconds


**TODO:** The histogram section below is still a work in progress.

In [6]:
def plot_histograms(hexes_traveled_list, time_spent_list, time_bins=20, hex_bins=3):
    """Plot a 2x2 summary figure of dead end visits.

    Left column: histograms of time spent (top) and hexes traveled (bottom).
    Right column: the same two quantities plotted in list order, one point per
    dead end entry.

    Parameters
    ----------
    hexes_traveled_list : sequence of numbers
        Hexes traveled down the dead end, one value per entry.
    time_spent_list : sequence of numbers
        Time spent in the dead end (labeled as seconds), one value per entry.
    time_bins : int, optional
        Number of bins for the time histogram (default 20).
    hex_bins : int, optional
        Number of bins for the hexes-traveled histogram (default 3).
        NOTE(review): 3 presumably mirrors the dead end length used in this
        notebook - pass a different value for longer dead ends.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.
    """
    fig, axs = plt.subplots(2, 2)

    # Histogram of time spent in dead end
    axs[0,0].hist(time_spent_list, bins=time_bins, edgecolor='black', alpha=0.7)
    axs[0,0].set_xlabel('Time (seconds)')
    axs[0,0].set_ylabel('Count')
    axs[0,0].set_title('Distribution of time spent in a dead end')

    # Histogram of hexes traveled down dead end
    axs[1,0].hist(hexes_traveled_list, bins=hex_bins, edgecolor='black', alpha=0.7)
    axs[1,0].set_xlabel('Hexes traveled')
    axs[1,0].set_ylabel('Count')
    axs[1,0].set_title('Distribution of hexes traveled down dead end')

    # now show these over time (x axis is the position in the list, i.e. entry order)
    axs[0,1].plot(time_spent_list, marker='o', linestyle='-')
    axs[0,1].set_xlabel('Dead end entry')
    axs[0,1].set_ylabel('Time (s)')
    axs[0,1].set_title('Time spent in dead end')

    axs[1,1].plot(hexes_traveled_list, marker='o', linestyle='-')
    axs[1,1].set_xlabel('Dead end entry')
    axs[1,1].set_ylabel('Hexes')
    axs[1,1].set_title('Hexes traveled down dead end')

    plt.tight_layout()
    plt.show()

In [7]:
# Summary across every dead end entry in the session
hexes_traveled_list, time_spent_list = get_stats(dead_end_df)
plot_histograms(hexes_traveled_list, time_spent_list)

# ok, do the same thing for an individual dead end!
# TODO: Fix this part!!!
# Loop-local names are used on purpose so this loop no longer overwrites the
# notebook-level `path`, `hexes_traveled_list` and `time_spent_list`.
for single_path in dead_end_paths:
    in_this_path = dead_end_df['dead_end_path'].apply(lambda p: np.array_equal(p, single_path))
    slice_df = dead_end_df[in_this_path]
    path_hexes_traveled, path_time_spent = get_stats(slice_df)
    plot_histograms(path_hexes_traveled, path_time_spent)

### (De)function definitions 
(aka functions I wrote once but no longer use but don't wanna delete yet so they are defunct haha get it)

In [8]:
def calculate_direction_vector(x1, y1, x2, y2):
    """ Return the unit vector pointing from (x1, y1) toward (x2, y2).

    When the two points coincide there is no direction, so (0, 0) is returned
    rather than dividing by zero.
    """

    dx, dy = x2 - x1, y2 - y1
    length = math.sqrt(dx**2 + dy**2)
    # Normalize only when the points are actually distinct
    return (dx / length, dy / length) if length != 0 else (0, 0)


def find_direction_change_indices(x_coords, y_coords, angle_threshold=1.57): # 1.5708 = 90 degrees in radians
    """ Given x and y coordinates and an angle_threshold, find the positions where
    the direction between successive coordinates changed by more than
    angle_threshold radians.

    x_coords, y_coords: equal-length sequences of coordinates (lists, arrays or
        pandas Series). They are read by position, so a Series sliced out of a
        larger dataframe (index not starting at 0) works too.
    angle_threshold: minimum turn angle in radians that counts as a direction change.

    Returns a list of 0-based positions into the inputs. The list is empty when
    there are fewer than 3 points or the two inputs differ in length. Points
    where the rat did not move relative to a neighbor are skipped.
    """

    # Convert to plain lists so [i] below is always positional. On a pandas
    # Series, series[i] is a label lookup and fails or reads the wrong row
    # whenever the index does not start at 0.
    x_coords = list(x_coords)
    y_coords = list(y_coords)

    if len(x_coords) < 3 or len(y_coords) < 3 or len(x_coords) != len(y_coords):
        return []  # Invalid data

    direction_change_indices = []

    for i in range(1, len(x_coords) - 1):
        prev_direction = calculate_direction_vector(x_coords[i-1], y_coords[i-1], x_coords[i], y_coords[i])
        next_direction = calculate_direction_vector(x_coords[i], y_coords[i], x_coords[i+1], y_coords[i+1])

        # Check if either of the directions is zero (no movement)
        if prev_direction == (0, 0) or next_direction == (0, 0):
            continue

        # Calculate the angle between direction vectors using the dot product.
        # Clamp to [-1, 1] so floating point error cannot push acos out of its domain.
        dot_product = prev_direction[0] * next_direction[0] + prev_direction[1] * next_direction[1]
        angle = math.acos(max(-1, min(1, dot_product)))

        # Check if the angle is greater than the threshold
        if angle > angle_threshold:
            # uncomment the print statement below for useful debugging
            #print("angle: ",angle," xy-1: ",x_coords[i-1],",",y_coords[i-1]," xy: ",x_coords[i],",", y_coords[i], " xy+1: ",x_coords[i+1],",",y_coords[i+1],"prev dir",prev_direction,"next dir",next_direction)
            direction_change_indices.append(i)

    return direction_change_indices

def find_turnaround_index_xy(x, y):
    """ Find the position of the coordinate furthest from the first coordinate.

    x, y: equal-length sequences of coordinates (lists, arrays or pandas Series).
        They are read by position, so a Series sliced out of a larger dataframe
        (index not starting at 0) works too.

    Returns a one-element list holding the 0-based position of the furthest
    point (the first one if there is a tie). Returns an empty list when the
    input is empty or no point is further than distance 0 from the start.
    """

    # Convert to plain lists so [i] below is always positional. On a pandas
    # Series, series[0] is a label lookup and raises KeyError when the index
    # does not contain 0.
    x = list(x)
    y = list(y)

    max_distance = 0
    turnaround_index = []

    for i in range(len(x)):
        distance = math.sqrt((x[i] - x[0]) ** 2 + (y[i] - y[0]) ** 2)
        if distance > max_distance:
            max_distance = distance
            turnaround_index = [i]

    return turnaround_index