In [None]:
%load_ext autoreload

In [None]:
from pathlib import Path
import numpy as np
import flammkuchen as fl
import json

from split_dataset import SplitDataset
%autoreload
from glob import glob
from bouter import Experiment

from lavian_et_al_2025.visual_motion.stimulus_functions import stim_vel_dir_dataframe, quantize_directions
from scipy.interpolate import interp1d
from scipy.signal import convolve2d
import pandas as pd

In [None]:
import shutil
import warnings
from typing import List, Union, Tuple
import time

In [None]:
def make_sensory_regressors(exp, n_dirs=8, upsampling=5, sampling=1/3):
    """Build one calcium-smoothed regressor per motion direction.

    For every direction bin a boxcar is built that is 1 whenever the quantized
    stimulus direction equals that bin and the stimulus velocity exceeds 0.1.
    The boxcars are interpolated onto a time base ``upsampling`` times finer
    than ``sampling``, convolved with a normalised exponential decay
    (half-life 1.5 in the units of ``stim.t``) and decimated back.

    Parameters
    ----------
    exp
        Experiment object handed to ``stim_vel_dir_dataframe``.
    n_dirs : int
        Number of direction regressors to build.
        NOTE(review): ``n_dirs`` is not forwarded to ``quantize_directions``;
        this presumably relies on its default bin count matching — confirm.
    upsampling : int
        Oversampling factor used while convolving.
    sampling : float
        Step of the output time base (a period, not a rate), in the units of
        ``stim.t``.

    Returns
    -------
    pd.DataFrame
        One row per output sample, columns ``motion_0`` ... ``motion_{n_dirs-1}``.

    Notes
    -----
    Prints the last stimulus timestamp as a side effect.
    """
    stim_df = stim_vel_dir_dataframe(exp)
    _bin_centres, direction_idx = quantize_directions(stim_df.theta)

    # Boxcar per direction: right bin AND stimulus actually moving.
    is_moving = stim_df.vel > 0.1
    boxcars = np.zeros((n_dirs, len(stim_df)))
    for direction in range(n_dirs):
        in_this_bin = np.abs(direction_idx - direction) < 0.1
        boxcars[direction, :] = in_this_bin & is_moving

    # Resample onto the fine, regularly spaced time base.
    stim_times = stim_df.t.values
    fine_dt = sampling / upsampling
    fine_times = np.arange(0, stim_times[-1], fine_dt)
    resample = interp1d(stim_times, boxcars, axis=1, fill_value="extrapolate")
    boxcars_fine = resample(fine_times)

    # Exponential decay kernel (labelled "6s kernel" by the author, presumably
    # GCaMP6s — confirm). It is as long as the recording, so it is never truncated.
    n_fine = fine_times.shape[0]
    print(stim_times[-1])
    lags = np.arange(n_fine) * fine_dt
    time_constant = 1.5 / np.log(2)
    weights = np.exp(-lags / time_constant)
    weights = weights / np.sum(weights)

    # Full convolution, trimmed to the causal part that aligns with the input.
    smoothed = convolve2d(boxcars_fine, weights[None, :])[:, 0:n_fine]
    regressors = smoothed[:, ::upsampling]

    column_names = [f"motion_{i}" for i in range(n_dirs)]
    return pd.DataFrame(regressors.T, columns=column_names)

In [None]:
def _per_plane_baseline_frames(baseline_frames, n_planes):
    """Normalise ``baseline_frames`` into one 1D integer index array per z-plane."""
    if len(baseline_frames) > 0 and np.ndim(baseline_frames[0]) == 0:
        # Flat sequence of frame indices: the same baseline is used for every plane.
        shared = np.asarray(baseline_frames, dtype=int).ravel()
        per_plane = [shared] * n_planes
    else:
        if len(baseline_frames) < n_planes:
            raise ValueError(
                f"baseline_frames has {len(baseline_frames)} entries but the stack "
                f"has {n_planes} z-planes"
            )
        per_plane = [
            np.asarray(baseline_frames[z], dtype=int).ravel() for z in range(n_planes)
        ]

    for z, frames in enumerate(per_plane):
        if frames.size == 0:
            # A mean over zero frames is NaN; fail loudly instead of returning garbage.
            raise ValueError(f"No baseline frames given for z-plane {z}")
    return per_plane


def calculate_dff(
    stack: np.ndarray,
    baseline_frames: Union[np.ndarray, List[int], List[List[int]], List[np.ndarray]],
    min_baseline_value: float = 1.0,
    percentile_value: float = 1.0,
    global_offset: Union[float, None] = None,
    timing: bool = True,
    handle_nans: bool = True,
) -> np.ndarray:
    """
    Calculate ΔF/F for a 4D calcium imaging stack with per-plane offset correction.

    For every z-plane a low percentile of the plane is subtracted as offset,
    the baseline F0 is the per-pixel mean of the offset-corrected baseline
    frames (clipped from below at ``min_baseline_value``), and the result is
    ``(F - F0) / F0``. NaNs present in the input are NaN in the output.

    Parameters
    ----------
    stack : np.ndarray
        4D array with dimensions (time, z, x, y). Non-float input is converted
        to float32; the input array is never modified.
    baseline_frames : sequence
        Either one sequence of frame indices per z-plane
        (``baseline_frames[z]`` is used for plane ``z``), or a flat sequence of
        frame indices that is shared by all planes.
    min_baseline_value : float, optional
        Lower bound applied to F0 (also used where F0 is undefined because all
        baseline frames of a pixel are NaN), by default 1.0
    percentile_value : float, optional
        Percentile of each plane subtracted as offset (1.0 = 1st percentile),
        by default 1.0
    global_offset : float, optional
        Currently unused: the per-plane percentile offset is always applied.
        A warning is emitted when a value is passed. By default None
    timing : bool, optional
        Whether to print timing and progress information, by default True
    handle_nans : bool, optional
        Whether to ignore NaNs when computing the offset percentile and the
        baseline mean, by default True

    Returns
    -------
    np.ndarray
        4D array with ΔF/F values, same shape as ``stack``

    Raises
    ------
    ValueError
        If ``stack`` is not 4D, if fewer baseline entries than z-planes are
        given, or if a plane has no baseline frames.
    """
    start_time = time.time()

    if global_offset is not None:
        warnings.warn(
            "calculate_dff: 'global_offset' is currently ignored; the per-plane "
            "percentile offset is always used.",
            stacklevel=2,
        )

    stack = np.asarray(stack)
    if stack.ndim != 4:
        raise ValueError(f"stack must be 4D (time, z, x, y), got shape {stack.shape}")

    # Convert to float if needed
    if not np.issubdtype(stack.dtype, np.floating):
        stack = stack.astype(np.float32)
        if timing:
            print(f"Data type conversion: {time.time() - start_time:.2f} seconds")

    z_max = stack.shape[1]

    # Validate the baseline specification before doing any heavy work.
    frames_per_plane = _per_plane_baseline_frames(baseline_frames, z_max)

    # Remember where the input is NaN so those entries stay NaN in the output.
    original_nan_mask = np.isnan(stack)
    has_nans = bool(np.any(original_nan_mask))
    if timing and has_nans:
        print(f"NaN values detected in input: {np.sum(original_nan_mask)} out of {stack.size} elements")

    offset_start = time.time()

    # Work on a copy (NaNs replaced by zeros) so the caller's array is untouched.
    corrected_stack = np.nan_to_num(stack, nan=0.0)

    # For each z-plane, subtract a low percentile as offset.
    percentile_fn = np.nanpercentile if handle_nans else np.percentile
    for z in range(z_max):
        min_value = percentile_fn(stack[:, z, :, :], percentile_value)
        corrected_stack[:, z, :, :] -= min_value
        if timing:
            print(f"Z-plane {z}: Subtracted {min_value:.2f} (percentile {percentile_value})")

    if timing:
        print(f"Offset correction: {time.time() - offset_start:.2f} seconds")
    baseline_start = time.time()

    dff = np.zeros_like(corrected_stack)

    # Process each z-plane with its own baseline frames.
    for z in range(z_max):
        plane_frames = frames_per_plane[z]
        baseline_block = corrected_stack[plane_frames, z, :, :]

        if handle_nans:
            # F0 must come from the same offset-corrected data as the numerator.
            # Re-insert the original NaNs so nanmean skips them instead of
            # averaging the zero placeholders.
            baseline_block = np.where(
                original_nan_mask[plane_frames, z, :, :], np.nan, baseline_block
            )
            f0_plane = np.nanmean(baseline_block, axis=0)
            # Pixels that are NaN in every baseline frame fall back to the floor value.
            f0_plane = np.nan_to_num(f0_plane, nan=min_baseline_value)
        else:
            f0_plane = np.mean(baseline_block, axis=0)

        # Apply minimum threshold to avoid dividing by ~0 or negative baselines.
        f0_plane = np.maximum(f0_plane, min_baseline_value)
        f0_plane = f0_plane.astype(corrected_stack.dtype, copy=False)

        dff[:, z, :, :] = (corrected_stack[:, z, :, :] - f0_plane) / f0_plane

    # Restore original NaN values
    if has_nans:
        dff[original_nan_mask] = np.nan

    if timing:
        total_time = time.time() - start_time
        baseline_time = time.time() - baseline_start
        print(f"Baseline calculation: {baseline_time:.2f} seconds")
        print(f"Total ΔF/F calculation: {total_time:.2f} seconds")

    return dff


In [None]:
def no_regressor_frames(regressors, threshold=0.05):
    """Find the frames to calculate the baseline.

    The comparison is reduced over axis 0, so the returned indices run along
    the remaining axis: pass an array laid out as (n_regressors, n_frames) to
    get frame indices.

    Parameters
    ----------
    regressors : array-like
        Regressor values; axis 0 is the axis that is collapsed.
    threshold : float
        A position counts as stimulus-free only if every regressor is strictly
        below this value there.

    Returns
    -------
    np.ndarray
        Indices at which all regressors are below ``threshold``.
    """
    below_threshold = regressors < threshold
    all_quiet = np.all(below_threshold, axis=0)
    (quiet_indices, *_rest) = np.where(all_quiet)
    return quiet_indices


In [None]:
# Root folder of the dataset. NOTE(review): hard-coded network share; adjust when
# running from another machine.
master = Path(r"\\Funes2\data\Hagar and Ot\E0040\v10\2p\s1186t")

# Path.glob gives no ordering guarantee; sort so the fish list (and therefore
# "the last fish") is the same on every run and every machine.
all_fish = sorted(master.glob("*_f*"))
if not all_fish:
    # Typically the share is not mounted or the path is wrong.
    raise FileNotFoundError(f"No fish folders matching '*_f*' found in {master}")
fish_dir = all_fish[-1]
print(all_fish)

In [None]:
# Compute ΔF/F for every fish that has no "dff2" output yet and save it as one
# HDF5 file per plane, next to a copy of the aligned stack's metadata.
for f in all_fish[:]:
    print(f)

    dff_dir = f / "dff2"
    # Skip fish that already have an output folder.
    # NOTE(review): a run interrupted while writing leaves a partial "dff2" that is
    # skipped here next time — delete such folders by hand before re-running.
    if dff_dir.exists():
        continue

    stack = SplitDataset(f / "aligned")[:, :, :, :]
    len_rec, num_planes, x_pix, y_pix = np.shape(stack)

    # glob() returns files in arbitrary order, but experiment k is paired with
    # plane k below, so fix the order. Assumes the file names sort in plane
    # order (e.g. timestamp prefix) — TODO confirm.
    exp_list = sorted(glob(str(f / "behavior/*.json")))
    reg_list = [make_sensory_regressors(Experiment(exp), sampling=1/3) for exp in exp_list]
    if len(reg_list) < num_planes:
        raise ValueError(
            f"{f}: found {len(reg_list)} behavior file(s) for {num_planes} plane(s)"
        )

    baseline_frames = []
    for reg in reg_list:
        # The regressors come as a (frames x directions) DataFrame, whereas
        # no_regressor_frames reduces over axis 0: transpose so it returns
        # frame indices rather than direction indices.
        frames = no_regressor_frames(reg.to_numpy().T)
        # Frames past the end of the recording cannot serve as baseline.
        baseline_frames.append(frames[frames < len_rec])
    print("Baseline frames per plane:", [len(frames) for frames in baseline_frames])

    # Optional diagnostics: this helper is not defined in this notebook, so only
    # call it when a previous session/cell provides it.
    diagnostics = globals().get("analyze_data_statistics")
    if callable(diagnostics):
        diagnostics(stack, z_plane=0)

    dff_result = calculate_dff(
        stack,
        baseline_frames,
        percentile_value=1.0,             # 1st percentile as per-plane offset
        timing=True                       # Show timing information
    )

    # Create the output folder once, and only after the computation succeeded.
    dff_dir.mkdir(exist_ok=True)
    for i in range(num_planes):
        file_name = f"{i:04d}.h5"

        # Keep a singleton plane axis so each file holds a (time, 1, x, y) stack.
        data_out = np.expand_dims(dff_result[:, i, :, :], 1)

        data_out_folder = str(dff_dir / file_name)
        fl.save(data_out_folder, {"stack_4D": data_out}, compression="blosc")

    shutil.copy(str(f / "aligned/stack_metadata.json"), str(dff_dir))