In [None]:
%matplotlib widget

In [None]:
from glob import glob
import numpy as np
import flammkuchen as fl
from split_dataset import SplitDataset
from pathlib import Path

In [None]:
# Pixel-wise Pearson correlation between dF/F traces and each sensory regressor
def get_tuning_map(img, sens_regs):
    """
    Calculates pixel-wise correlation maps between an imaging time series and sensory regressors.

    For every pixel, the Pearson correlation coefficient between its time series and each
    regressor column is computed, giving one spatial correlation map per regressor.

    Parameters
    ----------
    img : ndarray
        Imaging data with shape (time, height, width).
    sens_regs : ndarray or DataFrame
        Sensory regressors with shape (time, n_regressors). A 1-D array is treated as a
        single regressor.

    Returns
    -------
    reg : ndarray
        Correlation maps with shape (n_regressors, height, width). Pixels whose trace has
        zero variance over the analysed window are NaN.

    Raises
    ------
    ValueError
        If ``img`` and ``sens_regs`` share no time points.

    Notes
    -----
    Only the first ``min(img.shape[0], len(sens_regs))`` time points are used, so the two
    inputs may differ in length (either one may be the longer).
    """
    img = np.asarray(img)
    regs = np.asarray(sens_regs, dtype=float)
    if regs.ndim == 1:
        regs = regs[:, None]  # single regressor -> (time, 1)

    # Flatten space so every pixel is a column: (time, pixels)
    traces = img.reshape(img.shape[0], -1)

    # Align both series on their common length instead of assuming the recording is longer.
    n_t = min(traces.shape[0], regs.shape[0])
    if n_t == 0:
        raise ValueError("img and sens_regs share no time points")
    traces = traces[:n_t]
    regs = regs[:n_t]

    # cov(x, y) = mean(x * (y - mean(y))): centring the regressors alone is sufficient,
    # which avoids materialising a centred copy of the (large) pixel array.
    regs_centered = regs - np.nanmean(regs, axis=0)
    cov = traces.T @ regs_centered / n_t  # (pixels, n_regressors)

    # Population std (ddof=0) to match the 1/n covariance above. Mixing ddof conventions
    # (as a (n - 1) factor with ddof=0 stds would) scales every coefficient by n / (n - 1).
    denom = np.outer(np.nanstd(traces, axis=0), np.nanstd(regs, axis=0))

    # Zero-variance (or NaN) pixels get NaN instead of inf / divide-by-zero warnings.
    corr = np.full_like(cov, np.nan)
    np.divide(cov, denom, out=corr, where=denom > 0)

    # Back to spatial layout: (n_regressors, height, width)
    return corr.T.reshape(regs.shape[1], img.shape[-2], img.shape[-1])

In [None]:
# Root directory holding one sub-folder per fish (folder names matching "*_f*").
# NOTE(review): an empty path means the current working directory; set the data root here.
master = Path("")
fish_list = [*master.glob("*_f*")]

# Time per imaging frame; frame count * sampling gives recording time below.
# NOTE(review): units not stated here (1/3 suggests a 3 Hz frame rate in seconds) — confirm.
sampling = 1 / 3
# Number of sensory regressors expected per plane (sizes the regressor axis of the maps).
n_dir = 8

In [None]:
# For every fish, compute one correlation map per sensory regressor and imaging plane and
# cache the result in <fish>/plane_corrmap_corrvalues.h5; fish that already have it are skipped.
for fish in fish_list:
    print(fish)

    # Check and save against the same path. The existence check previously looked for
    # "plane_corrmap_corrvalues_.h5" (stray underscore), so no fish was ever skipped.
    out_path = fish / "plane_corrmap_corrvalues.h5"
    if out_path.exists():
        continue

    stack = SplitDataset(fish / "dff")
    # NOTE(review): exp_list and time are not used in this cell; kept as notebook globals
    # in case later cells read them — remove if unused.
    exp_list = glob(str(fish / "behavior/*.json"))

    len_rec, num_planes, x_pix, y_pix = stack.shape
    # Frame k is at k * sampling (linspace(0, N*dt, N) would space frames N*dt/(N-1) apart).
    time = np.arange(len_rec) * sampling

    # glob order is filesystem-dependent; sort so plane_list[i] lines up with plane i.
    plane_list = sorted(glob(str(fish / "suite2p/00*")))

    plane_maps = []
    for i in range(num_planes):
        # plane_list is only used for progress output; don't crash if it is shorter.
        print(plane_list[i] if i < len(plane_list) else f"plane {i}")

        file_name = "sensory_regressors" + str(i) + ".h5"
        regs = fl.load(fish / file_name)['regressors'].values

        plane_maps.append(get_tuning_map(stack[:, i, :, :], regs))

    # Shape (num_planes, n_regressors, x_pix, y_pix). The regressor count comes from the
    # regressor files rather than the hard-coded n_dir, so a mismatch cannot mis-broadcast.
    plane_corr = np.stack(plane_maps)

    d = {'plane_corr': plane_corr}
    fl.save(out_path, d)
