# Detect place cells

Load the required packages

In [None]:
import numpy as np
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import os
import sys
import re
import math
import holoviews as hv
import panel as pn
import bisect
hv.extension('bokeh', 'matplotlib')
from IPython.display import display
from ipyfilechooser import FileChooser
import warnings
from scipy.stats import zscore
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from scipy.signal import resample
from scipy.stats import sem
import json
import ipywidgets as widgets
%matplotlib widget
warnings.filterwarnings("ignore")
#%reset
from scipy.interpolate import interp1d
from collections import defaultdict


def remove_outliers_avg_filter(data):
    """Fill NaN gaps with the mean of the nearest valid neighbours.

    Despite its name, this function does not detect outliers: it only
    replaces NaN entries.  Each NaN is set to the average of the closest
    non-NaN value before it and the closest non-NaN value after it.  NaNs
    without a valid value on both sides (leading/trailing gaps) stay NaN.

    The neighbour search uses running max/min of valid indices, so a long
    NaN run costs O(n) instead of O(n**2).

    Parameters
    ----------
    data : array-like, 1-D
        Values to fill; converted to a float NumPy array.

    Returns
    -------
    numpy.ndarray
        A filled copy of ``data``; the input is not modified.
    """
    data = np.array(data, dtype=float)  # Ensure NumPy array with float type
    filtered_data = np.copy(data)  # Copy to avoid modifying original data
    n = len(data)
    if n == 0:
        return filtered_data
    positions = np.arange(n)
    valid = ~np.isnan(data)
    # Index of the closest valid sample at or before each position (-1 if none)
    prev_idx = np.maximum.accumulate(np.where(valid, positions, -1))
    # Index of the closest valid sample at or after each position (n if none)
    next_idx = np.minimum.accumulate(np.where(valid, positions, n)[::-1])[::-1]
    # Only interior gaps (a valid neighbour on both sides) are filled
    fillable = ~valid & (prev_idx >= 0) & (next_idx < n)
    filtered_data[fillable] = (data[prev_idx[fillable]] + data[next_idx[fillable]]) / 2
    return filtered_data

def find_closest_index_sorted(arr, target):
    """Return the index of the element of a sorted sequence closest to *target*.

    Parameters
    ----------
    arr : sorted sequence supporting ``len()`` and ``arr[i]``
        E.g. a list, a NumPy array, or a pandas Series with a default
        RangeIndex (``arr[i]`` is label-based for a Series).
    target : number
        Value to look up.

    Returns
    -------
    int
        Index of the closest element.  On a tie the earlier index wins.
        Targets outside the range map to the first/last index.

    Raises
    ------
    ValueError
        If *arr* is empty, since there is no valid index to return.
    """
    if len(arr) == 0:
        raise ValueError("find_closest_index_sorted() arg is an empty sequence")
    idx = bisect.bisect_left(arr, target)  # Find the insertion point
    if idx == 0:
        return 0
    if idx == len(arr):
        return len(arr) - 1
    before = idx - 1
    after = idx
    return before if abs(arr[before] - target) <= abs(arr[after] - target) else after

# Callback registered on the folder chooser below (fc1.register_callback)
def update_my_folder(chooser):
    """Remember the folder picked in *chooser* as the notebook-global ``dpath``.

    The value is also persisted with the IPython ``%store`` magic, so a later
    run (or another notebook) can restore it with ``%store -r dpath``.
    """
    global dpath
    dpath = chooser.selected
    %store dpath
    return 

def detect_longest_lowest_sequence(arr, margin=0):
    """Locate the longest run of consecutive samples close to the minimum.

    A sample qualifies when ``arr[i] <= nanmin(arr) + margin``; NaNs never
    qualify.  Among the runs of consecutive qualifying indices, the longest
    is returned.  On a tie the earliest run wins.

    Parameters
    ----------
    arr : numpy.ndarray
        1-D signal.
    margin : float, optional
        Tolerance added to the minimum to form the threshold.

    Returns
    -------
    tuple
        ``(min_val, threshold, longest_sequence)``.  ``longest_sequence`` is
        ``(start_index, end_index, length)`` with an inclusive end, or
        ``None`` when no sample qualifies.
    """
    lowest = np.nanmin(arr)
    limit = lowest + margin
    hits = np.where(arr <= limit)[0]
    if len(hits) == 0:
        return lowest, limit, None
    # Split the qualifying indices wherever they stop being consecutive
    breaks = np.where(np.diff(hits) != 1)[0] + 1
    runs = np.split(hits, breaks)
    best = max(runs, key=len)  # max() keeps the first of equally long runs
    return lowest, limit, (best[0], best[-1], best[-1] - best[0] + 1)


In [None]:
# IPython `cd` magic: work from the local minian scripts folder (machine-specific path, adjust for your setup).
# The next cell resolves the minian package relative to this working directory.
cd "C:/Users/Manip2/SCRIPTS/minian/"

In [None]:
# The minian package is expected in a "minian" folder inside the parent of the working directory
parent_folder = os.path.abspath('..')
minian_path = os.path.join(parent_folder, 'minian')
print(f"The folder used for minian procedures is : {minian_path}")

In [None]:
# Make the local minian package importable, then import its helper functions
sys.path.append(minian_path)
from minian.utilities import (
    TaskAnnotation,
    get_optimal_chk,
    load_videos,
    open_minian,  # used below to load the minian dataset
    save_minian,
)

Select the recording folder (the folder with the videos, which contains the `minian` output folder)

In [None]:
try: # tries to retrieve dpath either from a previous run or from a previous notebook
    %store -r dpath
except:
    print("the path was not defined in store")
    #dpath = "/Users/mb/Documents/Syntuitio/AudreyHay/PlanB/ExampleRedLines/2022_08_06/13_30_01/My_V4_Miniscope/"
    dpath = "//10.69.168.1/crnldata/waking/audrey_hay/L1imaging/AnalysedMarch2023/Gaelle/Baseline_recording"

fc1 = FileChooser(dpath,select_default=True, show_only_dirs = True, title = "<b>Folder with videos</b>", layout=widgets.Layout(width='100%'))
display(fc1)
# Register callback function
fc1.register_callback(update_my_folder)

Import the spatial maps (A) and Ca²⁺ traces (C)

In [None]:
mice=Path(dpath).parent.parent.parent.parent.parent.name
date=Path(dpath).parent.parent.parent.name
sessiontype=Path(dpath).parent.parent.name
hour=Path(dpath).parent.name
print(mice, '-',date, '-', sessiontype ,'-', hour)


minianversion = 'minian'
try: # tries to retrieve minianversion either from a previous run or from a previous notebook
    %store -r minianversion
except:
    print("the minian folder to use was not defined in store")
    minianversion = 'minian' #'minianAB' # or 'minian_intermediate'
    %store minianversion

folderMouse = Path(os.path.join(dpath,minianversion))
print(folderMouse)
minian_ds = open_minian(folderMouse)

try: 
    StampsMiniscopeFile = Path(os.path.join(dpath, f'timeStamps.csv'))
    tsmini=pd.read_csv(StampsMiniscopeFile)['Time Stamp (ms)']
    V4subfolder=False
except:
    StampsMiniscopeFile = Path(os.path.join(Path(dpath).parent, f'timeStamps.csv'))
    tsmini=pd.read_csv(StampsMiniscopeFile)['Time Stamp (ms)']
    V4subfolder=True

minian_freq=round(1/np.mean(np.diff(np.array(tsmini)/1000)))
print('Miniscope sample rate =', minian_freq, 'Hz')

Ao = minian_ds['A']
Co = minian_ds['C']

try: 
    TodropFile = folderMouse / f'TodropFileAB.json'
    with open(TodropFile, 'r') as f:
        unit_to_drop = json.load(f)
except:
    TodropFile = folderMouse.parent / f'TodropFileAB.json'
    with open(TodropFile, 'r') as f:
        unit_to_drop = json.load(f)
    
C=Co.drop_sel(unit_id=unit_to_drop)
A=Ao.drop_sel(unit_id=unit_to_drop)

idloc = A.idxmax("unit_id")
Hmax = A.idxmax("height")
Hmax2 = Hmax.max("width")

Wmax = A.idxmax("width")
Wmax2 = Wmax.max("height")
coord1 = Wmax2.to_series()
coord2 = Hmax2.to_series()

a = pd.concat([coord1,coord2], axis=1)
unit = len(a)
print("{} units have been found".format(unit))

Import DeepLabCut data

In [None]:
# Define parameters
pixel_to_cm = 2.25  # video pixels per cm: pixel distances are divided by this to get cm
# Center of the cheeseboard table on the video (pixels).
# Earlier calibrations were (313, 283) and (300, 270); they were always overwritten by this one.
table_center_x, table_center_y = 315, 275
table_radius = 290 / 2  # table radius in pixels (290 px diameter)

# Define functions
def calculate_relative_distance(x1, y1, x2, y2):
    """Return the Euclidean distance between (x1, y1) and (x2, y2).

    NaN coordinates propagate to a NaN result.
    """
    delta_x = x2 - x1
    delta_y = y2 - y1
    return math.sqrt(delta_x ** 2 + delta_y ** 2)

def calculate_distance_run(x_coords, y_coords, px_per_cm=None):
    """Return the total path length in cm and the per-step distances in pixels.

    Step lengths are Euclidean distances between consecutive samples.  An
    interior NaN step (not the first or last one) is replaced by the mean
    of its non-NaN neighbours.  The scan runs left to right, so a step filled
    earlier can serve as the left neighbour of the next one.  If both
    neighbours are NaN, the step stays NaN.  Remaining NaNs are ignored in
    the total.

    Parameters
    ----------
    x_coords, y_coords : array-like
        Positions in pixels, same length.
    px_per_cm : float, optional
        Pixels per centimetre.  Defaults to the notebook-level ``pixel_to_cm``.

    Returns
    -------
    total_distance_cm : float
        Sum of the step lengths, converted to cm.
    distances : numpy.ndarray
        Per-step distances in pixels (``len(x_coords) - 1`` values).
    """
    if px_per_cm is None:
        px_per_cm = pixel_to_cm
    distances = np.sqrt(np.diff(x_coords) ** 2 + np.diff(y_coords) ** 2)
    for i in range(1, len(distances) - 1):
        if np.isnan(distances[i]):
            neighbors = [d for d in (distances[i-1], distances[i+1]) if not np.isnan(d)]
            # np.mean([]) would give NaN with a RuntimeWarning; be explicit instead
            distances[i] = np.mean(neighbors) if neighbors else np.nan
    total_distance_cm = np.nansum(distances) / px_per_cm  # Convert to cm
    return total_distance_cm, distances

def find_long_non_nan_sequences(arr, min_length=100):
    """Return the runs of consecutive non-NaN values longer than *min_length*.

    Parameters
    ----------
    arr : numpy.ndarray
        1-D array, possibly containing NaNs.
    min_length : int, optional
        A run is kept only if its length is strictly greater than this.

    Returns
    -------
    list of numpy.ndarray
        The qualifying runs (slices of *arr*), in order of appearance.
    """
    valid = (~np.isnan(arr)).astype(int)
    # Zero padding makes every run produce a +1 edge at its start and a -1 edge after its end
    edges = np.diff(np.concatenate(([0], valid, [0])))
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1)
    return [arr[s:e] for s, e in zip(run_starts, run_ends) if e - s > min_length]

def remove_outliers_median_filter(data, window=1):
    """Apply a NaN-aware sliding median, leaving NaN positions untouched.

    Every non-NaN sample is replaced by the median of the non-NaN values in a
    centred window of ``window // 2`` samples on each side, clipped at the
    array bounds.  With the default ``window=1`` the data are returned
    unchanged.

    Parameters
    ----------
    data : array-like, 1-D
        Values to filter; converted to a float NumPy array.
    window : int, optional
        Window width in samples.

    Returns
    -------
    numpy.ndarray
        Filtered copy of ``data``.
    """
    values = np.array(data, dtype=float)
    smoothed = values.copy()
    half = window // 2
    size = len(values)
    for i in range(size):
        if np.isnan(values[i]):
            continue  # NaNs stay NaN
        segment = values[max(0, i - half):min(size, i + half + 1)]
        smoothed[i] = np.nan if np.isnan(segment).all() else np.nanmedian(segment)
    return smoothed

def replace_high_speed_points_with_nan(x, y, speed_threshold):
    """Blank (set to NaN) one point of each step longer than *speed_threshold*.

    Step lengths are computed once from the input.  For a too-long step
    ``i -> i+1`` with ``0 < i``, point ``i+1`` is removed when this step is
    longer than the previous one; otherwise point ``i`` is removed.  That
    includes the case where the previous step is NaN.  A too-long first
    step (``i == 0``) is left untouched.

    Parameters
    ----------
    x, y : array-like
        Coordinates of equal length.
    speed_threshold : float
        Maximum allowed step length (same units as the coordinates, per sample).

    Returns
    -------
    x_out, y_out : numpy.ndarray
        Float copies with the selected points set to NaN.
    """
    xs = np.array(x, dtype='float')
    ys = np.array(y, dtype='float')
    step = np.sqrt(np.diff(xs) ** 2 + np.diff(ys) ** 2)
    x_clean, y_clean = xs.copy(), ys.copy()
    last = len(xs) - 1
    for i in np.flatnonzero(step > speed_threshold):
        if not 0 < i < last:
            continue
        target = i + 1 if step[i] > step[i - 1] else i
        x_clean[target] = np.nan
        y_clean[target] = np.nan
    return x_clean, y_clean

def interpolate_2d_path(x, y, kind='linear', fill='extrapolate'):
    """Fill missing samples of a 2-D path by interpolating over sample index.

    A sample is missing when either its x or its y is NaN.  Both coordinates
    of every missing sample are re-estimated with ``scipy.interpolate.interp1d``,
    built from the samples where both coordinates are valid.

    Parameters
    ----------
    x, y : array-like
        Coordinates of equal length.
    kind : str, optional
        Interpolation kind passed to ``interp1d`` (e.g. 'linear', 'nearest').
    fill : optional
        ``fill_value`` passed to ``interp1d``; 'extrapolate' also fills the ends.

    Returns
    -------
    x_filled, y_filled : numpy.ndarray
        Float copies with the missing samples filled.

    Raises
    ------
    ValueError
        If fewer than two samples have both coordinates valid.
    """
    xs = np.array(x, dtype='float')
    ys = np.array(y, dtype='float')
    t = np.arange(len(xs))
    good = ~(np.isnan(xs) | np.isnan(ys))
    if good.sum() < 2:
        raise ValueError("Not enough valid points to interpolate/extrapolate.")
    missing = ~good
    x_out, y_out = xs.copy(), ys.copy()
    for src, dst in ((xs, x_out), (ys, y_out)):
        estimator = interp1d(t[good], src[good], kind=kind, fill_value=fill, bounds_error=False)
        dst[missing] = estimator(t[missing])
    return x_out, y_out

def limit_speed(x, y, max_speed):
    """Hold the position for two samples after each too-long step.

    Step lengths are computed once from the input positions.  For every step
    ``i -> i+1`` longer than ``max_speed``, samples ``i+1`` and ``i+2`` are
    set to the current position of sample ``i``.  Sample ``i`` itself may
    already have been overwritten by an earlier step.  ``i+2`` is skipped
    when it lies past the end (the final step).

    Parameters
    ----------
    x, y : array-like
        Positions of equal length.
    max_speed : float
        Maximum allowed step length (position units per sample).

    Returns
    -------
    x, y : numpy.ndarray
        Clamped copies; the caller's arrays are not modified.
    """
    x = np.array(x)  # copy so the caller's arrays are not modified in place
    y = np.array(y)
    speeds = np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2)
    last = len(x) - 1
    for i, step in enumerate(speeds):
        if step > max_speed:
            # i + 1 always exists: speeds has len(x) - 1 entries
            x[i+1] = x[i]
            y[i+1] = y[i]
            if i + 2 <= last:  # the final step has no sample i + 2
                x[i+2] = x[i]
                y[i+2] = y[i]
    return x, y

def remove_short_sequences(arr, max_len=10):
    """Blank out isolated runs of at most *max_len* consecutive valid values.

    Each maximal run of non-NaN values is bounded by NaNs or by the array
    edges.  Every run whose length is ``<= max_len`` is replaced by NaN.

    Parameters
    ----------
    arr : array-like, 1-D
        Values, possibly containing NaNs; converted to float.
    max_len : int, optional
        Longest run length that is removed.

    Returns
    -------
    numpy.ndarray
        Float copy with the short runs set to NaN.
    """
    values = np.array(arr, dtype='float')
    cleaned = values.copy()
    present = (~np.isnan(values)).astype(int)
    # Zero padding gives a +1 edge at each run start and a -1 edge just after each run end
    edges = np.diff(np.concatenate(([0], present, [0])))
    for first, stop in zip(np.flatnonzero(edges == 1), np.flatnonzero(edges == -1)):
        if stop - first <= max_len:
            cleaned[first:stop] = np.nan
    return cleaned

In [None]:
# Locate the DeepLabCut output (.h5) in the webcam folder next to the miniscope
# folder (one level higher with the V4 sub-folder layout)
session_folder = Path(dpath).parent.parent if V4subfolder else Path(dpath).parent
dlcpath = session_folder / 'My_First_WebCam'
h5_files = sorted(file for file in os.listdir(dlcpath) if file.endswith('.h5'))
if not h5_files:
    # Fail loudly instead of silently reusing a dlcfile left over from a previous run
    raise FileNotFoundError(f'No DeepLabCut .h5 file found in {dlcpath}')
dlcfile = h5_files[0]  # sorted, so the choice is deterministic if there are several
dlc_path = os.path.join(dlcpath, dlcfile)
print(dlcfile)

# Load HDF5 file
df = pd.read_hdf(dlc_path)
directory = os.path.dirname(dlc_path)
timestamps_path = Path(directory,'timeStamps.csv')
if timestamps_path.exists():
    timestamps = pd.read_csv(timestamps_path)
    tswebcam = timestamps['Time Stamp (ms)']
    frame_rate = round(1/(np.mean(np.diff(tswebcam))/1000))  # fps
    print(f'Acquisition with DAQ, frame rate = {frame_rate} fps')
else:
    frame_rate = 16  # fps /!\ CHANGE ACCORDING TO YOUR DATA
    print(f'Acquisition with Webcam, frame rate = {frame_rate} fps')
    # tswebcam is needed below: build an evenly spaced clock (ms) from the frame rate.
    # NOTE(review): this clock is not synchronised with the miniscope timestamps, so later
    # webcam/miniscope alignment only makes sense if both recordings started together.
    tswebcam = pd.Series(np.arange(len(df)) * 1000 / frame_rate)

X0 = df.iloc[:, 0]
Y0 = df.iloc[:, 1]

# Remove uncertain location predictions (likelihood <= 0.5).
# The first three DLC columns are used as x, y and likelihood.
likelihood_threshold = 0.5
confident = df.iloc[:, 2] > likelihood_threshold
df.iloc[:, 0] = df.iloc[:, 0].where(confident)
df.iloc[:, 1] = df.iloc[:, 1].where(confident)

X = df.iloc[:, 0]
Y = df.iloc[:, 1]

# Separate the individual's positions into x and y coordinates
individual_xO= np.array(X.values)
individual_yO = np.array(Y.values)

# Keep only positions on the cheeseboard (strictly inside the table circle)
dist_to_center = np.sqrt((individual_xO - table_center_x) ** 2 + (individual_yO - table_center_y) ** 2)
off_table = dist_to_center >= table_radius
individual_xO[off_table] = np.nan
individual_yO[off_table] = np.nan

individual_xOO = remove_short_sequences(individual_xO, max_len=3)
individual_yOO = remove_short_sequences(individual_yO, max_len=3)

# The trial starts at the first run of more than 100 consecutive valid positions.
# Use the run's start index directly: searching for its first x value could match
# an earlier frame that happens to have the same coordinate.
valid = ~np.isnan(individual_xOO)
edges = np.diff(np.concatenate(([0], valid.astype(int), [0])))
run_starts = np.flatnonzero(edges == 1)
run_ends = np.flatnonzero(edges == -1)
long_run_starts = run_starts[(run_ends - run_starts) > 100]
if len(long_run_starts) == 0:
    raise ValueError('No run of more than 100 consecutive valid positions: cannot find the trial start')
start_frame = int(long_run_starts[0])
x_start = individual_xOO[start_frame]
y_start = individual_yOO[start_frame]

individual_xOO[:start_frame]=np.nan # remove any path before the real start
individual_yOO[:start_frame]=np.nan # remove any path before the real start

individual_x1, individual_y1 = replace_high_speed_points_with_nan(individual_xOO, individual_yOO, speed_threshold=10)

# Keep everything up to the end of the recording (last_frame is an exclusive bound)
last_frame = len(individual_x1)

individual_x2, individual_y2 = interpolate_2d_path(individual_x1[start_frame:last_frame], individual_y1[start_frame:last_frame], kind='nearest')
individual_x3, individual_y3 = limit_speed(individual_x2, individual_y2, max_speed=20)

individual_x = np.concatenate((individual_x1[:start_frame], individual_x3))
individual_y = np.concatenate((individual_y1[:start_frame], individual_y3))

if len(individual_x) == len(tswebcam):
    start_time = tswebcam.iloc[start_frame].item() / 1000  # s, also used by the layout cell
    if timestamps_path.exists():
        end_time = tswebcam.iloc[-1].item() / 1000
        duration_trial = end_time - start_time
    else:
        duration_trial = (last_frame - start_frame) / frame_rate
    print(f'Total trial duration: {round(duration_trial)} sec')

    total_distance, speed = calculate_distance_run(individual_x[start_frame:last_frame], individual_y[start_frame:last_frame])
    print(f"Total distance run: {round(total_distance)} cm")
    # speed is in pixels per frame: convert to cm per second
    print(f"Average speed: {round(np.nanmean(speed)/pixel_to_cm*frame_rate,2)} cm/s")

    # Create the plot
    fig, ax = plt.subplots(figsize=(3, 3)) 

    # Plot individual positions over time (segment color encodes the frame index)
    cmap = plt.get_cmap('gnuplot2')
    norm = plt.Normalize(vmin=0, vmax=len(individual_x))

    for i in range(1, len(individual_x)):
        ax.plot(individual_x[i-1:i+1], individual_y[i-1:i+1], color=cmap(norm(i)), linewidth=1)

    plt.scatter(x_start, y_start, color='black', s=100, label='Start')
    # Draw cheeseboard circle
    table_circle = plt.Circle((table_center_x, table_center_y), table_radius, color='k', fill=False)
    plt.gca().add_patch(table_circle) 

    # Add labels and title
    ax.set_aspect('equal')
    ax.invert_yaxis()
    plt.title(f'Mouse Path On Cheeseboard Maze')
    plt.xlabel('X Position')
    plt.ylabel('Y Position')
    plt.legend(loc='upper left')
    plt.show()
else: 
    print(f'Error: Length of DLC data ({len(individual_x)}) does not match length of timestamps ({len(tswebcam)})')

In [None]:
# Occupancy map: count how many valid tracked positions (one per webcam frame) fall in
# each square of a regular grid laid over the cheeseboard table, then draw the grid.
square_size = pixel_to_cm * 6  # grid square side in pixels (6 cm, since pixel_to_cm is px per cm)

# Filter out NaNs
valid_mask = ~np.isnan(individual_x) & ~np.isnan(individual_y)
path_x = individual_x[valid_mask]
path_y = individual_y[valid_mask]

# Generate symmetric grid of square centers
# n squares per axis fit across the table diameter; centers are symmetric around the table center
n = int(np.floor(2 * table_radius / square_size))
offsets = (np.arange(n) - (n - 1) / 2.0) * square_size
centers_x = table_center_x + offsets
centers_y = table_center_y + offsets

# Count visits per square
counts = defaultdict(int)
for px, py in zip(path_x, path_y):
    if np.sqrt((px - table_center_x)**2 + (py - table_center_y)**2) > table_radius:
        continue  # skip points outside circle
    # Grid cell index: cell k spans [c - n/2*s + k*s, c - n/2*s + (k+1)*s), matching centers_x/centers_y.
    # Points inside the circle but outside the n x n grid get indices < 0 or >= n and are never drawn.
    ix = int(np.floor((px - (table_center_x - n/2 * square_size)) / square_size))
    iy = int(np.floor((py - (table_center_y - n/2 * square_size)) / square_size))
    counts[(ix, iy)] += 1

# The color scale saturates at one third of the busiest square (higher ratios get the top color)
max_count = max(counts.values())/3 if counts else 1

# Plot
fig, ax = plt.subplots(figsize=(3,3))

# Draw circle outline
theta = np.linspace(0, 2*np.pi, 500)
ax.plot(table_center_x + table_radius*np.cos(theta),
        table_center_y + table_radius*np.sin(theta),
        'k', lw=1)

# Draw squares: only squares lying entirely inside the circle (center distance + half diagonal <= radius)
for i, cx in enumerate(centers_x):
    for j, cy in enumerate(centers_y):
        if np.sqrt((cx - table_center_x)**2 + (cy - table_center_y)**2) + square_size/np.sqrt(2) <= table_radius:
            count = counts.get((i, j), 0)
            if count > 0:                
                intensity = count / max_count # Map count to viridis colormap
                color = plt.cm.viridis(intensity)
            else:
                color = 'lightgrey'  # no visits
            rect = plt.Rectangle((cx - square_size/2, cy - square_size/2), 
                                 square_size, square_size, 
                                 facecolor=color, edgecolor=None, lw=0.5)
            ax.add_patch(rect)

# Image coordinates: y grows downwards
ax.invert_yaxis()
ax.set_aspect('equal')
plt.show()

Plot the spatial map of all cells with an interactive Ca²⁺ trace viewer

In [None]:
# Unit selector: a slider over the unit ids of `a`, plus previous/next buttons
discrete_slider = pn.widgets.DiscreteSlider(
    name=f"Unit n°",
    options=list(a.index),
    value=a.index[0],
)

next_unit_button = pn.widgets.Button(name='Next unit > ', button_type='primary')
previous_unit_button = pn.widgets.Button(name='< Previous unit')


def _unit_position(unit_id):
    """Return the row position of *unit_id* in the unit table ``a``."""
    return np.where(a.index == unit_id)[0][0]


def nextunit_callback(event):
    """Button callback: select the next unit, wrapping to the first after the last."""
    current = _unit_position(discrete_slider.value)
    discrete_slider.value = a.index[(current + 1) % len(a)]


def previousunit_callback(event):
    """Button callback: select the previous unit, wrapping to the last before the first."""
    current = _unit_position(discrete_slider.value)
    discrete_slider.value = a.index[(current - 1) % len(a)]


next_unit_button.on_click(nextunit_callback)
previous_unit_button.on_click(previousunit_callback)


@pn.depends(indexes=discrete_slider)
def calciumtrace(indexes):
    """Red trace (row of C) of the selected unit against miniscope time in seconds."""
    row = _unit_position(indexes)
    trace = hv.Curve((tsmini/1000, C[row, :]), label=f'Unit n°{indexes} \nNr #{row}')
    return trace.opts(ylim=(0, 10), xlim=(0, tsmini.values[-1]/1000), frame_height=200, color='red')


@pn.depends(indexes=discrete_slider)
def unitshadow(indexes):
    """Spatial map of the selected unit on a 600 x 600 grid, hiding values below 0.01.

    NOTE(review): the 600 x 600 extent is hard-coded - confirm it matches the frame size.
    """
    footprint = A.sel(unit_id=indexes)
    xs = np.linspace(0, 600, 600)
    ys = np.linspace(0, 600, 600)
    shown = np.where(footprint < 0.01, np.nan, footprint)
    return hv.Image((xs, ys, shown)).opts(cmap='hot', clim=(0, 1))


@pn.depends(indexes=discrete_slider)
def circlepath(indexes):
    """Red circle (radius 15) around the selected unit's position stored in ``a``."""
    radius = 15
    theta = np.linspace(0, 2*np.pi, 100)
    row = _unit_position(indexes)
    cx = a.iloc[row, 0]
    cy = a.iloc[row, 1]
    ring = (cx + radius * np.cos(theta), cy + radius * np.sin(theta))
    return hv.Path(ring, group='keep').opts(ylim=(0, 600), xlim=(0, 600), line_color='red', line_width=3)

In [None]:
# Size of the HoloViews output
output_size = 120
hv.output(size=int(output_size))

# Background image: maximum over units of all spatial maps
image = hv.Image(
    A.max("unit_id").compute().astype(np.float32).rename("A"),
    kdims=["width", "height"],
).opts(colorbar=False, invert_yaxis=False,cmap="Viridis")

# One faint black curve per row of C (all units).
# NOTE(review): alltraces is built here but not included in `layout` below.
# NOTE(review): uses the full tsmini as time axis; with the V4 sub-folder layout C may be shorter - confirm lengths match.
alltraces=hv.NdOverlay({idx: hv.Curve((tsmini/1000, C[idx,:])).opts(frame_height=200, show_legend=False, color='black', alpha=0.2, xlabel='time (s)')
                       for idx in np.arange(len(C))})

# Vertical marker and label at the trial start (seconds).
# NOTE(review): start_time is set in the DeepLabCut cell only on some code paths - check it exists before running this cell.
start =  hv.VLine(start_time).opts(color='blue', line_width=2)
starttxt = hv.Text(start_time + 5, 9.5, 'Start').opts(text_color='blue')

# Left: spatial maps with the selected unit overlaid; right: its trace, the slider and the buttons
layout = pn.Column(pn.Row(image * hv.DynamicMap(unitshadow), 
            pn.Column(starttxt * start * hv.DynamicMap(calciumtrace), discrete_slider, pn.Row(previous_unit_button, next_unit_button),       
                    ),
                    ))   

display(layout)

Match the miniscope timestamps to the analysed frames, then choose a neuron to plot

In [None]:
# Miniscope timestamps for the frames analysed in this minian folder only.
# With the V4 sub-folder layout, the sub-folder name ends with its 1-based index and
# each sub-folder holds 15 videos of 1000 frames, so it starts at frame
# (index - 1) * 15 * 1000 of the session-wide timeStamps.csv.
VIDEOS_PER_SUBFOLDER = 15
FRAMES_PER_VIDEO = 1000
if V4subfolder:
    # Read the whole trailing number (a single last character breaks for indices >= 10)
    subfolder_number = re.search(r'(\d+)$', folderMouse.parent.name)
    if subfolder_number is None:
        raise ValueError(f'Cannot read the sub-folder index from {folderMouse.parent.name!r}')
    V4subfolder_id = int(subfolder_number.group(1)) - 1
    ts_start = V4subfolder_id * VIDEOS_PER_SUBFOLDER * FRAMES_PER_VIDEO
    ts_stop = np.shape(C)[1] + ts_start  # number of frames in C
    tsmini_sub = tsmini.iloc[ts_start:ts_stop].reset_index(drop=True)
else: 
    tsmini_sub = tsmini

In [None]:
nr=4  # row position of the neuron to plot in C (used as C[nr, :]); not necessarily its unit_id

Plot the neuron's activity during the trial

In [None]:
plt.close('all')
plt.figure(figsize=[10,2])

# Mouse speed (pixels per frame step), scaled by its maximum.
# speed has one value per step, hence the timestamps start one frame after start_frame.
# nanmax ignores NaN steps (the builtin max returns NaN when the first step is NaN).
total_distance, speed = calculate_distance_run(individual_x[start_frame:last_frame], individual_y[start_frame:last_frame])
plt.plot(tswebcam[start_frame+1:last_frame]/1000, speed/np.nanmax(speed), 'c', label='Mouse speed')

# Selected neuron, min-max normalised to [0, 1]; a flat trace cannot be scaled and is plotted as is
Cnr=C[nr,:]
Cnr_range = np.max(Cnr.values) - np.min(Cnr.values)
normalized_C = (Cnr - np.min(Cnr)) / Cnr_range if Cnr_range > 0 else Cnr
plt.plot(tsmini_sub[:]/1000, normalized_C, 'k', label= f'Neuron #{nr}')

# Mean over all neurons, normalised the same way (the mean is computed once)
Call=C[:,:]
Cmean = np.mean(Call, axis=0)
Cmean_range = float(np.max(Cmean) - np.min(Cmean))
normalized_Cmean = (Cmean - np.min(Cmean)) / Cmean_range if Cmean_range > 0 else Cmean
plt.plot(tsmini_sub[:]/1000, normalized_Cmean , 'k', alpha=0.2, label= f'Mean neuron activity')

# Trial start marker
plt.axvline(x=tswebcam[start_frame+1]/1000, color='b', linewidth=2)
plt.text(tswebcam[start_frame+1]/1000, plt.gca().get_ylim()[1], 'Start', fontsize=8, color='blue')
plt.xlabel('time (s)')
plt.legend(frameon=False, bbox_to_anchor=(1, 1))
plt.tight_layout() 
plt.show()

Plot the neuron's activity along the mouse's path on the cheeseboard

In [None]:
# Restrict the webcam frames to the period covered by the miniscope recording.
# start_frame is inclusive and last_frame exclusive, as in individual_x[start_frame:last_frame].
# If the miniscope recording started after the webcam recording, cut the webcam data
if tsmini_sub.iloc[0] > tswebcam[start_frame]:
    # First webcam frame recorded at or after the first miniscope frame
    Newstart_frame = np.where(tswebcam >= tsmini_sub.iloc[0].item())[0][0].item()
    print(f'... webcam data cut to match miniscope length, new start at frame {Newstart_frame} (instead of {start_frame})')
    start_frame = Newstart_frame 
# If the miniscope recording is shorter than the webcam recording, cut the webcam data
if tsmini_sub.iloc[-1] < tswebcam[last_frame-1]:
    # +1 because last_frame is exclusive: keep the last webcam frame recorded
    # at or before the last miniscope frame
    Newlast_frame = np.where(tswebcam <= tsmini_sub.iloc[-1].item())[0][-1].item() + 1
    print(f'... webcam data cut to match miniscope length, new end at frame {Newlast_frame} (instead of {last_frame})')
    last_frame = Newlast_frame


# Align data
x = individual_x[start_frame:last_frame]
y = individual_y[start_frame:last_frame]
        
Cnr_ = Cnr[:].to_numpy()

# Plot circular environment
fig, ax = plt.subplots(figsize=(4, 4))
table_circle = plt.Circle((table_center_x, table_center_y), table_radius, color='k', fill=False)
plt.gca().add_patch(table_circle) 

# Normalize Cnr_ for colormap
norm = mcolors.Normalize(vmin=Cnr_.min(), vmax=Cnr_.max())
cmap = cm.jet

# Plot path with colormap: each segment is colored by the neuron's activity at the
# miniscope sample closest in time to the segment's first webcam frame
for i in range(len(x) - 1):
    closest_point= find_closest_index_sorted(tsmini_sub, tswebcam[start_frame+i])
    color = cmap(norm(Cnr_[closest_point]))
    ax.plot([x[i], x[i+1]], [y[i], y[i+1]], color=color, linewidth=3)

# Add colorbar
sm = cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax)
ax.invert_yaxis()
ax.set_aspect('equal')
plt.title(f'Neuron #{nr}')
plt.tight_layout()
plt.show()

In [None]:
# Restrict the webcam frames to the period covered by the miniscope recording.
# start_frame is inclusive and last_frame exclusive, as in individual_x[start_frame:last_frame].
# If the miniscope recording started after the webcam recording, cut the webcam data
if tsmini_sub.iloc[0] > tswebcam[start_frame]:
    # First webcam frame recorded at or after the first miniscope frame
    Newstart_frame = np.where(tswebcam >= tsmini_sub.iloc[0].item())[0][0].item()
    print(f'... webcam data cut to match miniscope length, new start at frame {Newstart_frame} (instead of {start_frame})')
    start_frame = Newstart_frame 
# If the miniscope recording is shorter than the webcam recording, cut the webcam data
if tsmini_sub.iloc[-1] < tswebcam[last_frame-1]:
    # +1 because last_frame is exclusive: keep the last webcam frame recorded
    # at or before the last miniscope frame
    Newlast_frame = np.where(tswebcam <= tsmini_sub.iloc[-1].item())[0][-1].item() + 1
    print(f'... webcam data cut to match miniscope length, new end at frame {Newlast_frame} (instead of {last_frame})')
    last_frame = Newlast_frame


# Align data
x = individual_x[start_frame:last_frame]
y = individual_y[start_frame:last_frame]
        
Cnr_ = Cnr[:].to_numpy()

# Generate symmetric grid of square centers (same grid as the occupancy map)
n = int(np.floor(2 * table_radius / square_size))
offsets = (np.arange(n) - (n - 1) / 2.0) * square_size
centers_x = table_center_x + offsets
centers_y = table_center_y + offsets

nr_tot_act_biaised = defaultdict(int) # Summed neuron activity per square (biased by occupancy)
counts = defaultdict(int) # Count visits per square
for idx, (px, py) in enumerate(zip(x, y)):
    if np.sqrt((px - table_center_x)**2 + (py - table_center_y)**2) > table_radius:
        continue  # skip points outside circle
    ix = int(np.floor((px - (table_center_x - n/2 * square_size)) / square_size))
    iy = int(np.floor((py - (table_center_y - n/2 * square_size)) / square_size))
    # Activity at the miniscope sample closest in time to this webcam frame
    closest_point= find_closest_index_sorted(tsmini_sub, tswebcam[start_frame+idx])
    nr_tot_act_biaised[(ix, iy)] += Cnr_[closest_point]
    counts[(ix, iy)] += 1

# Mean activity per visited square: dividing by the visit count removes the occupancy bias.
# Every key of counts was incremented at least once, so the division is safe.
nr_tot_act = defaultdict(float, {key: nr_tot_act_biaised[key] / counts[key] for key in counts})
max_nr_tot_act = max(nr_tot_act.values(), default=0)
if max_nr_tot_act <= 0:
    max_nr_tot_act = 1  # nothing visited or silent neuron: avoid dividing by zero

# Plot
fig, ax = plt.subplots(figsize=(3,3))
# Draw circle outline
theta = np.linspace(0, 2*np.pi, 500)
ax.plot(table_center_x + table_radius*np.cos(theta),
        table_center_y + table_radius*np.sin(theta),
        'k', lw=1)
# Draw squares lying entirely inside the circle
for i, cx in enumerate(centers_x):
    for j, cy in enumerate(centers_y):
        if np.sqrt((cx - table_center_x)**2 + (cy - table_center_y)**2) + square_size/np.sqrt(2) <= table_radius:
            nr_act = nr_tot_act.get((i, j), np.nan)
            if not np.isnan(nr_act):
                intensity = nr_act / max_nr_tot_act # Map mean activity to viridis colormap
                color = plt.cm.viridis(intensity)
            else:
                color = 'lightgrey'  # no visits
            rect = plt.Rectangle((cx - square_size/2, cy - square_size/2), 
                                 square_size, square_size, 
                                 facecolor=color, edgecolor='None', lw=0.5)
            ax.add_patch(rect)
ax.invert_yaxis()
ax.set_aspect('equal')
plt.show()