### For Berke Lab: plot DA/ACh aligned to port entry

A good test of whether our time alignment is actually working! (It is!!!!! Yay!!!)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pynwb import NWBHDF5IO
from spyglass.common import Nwbfile

# Replace this with path to your nwb!!
nwb_file_name = "IM-1478_20220725.nwb"
nwb_path = Nwbfile().get_abs_path(nwb_file_name)
#nwb_path = "/Users/steph/berkelab/jdb_to_nwb/plotplot/IM-1478_20220725.nwb"
name_of_DA_trace = "z_scored_green_dFF" # "z_scored_green_dFF" if Tim/YS, "zscored_470_405_ratio" if Jose

# Open read-only: this cell only reads from the file
with NWBHDF5IO(nwb_path, mode="r") as io:
    nwbfile = io.read()
    session_id = nwbfile.session_id

    # Get trial and reward data
    trials = nwbfile.intervals["trials"]
    poke_in_times = trials["poke_in"].data[:]
    rewards = trials["reward"].data[:]

    # Get DA trace
    DA_object = nwbfile.acquisition[name_of_DA_trace]
    DA_trace = DA_object.data[:]
    timestamps = DA_object.get_timestamps()
    sampling_rate = DA_object.rate

# A TimeSeries stored with explicit timestamps has rate=None, so estimate the rate from them
if sampling_rate is None:
    sampling_rate = 1 / np.median(np.diff(timestamps))

# Define time window
time_window = 3  # seconds
num_samples = int(2 * time_window * sampling_rate)
half_window = num_samples // 2
# Time of each sample in an extracted window relative to poke_in. Built from the same integer
# sample offsets used for extraction below, so t=0 is exactly the poke_in sample.
time_vector = (np.arange(num_samples) - half_window) / sampling_rate

# Split by rewarded vs unrewarded trials
rewarded, un_rewarded = [], []
n_skipped = 0

for poke_in, reward in zip(poke_in_times, rewards):

    # Pokes outside the DA recording have no trace to extract
    if not timestamps[0] <= poke_in <= timestamps[-1]:
        n_skipped += 1
        continue

    # Find the DA sample nearest to poke_in. (Exact float equality raises an IndexError as soon
    # as a poke time isn't bit-for-bit identical to a DA timestamp; an exact match is still
    # returned unchanged here.)
    idx = int(np.clip(np.searchsorted(timestamps, poke_in), 1, len(timestamps) - 1))
    if poke_in - timestamps[idx - 1] <= timestamps[idx] - poke_in:
        idx -= 1

    # Extract DA trace centered on poke_in. Skip pokes whose window would run off either end
    # of the recording: clamping start_idx at 0 would shift the trace relative to time_vector
    # and silently misalign it.
    start_idx = idx - half_window
    end_idx = start_idx + num_samples
    if start_idx < 0 or end_idx > len(DA_trace):
        n_skipped += 1
        continue

    trace = DA_trace[start_idx:end_idx]
    (rewarded if reward == 1 else un_rewarded).append(trace)

# Convert to arrays of shape (n_trials, num_samples)
rewarded, un_rewarded = map(np.array, (rewarded, un_rewarded))

print(f"Got {len(poke_in_times)} pokes from nwbfile!")
print(f"{rewarded.shape[0]} rewarded")
print(f"{un_rewarded.shape[0]} unrewarded")
if n_skipped:
    print(f"Skipped {n_skipped} pokes whose +/-{time_window} s window does not fit in the recording")

# Now plot!!
fig, axes = plt.subplots(2, 1, figsize=(8, 10), sharex=False)

# First subplot: Average DA response across all trials
# (a group with no trials is skipped: its mean can't be plotted against time_vector)
for traces, label, color in ((rewarded, "rewarded", "red"), (un_rewarded, "unrewarded", "blue")):
    if len(traces) > 0:
        axes[0].plot(time_vector, traces.mean(axis=0), label=label, color=color)
axes[0].axvline(0, linestyle="--", color="black", label="Poke In")
axes[0].set_xlabel("Time (s)")
axes[0].set_ylabel("z-scored dF/F")
axes[0].set_title(f"DA aligned to port entry ({session_id})")
axes[0].legend()

# Second subplot: Full session DA trace, with each poke marked (red = rewarded, blue = unrewarded)
axes[1].plot(timestamps, DA_trace)
for poke_in, reward in zip(poke_in_times, rewards):
    color = "red" if reward == 1 else "blue"
    axes[1].axvline(poke_in, linestyle="--", color=color)
axes[1].set_xlabel("Time (s)")
axes[1].set_ylabel("z-scored dF/F")
axes[1].set_title("Full session DA trace and port entries")
axes[1].set_ylim([-2, 10])

plt.tight_layout()
plt.show()

## Plot the rat's raw position overlaid on the maze

Creates a heatmap of rat position for each block, and a combined figure with the rat's position for the first vs. second half of each block (to hopefully show trajectory refinement as the rat learns the maze).

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from hexmaze import plot_hex_maze
from spyglass.common import Nwbfile

# Replace this with path to your nwb!!
nwb_file_name = "BraveLu20240519_.nwb"
nwb_path = Nwbfile().get_abs_path(nwb_file_name)
#nwb_path = "/Users/steph/berkelab/jdb_to_nwb/nwbs/centroids/IM-1478_20220725.nwb"

# Choose the name of the spatial series in the nwb to plot
spatial_series_name = "led_0_series_1" # "cap_back_position"


def _plot_position_heatmap(ax, positions, maze, reward_probs, centroids_dict, bins=100):
    """Draw a log-scaled 2D histogram (heatmap) of the rat's x, y positions on ax.

    If centroids_dict is not None, the open hexes of maze are drawn first with
    plot_hex_maze so the heatmap sits on top of the maze layout. Empty bins are
    masked so the maze shows through wherever the rat never went.

    Args:
    ax: matplotlib axes to draw on
    positions (pd.DataFrame): position samples with 'x' and 'y' columns (no NaNs)
    maze: barrier configuration passed to plot_hex_maze
    reward_probs (list): reward probabilities passed to plot_hex_maze
    centroids_dict (dict or None): hex: (x, y) centroid, or None to skip the maze
    bins (int): number of histogram bins along each axis
    """
    heatmap, xedges, yedges = np.histogram2d(positions['x'].values, positions['y'].values, bins=bins)
    log_heatmap = np.log1p(np.ma.masked_where(heatmap == 0, heatmap))

    # Plot maze layout (open hexes only) using custom centroids if they exist
    if centroids_dict is not None:
        plot_hex_maze(
            barriers=maze, centroids=centroids_dict, ax=ax, show_hex_labels=False,
            show_barriers=False, show_choice_points=False, reward_probabilities=reward_probs,
            invert_yaxis=True
        )
    # histogram2d indexes as [x, y] but imshow draws [row, col] = [y, x], hence the transpose
    ax.imshow(
        log_heatmap.T, origin='lower', cmap='viridis',
        extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
        aspect='equal', zorder=1
    )
    ax.set_xticks([])
    ax.set_yticks([])


# Open read-only: this cell only reads from the file
with NWBHDF5IO(nwb_path, mode="r") as io:
    nwbfile = io.read()
    session_id = nwbfile.session_id

    # Get hex centroids and convert to a dict so we can plot the hex maze with custom centroids
    behavior_module = nwbfile.processing["behavior"]
    if "hex_centroids" in behavior_module.data_interfaces:
        centroids_df = behavior_module.data_interfaces["hex_centroids"].to_dataframe()
        centroids_dict = centroids_df.set_index('hex')[['x', 'y']].apply(tuple, axis=1).to_dict()
    else:
        # If we have no centroids, still make the plot, just without the maze background
        # (It is way worse this way... but fine I guess)
        centroids_dict = None

    # Get position data for the given spatial series (read into memory so we can close the file).
    # get_timestamps() also works for series stored with a sampling rate instead of timestamps.
    position = behavior_module.data_interfaces["position"].spatial_series[spatial_series_name]
    position_df = pd.DataFrame(position.data[:], columns=["x", "y"])
    position_df["timestamp"] = position.get_timestamps()

    # Get block data
    block_data = nwbfile.intervals["block"].to_dataframe()

n_blocks = len(block_data)

# Set up n_blocks x 2 plot to plot rat position by (first half, second half) of each block.
# squeeze=False keeps axs 2D even with a single block, so axs[row, col] always works.
fig, axs = plt.subplots(n_blocks, 2, figsize=(8, 4 * n_blocks), sharex=True, sharey=True, squeeze=False)

for row, block in enumerate(block_data.itertuples(index=False)):
    # Get maze configuration and reward probabilities for this block
    maze = block.maze_configuration
    reward_probs = [block.pA, block.pB, block.pC]

    # Filter position data for this block (excluding nans)
    block_times = (position_df["timestamp"] >= block.start_time) & (position_df["timestamp"] <= block.stop_time)
    block_positions = position_df[block_times].dropna(subset=['x', 'y'])

    ### Plot 1: Rat position heatmap for this block
    fig_full, ax_full = plt.subplots(figsize=(6, 6))
    _plot_position_heatmap(ax_full, block_positions, maze, reward_probs, centroids_dict)
    ax_full.set_title(f"Rat position heatmap ({session_id}, block {block.block})")
    fig_full.tight_layout()
    fig_full.show()

    # Now do the same thing, but split into first/second half so we can see behavioral adaptation.
    # NOTE: the split is by number of tracked (non-NaN) samples, not by elapsed time.
    block_midpoint = len(block_positions) // 2
    halves = [block_positions.iloc[:block_midpoint], block_positions.iloc[block_midpoint:]]

    # Add each block half to our big plot
    for col, half in enumerate(halves):
        ax = axs[row, col]
        _plot_position_heatmap(ax, half, maze, reward_probs, centroids_dict)
        ax.set_title(f"Block {block.block} ({'first' if col == 0 else 'second'} half)")

fig.suptitle(f"{session_id} rat position heatmaps", fontsize=16)
fig.tight_layout()
plt.show()

## Assign each x, y location to a hex!

Note that the 'hex' column will be NaN for time points outside valid block boundaries, and for time points where the rat's position is missing (NaN x, y).

This is just a test that we can do assignment correctly - next step is putting it in an appropriate downstream table in spyglass!

In [None]:
import math

# Specify the position series you want to use as the rat's position for assignment
# (This should probably be the head position if you are tracking multiple body parts)
# This must be the name of a spatial series inside the "position" interface of the
# nwb's "behavior" processing module.
name_of_position_series = 'cap_front_position' # cap_position for Jose, cap_front_position for Tim/YS

def euclidean_distance(coord1, coord2):
    """Return the straight-line (Euclidean) distance between two 2D points.

    Args:
    coord1: first point; only its first two elements (x, y) are used
    coord2: second point; only its first two elements (x, y) are used
    """
    dx = coord1[0] - coord2[0]
    dy = coord1[1] - coord2[1]
    squared_distance = dx ** 2 + dy ** 2
    return math.sqrt(squared_distance)

def assign_to_hex(x_list, y_list, hex_centroids):
    """
    Assign each (x, y) coordinate to the nearest hex centroid in hex_centroids.

    Make sure the rat's x, y coordinates and the centroids in hex_centroids
    are in the same units (both in pixels or both in cm/meters) !!

    Distances from every point to every centroid are computed at once with numpy
    rather than in a pure-Python double loop, which is much faster for a full
    session (tens of thousands of position samples x every hex).

    Args:
    x_list (list or array-like): The rat's x coordinates
    y_list (list or array-like): The rat's y coordinates
    hex_centroids (dict): Dictionary of hex: (x, y) centroid

    Returns:
    List of hexes the same length as x_list and y_list indicating which hex
    each point has been assigned to. A point is assigned None when no centroid
    is at a finite distance from it (e.g. its x or y is NaN, or hex_centroids
    is empty). If two hexes are equally close, the one that comes first in
    hex_centroids wins.

    Raises:
    ValueError: if x_list and y_list have different lengths.
    """
    # Check that we have the same length x and y
    if len(x_list) != len(y_list):
        raise ValueError("x_list and y_list must have the same length.")

    hex_ids = list(hex_centroids.keys())
    if not hex_ids or len(x_list) == 0:
        return [None] * len(x_list)

    # Points as column vectors (n_points, 1) and centroids as (n_hexes, 2) so the
    # subtraction below broadcasts to an (n_points, n_hexes) distance matrix
    x = np.asarray(x_list, dtype=float).reshape(-1, 1)
    y = np.asarray(y_list, dtype=float).reshape(-1, 1)
    centroids = np.array([(c[0], c[1]) for c in hex_centroids.values()], dtype=float)

    distances = np.sqrt((x - centroids[:, 0]) ** 2 + (y - centroids[:, 1]) ** 2)
    # A NaN distance can never be the closest hex, so treat it as infinitely far
    distances[np.isnan(distances)] = np.inf

    # argmin returns the first index on ties, i.e. the first closest hex in dict order
    closest = distances.argmin(axis=1)
    closest_distance = distances[np.arange(len(closest)), closest]

    return [hex_ids[i] if np.isfinite(d) else None for i, d in zip(closest, closest_distance)]


# Open read-only: this cell only reads from the file
with NWBHDF5IO(nwb_path, mode="r") as io:
    nwbfile = io.read()
    behavior_module = nwbfile.processing["behavior"]

    # Get hex centroids and convert to a dict of hex: (x, y)
    centroids_df = behavior_module.data_interfaces["hex_centroids"].to_dataframe()
    centroids_dict = centroids_df.set_index('hex')[['x', 'y']].apply(tuple, axis=1).to_dict()

    # Get position data (we may have multiple spatial series, so choose the name).
    # get_timestamps() also works for series stored with a sampling rate instead of timestamps.
    position = behavior_module.data_interfaces["position"].spatial_series[name_of_position_series]
    position_df = pd.DataFrame(position.data[:], columns=["x", "y"])
    position_df["timestamp"] = position.get_timestamps()
    # Stays NaN for samples outside every block and for samples with a missing (NaN) position
    position_df["hex"] = np.nan

    # Get block data
    block_data = nwbfile.intervals["block"].to_dataframe()

for block in block_data.itertuples(index=False):
    # Get barrier locations for this block (comma-separated string -> set of ints).
    # Empty tokens are skipped so a block with no barriers ("") doesn't crash int().
    barriers = {int(h) for h in block.maze_configuration.split(",") if h.strip()}

    # Get the centroids of open hexes for this block (remove hexes that are barriers)
    hex_centroids_block = {h: coords for h, coords in centroids_dict.items() if h not in barriers}

    # Get position indices for the current block
    block_mask = (position_df["timestamp"] >= block.start_time) & (position_df["timestamp"] <= block.stop_time)

    # Assign each x, y position in this block to a hex
    block_hexes = assign_to_hex(position_df.loc[block_mask, 'x'],
                                position_df.loc[block_mask, 'y'],
                                hex_centroids_block)
    # Convert to floats (None -> NaN) so the float "hex" column keeps its dtype; assigning the
    # raw mixed int/None list triggers pandas' incompatible-dtype warning
    position_df.loc[block_mask, "hex"] = np.asarray(block_hexes, dtype=float)

display(position_df)


  position_df.loc[block_mask, "hex"] = assign_to_hex(position_df.loc[block_mask, 'x'],


Unnamed: 0,x,y,timestamp,hex
0,,,0.311375,
1,,,0.373525,
2,,,0.440090,
3,115.463905,226.461227,0.501664,27.0
4,116.761955,225.359528,0.565580,27.0
...,...,...,...,...
78036,83.313751,370.878845,5202.862919,
78037,83.252724,370.941010,5202.943314,
78038,83.065994,371.158356,5203.007520,
78039,83.120041,370.853638,5203.070971,
