In [1]:
import sys
path1 = '/Users/steph/berkelab/DA_maze/DA/Modules/'
path2 = '/Users/steph/berkelab/DA_maze/Behavior/Modules/'
sys.path += [path1,path2]

In [2]:
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns
sns.set()
plt.style.use('default')
from multi_rat_da import *
from tmat_ops import *
from hexLevelAnalyses import get_sigRats_fromMeanList
from photometryQuantifications import *
from scipy.stats import wilcoxon
from celluloid import Camera
from pdf2image import convert_from_path
%matplotlib qt

In [3]:
loadpath = "/Users/steph/berkelab/Data/" #location of dataset
#df = pd.read_csv(loadpath+"phot_decode_df_withHexStates.csv")
#df = pd.read_csv("/Volumes/Tim/Photometry/IM-1478/07162022/IM-1478_07162022_h_sampleframe.csv")

def normalize(data):
    return (data - np.min(data)) / (np.max(data) - np.min(data))

def reduce_mem_usage(df):
    start_mem = df.memory_usage().sum() / 1024**2
    print('Memory usage of dataframe is {:.2f} MB'.format(start_mem))
    for col in df.columns:
        col_type = df[col].dtype
        if col_type != object:
            c_min = df[col].min()
            c_max = df[col].max()
            if str(col_type)[:3] == 'int':
                if c_min > np.iinfo(np.int8).min and c_max < np.iinfo(np.int8).max:
                    df[col] = df[col].astype(np.int8)
                elif c_min > np.iinfo(np.int16).min and c_max < np.iinfo(np.int16).max:
                    df[col] = df[col].astype(np.int16)
                elif c_min > np.iinfo(np.int32).min and c_max < np.iinfo(np.int32).max:
                    df[col] = df[col].astype(np.int32)
                elif c_min > np.iinfo(np.int64).min and c_max < np.iinfo(np.int64).max:
                    df[col] = df[col].astype(np.int64)
            else:
                if c_min > np.finfo(np.float16).min and c_max < np.finfo(np.float16).max:
                    df[col] = df[col].astype(np.float16)
                elif c_min > np.finfo(np.float32).min and c_max < np.finfo(np.float32).max:
                    df[col] = df[col].astype(np.float32)
                else:
                    df[col] = df[col].astype(np.float64)
        else:
            df[col] = df[col].astype('category')
    end_mem = df.memory_usage().sum() / 1024**2
    print('Memory usage after optimization is: {:.2f} MB'.format(end_mem))
    print('Decreased by {:.1f}%'.format(100 * (start_mem - end_mem) / start_mem))
    return df

#df = reduce_mem_usage(df)

In [61]:
# pls show me the names of the dataframe columns bc I forget them thanks
df.columns

Index(['Unnamed: 0', 'x', 'y', 'green', 'port', 'rwd', 'block', 'pA', 'pB',
       'pC', 'frame', 'ref', 'vel', 'acc', 'tri', 'fromP', '470', 'beamA',
       'beamB', 'beamC', 'tot_tri', 'green_z_scored', 'fiberloc',
       'session_type', 'rat', 'date', 'nextprob', 'nextp', 'lenAC', 'lenBC',
       'lenAB', 'simple_rr', 'pchosen', 'dtop', 'nom_rwd_a', 'nom_rwd_b',
       'nom_rwd_c', 'hexlabels'],
      dtype='object')

In [124]:
# plot all x,y points (in a given block) to get a picture of the maze
bnum = 1
pad = 20
x = df.x[df.block==bnum]
y = df.y[df.block==bnum]
minx = np.nanmin(x)
miny = np.nanmin(y)
maxx = np.nanmax(x)
maxy = np.nanmax(y)
fig = plt.figure()
plt.plot(x, y, '.')
plt.xlim(minx-pad, maxx+pad)
plt.ylim(miny-pad, maxy+pad)
plt.axis('off')
plt.savefig('maze_background.png', bbox_inches='tight', transparent=True, pad_inches=0)

# same thing, but use a different color for each hex (rainbow order, yay)
fig = plt.figure()
hexcolor = iter(cm.rainbow(np.linspace(0, 1, 50)))    
for hex in range (1,50):
    c = next(hexcolor)
    plt.plot(x[df.hexlabels==hex], y[df.hexlabels==hex], '.', c=c, alpha=0.1)
plt.axis('off')
plt.savefig('maze_background_colored.png', bbox_inches='tight', transparent=True, pad_inches=0)

In [6]:
# test load image for plot background
img = plt.imread("maze_background.png")
fig, ax = plt.subplots()
ax.imshow(img, extent=[minx-pad, maxx+pad, miny-pad, maxy+pad])

<matplotlib.image.AxesImage at 0x12bf3e710>

In [7]:
# create plot of rat's xy (colored by hex) overlayed on the hex background

# these are empirical fudge factors so the hex background and the rat's xy position (kinda sorta) line up
# ideally we would just scale automatically but the rat's min/max xy aren't super reliable
xshift = -7
yshift = -35
imscale = 0

img = plt.imread("hex_background.png")
fig, ax = plt.subplots()
ax.imshow(img, extent=[minx+xshift-imscale, maxx+xshift+imscale, maxy+yshift+imscale, miny+yshift-imscale])

hexcolor = iter(cm.rainbow(np.linspace(0, 1, 50)))    
x = df.x[df.block==1]
y = df.y[df.block==1]
for hex in range (1,50):
    c = next(hexcolor)
    plt.plot(x[df.hexlabels==hex], y[df.hexlabels==hex], '.', c=c, alpha=0.1)

In [8]:
# Animation of rat xy and predicted xy with background generated from all rat xy
img = plt.imread("maze_background_colored.png")
fig, ax = plt.subplots()

camera = Camera(fig)
for i in range(1000,2000):
    ax.imshow(img, extent=[minx-pad, maxx+pad, miny-pad, maxy+pad], alpha=0.3)
    plt.plot(df.x[i], df.y[i], color='grey', marker='o')
    plt.plot(df.x_pred[i], df.y_pred[i], 'yo')
    camera.snap()
animation = camera.animate(interval=10, blit=True)

In [5]:
# Animation of rat xy and predicted xy with hex background
img = plt.imread("hex_background.png")
fig, ax = plt.subplots()

camera = Camera(fig)
for i in range(1000,2000):
    ax.invert_yaxis()
    ax.imshow(img, extent=[minx+xshift-imscale, maxx+xshift+imscale, maxy+yshift+imscale, miny+yshift-imscale], alpha=0.3)
    plt.plot(df.x[i], df.y[i], color='grey', marker='o')
    plt.plot(df.x_pred[i], df.y_pred[i], 'yo')
    ax.text(minx+xshift, miny+yshift, 'Rat in {}, pred in {}'.format(int(df.hexlabels[i]), df.pred_hexlabels[i]), verticalalignment='top', horizontalalignment='left', fontsize=15)
    camera.snap()
animation = camera.animate(interval=10, blit=True)

NameError: name 'xshift' is not defined

In [123]:
# Plot of rat xy colored by dopamine

# these are empirical fudge factors so the hex background and the rat's xy position (kinda sorta) line up
# ideally we would just scale automatically but the rat's min/max xy aren't super reliable
xshift = -7
yshift = -35
imscale = 0
trialnum = 6
blocknum = 2

xpred = df.x_pred[(df.block==blocknum)&(df.tri==trialnum)]
ypred = df.y_pred[(df.block==blocknum)&(df.tri==trialnum)]
x = df.x[(df.block==blocknum)&(df.tri==trialnum)]
y = df.y[(df.block==blocknum)&(df.tri==trialnum)]
dopamine = normalize(df.green_z_scored[(df.block==blocknum)&(df.tri==trialnum)])

img = plt.imread("hex_background.png")
fig, (ax1, ax2) = plt.subplots(1,2)
ax1.imshow(img, extent=[minx+xshift-imscale, maxx+xshift+imscale, maxy+yshift+imscale, miny+yshift-imscale])
ax1.scatter(x, y, c=dopamine, cmap='plasma', alpha=0.3)
ax2.imshow(img, extent=[minx+xshift-imscale, maxx+xshift+imscale, maxy+yshift+imscale, miny+yshift-imscale])
ax2.scatter(xpred, ypred, c=dopamine, cmap='plasma', alpha=0.3)


NameError: name 'minx' is not defined

In [7]:
# add animation of this ^ trial
fig, ax = plt.subplots()
img = plt.imread("hex_background.png")

camera = Camera(fig)
for i in range(min(x.index), max(x.index)):
    #ax.invert_yaxis()
    ax.imshow(img, extent=[minx+xshift-imscale, maxx+xshift+imscale, maxy+yshift+imscale, miny+yshift-imscale], alpha=0.3)
    plt.plot(x[i], y[i], color='grey', marker='o')
    plt.plot(xpred[i],ypred[i], 'yo')
    ax.text(minx+xshift, miny+yshift, 'Rat in {}, pred in {}'.format(int(df.hexlabels[i]), df.pred_hexlabels[i]), verticalalignment='top', horizontalalignment='left', fontsize=15)
    camera.snap()
animation = camera.animate(interval=10, blit=True)

Let's mess with transition matrices!

In [11]:
tmat = np.load('/Users/steph/berkelab/DA_maze/Data/Transition_Matrices/IM-1478/07252022/tmat_block_1.0.npy')

In [12]:
fig, ax = plt.subplots()
ax.imshow(tmat)


<matplotlib.image.AxesImage at 0x135abfd10>

In [13]:
# when using Photrat to load things, make sure Tim's directory is mounted
remy = Photrat('IM-1478', '07252022')
remy.load_df()


Memory usage of dataframe is 387.17 MB
Memory usage after optimization is: 112.92 MB
Decreased by 70.8%


In [14]:
#tmat = remy.load_tmat()
tmo = TmatOperations(remy)
tmo.get_availstates()

TypeError: Photrat.__init__() missing 1 required positional argument: 'date'

In [None]:
remy.create_hexdf(98.0, 1)

In [16]:
hexData = PhotRats(None)
# the directory containing your dataframe
loadpath = "/Volumes/Tim/Photometry/10MfRatDataSet/data4sharing/"
datName = "hexLevelDf"
hexData.df = reduce_mem_usage(pd.read_csv(loadpath+datName+".csv",index_col=[0]))
hexData.directory_prefix = loadpath

Memory usage of dataframe is 52.45 MB
Memory usage after optimization is: 12.54 MB
Decreased by 76.1%


In [36]:
seshdict = "/Volumes/Tim/Photometry/10MfRatDataSet/sessionTable.csv"

photrats = PhotRats(seshdict)
photrats.directory_prefix = "/Volumes/Tim/Photometry/"
#photrats.load_tmats()

hexData = PhotRats(None)
# the directory containing your dataframe
loadpath = "/Volumes/Tim/Photometry/10MfRatDataSet/data4sharing/"
datName = "hexLevelDf"
hexData.df = reduce_mem_usage(pd.read_csv(loadpath+datName+".csv",index_col=[0]))
hexData.directory_prefix = loadpath

In [46]:
seshdict = pd.read_csv("/Volumes/Tim/Photometry/10MfRatDataSet/sessionTable.csv")
photrats = PhotRats(seshdict)
photrats.directory_prefix = "/Volumes/Tim/Photometry/"
photrats.load_dfs()

could not load df for Unnamed: 0; 0


TypeError: 'NoneType' object does not support item assignment

All of my functions for things! 

In [6]:
def distance_squared(x1, y1, x2, y2):
    """ Calculates the squared Euclidean distance between two points """
    return (x2 - x1)**2 + (y2 - y1)**2

def find_turnaround(x_coords, y_coords):
    """ Find the index at which the rat on a dead end path turns around

    Parameters: x_coords (list); y_coords (list): 2 lists of x and y coordinates representing the path
    Returns: the index (int) at which the rat turns around, or None if no turnaround is detected
    """
    
    if len(x_coords) < 2 or len(y_coords) < 2 or len(x_coords) != len(y_coords):
        return None  # Invalid data
    
    max_distance_sq = 0
    turnaround_index = None
    
    for i in range(1, len(x_coords)):
        dist_sq = distance_squared(x_coords[0], y_coords[0], x_coords[i], y_coords[i])
        if dist_sq > max_distance_sq:
            max_distance_sq = dist_sq
            turnaround_index = i
    
    return turnaround_index

In [4]:
def find_dead_ends(tmat):
    """ Given a transition matrix, returns an array of hexes that are dead ends """
    
    dead_ends = np.unique(np.where(tmat==1)[0])
    dead_ends = np.delete(dead_ends,np.isin(dead_ends,[1,2,3]))
    return dead_ends

def path_to_dead_end(tmat, next_hex, path=[]):
    """ Given a transition matrix and a dead end hex, returns the path of hexes to that dead end 
    
    Note: Third argument [] or name of empty list is required to ensure you get a new list 
    instead of modifying the same list from the last time this function was run
    """

    path.append(next_hex)
    next_hexes = np.where(tmat[:, next_hex]==0.5)[0]
    for hex in next_hexes:
        if hex not in path and hex not in [1,2,3]:
            path_to_dead_end(tmat, hex, path)
    return list(path)

def get_all_dead_end_hexes(tmat):
    """ Given a transition matrix, get all hexes that are part of dead end paths """

    all_dead_end_hexes = []
    dead_ends = find_dead_ends(tmat)
    for hex in dead_ends:
        all_dead_end_hexes += path_to_dead_end(tmat, hex, [])[:]
    return list(all_dead_end_hexes)

def animate_rat_and_pred(df, background_img):
    """ Animation of rat xy and predicted xy with given background image """
    
    fig, ax = plt.subplots()
    camera = Camera(fig)
    for i in df.index:
        #ax.invert_yaxis()
        #ax.imshow(img, extent=[minx+xshift-imscale, maxx+xshift+imscale, maxy+yshift+imscale, miny+yshift-imscale], alpha=0.3)
        ax.imshow(background_img, alpha=0.3)
        plt.plot(df.x[i], df.y[i], color='grey', marker='o')
        plt.plot(df.x_pred[i], df.y_pred[i], 'yo')
        #ax.text(minx+xshift, miny+yshift, 'Rat in {}, pred in {}'.format(int(df.hexlabels[i]), df.pred_hexlabels[i]), verticalalignment='top', horizontalalignment='left', fontsize=15)
        #camera.snap()
    #animation = camera.animate(interval=10, blit=True)

def get_bounds(df):
    return np.nanmin(df.x), np.nanmin(df.y), np.nanmax(df.x), np.nanmax(df.y)

def divide_into_sections(arr):
    """ Divides a sorted list of numbers into sections of consecutive numbers increasing by 1.
    
    Parameters: arr (list): A sorted list of numbers with potential breaks. 
    Returns: result (list): A list where each section of increasing numbers is represented by consecutive integers.

    Example:
    >>> input_array = [7, 8, 9, 10, 11, 55, 56, 57, 58, 59, 60, 61, 990, 991, 992, 993]
    >>> result_array = divide_into_sections(input_array)
    >>> print(result_array)
    [1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3]
    """

    result = [0]*len(arr) # preallocate result for speeeeed
    current_section = 1

    for i in range(1, len(arr)):
        if arr[i] != arr[i-1] + 1:
            current_section += 1
        result[i] = current_section

    return result

In [5]:
# 1. load dataframe and corresponding transition matrix
# 2. find dead end hexes
# 3. find paths to those dead ends
# 4. find all times a rat is inside a dead end path

specific_path = "IM-1478/07202022"
phot_path_base = "/Volumes/Tim/Photometry/"
ephys_path_base = "/Volumes/Tim/Ephys/"

# Load transition matrix and dataframe for this session
transition_mat = np.load(phot_path_base+specific_path+"/tmat.npy")
df = reduce_mem_usage(pd.read_csv(ephys_path_base+specific_path+"/phot_decode_df_withHexStates.csv"))
minx, miny, maxx, maxy = get_bounds(df)

# Load pdf image of hexes from this session to use as the plot background
image = convert_from_path(phot_path_base+specific_path+"/hex_layout.pdf")
image[0].save('hex_background.png', 'png')

# Get subset of dataframe where rat the is in dead ends
dead_end_hexes = get_all_dead_end_hexes(transition_mat)
indices = [i for i in df.index if df['hexlabels'][i] in dead_end_hexes]
dead_end_df = df[df.index.isin(indices)]

# Add column to the dataframe labeling each entry into a dead end path
dead_end_entry = divide_into_sections(dead_end_df.index)
dead_end_df['dead_end_entry'] = dead_end_entry

Memory usage of dataframe is 522.47 MB
Memory usage after optimization is: 122.46 MB
Decreased by 76.6%


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dead_end_df['dead_end_entry'] = dead_end_entry


In [15]:
# Plot instance(s) of when a rat is in a dead end
scale = 60
xshift = 10
yshift = -30

entry = 6 # which entry into a dead end we care about
plot_direction_changes = True # if we want to highlight direction changes on the plot

# get subsets of data (it's faster than constantly accessing the dataframe)
x = dead_end_df.x[dead_end_df.dead_end_entry==entry]
y = dead_end_df.y[dead_end_df.dead_end_entry==entry]
x_pred = dead_end_df.x_pred[dead_end_df.dead_end_entry==entry]
y_pred = dead_end_df.y_pred[dead_end_df.dead_end_entry==entry]
dopamine = dead_end_df.green_z_scored[dead_end_df.dead_end_entry==entry]
idx = dead_end_df.index[dead_end_df.dead_end_entry==entry]
vel = dead_end_df.vel[dead_end_df.dead_end_entry==entry]

# plot !
img = plt.imread("hex_background.png")
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2,2)
ax1.set_title("Rat's actual position")
ax1.imshow(img, extent=[minx+xshift-scale, maxx+xshift+scale, maxy+yshift+scale, miny+yshift-scale])
ax1.scatter(x, y, c=idx, alpha=0.1, cmap='autumn')
ax3.set_title("Decoded position")
ax3.imshow(img, extent=[minx+xshift-scale, maxx+xshift+scale, maxy+yshift+scale, miny+yshift-scale])
ax3.scatter(x_pred, y_pred, c=idx, alpha=0.5, cmap='autumn')
ax2.set_title("Dopamine")
ax2.scatter(idx, dead_end_df.theta_phase[dead_end_df.dead_end_entry==entry], c='grey', alpha=0.1)
ax2.scatter(idx, dopamine)
ax4.set_title("Velocity")
ax4.scatter(idx, vel)

if plot_direction_changes:
    # find points where the rat's direction changes
    direction_changes = find_direction_change_indices(list(x), list(y), 1.58)
    direction_changes = [d + min(x.index) for d in direction_changes]
    # highlight these points on the plot in green
    for i in direction_changes:
        ax1.scatter(x[i], y[i], c='lawngreen')
        ax3.scatter(x_pred[i], y_pred[i], c='lawngreen')
        ax2.scatter(i, dopamine[i], c='lawngreen')
        ax4.scatter(i, vel[i], c='lawngreen')

angle:  2.0344439357957027  xy-1:  291.75 , 235.0  xy:  291.75 , 235.125  xy+1:  292.0 , 235.0 prev dir (0.0, 1.0) next dir (0.8944271909999159, -0.4472135954999579)
angle:  3.141592653589793  xy-1:  295.0 , 225.75  xy:  294.75 , 225.75  xy+1:  295.0 , 225.75 prev dir (-1.0, 0.0) next dir (1.0, 0.0)
angle:  2.0344439357957027  xy-1:  307.5 , 209.75  xy:  307.5 , 209.625  xy+1:  307.25 , 209.75 prev dir (0.0, -1.0) next dir (-0.8944271909999159, 0.4472135954999579)
angle:  2.677945044588987  xy-1:  305.0 , 211.25  xy:  304.75 , 211.375  xy+1:  305.0 , 211.375 prev dir (-0.8944271909999159, 0.4472135954999579) next dir (1.0, 0.0)


In [14]:
import math 

def calculate_direction_vector(x1, y1, x2, y2):
    """
    Calculate the direction vector from (x1, y1) to (x2, y2).
    """
    delta_x = x2 - x1
    delta_y = y2 - y1
    magnitude = math.sqrt(delta_x**2 + delta_y**2)
    
    if magnitude == 0:
        return (0, 0)  # Avoid division by zero
    
    return (delta_x / magnitude, delta_y / magnitude)

def find_direction_change_indices(x_coords, y_coords, angle_threshold=1.570):  # 2.094 120 degrees in radians
    if len(x_coords) < 3 or len(y_coords) < 3 or len(x_coords) != len(y_coords):
        return []  # Invalid data
    
    direction_change_indices = []
    
    for i in range(1, len(x_coords) - 1):
        prev_direction = calculate_direction_vector(x_coords[i-1], y_coords[i-1], x_coords[i], y_coords[i])
        next_direction = calculate_direction_vector(x_coords[i], y_coords[i], x_coords[i+1], y_coords[i+1])
        
        # Check if either of the directions is zero (no movement)
        if prev_direction == (0, 0) or next_direction == (0, 0):
            continue
            
        dot_product = prev_direction[0] * next_direction[0] + prev_direction[1] * next_direction[1]
        
        # Calculate the angle between direction vectors using the dot product
        angle = math.acos(max(-1, min(1, dot_product)))
        
        # Check if the angle is greater than the threshold (120 degrees)
        if angle > angle_threshold:
            print("angle: ",angle," xy-1: ",x_coords[i-1],",",y_coords[i-1]," xy: ",x_coords[i],",", y_coords[i], " xy+1: ",x_coords[i+1],",",y_coords[i+1],"prev dir",prev_direction,"next dir",next_direction)
            direction_change_indices.append(i)
    
    return direction_change_indices

direction_changes = find_direction_change_indices(list(x), list(y))

if direction_changes:
    print("Direction changes greater detected at the following points:")
    for point in direction_changes:
        print(point)
else:
    print("No direction changes detected.")

angle:  1.5707963267948966  xy-1:  325.75 , 205.5  xy:  325.75 , 205.625  xy+1:  326.0 , 205.625 prev dir (0.0, 1.0) next dir (1.0, 0.0)
angle:  1.5707963267948966  xy-1:  325.75 , 205.625  xy:  326.0 , 205.625  xy+1:  326.0 , 205.75 prev dir (1.0, 0.0) next dir (0.0, 1.0)
angle:  1.5707963267948966  xy-1:  326.25 , 206.5  xy:  326.25 , 206.625  xy+1:  326.5 , 206.625 prev dir (0.0, 1.0) next dir (1.0, 0.0)
angle:  1.5707963267948966  xy-1:  326.25 , 206.625  xy:  326.5 , 206.625  xy+1:  326.5 , 206.75 prev dir (1.0, 0.0) next dir (0.0, 1.0)
angle:  1.5707963267948966  xy-1:  329.25 , 209.125  xy:  329.25 , 209.25  xy+1:  329.5 , 209.25 prev dir (0.0, 1.0) next dir (1.0, 0.0)
angle:  1.5707963267948966  xy-1:  329.25 , 209.25  xy:  329.5 , 209.25  xy+1:  329.5 , 209.375 prev dir (1.0, 0.0) next dir (0.0, 1.0)
angle:  1.5707963267948966  xy-1:  329.5 , 209.375  xy:  329.5 , 209.5  xy+1:  329.75 , 209.5 prev dir (0.0, 1.0) next dir (1.0, 0.0)
angle:  1.5707963267948966  xy-1:  329.5 , 20

In [43]:
#plt.plot(dead_end_df.index)
plt.plot(dead_end_df.tri)

[<matplotlib.lines.Line2D at 0x310da4e90>]

369628    185.375
369629    185.500
369630    185.625
369631    185.625
369632    185.750
           ...   
374192    179.125
374193    179.000
374194    178.750
374195    178.500
374196    178.250
Name: x, Length: 4569, dtype: float16

In [46]:
fig, ax = plt.subplots()
#ax.imshow(img, alpha=0.3)
ax.invert_yaxis()
ax.scatter(dead_end_df.x, dead_end_df.y, c='grey')
ax.scatter(dead_end_df.x_pred, dead_end_df.y_pred, c='yellow')
    #ax.text(minx+xshift, miny+yshift, 'Rat in {}, pred in {}'.format(int(df.hexlabels[i]), df.pred_hexlabels[i]), verticalalignment='top', horizontalalignment='left', fontsize=15)
    #camera.snap()
    #animation = camera.animate(interval=10, blit=True)

<matplotlib.collections.PathCollection at 0x310e25a90>

In [None]:
# extract the stuff we care about once so plotting is way faster
x = df.x
y = df.y
minx = np.nanmin(x)
miny = np.nanmin(y)
maxx = np.nanmax(x)
maxy = np.nanmax(y)

xpred = df.x_pred[(df.block==blocknum)&(df.tri==trialnum)]
ypred = df.y_pred[(df.block==blocknum)&(df.tri==trialnum)]
x = df.x[(df.block==blocknum)&(df.tri==trialnum)]
y = df.y[(df.block==blocknum)&(df.tri==trialnum)]
dopamine = normalize(df.green_z_scored[(df.block==blocknum)&(df.tri==trialnum)])

img = plt.imread("hex_background.png")
fig, (ax1, ax2) = plt.subplots(1,2)
ax1.imshow(img, extent=[minx+xshift-imscale, maxx+xshift+imscale, maxy+yshift+imscale, miny+yshift-imscale])
ax1.scatter(x, y, c=dopamine, cmap='plasma', alpha=0.3)
ax2.imshow(img, extent=[minx+xshift-imscale, maxx+xshift+imscale, maxy+yshift+imscale, miny+yshift-imscale])
ax2.scatter(xpred, ypred, c=dopamine, cmap='plasma', alpha=0.3)