In [None]:
# WIP: stitch together LabVIEW photometry recordings that were split across multiple .phot/.box files (e.g. when the recording cut out mid-session)

from jdb_to_nwb.convert_photometry import process_raw_labview_photometry_signals, process_and_add_labview_to_nwb
from pynwb import NWBHDF5IO

# This session's LabVIEW photometry is split across two recordings, listed in chronological order.
# Each recording has a .phot and a .box file sharing the same name stem; building both path lists
# from one list of stems keeps every .phot/.box pair matched and in the same order.
LABVIEW_DIR = "/Volumes/Tim/Photometry/IM-1875/20250725"
LABVIEW_RECORDING_STEMS = [
    "IM-1875_2025-07-25_11-11-56____Tim_Conditioning",
    "IM-1875_2025-07-25_12-16-34____Tim_Conditioning",
]

metadata = {
    "photometry": {
        "phot_file_path": [f"{LABVIEW_DIR}/{stem}.phot" for stem in LABVIEW_RECORDING_STEMS],
        "box_file_path": [f"{LABVIEW_DIR}/{stem}.box" for stem in LABVIEW_RECORDING_STEMS],
    }
}

# Read the NWB file for this session; its trials table provides the ground truth port visit times
nwb_file_path = "/Users/steph/berkelab/jdb_to_nwb/nwbs/IM-1875/Darling_20250725.nwb"
# Named nwb_io (not io) so the stdlib `io` module name is not shadowed in the notebook namespace
with NWBHDF5IO(nwb_file_path, 'r') as nwb_io:
    nwbfile = nwb_io.read()

    # Get trial and block data from the nwb
    block_data = nwbfile.intervals["block"].to_dataframe()
    trial_data = nwbfile.intervals["trials"].to_dataframe()
    visit_times = trial_data["poke_in"][:]
    print(visit_times)


class DummyLogger:
    """Minimal stand-in for a logger: prints each message to stdout, prefixed by its level tag."""

    def _emit(self, tag, msg):
        # Shared formatter for every level, producing lines like "[INFO] message"
        print(f"[{tag}] {msg}")

    def debug(self, msg):
        self._emit("DEBUG", msg)

    def info(self, msg):
        self._emit("INFO", msg)

    def warning(self, msg):
        self._emit("WARNING", msg)

    def error(self, msg):
        self._emit("ERROR", msg)


logger = DummyLogger()


# If we have raw LabVIEW data (.phot and .box files), build sig_list: one signals dict per recording
if "phot_file_path" in metadata["photometry"] and "box_file_path" in metadata["photometry"]:
    # Process photometry data from LabVIEW to create a signals dict of relevant photometry signals.
    # DummyLogger already prints to stdout, so no separate print() is needed alongside these.
    logger.info("Using LabVIEW for photometry!")
    logger.info("Processing raw .phot and .box files from LabVIEW...")
    phot_file_path = metadata["photometry"]["phot_file_path"]
    box_file_path = metadata["photometry"]["box_file_path"]

    # Ideally we have a string for phot_file_path and box_file_path
    if isinstance(phot_file_path, str) and isinstance(box_file_path, str):
        signals = process_raw_labview_photometry_signals(phot_file_path, box_file_path, logger)
        # Wrap the single recording so downstream cells can always index into sig_list
        sig_list = [signals]
        # TODO: photometry_data_dict = process_and_add_labview_to_nwb(nwbfile, signals, logger, fig_dir)
        #       (fig_dir is not defined in this notebook yet)
    # Or if LabVIEW photometry cut out during the recording, we specify a list of paths (one per chunk) instead
    elif isinstance(phot_file_path, list) and isinstance(box_file_path, list):
        # Explicit check rather than `assert`, so it is logged and is not stripped when running with -O
        if len(phot_file_path) != len(box_file_path):
            logger.error("phot and box file path lists must be the same length!")
            raise ValueError("phot and box file path lists must be the same length!")
        # One signals dict per recording chunk, in the same order as the path lists
        sig_list = [
            process_raw_labview_photometry_signals(phot_path, box_path, logger)
            for phot_path, box_path in zip(phot_file_path, box_file_path)
        ]
    else:
        logger.error("phot_file_path and box_file_path must both be strings or both be lists of equal length!")
        raise TypeError("phot_file_path and box_file_path must both be strings or both be lists of equal length!")


id
0        24.871130
1        36.674752
2        59.505293
3        79.218906
4       106.591770
          ...     
242    6589.534592
243    6619.991693
244    6629.178918
245    6667.868710
246    6718.653837
Name: poke_in, Length: 247, dtype: float64
[INFO] Using LabVIEW for photometry!
[INFO] Processing raw .phot and .box files from LabVIEW...
Processing raw .phot and .box files from LabVIEW...
[INFO] Reading LabVIEW .phot file into a dictionary...
[DEBUG] Data read from LabVIEW .phot file at /Volumes/Tim/Photometry/IM-1875/20250725/IM-1875_2025-07-25_11-11-56____Tim_Conditioning.phot:
[DEBUG] magic_key: 22289481
[DEBUG] header_size: 20480
[DEBUG] main_version: 0
[DEBUG] secondary_version: 0
[DEBUG] sampling_rate: 10000
[DEBUG] bytes_per_sample: 2
[DEBUG] num_channels: 8
[DEBUG] file_name: C:\Users\BerkeLab\Documents\Labview\IM-1875\2025-07-25\IM-1875_2025-07-25_11-11-56____Tim_Conditioning.phot                                                                                       

In [None]:
import numpy as np
from jdb_to_nwb.timestamps_alignment import trim_sync_pulses

# The alignment below (align_with_gap) handles exactly two recording chunks
assert len(sig_list) == 2, f"Expected exactly 2 LabVIEW recording chunks, got {len(sig_list)}"

SR = 10000  # Original sampling rate of the photometry system (Hz)
Fs = 250  # Target downsample frequency (Hz)

# Raw port visit sample indices (at SR) from each chunk, mapped onto the Fs grid.
# astype(int) truncates, so each visit snaps to the downsampled sample at or before it.
unaligned_visits_1_ds = np.divide(sig_list[0]["visits"], SR / Fs).astype(int)
unaligned_visits_2_ds = np.divide(sig_list[1]["visits"], SR / Fs).astype(int)

# Convert port visits to seconds to use for alignment against the NWB visit times
unaligned_visits_1 = list(unaligned_visits_1_ds / Fs)
unaligned_visits_2 = list(unaligned_visits_2_ds / Fs)


def align_with_gap(ground_truth_visits, unaligned_visits_1, unaligned_visits_2, logger):
    """Align sync pulses (port visits) from two recording chunks separated by a gap. WIP.

    Each chunk is aligned independently against the full list of ground truth visits using
    ``trim_sync_pulses``. The first chunk is expected to line up with the start of the ground
    truth visits and the second chunk with the end.

    Args:
        ground_truth_visits: Port visit times for the whole session (in this notebook, the
            ``poke_in`` column of the NWB trials table).
        unaligned_visits_1: Port visit times recorded by the first chunk (seconds in this notebook).
        unaligned_visits_2: Port visit times recorded by the second chunk (seconds in this notebook).
        logger: Object providing ``debug``, ``info`` and ``warning`` methods.

    Returns:
        tuple: ``(ground_truth_visits_1, unaligned_visits_1, ground_truth_visits_2, unaligned_visits_2)``,
        i.e. the pair of visit lists returned by ``trim_sync_pulses`` for each chunk.
    """
    num_ground_truth_visits = len(ground_truth_visits)
    num_unaligned_visits = len(unaligned_visits_1) + len(unaligned_visits_2)

    # Where each chunk is expected to start within the ground truth visits. Defined unconditionally
    # so the alignment still runs when the counts don't add up (e.g. visits were lost in the gap):
    # chunk 1 starts at the first ground truth visit, chunk 2 is assumed to end at the last one.
    expected_best_start_1 = 0
    expected_best_start_2 = max(num_ground_truth_visits - len(unaligned_visits_2), 0)

    if num_unaligned_visits == num_ground_truth_visits:
        # Ideal case - unaligned_visits_1 and unaligned_visits_2 perfectly span ground_truth_visits
        logger.info("The number of port visits recorded by each unaligned chunk "
                    f"sums to the total number of ground truth port visits ({num_unaligned_visits}).")
        logger.debug("We expect the first chunk to match the first section of the ground truth visits (starting at 0) "
                     f"and the second to match the second section (starting at {expected_best_start_2})")
    else:
        logger.warning(f"The unaligned chunks recorded {num_unaligned_visits} port visits in total, but there are "
                       f"{num_ground_truth_visits} ground truth port visits. Assuming the first chunk starts at "
                       f"ground truth visit 0 and the second chunk ends at the last ground truth visit "
                       f"(starting at {expected_best_start_2}).")

    logger.info("Finding best alignment for the first chunk...")
    ground_truth_visits_1, unaligned_visits_1 = trim_sync_pulses(ground_truth_visits=ground_truth_visits,
                                                                 unaligned_visits=unaligned_visits_1,
                                                                 logger=logger,
                                                                 expected_best_start=expected_best_start_1)

    logger.info("Finding best alignment for the second chunk...")
    ground_truth_visits_2, unaligned_visits_2 = trim_sync_pulses(ground_truth_visits=ground_truth_visits,
                                                                 unaligned_visits=unaligned_visits_2,
                                                                 logger=logger,
                                                                 expected_best_start=expected_best_start_2)

    # Return the trimmed lists so the caller can fit per-chunk time transforms and stitch the signals
    return ground_truth_visits_1, unaligned_visits_1, ground_truth_visits_2, unaligned_visits_2


# Bind the alignment result so later cells can use it (assignment also keeps it from being echoed as cell output)
alignment_result = align_with_gap(ground_truth_visits=visit_times,
                                  unaligned_visits_1=unaligned_visits_1,
                                  unaligned_visits_2=unaligned_visits_2,
                                  logger=logger)


# WIP sketch of the remaining stitching steps, written as a function body (it ends in `return`).
# Not wired up yet: the names it uses (B1, B2, B1_sync_indices, B2_sync_indices, A1_sync_times,
# A2_sync_times, rate) are not defined anywhere in this notebook.

# # Step 1: Convert sync indices to relative time
# B1_times = np.array(B1_sync_indices) / rate
# B2_times = np.array(B2_sync_indices) / rate

# # Step 2: Fit linear time transforms for each part
# scale1, offset1 = np.polyfit(B1_times, A1_sync_times, 1)
# scale2, offset2 = np.polyfit(B2_times, A2_sync_times, 1)

# # Step 3: Create timestamps for each sample
# B1_all_times = np.arange(len(B1)) / rate
# B2_all_times = np.arange(len(B2)) / rate

# B1_aligned_times = scale1 * B1_all_times + offset1
# B2_aligned_times = scale2 * B2_all_times + offset2

# # Step 4: Build full time axis with consistent sampling
# full_start = np.floor(B1_aligned_times[0] * rate) / rate
# full_end = np.ceil(B2_aligned_times[-1] * rate) / rate
# full_time_axis = np.arange(full_start, full_end, 1 / rate)

# # Step 5: Initialize stitched signal with NaNs
# stitched_values = np.full_like(full_time_axis, fill_value=np.nan, dtype=float)

# # Step 6: Insert values into correct time slots
# def insert_segment(aligned_times, values):
#     indices = np.round((aligned_times - full_start) * rate).astype(int)
#     valid = (indices >= 0) & (indices < len(stitched_values))
#     stitched_values[indices[valid]] = values[valid]

# insert_segment(B1_aligned_times, B1)
# insert_segment(B2_aligned_times, B2)

# return full_time_axis, stitched_values

[INFO] The number of port visits recorded by each unaligned chunk sums to the total number of ground truth port visits (247).
[DEBUG] We expect the first chunk to match the first section of the ground truth visits (starting at 0) and the second to match the second section (starting at 198)
[INFO] Finding best alignment for the first chunk...
[INFO] Initial number of port visits: ground truth=247, unaligned=198
[DEBUG] List of ground truth visits is longer; trimming it.
[INFO] Finding best alignment between port visits by minimizing error in pulse spacing.
[DEBUG] Trimming 0 samples from start and 49 from end of longer list. Sum of differences in pulse spacing = 0.44660480000973024
[DEBUG] Trimming 1 samples from start and 48 from end of longer list. Sum of differences in pulse spacing = 2884.5682816000735
[DEBUG] Trimming 2 samples from start and 47 from end of longer list. Sum of differences in pulse spacing = 2742.0027648000437
[DEBUG] Trimming 3 samples from start and 46 from end of