In [1]:
import sys
path1 = '/Users/steph/berkelab/DA_maze/DA/Modules/'
path2 = '/Users/steph/berkelab/DA_maze/Behavior/Modules/'
sys.path += [path1,path2]

import numpy as np
import pandas as pd
import math
from matplotlib import pyplot as plt
from turnaround_functions import *
from multi_rat_da import *
from celluloid import Camera
from pdf2image import convert_from_path
from mpl_toolkits.axes_grid1 import make_axes_locatable
%matplotlib qt

phot_path = "/Volumes/Tim/Photometry/"
ephys_path = "/Volumes/Tim/Ephys/"
min_dead_end_length = 3

### Data Preprocessing
- For each session, load dataframe + transition matrix
- Add a bunch of columns to the dataframe to make it easy to slice and iterate based on different things
- Print some stats about this session

In [4]:
def load_dataframe_and_tmat(base_path, block_num):
    """ Load transition matrix and dataframe for this session """
    
    tmat_suffix = "" if block_num == 0 else "_block_{}.0".format(block_num)
    transition_mat = np.load(phot_path+base_path+"/tmat"+tmat_suffix+".npy")
    df = reduce_mem_usage(pd.read_csv(ephys_path+base_path+"/phot_decode_df_withHexStates.csv"))
    if block_num !=0:
        df = df[df.block==block_num]
    return df, transition_mat


def do_preprocessing(df, transition_mat):
    """ Add a bunch of columns to the dataframe for dead end analysis """

    # Use the transition matrix to get paths to dead end hexes
    dead_end_paths = get_all_dead_end_paths(transition_mat, min_dead_end_length)

    # Get subset of dataframe where rat the is in dead ends
    dead_end_hexes = [hex for path in dead_end_paths for hex in path]
    indices = [i for i in df.index if df['hexlabels'][i] in dead_end_hexes]
    dead_end_df = df.loc[df.index.isin(indices)].copy()

    # Add column to the dataframe labeling which dead end path we are in
    print("Paths to dead end hexes:", dead_end_paths)
    path = find_dead_end_path_for_hex(dead_end_df.hexlabels, dead_end_paths)
    dead_end_df['dead_end_path'] = path

    # Add column to the dataframe counting each entry into a dead end path (one count for all dead end paths)
    dead_end_entry = divide_into_sections(dead_end_df.index)
    dead_end_df['dead_end_entry'] = dead_end_entry
    print("The rat entered a dead end path", max(dead_end_entry), "times in this session.")

    # Add column to the dataframe counting the entry into each dead end path (different counts for each dead end path)
    individual_dead_end_entry, dead_end_counts = count_entries_for_each_dead_end(path, dead_end_entry)
    dead_end_df['individual_dead_end_entry'] = individual_dead_end_entry
    print("Entry counts for each dead end path:", dead_end_counts)

    # Add columns to the dataframe indicating how far down a dead end path we got and how long we spent in that path
    hexes_traveled_list, time_spent_list = get_stats(dead_end_df)
    dead_end_entry_from_zero = [x-1 for x in dead_end_entry] # adjust to start at 0 so we can use as an index
    dead_end_df['max_hexes_traveled'] = expand_second_list(dead_end_entry_from_zero, hexes_traveled_list)
    dead_end_df['time_in_dead_end'] = expand_second_list(dead_end_entry_from_zero, time_spent_list)

    # Add column to the dataframe of indices centered around the turnaround point
    dead_end_df['centered_indices'] = get_centered_indices(dead_end_df)
    
    return dead_end_df

In [6]:
# Quick and dirty way to load all of these for now!! 
base_paths = ["IM-1478/07172022","IM-1478/07192022","IM-1478/07192022","IM-1478/07202022","IM-1478/07252022",
              "IM-1478/07252022","IM-1478/07272022","IM-1478/07272022","IM-1478/07272022","IM-1594/07252023"]
blocks = [0,1,2,0,2,3,1,2,3,0]

# Load all of the sessions and concatenate them into a mega dataframe
df_list = []
for base_path, block_num in zip(base_paths, blocks):
    df, transition_mat = load_dataframe_and_tmat(base_path, block_num)
    dead_end_df = do_preprocessing(df, transition_mat)
    df_list.append(dead_end_df)

big_df = pd.concat(df_list, ignore_index=True)

Memory usage of dataframe is 489.81 MB
Memory usage after optimization is: 114.80 MB
Decreased by 76.6%
Paths to dead end hexes: [[9, 12, 16, 20]]
The rat entered a dead end path 21 times in this session.
Entry counts for each dead end path: {9: 21}
Memory usage of dataframe is 534.37 MB
Memory usage after optimization is: 125.24 MB
Decreased by 76.6%
Paths to dead end hexes: [[12, 15, 19]]
The rat entered a dead end path 1 times in this session.
Entry counts for each dead end path: {12: 1}
Memory usage of dataframe is 534.37 MB
Memory usage after optimization is: 125.24 MB
Decreased by 76.6%
Paths to dead end hexes: [[11, 14, 17]]
The rat entered a dead end path 8 times in this session.
Entry counts for each dead end path: {11: 8}
Memory usage of dataframe is 522.47 MB
Memory usage after optimization is: 122.46 MB
Decreased by 76.6%
Paths to dead end hexes: [[11, 8, 6], [18, 22, 27], [21, 17, 13]]
The rat entered a dead end path 37 times in this session.
Entry counts for each dead end

#### Notes
"IM-1478/07192022", block 3 should be fine but is giving us UnboundLocalError: cannot access local variable 'exited_index' where it is not associated with a value
"IM-1594/07262023" is saying it has an HUGE number of dead end entries (1000s??) on all blocks but errors out due to single values... maybe the rat just barely enters the dead end path? Mess with turnaround code so it doesn't ruin everything when we have small values
- also, this dead end has a double dead end... not sure how to account for that

In [16]:
# Print columns so we remember what data we're working with
print(big_df.columns)

Index(['Unnamed: 0', 'index', 'green_z_scored', 'port', 'rwd', 'x', 'y',
       'nom_rwd_a', 'nom_rwd_b', 'beamA', 'beamB', 'beamC', 'tri', 'block',
       'nom_rwd_c', 'hexlabels', 'lenAC', 'lenBC', 'lenAB', 'dtop', 'fiberloc',
       'session_type', 'session', 'rat', 'date', 'lastleave', 'nom_rwd_chosen',
       'vel', 'pairedHexStates', 'acc', 'decodeDistToRat', 'decodeHPD',
       'x_pred', 'y_pred', 'mua', 'theta_env', 'theta_phase',
       'theta_phase_bin', 'pred_hexlabels', 'pred_hexState', 'dead_end_path',
       'dead_end_entry', 'individual_dead_end_entry', 'max_hexes_traveled',
       'time_in_dead_end', 'centered_indices', 'Unnamed: 0.1', 'tot_tri',
       'decodedVal', 'localVal'],
      dtype='object')


[<matplotlib.lines.Line2D at 0x16a07ead0>]

### Look at data across all entries into dead end paths
Plot mean (with shaded error) of the rat's velocity, dopamine, and distance from decode as he turns around in a dead end

In [9]:
data_to_plot = ['vel','green_z_scored','decodeDistToRat']

# Desired start and end indices (250 indices = 1 second)
start_index = -1000
end_index = 1000

# Calculate and plot mean and standard deviation
for data_name in data_to_plot:
    mean_arr, std_arr = get_mean_and_std(big_df, data_name, start_index, end_index)
    plot_mean_with_error(mean_arr, std_arr, start_index, end_index, data_name)

# Optional - Plot all of the lines just to see if we have anything weird going on (e.g. one session skewing the mean)
plot_all_lines_check = False
if plot_all_lines_check:
    fig, ax = plt.subplots()
    filtered_df = big_df[(big_df['centered_indices'] >= start_index) & (big_df['centered_indices'] <= end_index)]
    num_dead_end_entries = max(big_df.dead_end_entry)+1
    for entry in list(range(1, num_dead_end_entries)):
        data = filtered_df[data_name][filtered_df.dead_end_entry==entry]
        idx = filtered_df.centered_indices[filtered_df.dead_end_entry==entry]
        ax.plot(convert_indices_to_time(idx, -1), data)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel(data_name)
    plt.show()

Still working on this bit!

In [6]:
def plot_histograms(hexes_traveled_list, time_spent_list):
    # Histogram of time spent in dead end
    fig, axs = plt.subplots(2, 2)
    axs[0,0].hist(time_spent_list, bins=20, edgecolor='black', alpha=0.7)
    axs[0,0].set_xlabel('Time (seconds)')
    axs[0,0].set_ylabel('Count')
    axs[0,0].set_title('Distribution of time spent in a dead end')

    # Histogram of hexes traveled down dead end
    axs[1,0].hist(hexes_traveled_list, bins=3, edgecolor='black', alpha=0.7)
    axs[1,0].set_xlabel('Hexes traveled')
    axs[1,0].set_ylabel('Count')
    axs[1,0].set_title('Distribution of hexes traveled down dead end')

    # now show these over time
    axs[0,1].plot(time_spent_list, marker='o', linestyle='-')
    axs[0,1].set_xlabel('Dead end entry')
    axs[0,1].set_ylabel('Time (s)')
    axs[0,1].set_title('Time spent in dead end')

    axs[1,1].plot(hexes_traveled_list, marker='o', linestyle='-')
    axs[1,1].set_xlabel('Dead end entry')
    axs[1,1].set_ylabel('Hexes')
    axs[1,1].set_title('Hexes traveled down dead end')

    plt.tight_layout()
    plt.show()

In [7]:
hexes_traveled_list, time_spent_list = get_stats(dead_end_df)
plot_histograms(hexes_traveled_list, time_spent_list)

# ok, do the same thing for an individual dead end!
# TODO: Fix this part!!!
for path in dead_end_paths:
    slice_df = dead_end_df[dead_end_df['dead_end_path'].apply(lambda x: np.array_equal(x, path))]
    hexes_traveled_list, time_spent_list = get_stats(slice_df)
    plot_histograms(hexes_traveled_list, time_spent_list)

NameError: name 'dead_end_paths' is not defined

In [35]:
plt.figure()
plt.plot(dead_end_df.dead_end_path)

ValueError: setting an array element with a sequence.

In [37]:
p = "IM-1594/0722023"
b = 3
df, transition_mat = load_dataframe_and_tmat(p, b)
dead_end_df = do_preprocessing(df, transition_mat)

FileNotFoundError: [Errno 2] No such file or directory: '/Volumes/Tim/Ephys/IM-1594/07312023/phot_decode_df_withHexStates.csv'