Helper script for looking at the response

In [None]:
# now open in napari
import napari
import tifffile
import os
import numpy as np
import matplotlib.pyplot as plt

from photostim_deve.control_exp.io import get_med_img_s2p

from photostim_deve.response.io import parse_mark_points, mp_dict_to_stim_list, load_photostim_protocol
from photostim_deve.response.compute import get_fov_resp, get_fov_resp_mn

%load_ext autoreload
%autoreload 2

In [None]:
def get_all_tiff_paths(tiff_dir):
    """
    Collect the registered tiff files of a suite2p output directory in numeric order.

    Plain alphabetical ordering is not sufficient because the number embedded in the
    file name is not zero-padded to a fixed width, so the files are re-ordered by the
    integer found between 'file00' and the next '_' in each file name.

    -------------

    Parameters:
        tiff_dir : (str)
            Directory containing the motion corrected tiff files from suite2p.

    Returns:
        all_tiff_paths : (list)
            Full paths of all '.tif' files in tiff_dir, ordered by the integer embedded in the file name.

    """

    def _file_number(path):
        # e.g. 'file001000_chan0.tif' -> 1000
        return int(os.path.basename(path).split('file00')[1].split('_')[0])

    # alphabetical pre-sort, then numeric re-sort on the embedded number
    candidates = sorted(
        os.path.join(tiff_dir, name)
        for name in os.listdir(tiff_dir)
        if name.endswith('.tif')
    )
    numbers = np.array([_file_number(path) for path in candidates])
    order = np.argsort(numbers)

    return [candidates[k] for k in order]

In [None]:
def zscore_act(act):
    """
    Standardise every row of a 2D array (subtract the row mean, divide by the row std).

    Intended for visualization. A row with zero standard deviation yields non-finite values.
    """
    centered = act - np.mean(act, axis=1, keepdims=True)
    row_scale = np.std(act, axis=1, keepdims=True)
    return centered / row_scale

In [None]:
def get_dist_dff(fov_map, all_point, all_coords_x, all_coords_y, fov_shape=(512, 512), n_dist_bins=724):
    """
    Calculate the df/f of a pixel in the response map as a function of distance from the stimulus point.
    For every stimulus point the pixels of its response map are grouped into distance bins
    (distance measured from the stimulus point) and the mean and standard deviation of each group are returned.

    ------------------

    Parameters:
        fov_map : (np.ndarray)
            The response map with shape (n_stim_points, height, width) where each slice corresponds to a mean pixel response to a stimulus point (averaged across all repetitions of that stimulus).
        all_point : (np.ndarray)
            Array of stimulus point indices corresponding to each stimulation. The unique values are used
            directly as row indices into fov_map and the outputs, so they are expected to be 0..n_points-1.
        all_coords_x : (np.ndarray)
            X coordinates of the stimulus points. Indexed with the point index and compared against the
            ROW index of the map (first image axis).
        all_coords_y : (np.ndarray)
            Y coordinates of the stimulus points. Indexed with the point index and compared against the
            COLUMN index of the map (second image axis).
        fov_shape : (tuple)
            Shape of the field of view (height, width). Does not need to be square.
        n_dist_bins : (int)
            Number of bin edges, spaced evenly between 0 and the FOV diagonal. This gives n_dist_bins - 1 bins.

    Returns:
        dist_diff_mn : (np.ndarray)
            Shape (n_points, n_dist_bins). Entry [p, k] is the mean of the pixels of fov_map[p] whose distance
            to point p lies in [edge[k], edge[k+1]). Bins without pixels are NaN. The last column has no
            corresponding bin and is always NaN.
        dist_diff_std : (np.ndarray)
            Same layout as dist_diff_mn, holding the standard deviation instead of the mean.

    Raises:
        ValueError : if a slice of fov_map does not have shape fov_shape.

    """

    dist_max = np.sqrt(fov_shape[0]**2 + fov_shape[1]**2)  # Maximum distance in pixels (diagonal of the FOV)
    dist_bins = np.linspace(0, dist_max, n_dist_bins)  # Bin edges for distances

    points = np.unique(all_point)

    # NaN marks "no pixels in this bin" so that empty bins are ignored by nan-aware averages downstream
    dist_diff_mn = np.full((len(points), n_dist_bins), np.nan)
    dist_diff_std = np.full((len(points), n_dist_bins), np.nan)

    # Pixel coordinate grids do not depend on the stimulus point
    row_idx, col_idx = np.indices(fov_shape)
    # Bin k (1-based, as returned by np.digitize) covers [dist_bins[k-1], dist_bins[k])
    bin_ids = np.arange(1, n_dist_bins + 1)

    for i in points:
        coords_x = all_coords_x[i]
        coords_y = all_coords_y[i]

        print(f"Stimulus Point {i}: Coordinates: ({coords_x}, {coords_y})")

        point_map = np.asarray(fov_map[i])
        if point_map.shape != tuple(fov_shape):
            raise ValueError(f"fov_map[{i}] has shape {point_map.shape}, expected {tuple(fov_shape)}")

        # Distance of every pixel to the stimulus point (coords_x along rows, coords_y along columns).
        # Written without a transpose so that the map keeps the shape of the FOV for non-square FOVs.
        distance_map = np.sqrt((row_idx - coords_x) ** 2 + (col_idx - coords_y) ** 2)

        # Assign every pixel to its bin once, then group pixel values by bin with a stable sort.
        # Stability keeps the pixels of a bin in raster order, i.e. the same order a boolean mask would give.
        pixel_bin = np.digitize(distance_map.ravel(), dist_bins)
        order = np.argsort(pixel_bin, kind='stable')
        sorted_bin = pixel_bin[order]
        sorted_val = point_map.ravel()[order]
        bounds = np.searchsorted(sorted_bin, bin_ids)

        for j in range(1, n_dist_bins):
            in_bin = sorted_val[bounds[j-1]:bounds[j]]
            if in_bin.size == 0:
                continue  # leave NaN

            dist_diff_mn[i, j-1] = np.mean(in_bin)
            dist_diff_std[i, j-1] = np.std(in_bin)

    return dist_diff_mn, dist_diff_std

In [None]:
from scipy.interpolate import interp1d

def compute_dist_kernel(dist_diff_mn, n_dist_bins=724, bin_width=None):
    """
    Compute the distance kernel from the mean df/f as a function of distance from the stimulus point.

    The 1D kernel is the NaN-ignoring mean over stimulus points. The 2D kernel is obtained by evaluating
    a linear interpolation of the 1D kernel at the distance of every grid node from the grid centre.

    Parameters:
        dist_diff_mn : (np.ndarray)
            Each row corresponds to the mean of the pixel values for all pixels within a distance bin from a stimulus point. Rows correspond to different stimulus points.
        n_dist_bins : (int)
            Side length (in samples) of the square 2D kernel. The grid spans
            [-(n_dist_bins // 2), n_dist_bins // 2] along both axes.
        bin_width : (float or None)
            Width of one distance bin of dist_diff_mn, in grid units. When given, k1d[m] is placed at the
            centre of its bin, (m + 0.5) * bin_width. When None (default) the samples of k1d are spread
            evenly over [0, n_dist_bins // 2], which is the historical behaviour.
            NOTE(review): get_dist_dff spaces its bins over the full FOV diagonal, so the default mapping
            compresses the radial axis relative to the bin spacing - confirm which scale is intended.

    Returns:
        k1d : (np.ndarray)
            The distance kernel computed as the mean of the mean df/f across all stimulus points.
        k2d : (np.ndarray)
            The distance kernel rotated around the centre of an (n_dist_bins, n_dist_bins) grid.
            Distances outside the sampled range are linearly extrapolated.

    """
    k1d = np.nanmean(dist_diff_mn, axis=0)

    # Symmetric coordinate axis. Using an explicit half-width avoids the off-centre grid that
    # `-kernel_size//2` produces for odd sizes (floor division of a negative number rounds down).
    kernel_size = n_dist_bins
    half = kernel_size // 2
    axis = np.linspace(-half, half, kernel_size)
    X, Y = np.meshgrid(axis, axis)

    # Distance of every grid node from the centre
    distance = np.sqrt(X**2 + Y**2)

    # Radial position assigned to each sample of the 1D kernel
    if bin_width is None:
        radii = np.linspace(0, half, len(k1d))
    else:
        radii = (np.arange(len(k1d)) + 0.5) * bin_width

    interp_func = interp1d(radii, k1d, bounds_error=False, fill_value="extrapolate")
    k2d = interp_func(distance)

    return k1d, k2d


In [None]:
def plot_xyoff(xoff, yoff, save_path=None):
    """
    Draw the x and y motion-correction offsets from suite2p as two lines over frames.

    Parameters:
        xoff, yoff : (array-like)
            Offsets to plot, one value per frame.
        save_path : (str or None)
            Path to save the plot. If None, the plot will not be saved.
    """
    fig = plt.figure(figsize=(10, 1), dpi=300)
    ax = fig.gca()

    ax.set_title('Movement correction offsets')
    for trace, name in ((xoff, 'x'), (yoff, 'y')):
        ax.plot(trace, label=name)
    ax.set_xlabel('Frame')
    ax.set_ylabel('Offset (px)')
    ax.legend(loc='upper right')

    if save_path is None:
        return
    fig.savefig(save_path, bbox_inches='tight', dpi=300)
    

In [None]:
def plot_protocol(all_frame, all_point, n_frames=36000, save_path=None):
    """
    Scatter the photostimulation protocol: stimulation frame on x, stimulated point on y.

    Parameters:
        all_frame : (np.ndarray)
            Frame of each stimulation.
        all_point : (np.ndarray)
            Point index of each stimulation (same length as all_frame).
        n_frames : (int)
            Upper limit of the x-axis.
        save_path : (str or None)
            Path to save the plot. If None, the plot will not be saved.
    """
    fig = plt.figure(figsize=(10, 2), dpi=200)
    ax = fig.gca()

    # one scatter series per stimulation point
    for point_idx in np.unique(all_point):
        stim_frames = all_frame[all_point == point_idx]
        ax.scatter(stim_frames, np.ones_like(stim_frames) * point_idx, label=f'Point {point_idx}', s=1)

    ax.set_xlabel('Frame')
    ax.set_ylabel('Point')
    ax.set_title('Photostim protocol')
    ax.set_xlim(0, n_frames)
    # point 0 at the top
    ax.invert_yaxis()

    if save_path is None:
        return
    fig.savefig(save_path, bbox_inches='tight', dpi=300)
    

In [None]:
def plot_fov_diff_single(fov_diff, all_point, all_coords_x, all_coords_y, stim_idx, vlim=200, save_path=None):
    """
    Show the difference image of one stimulation trial with its stimulus position marked.

    Parameters:
        fov_diff : (np.ndarray)
            Difference images, indexed as [stimulation, row, column].
        all_point : (np.ndarray)
            Point index of each stimulation; used for the title.
        all_coords_x, all_coords_y : (np.ndarray)
            Stimulus coordinates per stimulation. all_coords_y is drawn along the horizontal axis and
            all_coords_x along the vertical axis.
        stim_idx : (int)
            Index of the stimulation trial to show.
        vlim : (int)
            Symmetric colour saturation limit.
        save_path : (str or None)
            Path to save the plot. If None, the plot will not be saved.
    """
    title = f'point {all_point[stim_idx]}'

    fig = plt.figure()
    ax = fig.gca()
    ax.imshow(fov_diff[stim_idx, :, :], cmap='bwr', vmin=-vlim, vmax=vlim)
    ax.scatter(all_coords_y[stim_idx], all_coords_x[stim_idx], color='black', s=1, label='Stimulus Point')
    ax.set_title(title)
    ax.axis('off')

    if save_path is None:
        return
    fig.savefig(save_path, bbox_inches='tight', dpi=300)
    

In [None]:
def plot_fov_all_point(mn_image, all_point, all_coords_x, all_coords_y, txt_shift=(7, 7), save_path=None):
    """
    Plot the mean image map with the stimulation points.

    Parameters:
        mn_image : (np.ndarray)
            2D image shown in grayscale as background.
        all_point : (np.ndarray)
            Point index of each stimulation. Every unique value is drawn once and is also used as the
            index into all_coords_x / all_coords_y.
        all_coords_x, all_coords_y : (np.ndarray)
            Stimulus coordinates. all_coords_y is drawn along the horizontal axis and all_coords_x along
            the vertical axis.
        txt_shift : (tuple)
            Offset (horizontal, vertical) of the text label relative to the marker.
        save_path : (str or None)
            Path to save the plot. If None, the plot will not be saved.
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(mn_image, cmap='gray')
    # Hide the axes once, up front, so this also applies when there are no points to draw
    plt.axis('off')

    for i in np.unique(all_point):
        plt.scatter(all_coords_y[i], all_coords_x[i], color='C0', s=1, label='stim point')
        plt.text(all_coords_y[i] + txt_shift[0], all_coords_x[i] + txt_shift[1], str(i), color='C0', fontsize=8, ha='center', va='center')

    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)

In [None]:
def plot_dist_dff(dist_diff_mn, n_points=40, dist_bins_xlim=362, dist_bins_xlim_zoom=45, save_path=None):
    """
    Plot the mean df/f as a function of distance from the stimulus point.

    One row of axes is drawn per stimulus point, followed by a summary row showing all points
    (faint grey) with their NaN-ignoring mean on top. Each row has a wide panel and a narrow zoomed panel.

    ------------------

    Parameters:

        dist_diff_mn : (np.ndarray)
            Each row corresponds to the mean of the pixel values for all pixels within a distance bin from a stimulus point. Rows correspond to different stimulus points.
        n_points : (int)
            Number of unique stimulus points in the experiment. The first n_points rows of dist_diff_mn
            get their own axes row.
        dist_bins_xlim : (int)
            The x-axis limit for the main plot.
        dist_bins_xlim_zoom : (int)
            The x-axis limit for the zoomed-in plot.
        save_path : (str or None)
            Path to save the plot. If None, the plot will not be saved.

    """

    col_xlims = (dist_bins_xlim, dist_bins_xlim_zoom)

    # squeeze=False keeps axs two-dimensional for every value of n_points
    fig, axs = plt.subplots(n_points+1, 2, figsize=(10, 2*(n_points+1)), gridspec_kw={'width_ratios': [5, 1]}, squeeze=False)

    for i in range(n_points):
        for col, xlim in enumerate(col_xlims):
            ax = axs[i, col]
            ax.plot(dist_diff_mn[i, :], label='mn', c='grey')
            ax.set_xlim(0, xlim)
            # only the summary row at the bottom carries x tick labels
            ax.set_xticklabels([])
        axs[i, 0].set_ylabel(f'point {i}')

    # Summary row is addressed explicitly instead of through the loop variable of the loop above
    summary_row = n_points
    mean_profile = np.nanmean(dist_diff_mn, axis=0)
    for col, xlim in enumerate(col_xlims):
        ax = axs[summary_row, col]
        ax.plot(dist_diff_mn.T, c='grey', alpha=0.1)
        ax.plot(mean_profile)
        ax.set_xlim(0, xlim)
        ax.set_xlabel('Dist (pixels)')
    axs[summary_row, 0].set_ylabel('mean')

    # remove all top and right spines
    for ax in axs.ravel():
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight', dpi=300)

In [None]:
def plot_fov_map(fov_plot, all_coords_x, all_coords_y, vlim=200, save_path=None):
    """
    Plot the fov_map with all stim points overlaid.

    The maps are laid out row by row on a grid with 8 rows and as many columns as needed to show
    every map. Grid cells without a map are left blank.

    Parameters:
        fov_plot : (np.ndarray)
            The response map with shape (n_stim_points, height, width) where each slice corresponds to a mean pixel response to a stimulus point.
        all_coords_x : (np.ndarray)
            X coordinates of the stimulus points, indexed with the map index. Drawn along the vertical axis.
        all_coords_y : (np.ndarray)
            Y coordinates of the stimulus points, indexed with the map index. Drawn along the horizontal axis.
        vlim : (int)
            Saturation limits for visualizing fov_map.
        save_path : (str or None)
            Path to save the plot. If None, the plot will not be saved.
    """

    n_maps = fov_plot.shape[0]
    n_rows = 8
    # Round up so that no map is dropped when n_maps is not a multiple of n_rows
    n_cols = max(1, int(np.ceil(n_maps / n_rows)))

    # squeeze=False keeps axs two-dimensional even for a single column
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols*4, n_rows*4), dpi=300, squeeze=False)

    # axs.ravel() walks the grid row by row, i.e. idx == row * n_cols + col
    for idx, ax in enumerate(axs.ravel()):
        if idx < n_maps:
            ax.imshow(fov_plot[idx, :, :], cmap='bwr', vmin=-vlim, vmax=vlim)
            ax.scatter(all_coords_y[idx], all_coords_x[idx], color='black', s=1, label='Stimulus Point')
            ax.set_title(f'point {idx}')
        ax.axis('off')

    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight', dpi=600)


In [None]:
def plot_kernel_2d(k2d, fov_shape=(512, 512), n_dist_bins=724, vlim=200, save_path=None):
    """
    Show the 2D kernel on the current axes, cropped to a FOV-sized window around its centre.

    Parameters:
        k2d : (np.ndarray)
            The 2D kernel to plot.
        fov_shape : (tuple)
            Shape of the field of view (height, width); sets the size of the displayed window.
        n_dist_bins : (int)
            Side length of k2d; its half (integer division) is taken as the kernel centre.
        vlim : (int)
            Saturation limits for visualizing the kernel.
        save_path : (str or None)
            Path to save the plot. If None, the plot will not be saved.
    """
    ax = plt.gca()

    image = ax.imshow(k2d, cmap='bwr', vmin=-vlim, vmax=vlim)
    plt.colorbar(image)

    centre = n_dist_bins//2
    half_h = fov_shape[0]//2
    half_w = fov_shape[1]//2

    ax.scatter(centre, centre, color='black', s=1, label='Stimulus Points')
    ax.set_title("2D Rotated Measured Kernel")
    # crop the view to a FOV-sized window centred on the kernel centre
    ax.set_xlim(centre - half_w, centre + half_w)
    ax.set_ylim(centre - half_h, centre + half_h)
    ax.axis('off')

    if save_path is None:
        return
    ax.figure.savefig(save_path, bbox_inches='tight', dpi=300)

In [None]:
# 1) Set parameters
data_dir = 'data_loc' # data_loc is  the directory on local ssd (only two sessions, one for jm049 and one for jm048)
experimenter = 'jm'
mouse = 'jm049' # 'jm049' or 'jm048'
session =  '2025-05-23_b' # '2025-05-23_b' or '2025-05-08_c'
channel = 2
plane = 0
frame_period = 0.033602463 # for jm049: 0.033602463 or for jm048: 0.033602476 # exact frame period from metadata (for '30Hz' acquisition) # TODO: GET FROM METADATA!!!
fov_shape = (512, 512) # shape of the FOV in pixels

# baseline and response parameters
bsln_n_frames = 10 # baseline window in frames
resp_n_frames = 10 # response window in frames

bsln_sub_type = 'trial_by_trial' # 'trial_by_trial' (subtract mean of bsln_n_frames for that specific trial) or 'session_wide' (subtract mean of bsln_n_frames across all repetitions of all trials)

# spatial extent of response
n_dist_bins = 724 # Number of distance bins when computing df/f as a function of distance from mark point - 724 corresponds to 1 pixel resolution for a 512x512 FOV (diagonal)

# visualisation parameters
vlim = 200          # saturation limits for visualising fov_map (response to stimulation)
txt_shift = (7, 7)  # shift when labeling stim point or ROI centroid positions with text

# x-limits of the distance plots, derived from n_dist_bins so they follow a change of the bin count
dist_bins_xlim = n_dist_bins // 2 # Limit the distance bins to half the FOV size for visualization (maximum distance for a stim point in the center - lower bound)
dist_bins_xlim_zoom = n_dist_bins // 16 # zoomed-in view close to the stimulus point


In [None]:
# TODO: simply append this to the pipeline after the suite2p pipeline if the session is _b

In [None]:
session_path = os.path.join(data_dir, experimenter, mouse, session)

# tiff file paths
s2p_path = os.path.join(session_path, 'suite2p', f'plane{plane}')
tiff_dir = os.path.join(s2p_path, f'reg_tif_chan{channel}')
all_tiff_paths = get_all_tiff_paths(tiff_dir)

# stimulation protocol paths (the csv is written to and read back from the same file)
csv_save_path = os.path.join(session_path, 'photostim_protocol.csv')
csv_load_path = csv_save_path

# output paths
output_path = os.path.join(session_path, 'photostim_deve')
output_fig_path = os.path.join(output_path, 'fig')

# output_fig_path lives inside output_path, so creating it creates both;
# exist_ok avoids the separate existence check
os.makedirs(output_fig_path, exist_ok=True)



In [None]:
# loading suite2p data
# NOTE(review): meaning of the returned values is inferred from their use further down - confirm against get_med_img_s2p:
#   meds     - indexed as meds[:, 0] / meds[:, 1] and compared with stimulus coordinates (presumably ROI centroids)
#   mn_image - shown as the grayscale background image
#   s2p_idxs - suite2p ROI index for each row of meds
#   f        - indexed as f[roi, frame] (fluorescence traces)
meds, mn_image, s2p_idxs, ops, f = get_med_img_s2p(session_path)
# per-frame registration offsets, plotted below as the movement correction offsets
xoff = ops['xoff']
yoff = ops['yoff']


In [None]:
# Plot the motion-correction offsets over frames and save the figure as xyoff.png
plot_xyoff(xoff, yoff, save_path=os.path.join(output_fig_path, 'xyoff.png'))

In [None]:
# Load stim protocol
mp_dict = parse_mark_points(session_path)
# print the parsed entries for a quick visual check
for key, value in mp_dict.items():
    print(f"Key: {key}, Value: {value}")

# The return value is discarded; the protocol is read back from csv_load_path (the same file) just below.
# NOTE(review): presumably this call writes the stimulation list to csv_save_path - confirm.
_ = mp_dict_to_stim_list(mp_dict, frame_period=frame_period, fov_shape=fov_shape, csv_save_path=csv_save_path)

# one entry per stimulation: time, frame, point index and point coordinates
all_time, all_frame, all_point, all_coords_x, all_coords_y = load_photostim_protocol(csv_load_path)

# number of distinct stimulation points
n_points = len(np.unique(all_point))

In [None]:
# now get the responses from suite2p motion corrected tiff files and related
# TODO: Issue with stim window being on the edge of two batches ...
# The response window uses resp_n_frames (the baseline length was previously passed for both windows)
fov_bsln, fov_resp, fov_diff = get_fov_resp(all_tiff_paths, all_frame, bsln_n_frames=bsln_n_frames, resp_n_frames=resp_n_frames, fov_shape=fov_shape)

# for each response get the movements from suite2p

In [None]:
# Plot which point was stimulated at which frame; the x-axis is limited to f.shape[1]
# (presumably the number of recorded frames - confirm) and the figure is saved as stim_protocol.png
plot_protocol(all_frame, all_point, n_frames=f.shape[1], save_path=os.path.join(output_fig_path, 'stim_protocol.png'))

In [None]:
# Show and save the difference image of the first three stimulation trials as a quick sanity check
for i in range(3):
    plot_fov_diff_single(fov_diff, all_point, all_coords_x, all_coords_y, i, vlim=vlim, save_path=os.path.join(output_fig_path, f'diff_single_trial{i}.png'))

In [None]:
# Build fov_map (one mean response map per stimulation point) with the selected baseline subtraction
if bsln_sub_type == 'trial_by_trial': # subtract the baseline in each trial of each point independently.
    fov_map = get_fov_resp_mn(fov_diff, all_point)

elif bsln_sub_type == 'session_wide': # subtract the baseline across all trials of all points (mean of all those).
    fov_resp_mn = get_fov_resp_mn(fov_resp, all_point)
    fov_bsln_glob_mean = np.nanmean(fov_bsln, axis=0)
    fov_map = fov_resp_mn - fov_bsln_glob_mean

else:
    # Fail here rather than with a NameError on fov_map (or a stale fov_map from an earlier run) further down
    raise ValueError(f"Unknown bsln_sub_type {bsln_sub_type!r}; expected 'trial_by_trial' or 'session_wide'")


In [None]:
# Overlay all stimulation points (labelled with their index) on the mean image and save as fov_mn_markpoints.png
plot_fov_all_point(mn_image, all_point, all_coords_x, all_coords_y, txt_shift=txt_shift, save_path=os.path.join(output_fig_path, 'fov_mn_markpoints.png'))

In [None]:
# calculate the pixel values as a function of distance from the stimulus point


In [None]:
# Mean and standard deviation of the response map per distance bin, one row per stimulation point
dist_diff_mn, dist_diff_std = get_dist_dff(fov_map, all_point, all_coords_x, all_coords_y, fov_shape=fov_shape, n_dist_bins=n_dist_bins)

In [None]:
# One distance profile per stimulation point plus a summary row; saved as dist_dff.png
plot_dist_dff(dist_diff_mn, n_points=n_points, dist_bins_xlim=dist_bins_xlim, dist_bins_xlim_zoom=dist_bins_xlim_zoom, save_path=os.path.join(output_fig_path, 'dist_dff.png'))

In [None]:
# Distance profile averaged over all stimulation points (NaN entries are ignored)
dist_diff_mn_mn = np.nanmean(dist_diff_mn, axis=0)

# x-axis is the distance-bin index
plt.plot(dist_diff_mn_mn, label='mean', c='black')

In [None]:
# Average the distance profiles into a 1D kernel and rotate it into a 2D kernel
k1d, k2d = compute_dist_kernel(dist_diff_mn, n_dist_bins=n_dist_bins)
# Show the 2D kernel cropped to the FOV size and save it as kernel_2d.png
plot_kernel_2d(k2d, fov_shape=fov_shape, n_dist_bins=n_dist_bins, vlim=vlim, save_path=os.path.join(output_fig_path, 'kernel_2d.png'))

In [None]:
# Grid of the mean response maps of all stimulation points; saved as fov_map.png
plot_fov_map(fov_map, all_coords_x, all_coords_y, vlim=vlim, save_path=os.path.join(output_fig_path, 'fov_map.png'))

In [None]:
# TODO: Average image across all points:
# 1) Take a point as centroid and pad with a sufficient number of NaNs
# 2) Center the point in the FOV
# 3) Average across all points ...

In [None]:

# TODO: add bounding box to the FOV based on maximum displacement and the PPSF (physiological point spread function) of the neurons
# TODO: for each stimulus show a point in FOV where the stimulus was applied (based on the suite2p correction)

In [None]:
# For every stimulation point find the suite2p ROI whose entry in meds is closest to the point
all_point_med_idx = np.zeros(n_points, dtype=int)  # row of meds closest to each stimulation point
all_point_s2p_idx = np.zeros(n_points, dtype=int) # to store the s2p_idx of the ROI closest to each stimulation point

for i in np.unique(all_point):
    coords_x = int(all_coords_x[i])
    coords_y = int(all_coords_y[i])

    # row of meds with the smallest Euclidean distance to the stimulus point
    point_meds_idx = np.argmin(np.sqrt((meds[:, 0] - coords_x) ** 2 + (meds[:, 1] - coords_y) ** 2))

    # keep both the row in meds and the corresponding suite2p ROI index
    # (all_point_med_idx was previously allocated but never filled)
    all_point_med_idx[i] = point_meds_idx
    all_point_s2p_idx[i] = s2p_idxs[point_meds_idx]

In [None]:
# , save_path=os.path.join(output_fig_path, 'fov_mn_markpoints_s2proi.png')

In [None]:
# Stimulation points and their matched suite2p ROIs on top of the mean absolute response map
plt.figure(figsize=(10, 10))
plt.imshow(np.mean(np.abs(fov_map), axis=0), cmap='gray')
plt.axis('off')

for n_drawn, i in enumerate(np.unique(all_point)):
    plt.scatter(all_coords_y[i], all_coords_x[i], color='C0', s=1, label='stim point')
    plt.text(all_coords_y[i] + txt_shift[0], all_coords_x[i] + txt_shift[1], str(i), color='C0', fontsize=8, ha='center', va='center')

    # Row of meds belonging to the matched ROI. Taking the first match gives scalar
    # coordinates, so plt.text receives plain numbers instead of length-1 arrays.
    roi_row = np.flatnonzero(s2p_idxs == all_point_s2p_idx[i])[0]
    s2p_cent_x = meds[roi_row, 0]
    s2p_cent_y = meds[roi_row, 1]

    plt.scatter(s2p_cent_y, s2p_cent_x, color='C1', s=1, label='matched ROI')
    plt.text(s2p_cent_y + txt_shift[0], s2p_cent_x - txt_shift[1], str(all_point_s2p_idx[i]), color='C1', fontsize=8, ha='center', va='center')

    # build the legend after the first point only, so it holds exactly one entry per marker type
    if n_drawn == 0:
        plt.legend()

In [None]:
# now plot the traces corresponding to the stimulus (of raw fluorescence)
f_point_s2p = f[all_point_s2p_idx, :]

# restrict to the stim period (first to last stimulation frame);
# the bounds are cast to int, as done elsewhere in this notebook, because slice bounds must be integers
stim_start = int(all_frame[0])
stim_stop = int(all_frame[-1])
f_point_s2p_stim_epoch = f_point_s2p[:, stim_start:stim_stop]

# zscore rows for display
plt.figure(figsize=(20, 2), dpi=300)
plt.imshow(zscore_act(f_point_s2p_stim_epoch), aspect='auto', cmap='gray_r', vmin=0, vmax=2)

In [None]:
# now average the traces over the repetitions of the protocol
# a repetition starts at a stimulation of point 0 and lasts until (and including) the next one
rep_onsets = all_frame[all_point == 0].astype(int)
n_frames_repetition = int(rep_onsets[1] - rep_onsets[0] + 1)
print(f"Number of frames per repetition: {n_frames_repetition}")

f_point_s2p_stim_epoch_mn = np.zeros((n_points, n_frames_repetition))

# f_point_s2p_stim_epoch starts at all_frame[0], so onsets are expressed relative to that frame.
# Each window is anchored to the actual onset of its repetition; stepping by a fixed window length
# would shift every window one frame further away from its onset.
epoch_start = int(all_frame[0])
n_averaged = 0
for onset in rep_onsets[:-1] - epoch_start:  # the last repetition is left out (it may be cut off by the end of the epoch)
    window = f_point_s2p_stim_epoch[:, onset:onset + n_frames_repetition]
    if window.shape[1] < n_frames_repetition:
        continue  # incomplete window
    f_point_s2p_stim_epoch_mn += window
    n_averaged += 1

# divide by the number of repetitions that were actually summed
f_point_s2p_stim_epoch_mn /= n_averaged


plt.figure(figsize=(2, 2), dpi=300)
plt.imshow(zscore_act(f_point_s2p_stim_epoch_mn), aspect='auto', cmap='gray_r', vmin=0, vmax=5, interpolation='nearest')


In [None]:
# compute the peri-stimulus response of the ROI matched to each stimulation point
# TODO: SUBTRACT BASELINE IN THE SAME WAY AS FOR THE FOV PLOT
# TODO: STICK TO THE SAME CONVENTION WITH resp, bsln and diff
peristim_wind = [10, 30] # +- in frames

n_repetitions = len(np.where(all_point==0)[0]) # TODO: standardize this to not have to define it each time
n_points = len(np.unique(all_point))
n_wind_frames = peristim_wind[0] + peristim_wind[1] + 1

# [point, repetition, frame within the peri-stimulus window]
s2p_resp = np.zeros((n_points, n_repetitions, n_wind_frames))
s2p_resp_zscore = np.zeros((n_points, n_repetitions, n_wind_frames))

f_zscore = zscore_act(f)

for point_idx in np.unique(all_point):
    roi = all_point_s2p_idx[point_idx]
    point_frames = all_frame[np.where(all_point==point_idx)[0]]

    for j in range(n_repetitions):
        frame = int(point_frames[j])
        wind = slice(frame - peristim_wind[0], frame + peristim_wind[1] + 1)
        bsln_wind = slice(frame - bsln_n_frames, frame)

        # raw trace, baseline (mean of the bsln_n_frames frames before the stimulus) subtracted
        bsln = np.mean(f[roi, bsln_wind])
        s2p_resp[point_idx, j, :]  = f[roi, wind] - bsln

        # z-scored trace, baseline subtracted in the same way as the raw trace
        # (the baseline was previously computed but not applied)
        bsln_zscore = np.mean(f_zscore[roi, bsln_wind])
        s2p_resp_zscore[point_idx, j, :]  = f_zscore[roi, wind] - bsln_zscore






In [None]:
# grid of peri-stimulus traces, one panel per stimulation point (single repetitions in grey, mean in colour)
n_cols = 4
# round up so that no point is dropped when n_points is not a multiple of n_cols
n_rows = int(np.ceil(n_points / n_cols))

# squeeze=False keeps axs two-dimensional even for a single row
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols*2, n_rows*2), dpi=300, squeeze=False)

# tick positions at stimulus onset, middle and end of the post-stimulus window
xticks = [peristim_wind[0], (peristim_wind[1])/2 + peristim_wind[0], peristim_wind[1] + peristim_wind[0]]

for i in range(n_rows):
    for j in range(n_cols):
        idx = i * n_cols + j
        ax = axs[i, j]

        if idx >= n_points:
            ax.axis('off')  # unused grid cell
            continue

        ax.plot(s2p_resp[idx, :, :].T, color='grey', alpha=0.1)
        ax.plot(np.mean(s2p_resp[idx, :, :], axis=0), color='C0', alpha=0.5, zorder=10)
        ax.axvline(peristim_wind[0], color='k', linestyle='--')
        ax.set_xticks(xticks)
        ax.set_xticklabels([0, 0.5, 1])
        ax.set_title(f'ROI {all_point_s2p_idx[idx]} (point {idx})', fontsize=8)

        # remove all top and right spines
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        # only put x tick labels and axis label on the lowest panel of each column
        if idx + n_cols >= n_points:
            ax.set_xlabel('Time (s)')
        else:
            ax.set_xticklabels([])

        # y axis label on the first column only, no y tick labels anywhere
        if j == 0:
            ax.set_ylabel('F (a.u.)')

        ax.set_yticklabels([])

# reduce the spacing between the subplots
plt.tight_layout()
# now add suptitle with a bit of padding
plt.suptitle('Response of s2p ROI nearest to stim point', fontsize=16, y=1.02)

In [None]:
# same responses as above, shown as heatmaps (repetition x time), one panel per stimulation point
n_cols = 4
# round up so that no point is dropped when n_points is not a multiple of n_cols
n_rows = int(np.ceil(n_points / n_cols))

# squeeze=False keeps axs two-dimensional even for a single row
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols*2, n_rows*2), dpi=300, squeeze=False)

# tick positions at stimulus onset, middle and end of the post-stimulus window
xticks = [peristim_wind[0], (peristim_wind[1])/2 + peristim_wind[0], peristim_wind[1] + peristim_wind[0]]

for i in range(n_rows):
    for j in range(n_cols):
        idx = i * n_cols + j
        ax = axs[i, j]

        if idx >= n_points:
            ax.axis('off')  # unused grid cell
            continue

        # colour scale centred on the median of this point's responses, +- 8 standard deviations
        resp = s2p_resp[idx, :, :]
        center = np.median(resp)
        spread = 8 * np.std(resp)

        ax.imshow(resp, aspect='auto', cmap='bwr', vmin=center - spread, vmax=center + spread)
        ax.axvline(peristim_wind[0], color='k', linestyle='--')
        ax.set_xticks(xticks)
        ax.set_xticklabels([0, 0.5, 1])
        ax.set_title(f'ROI {all_point_s2p_idx[idx]} (point {idx})', fontsize=8)

        # remove all top and right spines
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        # only put x tick labels and axis label on the lowest panel of each column
        if idx + n_cols >= n_points:
            ax.set_xlabel('Time (s)')
        else:
            ax.set_xticklabels([])

        # y axis label on the first column only, no y tick labels anywhere
        if j == 0:
            ax.set_ylabel('Repetition')

        ax.set_yticklabels([])

# reduce the spacing between the subplots
plt.tight_layout()
# now add suptitle with a bit of padding
plt.suptitle('Response of s2p ROI nearest to stim point', fontsize=16, y=1.02)

In [None]:
# now the mean response across all repetitions, with two visualisations...


In [None]:
# mean response over repetitions, one row per stimulation point

# The image columns are frames of the peri-stimulus window. The ticks are relabelled the same way as
# in the per-point panels (0 at stimulus onset), so that the 'Time (s)' axis label matches the ticks.
xticks = [peristim_wind[0], (peristim_wind[1])/2 + peristim_wind[0], peristim_wind[1] + peristim_wind[0]]
xticklabels = [0, 0.5, 1]

# raw fluorescence, colour scale centred on the median, +- 8 standard deviations
s2p_resp_mn = np.mean(s2p_resp, axis=1)
center = np.median(s2p_resp_mn)
spread = 8 * np.std(s2p_resp_mn)

plt.figure(figsize=(2, 2), dpi=300)
plt.imshow(s2p_resp_mn, aspect='auto', cmap='bwr', vmin=center - spread, vmax=center + spread)
plt.axvline(peristim_wind[0], color='k', linestyle='--')
plt.xticks(xticks, xticklabels)
plt.ylabel('Point')
plt.xlabel('Time (s)')
plt.title('Mean F of s2p ROI')
plt.colorbar(label='F (a.u.)')

# z-scored fluorescence, fixed colour scale
s2p_resp_mn_zscore = np.mean(s2p_resp_zscore, axis=1)

plt.figure(figsize=(2, 2), dpi=300)
plt.imshow(s2p_resp_mn_zscore, aspect='auto', cmap='bwr', vmin=-8, vmax=8)
plt.axvline(peristim_wind[0], color='k', linestyle='--')
plt.xticks(xticks, xticklabels)
plt.ylabel('Point')
plt.xlabel('Time (s)')
plt.title('Mean F of s2p ROI')
plt.colorbar(label='F (zscore)')



In [None]:
# now check in suite2p if it looks like it makes sense




In [None]:
# # now show it in gui
# labels = [1, 2, 3, 4]
# colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
# size = 30
# properties = {'label': labels}

# with napari.gui_qt():
#     viewer = napari.Viewer()
#     viewer.add_image(crop_tiff, name=f'{experimenter}_{mouse}_{session}_plane{plane}_chan{channel}', colormap='gray')
#     # add points
#     viewer.add_points(crop_point_coords, name='stimulations', size=0, properties=properties, text={'string': labels, 'color': colors, 'size': size})
#     # set frame rate to 30 fps

#     if make_anim:
#         anim = Animation(viewer)

#         for i in np.arange(4150, 4150 + 5 * 120):
#             viewer.dims.set_point(0, i)
#             anim.capture_keyframe()
    

In [None]:
# if make_anim:
#     anim.animate('stim.mp4', fps=300, quality=5)


In [None]:
# # now make a video of the average stimulation (e.g. take the times from each stim0 and average them)
# n_frames_cycle = int(stim_frames[stim_points == 0][1] - stim_frames[stim_points == 0][0])

# crop_tiff_rep_avg = np.zeros((n_frames_cycle, crop_tiff.shape[1], crop_tiff.shape[2]))

# for i in range(int(mp_dict['Repetitions'])-2):
#     t_onset = int(stim_frames[stim_points == 0][i])
#     t_offset = int(stim_frames[stim_points == 0][i + 1])

#     if t_offset - t_onset != n_frames_cycle:
#         t_offset = t_onset + n_frames_cycle
#     print(i)
#     print(f'Onset: {t_onset}, Offset: {t_offset}')
#     crop_tiff_rep_avg += crop_tiff[t_onset:t_offset, :, :]

    
# crop_tiff_rep_avg = crop_tiff_rep_avg / int(mp_dict['Repetitions'])


In [None]:

# with napari.gui_qt():
#     viewer = napari.Viewer()
#     viewer.add_image(crop_tiff_rep_avg, name=f'{experimenter}_{mouse}_{session}_plane{plane}_chan{channel}', colormap='gray')
#     viewer.add_points(crop_point_coords, name='stimulations', size=0, properties=properties, text={'string': labels, 'color': colors, 'size': size})


#     if make_anim:
#         anim = Animation(viewer)

#         # set all frames as keyframes
#         for i in range(crop_tiff_rep_avg.shape[0]):
#             # go to frame i (dim 0 of crop_tiff_rep_avg)
#             viewer.dims.set_point(0, i)
#             anim.capture_keyframe()
#         # set the fps to 30

# if make_anim:
#     anim.animate('stim_avg.mp4', fps=300, quality=5)


In [None]:
# # get the average activations for each stimulus from the average
# resp_duration = 10 # in frames
# n_points = len(mp_dict['AllPoint'])
# inter_stim_interval = 30 # in frames
# # stim0_onset = 0
# # stim1_onset = 30
# # stim2_onset = 60
# # stim3_onset = 90

# # make the median projection be the baseline
# crop_tiff_rep_avg_median = np.median(crop_tiff_rep_avg, axis=0)

# all_mn_stim_resp = np.zeros((n_points, crop_tiff_rep_avg.shape[1], crop_tiff_rep_avg.shape[2]))

# for i in range(all_mn_stim_resp.shape[0]):
#     onset = i * inter_stim_interval
#     offset = onset + resp_duration
#     all_mn_stim_resp[i,:,:] = np.mean(crop_tiff_rep_avg[onset:offset, :, :], axis=0)



In [None]:
# std_vis = 10

# fig, axs = plt.subplots(1, n_points, figsize=(20, 5))

# for i in range(all_mn_stim_resp.shape[0]):
#     diff = crop_tiff_rep_avg_median- all_mn_stim_resp[i,:,:]
#     # center on median of diff
#     diff = diff - np.median(diff)

#     axs[i].imshow(diff, cmap='bwr_r', vmin=np.std(diff) * -std_vis, vmax=np.std(diff) * std_vis)
#     axs[i].scatter(crop_point_coords[i, 1], crop_point_coords[i, 0], c='black', s=10)
#     # remove axis ticks and labels
#     axs[i].set_xticks([])
#     axs[i].set_yticks([])
