In [211]:
import re
import warnings
import numpy as np
import pandas as pd
from collections import Counter

poke_in_regex = re.compile(r"^(\d+)\sUP\s(\d+)")  # matches: timestamp UP port_num
poke_out_regex = re.compile(r"^(\d+)\sDOWN\s(\d+)")  # matches: timestamp DOWN port_num
behavior_data_regex = re.compile(r"(\d+)\s+(contingency|trialThresh|totalPokes|totalRewards|ifDelay|countPokes[1-3]|countRewards[1-3]|portProbs[1-3])\s*=\s*(\d+)")
block_end_regex = re.compile(r"(\d+)\s+This block is over!")
session_end_regex = re.compile(r"(\d+)\s+This session is complete!")


def parse_nosepoke_events(nosepoke_events, nosepoke_DIOs):
    """
    Given a list of all nosepoke events from the statescript, 
    and nosepoke DIOs, ensure all events are valid.
    Return a dataframe including only nosepoke events at a new port.
    """

    # NOTE: For now, we do the exact same processing separately for the pokes from the statescript
    # and the pokes from the DIOs, and don't combine them until the very end.
    # This results in a lot of repeat code, but we do it because the number of raw pokes from the DIO
    # often doesn't match the number of raw pokes from the statescript, so we can't combine them earlier.
    # It's also nice to do the same checks on both the statescript and the DIOs every step of the way
    # at this early stage where we may not have run into all special cases yet. This will probably be streamlined later.

    # Make sure we have the same number of poke_in and poke_out events from the statescriptlog
    event_counts = Counter(event["event_name"] for event in nosepoke_events)
    if event_counts["poke_in"] != event_counts["poke_out"]:
        raise Exception(f"Warning: {event_counts['poke_in']} poke_in events but {event_counts['poke_out']} poke_out events!")

    # Convert statescript pokes from list of dicts to a dataframe (mapping DIO 1, 2, 4 to ports A, B, C)
    nosepoke_df = pd.DataFrame(nosepoke_events)
    nosepoke_df["port"] = nosepoke_df["port"].map({1: "A", 2: "B", 4: "C"})

    # Create a dataframe of DIO pokes that matches the dataframe from the statescript
    port_map = {'wellA_poke': 'A', 'wellB_poke': 'B', 'wellC_poke': 'C'}
    DIO_nosepoke_df = pd.DataFrame([
    {'timestamp': ts, 'event_name': 'poke_in' if d == 1 else 'poke_out', 'port': port_map[k]}
    for k, (data_list, timestamps) in nosepoke_DIOs.items()
    for d, ts in zip(data_list, timestamps)
    ])
    DIO_nosepoke_df = DIO_nosepoke_df.sort_values(by='timestamp').reset_index(drop=True)

    # Make sure the number of DIO pokes matches the number of pokes from the statescriptlog
    if len(DIO_nosepoke_df) != len(nosepoke_df):
        print(f"Length mismatch: DIO nosepokes has {len(DIO_nosepoke_df)} elements, "
              f"but statescript nosepokes has {len(nosepoke_df)} elements (all nosepokes).")

    # Make sure each poke_in is followed by a poke_out at the same port (statescript)
    for row in range(0, len(nosepoke_df) - 1, 2):
        event1 = nosepoke_df.iloc[row]
        event2 = nosepoke_df.iloc[row + 1]
        if not (event1["event_name"] == "poke_in" and event2["event_name"] == "poke_out" and event1["port"] == event2["port"]):
            raise Exception(f"Warning: Invalid nosepoke pair from statescript at timestamps {event1['timestamp']} and {event2['timestamp']}!")
        
    # Make sure each poke_in is followed by a poke_out at the same port (DIO)
    for row in range(0, len(DIO_nosepoke_df) - 1, 2):
        event1 = DIO_nosepoke_df.iloc[row]
        event2 = DIO_nosepoke_df.iloc[row + 1]
        if not (event1["event_name"] == "poke_in" and event2["event_name"] == "poke_out" and event1["port"] == event2["port"]):
            raise Exception(f"Warning: Invalid nosepoke pair from DIOs at timestamps {event1['timestamp']} and {event2['timestamp']}!")

    # Iterate over nosepoke events to keep only pokes at new ports (statescript)
    nosepokes_at_new_ports, current_port = [], None
    # If the poke_in is at a new port, append the poke_in and its corresponding poke_out
    for i in range(0, len(nosepoke_df) - 1, 2):
        if nosepoke_df.iloc[i]["port"] != current_port:
            nosepokes_at_new_ports.extend([nosepoke_df.iloc[i], nosepoke_df.iloc[i + 1]])
            current_port = nosepoke_df.iloc[i]["port"]

    # Iterate over nosepoke events to keep only pokes at new ports (DIO)
    DIO_nosepokes_at_new_ports, current_port = [], None
    # If the poke_in is at a new port, append the poke_in and its corresponding poke_out
    for i in range(0, len(DIO_nosepoke_df) - 1, 2):
        if DIO_nosepoke_df.iloc[i]["port"] != current_port:
            DIO_nosepokes_at_new_ports.extend([DIO_nosepoke_df.iloc[i], DIO_nosepoke_df.iloc[i + 1]])
            current_port = DIO_nosepoke_df.iloc[i]["port"]

     # Now that we only have pokes at new ports, make sure the number of DIO pokes = the number of statescript pokes
    if len(DIO_nosepokes_at_new_ports) != len(nosepokes_at_new_ports):
        print(f"Length mismatch: DIO nosepokes has {len(DIO_nosepokes_at_new_ports)} elements, "
              f"but statescript nosepokes has {len(nosepokes_at_new_ports)} elements "
              "(only nosepokes at new ports).")
        
    DIO_nosepokes_at_new_ports = pd.DataFrame(DIO_nosepokes_at_new_ports)
    nosepokes_at_new_ports = pd.DataFrame(nosepokes_at_new_ports)

    # Merge the DIO and statescript dataframes based on matching 'event_name' and 'port' values
    # NOTE: for now, we only iterate over the shorter length because the DIO might have extra entries 
    # that I *think* can be safely ignored but I probably want to add some extra checks here
    merged_nosepokes_at_new_ports = []
    for dio_row, statescript_row in zip(DIO_nosepokes_at_new_ports.iterrows(), nosepokes_at_new_ports.iterrows()):
        _, dio_row = dio_row
        _, statescript_row = statescript_row
    
        # Make sure the event_name and port match
        if dio_row['event_name'] == statescript_row['event_name'] and dio_row['port'] == statescript_row['port']:
            # Add matching rows to the result
            merged_nosepokes_at_new_ports.append({
                'timestamp_DIO': dio_row['timestamp'],
                'timestamp_statescript': statescript_row['timestamp'],
                'event_name': dio_row['event_name'],
                'port': dio_row['port']
            })
        else:
            warnings.warn(f"Event and port mismatch between DIO and statescript! DIO has {dio_row['event_name']} at port {dio_row['port']}, statescript has {statescript_row['event_name']} at {statescript_row['port']}")

    # Return a dataframe of nosepokes including only nosepokes at new ports
    return pd.DataFrame(merged_nosepokes_at_new_ports)


def parse_trial_and_block_data(behavior_data, block_ends):
    """
    Parse behavioral data from the stateScriptLog into dictionaries
    of trial-level and block-level data
    """
    
    # Convert our list of block end timestamps to a dictionary of block: timestamp
    block_ends_dict = {index + 1: item['timestamp'] for index, item in enumerate(block_ends)}
    # Set the default end time as a big number that will definitely be larger than all timestamps in the stateScriptLog.
    # This will be used if we don't have a recorded block end time and overwritten by the real timestamp later
    default_block_end_time = 100_000_000

    # Make sure we have the complete set of information for each trial
    variable_counts = Counter(item["name"] for item in behavior_data)
    if len(set(variable_counts.values())) != 1:
        raise Exception(f"Warning: Mismatch in the amount of information for each trial: {variable_counts}")

    # Initalize variables
    trial_data = []
    block_data = []
    current_trial = {}
    previous_trial = {}
    current_block = {}
    previous_block = {}
    trial_within_block = 1  # tracks the trials in each block
    trial_within_session = 1  # tracks the total trials in this session
    block = 1

    port_visit_counts = {1: 0, 2: 0, 3: 0}
    total_rewards = 0

    # Group our behavioral data into trials
    info_rows_per_trial = len(variable_counts)
    for row in range(0, len(behavior_data), info_rows_per_trial):
        # Grab the data for this trial
        trial_dict = {
            item["name"]: {"timestamp": item["timestamp"], "value": item["value"]}
            for item in behavior_data[row : row + info_rows_per_trial]
        }

        # Start the first block
        if trial_within_session == 1:
            current_block = {
                "block": block,
                "pA": trial_dict["portProbs1"]["value"],
                "pB": trial_dict["portProbs2"]["value"],
                "pC": trial_dict["portProbs3"]["value"],
                "end_timestamp": block_ends_dict.get(block, default_block_end_time),
                # This may be updated later if the rat does not complete all trials in this block
                "num_trials_in_block": trial_dict["trialThresh"]["value"],
            }
        # Or move to the next block if it's time
        elif trial_dict["contingency"]["timestamp"] >= current_block["end_timestamp"]:
            # Update the number of trials in the block because we may not have reached the trial threshold
            current_block["num_trials_in_block"] = trial_within_block-1
            # The current block is now the previous block
            previous_block = current_block
            block_data.append(previous_block)
            block += 1
            # Set up the new current block
            current_block = {
                "block": block,
                "pA": trial_dict["portProbs1"]["value"],
                "pB": trial_dict["portProbs2"]["value"],
                "pC": trial_dict["portProbs3"]["value"],
                "end_timestamp": block_ends_dict.get(block, default_block_end_time),
                "num_trials_in_block": trial_dict["trialThresh"]["value"],
            }
            # Reset port visit counts and reward info for the new block
            port_visit_counts = {1: 0, 2: 0, 3: 0}
            total_rewards = 0
            trial_within_block = 1

        # Get the end port for this trial by checking which poke count increased
        current_port_visit_counts = {port_num: trial_dict[f"countPokes{port_num}"]["value"] for port_num in [1, 2, 3]}
        end_port = next((i for i in [1, 2, 3] if current_port_visit_counts[i] == port_visit_counts[i] + 1), None)
        if end_port is None:
            raise Exception(f"Warning: No end port detected for trial: {trial_within_block} in block {current_block.get('block_num')}")

        # Add the information for this trial
        current_trial = {
            "trial": trial_within_block,
            "trial_within_session": trial_within_session,
            "block": current_block.get("block"),
            "start_port": previous_trial.get("end_port", -1),
            "end_port": end_port,
            "reward": 1 if trial_dict["totalRewards"]["value"] == total_rewards + 1 else 0,
            "delay": trial_dict["ifDelay"]["value"],
            "statescript_reference_timestamp": trial_dict["contingency"]["timestamp"],
        }
        trial_data.append(current_trial)

        # Update for the next trial
        previous_trial = current_trial
        trial_within_block += 1
        trial_within_session += 1
        port_visit_counts = current_port_visit_counts
        total_rewards = trial_dict["totalRewards"]["value"]

    # Update the number of trials in the final block because we may not have reached the trial threshold
    current_block["num_trials_in_block"] = trial_within_block-1
    # Update the end time of the final block
    if current_block["end_timestamp"] == default_block_end_time:
        current_block["end_timestamp"] = previous_trial.get("statescript_reference_timestamp")
    # Append the final block
    block_data.append(current_block)

    # Sanity check that we got data for the expected number of trials
    total_trials = set(variable_counts.values()).pop()
    if len(trial_data) != total_trials:
        raise Exception(f"Warning: Expected data for {total_trials} trials, got data for {len(trial_data)}")
    
    # Map ports 1, 2, 3 to A, B, C (mapping -1 to "None" for the first start_port)
    trial_df = pd.DataFrame(trial_data)
    trial_df["start_port"] = trial_df["start_port"].map({-1: "None", 1: "A", 2: "B", 3: "C"})
    trial_df["end_port"] = trial_df["end_port"].map({1: "A", 2: "B", 3: "C"})

    return trial_df, pd.DataFrame(block_data)


def combine_nosepoke_and_trial_data(nosepoke_df, trial_df, session_end):
    """
    Check that nosepoke data matches trial data
    and add nosepoke data to the trial dataframe
    """

    # Check that we have the right lengths for one poke_in and one poke_out per trial
    if len(nosepoke_df) != 2 * len(trial_df):
        print("More pokes than trials!!")
        if session_end is None:
            raise Exception(f"Warning: Expected {2*len(trial_df)} nosepokes for {len(trial_df)} trials, got {len(nosepoke_df)}")
        else:
            # We may have more nosepoke pairs than trials if the rat kept running after the session end.
            # If we have a recorded session_end time, ignore all nosepokes after this time
            nosepoke_df = nosepoke_df[nosepoke_df["timestamp"] <= session_end]

            # Check again after removing nosepokes after session end
            if len(nosepoke_df) != 2 * len(trial_df):
                raise Exception(f"Warning: After removing nosepokes after the session end, " 
                                f"expected {2*len(trial_df)} nosepokes for {len(trial_df)} trials, "
                                f"got {len(nosepoke_df)}")
    
    # Create columns to add poke_in and poke_out data to the trial_df
    trial_df["poke_in_time_statescript"] = None
    trial_df["poke_out_time_statescript"] = None
    trial_df["poke_in_time_DIO"] = None
    trial_df["poke_out_time_DIO"] = None

    # Iterate through the trial df and find corresponding poke_in and poke_out times
    for i, trial_row in trial_df.iterrows():
        # Find the nosepoke timestamps for the current trial and add them to the trial df
        poke_in_row = nosepoke_df.loc[nosepoke_df["event_name"] == "poke_in"].iloc[i]
        poke_out_row = nosepoke_df.loc[nosepoke_df["event_name"] == "poke_out"].iloc[i]
        trial_df.at[i, "poke_in_time_statescript"] = poke_in_row["timestamp_statescript"]
        trial_df.at[i, "poke_out_time_statescript"] = poke_out_row["timestamp_statescript"]
        trial_df.at[i, "poke_in_time_DIO"] = poke_in_row["timestamp_DIO"]
        trial_df.at[i, "poke_out_time_DIO"] = poke_out_row["timestamp_DIO"]

        # Sanity check that poke out timestamp is close enough to the time
        # the trial info was printed to ensure these are matched correctly
        if abs(trial_row["statescript_reference_timestamp"] - poke_out_row["timestamp_statescript"]) > 500:
            raise Exception(f"Warning: Poke out at time {poke_out_row['timestamp_statescript']} may not match trial printed at {trial_row['statescript_reference_timestam']}")

        # Sanity check to ensure the poke in and poke out match the end_port for this trial
        if not ((trial_row["end_port"] == poke_in_row["port"]) and (trial_row["end_port"] == poke_out_row["port"])):
            raise Exception(
                f"Warning: Trial ending at port {trial_row['end_port']} does not match "
                f"poke in at port {poke_in_row['port']} and poke out at port {poke_out_row['port']}")
    return trial_df


def combine_reward_and_trial_data(trial_df, reward_DIOs):
    """
    Check that reward DIOs matches the reward data from the statescript
    and add reward DIO times to the dataframe.
    """

    # Create a dataframe of reward pump times from the DIO data
    reward_pump_times = []
    port_map = {'wellA_pump': 'A', 'wellB_pump': 'B', 'wellC_pump': 'C'}
    for key, (data, timestamps) in reward_DIOs.items():
        for i in range(0, len(data), 2):
            # Make sure the data matches structure pump_on, pump_off
            assert (data[i] == 1 and data[i + 1] == 0), f"Data mismatch at index {i} for key {key}: expected [1, 0], got [{data[i]}, {data[i + 1]}]"
            
            # Make sure the pump_on and pump_off times are close together (<5s) to check they are matched correctly 
            assert (abs(timestamps[i] - timestamps[i+1])<5), f"Expected timestamps to be within 5s, got pump_on_time {timestamps[i]}, pump_off_time {timestamps[i+1]}]"
            
            # Combine the pump_on and pump_off events into a single row
            reward_pump_times.append({
            "port": port_map[key],
            "pump_on_time": timestamps[i],
            "pump_off_time": timestamps[i + 1]
        })
    
    # Make sure pump events end up in the same order regardless of if we sort by pump_on_time or pump_off_time
    reward_pump_df = pd.DataFrame(reward_pump_times).sort_values(by="pump_on_time").reset_index(drop=True)
    assert reward_pump_df.equals(pd.DataFrame(reward_pump_times).sort_values(by="pump_off_time").reset_index(drop=True)), \
    "DataFrames do not match when sorted by pump_on_time vs. pump_off_time"
    
    # Make sure each pump_on_time occurs before its corresponding pump_off_time
    assert (reward_pump_df["pump_on_time"] < reward_pump_df["pump_off_time"]).all(), \
    "Timing mismatch: not every pump_on_time is correctly matched to its pump_off_time"

    # Ensure we have one reward pump on/off DIO per rewarded trial
    rewarded_trial_df = trial_df[trial_df["reward"] == 1]

    if len(reward_pump_df) != len(rewarded_trial_df):
        warnings.warn(f"Expected {len(rewarded_trial_df)} reward DIO events "
                        f"for {len(rewarded_trial_df)} rewarded trials, "
                        f"got {len(reward_pump_df)}")
    
    # Create columns to add reward pump times to the trial_df
    trial_df["pump_on_time"] = None
    trial_df["pump_off_time"] = None

    # Iterate through the rewarded trials and their corresponding DIO events
    for trial_row, DIO_times in zip(rewarded_trial_df.itertuples(index=True), reward_pump_df.itertuples(index=False)):
        # The end_port of this rewarded trial must match the reward pump port
        assert trial_row.end_port == DIO_times.port, \
        f"Mismatch: trial end_port {trial_row.end_port} does not match reward pump port {DIO_times.port}"
    
        # Ensure the reward pump turns on within a second of the poke
        if abs(trial_row.poke_in_time_DIO - DIO_times.pump_on_time) > 1:
            raise Exception(f"Warning: Pump on at time {DIO_times.pump_on_time} may not match nosepoke at {trial_row.poke_in_time_DIO}")

        # Update the original trial_df with pump_on_time and pump_off_time
        trial_df.loc[trial_row.Index, "pump_on_time"] = DIO_times.pump_on_time
        trial_df.loc[trial_row.Index, "pump_off_time"] = DIO_times.pump_off_time

    return trial_df


def parse_state_script_log(statescriptlog, DIO_events):
    """Function to read and parse the stateScriptLog file and align it to DIO events"""
    nosepoke_events = []
    behavior_data = []
    block_ends = []
    session_end = None

    # Read the statescriptlog line by line
    for line in str(statescriptlog).splitlines():
        # Ignore lines starting with '#'
        if line.startswith("#"):
            continue

        # Find all poke_in and poke_out events
        for match in poke_in_regex.finditer(line):
            nosepoke_events.append(
                {"timestamp": int(match.group(1)),
                "event_name": "poke_in",
                "port": int(match.group(2))})
        for match in poke_out_regex.finditer(line):
            nosepoke_events.append(
                {"timestamp": int(match.group(1)),
                "event_name": "poke_out",
                "port": int(match.group(2))})
        # Find behavioral data and reward info
        for match in behavior_data_regex.finditer(line):
            behavior_data.append(
                {"timestamp": int(match.group(1)),
                "name": match.group(2),
                "value": int(match.group(3))})
        # Check for block or session end timestamps
        for match in block_end_regex.finditer(line):
            block_ends.append({"timestamp": int(match.group(1))})
        for match in session_end_regex.finditer(line):
            session_end = int(match.group(1))

    # Create dataframes of trial and block data based on the stateScriptLog
    trial_data, block_data = parse_trial_and_block_data(behavior_data, block_ends)

    # Align statescript and DIO nosepokes and create nosepoke dataframe including only nosepoke events at a new port
    # with both statescript timestamps (trodes time) and DIO timestamps (unix time)
    nosepoke_DIOs = {key: value for key, value in DIO_events.items() if key in ['wellA_poke', 'wellB_poke', 'wellC_poke']}
    nosepoke_df = parse_nosepoke_events(nosepoke_events, nosepoke_DIOs)

    # Add combined nosepoke timestamps (both statescript and DIO) to the trial dataframe
    combined_df = combine_nosepoke_and_trial_data(nosepoke_df, trial_data, session_end)

    # TODO: add trial start and stop times based on DIO

    # Add reward pump timestamps from DIOs to the combined dataframe
    reward_DIOs =  {key: value for key, value in DIO_events.items() if key in ['wellA_pump', 'wellB_pump', 'wellC_pump']}
    combined_df = combine_reward_and_trial_data(combined_df, reward_DIOs)

    # TODO: adjust what trial the blocks started at - if it's a barrier shift session, use the barrier shift DIOs. 
    # Or if its an early barrier shift session, read from the excel sheet... sigh.
    # Maybe this can just boul down to getting which trial the block started in? Not too hard
    # But we read from the excel sheet anyway to get the maze config

    # TODO: now that we have what trial the block started in, adjust the block start and end times

    # TODO: add them as tables to the NWB (easy, same as berke) and do the same checks that berke does

    return combined_df, block_data

In [212]:
from pynwb import NWBHDF5IO
import warnings

nwb_path = 'BraveLu20240519_copy.nwb'

def get_DIO_event_data(nwbfile, behavioral_event_name):
    """Get DIO data and timestamps from the nwbfile for a given behavioral event"""

    data = nwbfile.processing["behavior"]["behavioral_events"][behavioral_event_name].data[:]
    timestamps = nwbfile.processing["behavior"]["behavioral_events"][behavioral_event_name].timestamps[:]
    return data, timestamps


def parse_DIOs(behavioral_event_data):
    """Parse behavioral event DIOs and timestamps into DIO pulses for actual events vs epoch starts."""

    # Get timestamps shared among all behavioral events (triggered by an epoch start)
    epoch_start_timestamps = set.intersection(*[set(ts) for _, ts in behavioral_event_data.values()])

    # Remove epoch start data/timestamps so we are left with only DIOs triggered by real behavioral events
    behavioral_event_data = {key: 
                     ([d for d, ts in zip(data, timestamps) if ts not in epoch_start_timestamps],
                      [ts for ts in timestamps if ts not in epoch_start_timestamps])
                      for key, (data, timestamps) in behavioral_event_data.items()}
    
    # After removing extra 0s for epoch starts, check that each 1 has a corresponding 0
    for key, (data, timestamps) in behavioral_event_data.items():
        for i in range(len(data) - 1):
            if not ((data[i] == 1 and data[i + 1] == 0) or (data[i] == 0 and data[i + 1] == 1)):
                # For now, just warn about it - it may end up being ok
                # It may be due to a session timeout that cut off a poke_out - we can deal with that elsewhere
                warnings.warn(f"{key} has mismatched DIO {data[i], data[i+1]} at timestamps {timestamps[i], timestamps[i+1]}")
                
    return behavioral_event_data, sorted(epoch_start_timestamps)


with NWBHDF5IO(nwb_path, 'r') as io:
    nwbfile = io.read()

    # Get epoch table defining session boundaries with columns "start_time", "end_time", and "tags"
    epoch_table = nwbfile.intervals["epochs"][:]

    # Filter epochs to include only run sessions (should include "r" in the tags)
    run_epochs = epoch_table[epoch_table["tags"].apply(lambda x: 'r' in x[0])]
    # Filtering epochs for run sessions should be the same as taking every other epoch
    assert run_epochs.equals(epoch_table.iloc[1::2])

    # Get all stateScriptLogs from run sessions (ignoring logs from sleep sessions)
    module = nwbfile.get_processing_module("associated_files")
    run_statescript_logs = {name: log 
                            for name, log in module.data_interfaces.items()
                            if name.startswith("statescript r")
                            }
    assert len(run_statescript_logs) == len(run_epochs)

    # Get behavioral events from the nwbfile as a dict of (data, timestamps) for each named behavioral event 
    behavioral_events = ["barrier_shift", "wellA_poke", "wellA_pump", "wellB_poke", 
                     "wellB_pump", "wellC_poke", "wellC_pump"]
    behavioral_event_data = {event: get_DIO_event_data(nwbfile, event) for event in behavioral_events}

    # Separate DIOs into those for actual behavioral events vs epoch starts
    behavioral_event_data, epoch_start_timestamps = parse_DIOs(behavioral_event_data)

    # Check that we have the expected amount of epoch starts
    # NOTE: epoch_start_timestamps from the DIO pulses lag the timestamps in the epoch_table
    # by ~1 second to ~1 minute - check where this discrepancy comes from and which one to use!
    assert len(epoch_start_timestamps) == len(epoch_table)
    
    # Parse behavioral data for each epoch using the statescriptlog and align to DIOs
    log_num = 0
    for idx, epoch in run_epochs.iterrows():
        print(idx)
        print(epoch.start_time)
        statescriptlog = list(run_statescript_logs.items())[log_num]

        # Filter DIOs to only include those in this epoch
        # TODO: maybe replace epoch.start_time and epoch.stop_time with DIO times instead?
        DIO_events_in_epoch = {
            event: (list(filtered_data), list(filtered_timestamps))
            for event, (data, timestamps) in behavioral_event_data.items()
            if (filtered := [(d, ts) for d, ts in zip(data, timestamps) if epoch.start_time <= ts <= epoch.stop_time])
            # if (filtered := [(d, ts) for d, ts in zip(data, timestamps) if epoch_start_timestamps[idx] < ts < epoch_start_timestamps[idx+1]])
            for filtered_data, filtered_timestamps in [zip(*filtered)]
        }

        combined_df, block_data = parse_state_script_log(statescriptlog, DIO_events_in_epoch)
        log_num += 1

        display(combined_df)
        display(block_data)

        
        
pd.set_option('display.float_format', '{:.0f}'.format)

#print(run_epochs)
#print(run_statescript_logs.keys())
#print("") 
print(epoch_table)
for start in epoch_start_timestamps:
    print(start)




  return func(args[0], **pargs)
  return func(args[0], **pargs)
  return func(args[0], **pargs)


1
1716147415.683


Unnamed: 0,trial,trial_within_session,block,start_port,end_port,reward,delay,statescript_reference_timestamp,poke_in_time_statescript,poke_out_time_statescript,poke_in_time_DIO,poke_out_time_DIO,pump_on_time,pump_off_time
0,1,1,1,,B,1,0,75570,75324,75540,1716147462,1716147463,1716147462,1716147463
1,2,2,1,B,A,0,1,97335,96993,97299,1716147484,1716147484,,
2,3,3,1,A,B,1,1,115265,114951,115246,1716147502,1716147502,1716147502,1716147502
3,4,4,1,B,A,1,1,136308,135776,136283,1716147523,1716147523,1716147523,1716147523
4,5,5,1,A,B,0,1,155731,155344,155712,1716147542,1716147543,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
202,65,203,3,B,A,0,0,3263628,3263508,3263599,1716150651,1716150651,,
203,66,204,3,A,B,0,0,3272356,3271929,3272336,1716150659,1716150659,,
204,67,205,3,B,A,0,0,3282576,3282434,3282553,1716150670,1716150670,,
205,68,206,3,A,B,1,0,3291477,3290981,3291455,1716150678,1716150679,1716150678,1716150678


Unnamed: 0,block,pA,pB,pC,end_timestamp,num_trials_in_block
0,1,10,50,90,1059629,69
1,2,90,10,50,2193559,69
2,3,10,90,50,3315015,69


3
1716154367.35
Length mismatch: DIO nosepokes has 1998 elements, but statescript nosepokes has 1996 elements (all nosepokes).


Unnamed: 0,trial,trial_within_session,block,start_port,end_port,reward,delay,statescript_reference_timestamp,poke_in_time_statescript,poke_out_time_statescript,poke_in_time_DIO,poke_out_time_DIO,pump_on_time,pump_off_time
0,1,1,1,,C,0,0,75363,75051,75334,1716154390,1716154390,,
1,2,2,1,C,B,1,0,85874,85544,85855,1716154400,1716154401,1716154400,1716154401
2,3,3,1,B,C,1,1,107835,107317,107795,1716154422,1716154423,1716154422,1716154423
3,4,4,1,C,B,0,0,125859,125274,125770,1716154440,1716154441,,
4,5,5,1,B,A,1,0,134591,134552,134568,1716154449,1716154449,1716154449,1716154450
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
202,65,203,3,B,C,1,0,3326405,3313873,3326380,1716157629,1716157641,1716157629,1716157629
203,66,204,3,C,B,1,1,3340188,3339551,3340168,1716157655,1716157655,1716157655,1716157655
204,67,205,3,B,C,0,1,3366203,3365182,3366177,1716157680,1716157681,,
205,68,206,3,C,B,1,1,3374084,3373261,3374049,1716157688,1716157689,1716157688,1716157689


Unnamed: 0,block,pA,pB,pC,end_timestamp,num_trials_in_block
0,1,90,50,10,1073804,69
1,2,50,10,90,2202466,69
2,3,10,90,50,3392697,69


5
1716162039.383
Length mismatch: DIO nosepokes has 1511 elements, but statescript nosepokes has 1500 elements (all nosepokes).
Length mismatch: DIO nosepokes has 282 elements, but statescript nosepokes has 278 elements (only nosepokes at new ports).


Unnamed: 0,trial,trial_within_session,block,start_port,end_port,reward,delay,statescript_reference_timestamp,poke_in_time_statescript,poke_out_time_statescript,poke_in_time_DIO,poke_out_time_DIO,pump_on_time,pump_off_time
0,1,1,1,,B,1,0,119187,118917,119153,1716162121,1716162122,1716162122,1716162122
1,2,2,1,B,A,0,1,141553,141261,141529,1716162144,1716162144,,
2,3,3,1,A,B,1,1,157063,156623,157036,1716162159,1716162160,1716162159,1716162159
3,4,4,1,B,C,0,0,185988,177136,185967,1716162180,1716162189,,
4,5,5,1,C,B,1,0,194394,193973,194364,1716162197,1716162197,1716162197,1716162197
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
134,3,135,3,A,B,1,0,2208143,2207739,2208105,1716164210,1716164211,1716164210,1716164211
135,4,136,3,B,A,1,1,2231108,2230547,2231082,1716164233,1716164234,1716164233,1716164233
136,5,137,3,A,B,1,0,2253106,2252598,2253086,1716164255,1716164256,1716164255,1716164256
137,6,138,3,B,A,1,0,2277606,2277083,2277586,1716164280,1716164280,1716164280,1716164280


Unnamed: 0,block,pA,pB,pC,end_timestamp,num_trials_in_block
0,1,50,90,10,1011832,66
1,2,50,90,10,2137726,66
2,3,50,90,10,2300312,7


7
1716168683.379


Unnamed: 0,trial,trial_within_session,block,start_port,end_port,reward,delay,statescript_reference_timestamp,poke_in_time_statescript,poke_out_time_statescript,poke_in_time_DIO,poke_out_time_DIO,pump_on_time,pump_off_time
0,1,1,1,,B,0,0,47867,47791,47845,1716168699,1716168699,,
1,2,2,1,B,A,0,0,62704,62364,62679,1716168713,1716168714,,
2,3,3,1,A,C,1,0,80666,80278,80637,1716168731,1716168732,1716168731,1716168731
3,4,4,1,C,B,0,0,103493,102947,103475,1716168754,1716168754,,
4,5,5,1,B,A,0,0,109348,108764,109322,1716168760,1716168760,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
211,68,212,3,B,C,1,1,3203537,3203481,3203506,1716171854,1716171854,1716171854,1716171855
212,69,213,3,C,B,1,0,3222702,3222069,3222671,1716171873,1716171874,1716171873,1716171873
213,70,214,3,B,C,1,1,3244214,3243923,3244185,1716171895,1716171895,1716171895,1716171895
214,71,215,3,C,B,0,0,3266362,3265626,3266331,1716171917,1716171917,,


Unnamed: 0,block,pA,pB,pC,end_timestamp,num_trials_in_block
0,1,10,50,90,939577,72
1,2,10,50,90,1944017,72
2,3,10,50,90,3274151,72


    start_time  stop_time     tags
id                                
0   1716145015 1716145961  [00_s1]
1   1716147416 1716150715  [01_r1]
2   1716150832 1716152369  [02_s2]
3   1716154367 1716157714  [03_r2]
4   1716157838 1716159630  [04_s3]
5   1716162039 1716164342  [05_r3]
6   1716164609 1716166481  [06_s4]
7   1716168683 1716171931  [07_r4]
8   1716172175 1716173172  [08_s5]
1716145017.0422087
1716147454.6872206
1716150835.987083
1716154385.1617517
1716157840.1140137
1716162098.002699
1716164611.1275144
1716168692.8061771
1716172184.6517525


In [124]:
for start in epoch_start_timestamps:
    print(start)

1716145017.0422087
1716147454.6872206
1716150835.987083
1716154385.1617517
1716157840.1140137
1716162098.002699
1716164611.1275144
1716168692.8061771
1716172184.6517525


In [165]:
#print(DIO_events_in_epoch)
DIO_events = DIO_events_in_epoch

key_map = {'wellA_pump': 'A', 'wellB_pump': 'B', 'wellC_pump': 'C'}
reward_DIOs =  {key: value for key, value in DIO_events.items() if key in ['wellA_pump', 'wellB_pump', 'wellC_pump']}
DIO_rewards_df = pd.DataFrame([
    {'timestamp': ts, 'event_name': 'rwd_on' if d == 1 else 'rwd_off', 'port': key_map[k]}
    for k, (data_list, timestamps) in reward_DIOs.items()
    for d, ts in zip(data_list, timestamps)
    ])
DIO_rewards_df = DIO_rewards_df.sort_values(by='timestamp').reset_index(drop=True)

display(DIO_rewards_df)

Unnamed: 0,timestamp,event_name,port
0,1716168731,rwd_on,C
1,1716168731,rwd_off,C
2,1716168767,rwd_on,C
3,1716168767,rwd_off,C
4,1716168784,rwd_on,B
...,...,...,...
267,1716171855,rwd_off,C
268,1716171873,rwd_on,B
269,1716171873,rwd_off,B
270,1716171895,rwd_on,C


In [1]:
from pynwb import NWBHDF5IO







nwb_path = "BraveLu20240519_.nwb"

# Read and parse DIO events from the NWB file
with NWBHDF5IO(nwb_path, 'r') as io:
    nwbfile = io.read()
    # Read DIO events and timestamps into a dict for each named behavioral event 
    behavioral_event_data = {event: get_DIO_event_data(nwbfile, event) for event in behavioral_events}
    # Read epoch start timestamps into a separate dict for each behavioral event
    # epoch_start_timestamps_dict = {event: get_DIO_event_data(nwbfile, event)[2] for event in behavioral_events}


# Do a quick check that we got the same number of epoch start timestamps for all behavioral events
# It is rare (but not impossible)

print(behavioral_event_data)


# get epoch table
# for each run epoch
# get DIOS that occur in this epoch
# get statescriptlog and parse it + with alignment to DIOs
# return trial table and also put it in NWB!
# yay

# dear future self - if we don;t have block end or session end it means we ended because of a timeout
# this could cause mismatched DIO pokes



  from pandas.core import (
  return func(args[0], **pargs)
  return func(args[0], **pargs)
  return func(args[0], **pargs)


{'barrier_shift': (array([0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0,
       1, 0, 1, 0, 0], dtype=uint8), array([1.71614502e+09, 1.71614745e+09, 1.71614836e+09, 1.71614836e+09,
       1.71614969e+09, 1.71614969e+09, 1.71615084e+09, 1.71615439e+09,
       1.71615550e+09, 1.71615550e+09, 1.71615646e+09, 1.71615646e+09,
       1.71615784e+09, 1.71616210e+09, 1.71616270e+09, 1.71616270e+09,
       1.71616370e+09, 1.71616370e+09, 1.71616461e+09, 1.71616869e+09,
       1.71616960e+09, 1.71616960e+09, 1.71616999e+09, 1.71616999e+09,
       1.71617101e+09, 1.71617101e+09, 1.71617218e+09])), 'wellA_poke': (array([0, 1, 0, ..., 1, 0, 0], dtype=uint8), array([1.71614502e+09, 1.71614536e+09, 1.71614536e+09, ...,
       1.71617050e+09, 1.71617050e+09, 1.71617218e+09])), 'wellA_pump': (array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
       1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
       1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,

In [74]:
def parse_DIOs(behavioral_event_data, num_epochs):
    """Parse behavioral event DIOs and timestamps into DIO pulses for actual events vs epoch starts."""

    # This should be easy, buuuuut sometimes we have issues where things don't match up.
    # So far this has manifested on sessions that end due to a timeout so we have a poke_in 
    # without a corresponding poke_out. We'll keep track of epoch start timestamps and make sure
    # they match the expected epoch starts so we can adjust the ones that don't.
    # This function is more complicated than needed and I'll probably improve it later but for now
    # it does what it needs to to account for special cases

    official_epoch_start_timestamps = None

    for key in behavioral_event_data:
        data, timestamps = behavioral_event_data[key]

        # Get indices of 1-0 pairs (DIO triggered by behavioral events) vs singular 0s (DIO triggered by epoch starts)
        event_indices = [j for i in range(len(data) - 1) if data[i] == 1 and data[i + 1] == 0 for j in (i, i + 1)]
        epoch_start_indices = [i for i in range(len(data)) if data[i] == 0 and (i == 0 or data[i - 1] != 1)]
        epoch_start_timestamps = timestamps[epoch_start_indices]
        
        # Hopefully, we have the expected number of epoch start timestamps
        if len(epoch_start_timestamps) == num_epochs:
            # Set the official epoch start timestamps if they don't yet exist
            if official_epoch_start_timestamps is None:
                official_epoch_start_timestamps = epoch_start_timestamps
            # Or check them if they already exist
            elif not np.allclose(official_epoch_start_timestamps, epoch_start_timestamps, rtol=1e-10):
                raise Exception(f"Expected epoch start timestamps {official_epoch_start_timestamps} to match {epoch_start_timestamps}")
        # If we don't get the expected number of epoch start timestamps based on parsing, fix it
        else:
            # Use the official timestamps from a different behavioral event
            if official_epoch_start_timestamps is not None:
                epoch_start_indices = [timestamps.index(t) for t in official_epoch_start_timestamps if t in timestamps]
                event_indices = [i for i in range(len(data)) if i not in epoch_start_indices]
            else:
                # If we're here, we ran into this issue in the first key. This hasn't happened yet so I haven't implemented it yet, sorry
                raise Exception(f"Key {key} has {len(epoch_start_timestamps)} epoch start timestamps, but we have {num_epochs} epochs.")

        event_data = data[event_indices]
        event_timestamps = timestamps[event_indices]

        if len(event_timestamps) + len(epoch_start_timestamps) != len(timestamps):
            raise Exception("Warning: Not all timestamps have been classified as behavioral events or epoch starts!")
        

    
    #return event_data, event_timestamps, epoch_start_timestamps


parse_DIOs(behavioral_event_data, 9)

help! wellB_poke has 8 timestamps


In [96]:
import warnings






    #print(behavioral_event_data["wellB_poke"])
        #assert all(data[i] == 1 and data[i + 1] == 0 for i in range(0, len(data) - 1, 2))

parse_DIOs(behavioral_event_data)




#     separated_dict[key] = list(zip(*[(d, ts) for d, ts in separated_dict[key] if d is not None]))
#     extra_zeros_dict[key] = list(zip(*[(d, ts) for d, ts in extra_zeros_dict[key] if d is not None]))
    

{1716164611.1275144, 1716157840.1140137, 1716154385.1617517, 1716162098.002699, 1716150835.987083, 1716168692.8061771, 1716172184.6517525, 1716145017.0422087, 1716147454.6872206}




In [184]:
#display(reward_DIOs)


rows = []
for key, (data, timestamps) in reward_DIOs.items():
    for value, ts in zip(data, timestamps):
        rows.append({"key": key, f"{value}_timestamp": ts})

df = pd.DataFrame(rows)#.pivot(index="key", columns=None).reset_index()

display(rows)
display(df)

[{'key': 'wellA_pump', '1_timestamp': 1716168989.5362775},
 {'key': 'wellA_pump', '0_timestamp': 1716168989.8152504},
 {'key': 'wellA_pump', '1_timestamp': 1716169079.5354865},
 {'key': 'wellA_pump', '0_timestamp': 1716169079.8147595},
 {'key': 'wellA_pump', '1_timestamp': 1716169111.85122},
 {'key': 'wellA_pump', '0_timestamp': 1716169112.1302927},
 {'key': 'wellB_pump', '1_timestamp': 1716168784.3525212},
 {'key': 'wellB_pump', '0_timestamp': 1716168784.6318944},
 {'key': 'wellB_pump', '1_timestamp': 1716168831.1280494},
 {'key': 'wellB_pump', '0_timestamp': 1716168831.4076557},
 {'key': 'wellB_pump', '1_timestamp': 1716169004.9235933},
 {'key': 'wellB_pump', '0_timestamp': 1716169005.2034998},
 {'key': 'wellB_pump', '1_timestamp': 1716169026.0220056},
 {'key': 'wellB_pump', '0_timestamp': 1716169026.3026118},
 {'key': 'wellB_pump', '1_timestamp': 1716169094.8556008},
 {'key': 'wellB_pump', '0_timestamp': 1716169095.1350071},
 {'key': 'wellB_pump', '1_timestamp': 1716169126.2352796},

Unnamed: 0,key,1_timestamp,0_timestamp
0,wellA_pump,1716168990,
1,wellA_pump,,1716168990
2,wellA_pump,1716169080,
3,wellA_pump,,1716169080
4,wellA_pump,1716169112,
...,...,...,...
267,wellC_pump,,1716171812
268,wellC_pump,1716171854,
269,wellC_pump,,1716171855
270,wellC_pump,1716171895,


In [190]:
reward_pump_times = []
port_map = {'wellA_pump': 'A', 'wellB_pump': 'B', 'wellC_pump': 'C'}
for key, (data, timestamps) in reward_DIOs.items():
    for i in range(0, len(data), 2):
        reward_pump_times.append({
            "port": port_map[key],
            "pump_on_time": timestamps[i],
            "pump_off_time": timestamps[i + 1]
        })

reward_pump_df = pd.DataFrame(reward_pump_times).sort_values(by="pump_on_time").reset_index(drop=True)
assert (reward_pump_df["pump_on_time"] < reward_pump_df["pump_off_time"]).all()
display(reward_pump_df)

Unnamed: 0,port,pump_on_time,pump_off_time
0,C,1716168731,1716168731
1,C,1716168767,1716168767
2,B,1716168784,1716168785
3,B,1716168831,1716168831
4,C,1716168861,1716168861
...,...,...,...
131,C,1716171812,1716171812
132,B,1716171833,1716171833
133,C,1716171854,1716171855
134,B,1716171873,1716171873
