## Usage

Provide the path to a Frank Lab nwbfile with statescriptlogs saved as AssociatedFiles objects and behavioral event DIOs in the behavior processing module.

Also provide the path to the Excel sheet of experimental notes (including maze configurations for each block in column 'barrier location') and the name of the sheet tab (defaults to 'Daily configs and notes_Bandit+' if not specified).

Specify `save_type`: `"pickle"` or `"csv"` to save the trial and block dataframes for each epoch as .pkl or .csv files, or `"nwb"` to save the trial and block data directly as time intervals in the nwbfile. Save types can be combined, e.g. `save_type="pickle,nwb"`.
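A combined `save_type` string can be handled by splitting it on commas. The sketch below only illustrates that idea: `parse_save_types` and `VALID_SAVE_TYPES` are hypothetical names, not functions from this package, and the actual parsing inside `add_behavioral_data_to_nwb` may differ.

```python
# Hypothetical helper (not part of this package): split and validate a
# comma-separated save_type string such as "pickle,nwb".
VALID_SAVE_TYPES = {"pickle", "csv", "nwb"}


def parse_save_types(save_type):
    """Return the set of requested save types, raising on unknown values."""
    # Tolerate surrounding whitespace and mixed case, and skip empty entries
    requested = {s.strip().lower() for s in save_type.split(",") if s.strip()}
    unknown = requested - VALID_SAVE_TYPES
    if not requested or unknown:
        raise ValueError(
            f"Invalid save_type {save_type!r}; choose from {sorted(VALID_SAVE_TYPES)}"
        )
    return requested


print(sorted(parse_save_types("pickle,nwb")))  # ['nwb', 'pickle']
```

Returning a set means the order and any repetition of the save types in the string do not matter.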

The additional argument `overwrite` specifies whether existing block and trials data in the nwbfile should be overwritten. It applies only to `save_type="nwb"`. Keeping `overwrite=False` is a good protection against rewriting the file over and over: the data is written the first time, when the file has no existing trial and block data, and later calls stop and complain.
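The behavior of `overwrite` can be summarized as a small decision rule. This is a minimal sketch under stated assumptions: `decide_nwb_write` is a hypothetical name, and whether the real code raises an exception or only prints a message when it refuses to write is not shown here (the sketch raises).

```python
# Hypothetical sketch of the overwrite guard (names are illustrative only).
def decide_nwb_write(has_existing_tables, overwrite=False):
    """Return the action to take when saving block and trials data to the nwbfile."""
    if not has_existing_tables:
        # Nothing to protect: the first write always goes through
        return "write"
    if overwrite:
        # The caller explicitly asked to replace the existing tables
        return "overwrite"
    # Assumption: the refusal is modeled as an exception in this sketch
    raise RuntimeError(
        "Block and trials data already exist in the nwbfile; "
        "pass overwrite=True to replace them."
    )


print(decide_nwb_write(has_existing_tables=False))                 # write
print(decide_nwb_write(has_existing_tables=True, overwrite=True))  # overwrite
```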

#### Example
```python
nwb_path = 'data/BraveLu20240519_copy.nwb'
excel_path = 'data/BraveLu_experimental_notes.xlsx'
sheet_name = 'Daily configs and notes_Bandit+'
add_behavioral_data_to_nwb(nwb_path=nwb_path, excel_path=excel_path, sheet_name=sheet_name, save_type="nwb", overwrite=True)
```

### Deleting block and trials table from the nwb
If you need to delete the block and trials table from the nwb, run:
```python
nwb_path = 'data/BraveLu20240519_copy.nwb'
delete_blocks_and_trials_from_nwb(nwb_path)
```

This function deletes the block and trials tables (stored as TimeIntervals) from an nwbfile if they exist.
It modifies the file in place. Note that this does not actually reduce the file size, due to limitations of the HDF5 format.



import re
import ast
import numpy as np
import pandas as pd
from collections import Counter
import h5py
from pynwb import NWBFile, NWBHDF5IO

pd.set_option('display.float_format', '{:.0f}'.format)

# Print warnings on a single line, without echoing the source line of the `warnings.warn` call
import warnings
warnings.simplefilter("always")
warnings.showwarning = lambda message, category, filename, lineno, file=None, line=None: print(f"{filename}:{lineno}: {category.__name__}: {message}")

poke_in_regex = re.compile(r"^(\d+)\sUP\s(\d+)")  # matches: timestamp UP port_num
poke_out_regex = re.compile(r"^(\d+)\sDOWN\s(\d+)")  # matches: timestamp DOWN port_num
behavior_data_regex = re.compile(r"(\d+)\s+(contingency|trialThresh|totalPokes|totalRewards|ifDelay|countPokes[1-3]|countRewards[1-3]|portProbs[1-3])\s*=\s*(\d+)")
block_end_regex = re.compile(r"(\d+)\s+This block is over!")
session_end_regex = re.compile(r"(\d+)\s+This session is complete!")


def parse_trial_and_block_data(behavior_data, block_ends):
    """
    Parse behavioral data from the stateScriptLog into dataframes of trial-level and block-level data

    Args:
    behavior_data: list of dicts of behavioral event data from the statescriptlog
    block_ends: list of timestamps of block ends found in the statescriptlog

    Returns:
    trial_df: Dataframe of trial information for this epoch
    block_df: Dataframe of block information for this epoch
    """
    
    # Convert our list of block end timestamps to a dictionary of block: timestamp
    block_ends_dict = {index + 1: item['timestamp'] for index, item in enumerate(block_ends)}
    # Set the default end time as a big number that will definitely be larger than all timestamps in the stateScriptLog.
    # This will be used if we don't have a recorded block end time and overwritten by the real timestamp later
    default_block_end_time = 100_000_000

    # Make sure we have the complete set of information for each trial
    variable_counts = Counter(item["name"] for item in behavior_data)
    info_rows_per_trial = len(variable_counts)
    if len(set(variable_counts.values())) != 1:
        raise Exception(f"Warning: Mismatch in the amount of information for each trial: {variable_counts}")

    # Initialize variables
    trial_data = []
    block_data = []
    current_trial = {}
    previous_trial = {}
    current_block = {}
    previous_block = {}
    trial_within_block = 1  # tracks the trials in each block
    trial_within_session = 1  # tracks the total trials in this session
    block = 1

    port_visit_counts = {1: 0, 2: 0, 3: 0}
    total_rewards = 0

    # Group our behavioral data into trials
    for row in range(0, len(behavior_data), info_rows_per_trial):
        # Grab the data for this trial
        trial_dict = {
            item["name"]: {"timestamp": item["timestamp"], "value": item["value"]}
            for item in behavior_data[row : row + info_rows_per_trial]
        }

        # Start the first block
        if trial_within_session == 1:
            current_block = {
                "block": block,
                "pA": trial_dict["portProbs1"]["value"],
                "pB": trial_dict["portProbs2"]["value"],
                "pC": trial_dict["portProbs3"]["value"],
                "statescript_end_timestamp": block_ends_dict.get(block, default_block_end_time),
                "start_trial": 1,
                "end_trial": None,
                # This may be updated later if the rat does not complete all trials in this block
                "num_trials": trial_dict["trialThresh"]["value"],
            }
        # Or move to the next block if it's time
        elif trial_dict["contingency"]["timestamp"] >= current_block["statescript_end_timestamp"]:
            # Update the number of trials in the block because we may not have reached the trial threshold
            current_block["num_trials"] = trial_within_block-1
            current_block["end_trial"] = current_trial.get("trial_within_session")
            # The current block is now the previous block
            previous_block = current_block
            block_data.append(previous_block)
            block += 1
            # Set up the new current block
            current_block = {
                "block": block,
                "pA": trial_dict["portProbs1"]["value"],
                "pB": trial_dict["portProbs2"]["value"],
                "pC": trial_dict["portProbs3"]["value"],
                "statescript_end_timestamp": block_ends_dict.get(block, default_block_end_time),
                "start_trial": previous_block.get("end_trial")+1,
                "end_trial": None,
                "num_trials": trial_dict["trialThresh"]["value"],
            }
            # Reset port visit counts and reward info for the new block
            port_visit_counts = {1: 0, 2: 0, 3: 0}
            total_rewards = 0
            trial_within_block = 1

        # Get the end port for this trial by checking which poke count increased
        current_port_visit_counts = {port_num: trial_dict[f"countPokes{port_num}"]["value"] for port_num in [1, 2, 3]}
        end_port = next((i for i in [1, 2, 3] if current_port_visit_counts[i] == port_visit_counts[i] + 1), None)
        if end_port is None:
            raise Exception(f"Warning: No end port detected for trial: {trial_within_block} in block {current_block.get('block')}")

        # The trial was rewarded if the total reward count increased by one
        reward = 1 if trial_dict["totalRewards"]["value"] == total_rewards + 1 else 0

        # We may not have delay info for all trial types:
        # If trial dict does not include key "delay", create a default dict that makes delay value "N/A"
        delay_dict = trial_dict.get("delay", {"value": "N/A"})
        delay = delay_dict.get("value") if reward else "N/A"

        # Add the information for this trial
        current_trial = {
            "trial_within_block": trial_within_block,
            "trial_within_session": trial_within_session,
            "block": current_block.get("block"),
            "start_port": previous_trial.get("end_port", -1),
            "end_port": end_port,
            "reward": reward,
            "delay": delay,
            "statescript_reference_timestamp": trial_dict["contingency"]["timestamp"],
        }
        trial_data.append(current_trial)

        # Update for the next trial
        previous_trial = current_trial
        trial_within_block += 1
        trial_within_session += 1
        port_visit_counts = current_port_visit_counts
        total_rewards = trial_dict["totalRewards"]["value"]

    # Update the number of trials in the final block because we may not have reached the trial threshold
    current_block["num_trials"] = trial_within_block-1
    current_block["end_trial"] = current_trial.get("trial_within_session")
    # Update the end time of the final block
    if current_block["statescript_end_timestamp"] == default_block_end_time:
        current_block["statescript_end_timestamp"] = previous_trial.get("statescript_reference_timestamp")
    # Append the final block
    block_data.append(current_block)

    # Sanity check that we got data for the expected number of trials
    total_trials = set(variable_counts.values()).pop()
    if len(trial_data) != total_trials:
        raise Exception(f"Warning: Expected data for {total_trials} trials, got data for {len(trial_data)}")
    
    # Map ports 1, 2, 3 to A, B, C (mapping -1 to "None" for the first start_port)
    trial_df = pd.DataFrame(trial_data)
    trial_df["start_port"] = trial_df["start_port"].map({-1: "None", 1: "A", 2: "B", 3: "C"})
    trial_df["end_port"] = trial_df["end_port"].map({1: "A", 2: "B", 3: "C"})

    return trial_df, pd.DataFrame(block_data)


def parse_nosepoke_events(nosepoke_events, nosepoke_DIOs, poke_time_threshold=1):
    """
    Given all nosepoke events from the statescript and all nosepoke DIOs, 
    ensure all events are valid and match the statescript and DIO nosepokes.
    Return a dataframe including only nosepoke events at a new port.

    The rat must stop breaking the beam for at least poke_time_threshold
    for that poke to be considered over (multiple consecutive pokes at the same 
    port in a short period of time are considered one poke).
    Instead of recording the poke_out directly following the poke_in at a new port,
    we record the last poke_out after which the rat didn't immediately poke back in again
    (immediately = poke_out and next poke_in are less than poke_time_threshold apart).

    Args:
    nosepoke_events: list of dicts, where each dict contains key-value pairs
        describing each nosepoke event from the statescriptlog:
            'timestamp': statescript timestamp, 
            'event_name': 'poke_in' or 'poke_out', 
            'port': 1, 2, or 4 (referring to ports A, B, and C)
    nosepoke_DIOs: dict with keys wellA_poke, wellB_poke, and wellC_poke corresponding
        to DIO events, and values (data, timestamps) where data is 1/0 DIO high/low
        and timestamps are the DIO timestamps of these events

    Returns:
    Dataframe including only nosepoke events at new ports, with columns
    event_name, port, timestamp_DIO, timestamp_statescript
    """

    # Make sure we have the same number of poke_in and poke_out events from the statescriptlog
    # NOTE: We later check that each poke_in is followed by a poke_out at the same port, so maybe overkill.
    # Keeping for now, probably delete later.
    event_counts = Counter(event["event_name"] for event in nosepoke_events)
    if event_counts["poke_in"] != event_counts["poke_out"]:
        raise Exception(f"Warning: {event_counts['poke_in']} poke_in events but {event_counts['poke_out']} poke_out events in the statescript!")

    # Convert statescript pokes from list of dicts to a dataframe (mapping DIO 1, 2, 4 to ports A, B, C)
    statescript_nosepoke_df = pd.DataFrame(nosepoke_events)
    statescript_nosepoke_df["port"] = statescript_nosepoke_df["port"].map({1: "A", 2: "B", 4: "C"})

    # Create a dataframe of DIO pokes that matches the dataframe from the statescript
    port_map = {'wellA_poke': 'A', 'wellB_poke': 'B', 'wellC_poke': 'C'}
    DIO_nosepoke_df = pd.DataFrame([
        {'timestamp': ts, 'event_name': 'poke_in' if d == 1 else 'poke_out', 'port': port_map[k]}
        for k, (data_list, timestamps) in nosepoke_DIOs.items()
        for d, ts in zip(data_list, timestamps)
    ])
    DIO_nosepoke_df = DIO_nosepoke_df.sort_values(by='timestamp').reset_index(drop=True)

    # Print poke info for debugging purposes while this code is still in development
    # TODO: When we switch to a formal logging framework, make this a low/debug log level
    print(f"{len(DIO_nosepoke_df)} nosepokes from the DIOs: {DIO_nosepoke_df['port'].value_counts().to_dict()}")
    print(f"{len(statescript_nosepoke_df)} nosepokes from the statescript: {statescript_nosepoke_df['port'].value_counts().to_dict()}")

    # Make sure each poke_in is followed by a poke_out at the same port (statescript)
    for row in range(0, len(statescript_nosepoke_df) - 1, 2):
        event1 = statescript_nosepoke_df.iloc[row]
        event2 = statescript_nosepoke_df.iloc[row + 1]
        if not (event1["event_name"] == "poke_in" and event2["event_name"] == "poke_out" and event1["port"] == event2["port"]):
            raise Exception(f"Warning: Invalid nosepoke pair from statescript at timestamps {event1['timestamp']} and {event2['timestamp']}!")
      
    # Make sure each poke_in is followed by a poke_out at the same port (DIO)
    for row in range(0, len(DIO_nosepoke_df) - 1, 2):
        event1 = DIO_nosepoke_df.iloc[row]
        event2 = DIO_nosepoke_df.iloc[row + 1]
        if not (event1["event_name"] == "poke_in" and event2["event_name"] == "poke_out" and event1["port"] == event2["port"]):
            raise Exception(f"Warning: Invalid nosepoke pair from DIOs at timestamps {event1['timestamp']} and {event2['timestamp']}!")

    # Make sure the number of DIO pokes matches the number of pokes from the statescriptlog.
    # Note that the DIO may have more pokes because it keeps recording after the statescript has been stopped (this is ok).
    # Warn the user about it anyway for the sake of providing all of the info.
    if len(DIO_nosepoke_df) > len(statescript_nosepoke_df):
        warnings.warn(f"Length mismatch: {len(DIO_nosepoke_df)} nosepokes from DIOs, " 
                      f"but only {len(statescript_nosepoke_df)} nosepokes from statescript.\n"
                      "The DIO may have more pokes because it keeps recording after the statescript has been stopped (this is ok).")
    # The statescript should never have more pokes than the DIOs - break if this happens so we can figure out why.
    elif len(statescript_nosepoke_df) > len(DIO_nosepoke_df):
        raise Exception(f"Length mismatch: {len(statescript_nosepoke_df)} nosepokes from statescript but {len(DIO_nosepoke_df)} nosepokes from DIOs!")
    
    # Match statescript and DIO pokes.
    # For each event_name and port combination, add an index column enumerating which one it is.
    # This will allow us to merge the DIO and statescript dfs while matching the correct instances of each event
    DIO_nosepoke_df['index'] = DIO_nosepoke_df.groupby(['event_name', 'port']).cumcount()
    statescript_nosepoke_df['index'] = statescript_nosepoke_df.groupby(['event_name', 'port']).cumcount()

    # Merge based on matching event_name, port, and index (created above)
    merged_nosepokes = pd.merge(DIO_nosepoke_df, statescript_nosepoke_df, on=['event_name', 'port', 'index'], how='inner', suffixes=('_DIO', '_statescript'))

    # Also do an outer merge that keeps all rows so we can print info about which rows (if any) do not match.
    # This is for info/debugging purposes only.
    merged_nosepokes_outer = pd.merge(DIO_nosepoke_df, statescript_nosepoke_df, on=['event_name', 'port', 'index'], how='outer', suffixes=('_DIO', '_statescript'))
    DIO_statescript_mismatches = merged_nosepokes_outer[merged_nosepokes_outer["timestamp_DIO"].isna() | merged_nosepokes_outer["timestamp_statescript"].isna()]
    
    if not DIO_statescript_mismatches.empty:
        print("Mismatched rows:")
        print(DIO_statescript_mismatches)
    else:
        print("All DIO and statescript nosepokes were matched successfully.")

    # Iterate through pairs of rows in the dataframe, keeping only rows 
    # that represent poke_in and poke_out events at a new port.
    # The rat must stop breaking the beam for at least poke_time_threshold
    # for that poke to be considered over (multiple consecutive pokes at the same 
    # port in a short period of time are considered one poke).
    # Instead of recording the poke_out directly following the poke_in at a new port,
    # we record the last poke_out after which the rat didn't immediately poke back in again
    # (immediately = poke_out and next poke_in are less than poke_time_threshold apart).
    nosepokes_at_new_ports = []
    current_port = None
    potential_poke_out = None

    # Iterate through poke_in / poke_out pairs
    for row in range(0, len(merged_nosepokes) - 1, 2):
        poke_in = merged_nosepokes.iloc[row]
        poke_out = merged_nosepokes.iloc[row + 1]
        # Sanity check for merged statescript/DIO events: make sure each poke_in is followed by a poke_out at the same port
        if not (poke_in["event_name"] == "poke_in" and poke_out["event_name"] == "poke_out" and poke_in["port"] == poke_out["port"]):
            raise Exception(f"Warning: Invalid nosepoke pair at timestamps {poke_in['timestamp_DIO']} and {poke_out['timestamp_DIO']}!")
        
        # If we have a poke_in at a new port, record it!
        if poke_in["port"] != current_port:
            # Record the last poke_out for the previous port if we haven't already
            if potential_poke_out is not None:
                nosepokes_at_new_ports.append(potential_poke_out)
            # Add the poke_in event
            nosepokes_at_new_ports.append(poke_in)
            # Save the poke_out as the potential poke end
            # This will likely be overwritten by the "true" poke_out, defined as the time 
            # the rat pokes out and then does not immediately poke back in again
            potential_poke_out = poke_out
            # Update the current port so we can search for the "true" poke_out end
            current_port = poke_in["port"]

        # Or if we have another poke_in at the current port, and we are searching for the "true" poke_out,
        # check if the poke has already ended or if this is a continuation of the same poke event.
        elif potential_poke_out is not None:
            # If the poke_in is close enough in time to the previous poke_out, it counts as the same poke
            if (poke_in["timestamp_DIO"] - potential_poke_out["timestamp_DIO"]) <= poke_time_threshold:
                # Update the poke_out as the potential poke end
                potential_poke_out = poke_out
            # Otherwise, the poke_in is far enough in time from the previous poke_out, so the poke has ended.
            else:
                # The previous potential_poke_out is the true poke_out, so record it
                nosepokes_at_new_ports.append(potential_poke_out)
                # Indicate the poke has ended and we are no longer searching for the poke_out
                potential_poke_out = None
        
        # Otherwise, if we reach here, it means that we have another poke_in at the current port,
        # but we have already determined that the poke event has ended. Ignore these pokes.
        else:
            # NOTE: While this code is still in development, it may be helpful to print how often we reach
            # this case. It may also be helpful to print how far the poke_in was from the previous
            # poke_out as feedback on if we have chosen a good poke_time_threshold or if it should be adjusted.
            continue

    # Add the last poke_out if we missed it
    if potential_poke_out is not None:
        nosepokes_at_new_ports.append(potential_poke_out)

    # Return a dataframe of nosepokes including only nosepokes at new ports
    return pd.DataFrame(nosepokes_at_new_ports).drop(columns='index')


def combine_nosepoke_and_trial_data(nosepoke_df, trial_df, session_end):
    """
    Check that nosepoke data matches trial data and add nosepoke data to the trial dataframe

    Args:
    nosepoke_df: Dataframe of nosepoke events at new ports with columns event_name, port, timestamp_DIO, timestamp_statescript
    trial_df: Dataframe of trial information
    session_end: Timestamp of session end (in statescript time), or None if no session_end was recorded in the statescript

    Returns:
    trial_df: Dataframe of trial information with added columns for poke_in and poke_out times (both DIO time and statescript time)
    """
    

    # Check that we have the right lengths for one poke_in and one poke_out per trial
    if len(nosepoke_df) != 2 * len(trial_df):
        if session_end is None:
            raise Exception(f"Warning: Expected {2*len(trial_df)} nosepokes for {len(trial_df)} trials, got {len(nosepoke_df)}")
        else:
            # We may have more nosepoke pairs than trials if the rat kept running after the session end.
            # If we have a recorded session_end time, ignore all poke_in after this time.
            nosepoke_df = nosepoke_df.reset_index()
            pokes_before_session_end = nosepoke_df[nosepoke_df["timestamp_statescript"] <= session_end].copy()
        
            # Note that if the last event before the session end is a poke_in, make sure to keep the poke_out!
            # The poke_out likely happened after the session end time was printed 
            # (as the session end print is triggered by poke_in and not poke_out).
        
            last_poke_index = 0
            # If the last event before the session end was a poke_in, add 1 to the index to keep its poke_out
            if pokes_before_session_end.iloc[-1]["event_name"] == "poke_in":
                last_poke_index = pokes_before_session_end.index[-1] + 1
            # If the last event before the session end is a poke_out, no adjustment needed!
            elif pokes_before_session_end.iloc[-1]["event_name"] == "poke_out":
                last_poke_index = pokes_before_session_end.index[-1]
            else:
                raise Exception("event_name must be either poke_in or poke_out!!")

            # Filter dataframe to remove all extra pokes
            nosepoke_df = nosepoke_df[nosepoke_df.index <= last_poke_index]
    
            # Check again after removing nosepokes after session end
            if len(nosepoke_df) != 2 * len(trial_df):
                raise Exception(f"Warning: After removing nosepokes after the session end, " 
                                f"expected {2*len(trial_df)} nosepokes for {len(trial_df)} trials, "
                                f"got {len(nosepoke_df)}")
    
    # Create columns to add poke_in and poke_out data to the trial_df
    trial_df["poke_in_time_statescript"] = None
    trial_df["poke_out_time_statescript"] = None
    trial_df["poke_in_time"] = None # DIO time
    trial_df["poke_out_time"] = None # DIO time

    # Iterate through the trial df and find corresponding poke_in and poke_out times
    for i, trial_row in trial_df.iterrows():
        # Find the nosepoke timestamps for the current trial and add them to the trial df
        poke_in_row = nosepoke_df.loc[nosepoke_df["event_name"] == "poke_in"].iloc[i]
        poke_out_row = nosepoke_df.loc[nosepoke_df["event_name"] == "poke_out"].iloc[i]
        trial_df.at[i, "poke_in_time_statescript"] = poke_in_row["timestamp_statescript"]
        trial_df.at[i, "poke_out_time_statescript"] = poke_out_row["timestamp_statescript"]
        trial_df.at[i, "poke_in_time"] = poke_in_row["timestamp_DIO"]
        trial_df.at[i, "poke_out_time"] = poke_out_row["timestamp_DIO"]

        # Sanity check that poke_in timestamp is close enough to the time
        # the trial info was printed to ensure these are matched correctly.
        # NOTE: It seems the trial info is printed after the first poke_out following
        # the poke_in (which is not always the recorded poke_out - see parse_nosepoke_events).
        # This check worked better in an earlier version of the code where we checked against that poke_out,
        # which we no longer record. We probably want to suppress output or set this to a low (debug) 
        # log level in the future, as even correctly matched pokes can trigger this warning
        # if the poke was long (causing trial info to be printed >5s after initial poke_in).
        # Keeping it for now - it is still a useful warning as we have not encountered all bug-causing cases.
        # if abs(trial_row["statescript_reference_timestamp"] - poke_in_row["timestamp_statescript"]) > 5000:
        #    warnings.warn(f"Poke in at time {poke_in_row['timestamp_statescript']} may not match trial printed at {trial_row['statescript_reference_timestamp']}")

        # Sanity check to ensure the poke in and poke out match the end_port for this trial
        if not ((trial_row["end_port"] == poke_in_row["port"]) and (trial_row["end_port"] == poke_out_row["port"])):
            raise Exception(
                f"Warning: Trial ending at port {trial_row['end_port']} does not match "
                f"poke in at port {poke_in_row['port']} and poke out at port {poke_out_row['port']}")
            
    # Now that every trial has its poke times, add start and end times
    # based on DIO poke times (trials are poke_out to poke_out)
    trial_df['start_time'] = trial_df['poke_out_time'].shift(1)
    trial_df['end_time'] = trial_df['poke_out_time']

    # Set the start time of the first trial to 3 seconds before the first poke_in.
    # This handles cases where the epoch start button was pressed and then the rat
    # was placed in the maze, so using epoch start time would be too early.
    # This may be overwritten by the epoch start time later, if the recorded epoch start 
    # is after this time (which could happen in the case where 2 people were present 
    # so the epoch start button was pressed at the same time the rat was placed in the maze).
    trial_df.at[0, 'start_time'] = trial_df.at[0, 'poke_in_time']-3

    # Add trial duration as a column
    trial_df['duration'] = trial_df['end_time'] - trial_df['start_time']

    return trial_df


def combine_reward_and_trial_data(trial_df, reward_DIOs):
    """
    Check that reward data from the statescript matches reward data 
    from the DIOs, and add reward DIO times to the trial dataframe.

    Args:
    trial_df: Dataframe of information for each trial including column 'reward'
    reward_DIOs: dict mapping reward DIO names 'wellA_pump', 'wellB_pump', 'wellC_pump' to tuples of (1/0 data, timestamps)

    Returns:
    trial_df: Dataframe of information for each trial with added columns 'pump_on_time' and 'pump_off_time'
    """

    # Create a dataframe of reward pump times from the DIO data
    reward_pump_times = []
    port_map = {'wellA_pump': 'A', 'wellB_pump': 'B', 'wellC_pump': 'C'}
    for key, (data, timestamps) in reward_DIOs.items():
        for i in range(0, len(data), 2):
            # Make sure the data matches structure pump_on, pump_off
            assert (data[i] == 1 and data[i + 1] == 0), f"Data mismatch at index {i} for key {key}: expected [1, 0], got [{data[i]}, {data[i + 1]}]"
            
            # Make sure the pump_on and pump_off times are close together (<1s) to check they are matched correctly 
            assert (abs(timestamps[i] - timestamps[i+1])<1), f"Expected timestamps to be within 1s, got pump_on_time {timestamps[i]}, pump_off_time {timestamps[i+1]}"
            
            # Combine the pump_on and pump_off events into a single row
            reward_pump_times.append({
                "port": port_map[key],
                "pump_on_time": timestamps[i],
                "pump_off_time": timestamps[i + 1],
            })
    
    # Make sure pump events end up in the same order regardless of if we sort by pump_on_time or pump_off_time
    reward_pump_df = pd.DataFrame(reward_pump_times).sort_values(by="pump_on_time").reset_index(drop=True)
    assert reward_pump_df.equals(pd.DataFrame(reward_pump_times).sort_values(by="pump_off_time").reset_index(drop=True)), \
    "DataFrames do not match when sorted by pump_on_time vs. pump_off_time"
    
    # Make sure each pump_on_time occurs before its corresponding pump_off_time
    assert (reward_pump_df["pump_on_time"] < reward_pump_df["pump_off_time"]).all(), \
    "Timing mismatch: not every pump_on_time is correctly matched to its pump_off_time"

    # Ensure we have one reward pump on/off DIO per rewarded trial
    rewarded_trial_df = trial_df[trial_df["reward"] == 1]

    if len(reward_pump_df) != len(rewarded_trial_df):
        warnings.warn(f"Expected {len(rewarded_trial_df)} reward DIO events "
                        f"for {len(rewarded_trial_df)} rewarded trials, "
                        f"got {len(reward_pump_df)}")
    
    # Create columns to add reward pump times to the trial_df
    trial_df["pump_on_time"] = "N/A"
    trial_df["pump_off_time"] = "N/A"

    # Iterate through the rewarded trials and their corresponding DIO events
    for trial_row, DIO_times in zip(rewarded_trial_df.itertuples(index=True), reward_pump_df.itertuples(index=False)):
        # The end_port of this rewarded trial must match the reward pump port
        assert trial_row.end_port == DIO_times.port, \
        f"Mismatch: trial end_port {trial_row.end_port} does not match reward pump port {DIO_times.port}"
    
        # Ensure the reward pump turns on within a second of the poke
        if abs(trial_row.poke_in_time - DIO_times.pump_on_time) > 1:
            raise Exception(f"Warning: Pump on at time {DIO_times.pump_on_time} may not match nosepoke at {trial_row.poke_in_time}")

        # Update the original trial_df with pump_on_time and pump_off_time
        trial_df.loc[trial_row.Index, "pump_on_time"] = DIO_times.pump_on_time
        trial_df.loc[trial_row.Index, "pump_off_time"] = DIO_times.pump_off_time

    return trial_df


def determine_session_type(block_data):
    """Determine the session type ("single block", "probability change", or "barrier change") based on block data."""

    # This case is rare/hopefully nonexistent - we always expect to have more than one block per session
    if len(block_data) == 1:
        return "single block"

    # Get the reward probabilities at each port for each block in the session
    reward_probabilities = []
    for _, block in block_data.iterrows():
        reward_probabilities.append([block["pA"], block["pB"], block["pC"]])

    # If the reward probabilities change with each block, this is a probability change session
    if reward_probabilities[0] != reward_probabilities[1]:
        return "probability change"
    # Otherwise, this must be a barrier change session
    else:
        return "barrier change"

    
def adjust_block_start_trials(trial_data, block_data, DIO_events, excel_data):
    '''
    Adjust the block start trials based on barrier_shift DIO events (if they exist)
    or data from the experimental notes excel sheet
    
    Args:
    trial_data: Dataframe of information for each trial in this epoch
    block_data: Dataframe of information for each block in this epoch
    DIO_events: dict of event_name: (data, timestamps) for each named DIO event,
    including "barrier_shift" event if we have data for it
    excel_data: Dataframe of info for this epoch, with column "barrier shift trial ID"
    
    Returns:
    trial_data: Dataframe of trial info reflecting updated block boundaries
    block_data: Dataframe of block info reflecting updated block boundaries
    '''

    barrier_shift_trials_DIO = None
    barrier_shift_trials_excel = None
    
    # If barrier_shift DIOs exist, use those as the ground truth
    if "barrier_shift" in DIO_events:
        print("Adjusting barrier shift times based on barrier_shift DIOs ...")
        
        barrier_shift_DIOs = DIO_events.get("barrier_shift")
        # The barrier_shift_DIOs are a pair of lists: (1/0 events, timestamps)
        # Take every other timestamp to get the times of the "1" (DIO button press) events
        # (We have already checked each 1 has a corresponding 0 so just taking every other is fine)
        barrier_shift_times = barrier_shift_DIOs[1][0::2]

        barrier_shift_trials_DIO = []
        for barrier_shift_time in barrier_shift_times:
            # Find the closest poke_in time just before the barrier shift time
            # Barrier shifts happen when the rat is at a port (just after poke_in)
            # The next trial (that begins on poke_out) is the first trial of the new block
            trials_pre_shift = trial_data.index[trial_data['poke_in_time'] <= barrier_shift_time]
            closest_idx = ((trial_data.loc[trials_pre_shift, 'poke_in_time'] - barrier_shift_time).abs()).idxmin()
            barrier_shift_trial = trial_data.loc[closest_idx, 'trial_within_session']

            # Sanity check: get the time from poke_in to barrier shift, and from barrier shift to the next poke_in
            barrier_shift_time_from_poke = barrier_shift_time - trial_data.loc[closest_idx, 'poke_in_time']
            print(f"Barrier shift DIO pressed {barrier_shift_time_from_poke:.2f}s "
                  f"after poke_in of trial {barrier_shift_trial}.")
            # Guard against a barrier shift during the final trial, which has no next poke_in
            if closest_idx + 1 in trial_data.index:
                time_to_next_poke = trial_data.loc[closest_idx+1, 'poke_in_time'] - barrier_shift_time
                print(f"Next poke_in was {time_to_next_poke:.2f}s after barrier shift DIO pressed.")
            print(f"Trial {barrier_shift_trial+1} is the first trial of the new block.")
  
            barrier_shift_trials_DIO.append(barrier_shift_trial)
        # Convert from np.int64 to int
        barrier_shift_trials_DIO = [int(x) for x in barrier_shift_trials_DIO]

    # If the excel sheet has barrier shift info, use that also
    if "barrier shift trial ID" in excel_data.columns:
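        # Read barrier shift trials as a comma-separated string, and convert to a list of ints.
        # Going through str() and float() also handles a single trial ID that excel stored as a number,
        # and an empty cell (NaN) is skipped instead of raising an AttributeError.
        barrier_shift_trials_raw = excel_data["barrier shift trial ID"].iloc[0]
        if pd.notna(barrier_shift_trials_raw):
            barrier_shift_trials_excel = [
                int(float(trial)) for trial in str(barrier_shift_trials_raw).split(",") if trial.strip()
            ]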

    # If we have barrier shift info from both DIOs and excel sheet, check if they match
    if barrier_shift_trials_DIO is not None and barrier_shift_trials_excel is not None:
        # If there is a mismatch between DIO and excel, DIO wins
        if barrier_shift_trials_DIO != barrier_shift_trials_excel:
            warnings.warn(f"Mismatch in barrier shift info between barrier_shift DIOs and data from excel sheet!\n"
                f"DIO has barrier shift trials {barrier_shift_trials_DIO}, "
                f"excel sheet has {barrier_shift_trials_excel}!")
        else:
            print(f"Barrier_shift DIOs match data from excel sheet, "
                  f"with barrier shifts at trials {barrier_shift_trials_DIO}")
        barrier_shift_trials = barrier_shift_trials_DIO
    
    # If only DIOs, use that
    elif barrier_shift_trials_DIO is not None:
        barrier_shift_trials = barrier_shift_trials_DIO

    # Or if only excel, use that
    elif barrier_shift_trials_excel is not None:
        barrier_shift_trials = barrier_shift_trials_excel
    
    else:
        raise ValueError("No 'barrier_shift' DIO event or 'barrier shift trial ID' from excel found\n"
                         "when trying to adjust block start trials for a barrier change session.")

    # Sanity check: make sure the last barrier shift is before the last trial
    last_trial = block_data['end_trial'].iloc[-1]
    if barrier_shift_trials[-1] >= last_trial:
        # This has never happened (and it never should), but DIOs can be odd.
        # Complain and break if it does so we can evaluate what caused it and how to handle it then.
        raise ValueError("Something went wrong! Last barrier shift is at or after the last trial, "
                         "so the final block would contain no trials!")

    # Set up the start and end trials of the blocks based on the barrier shifts
    # Add trial 1 as the start of the first block and the last trial as the end of the last block
    block_start_trials = [1] + [t + 1 for t in barrier_shift_trials]
    block_end_trials = barrier_shift_trials + [int(last_trial)]

    # Get pA, pB, pC (same for all blocks because this is a barrier change session)
    pA, pB, pC = block_data.iloc[0][['pA', 'pB', 'pC']]

    # Create new block dataframe using new block start/end trials
    # statescript_end_timestamp is now N/A because statescript timestamps 
    # no longer correspond to barrier changes
    new_block_data = pd.DataFrame({
        'block': range(1, len(block_start_trials)+1),
        'pA': [pA]*len(block_start_trials), 
        'pB': [pB]*len(block_start_trials),
        'pC': [pC]*len(block_start_trials),
        'statescript_end_timestamp': "N/A",
        'start_trial': block_start_trials,
        'end_trial': block_end_trials,
        'num_trials': [end-start+1 for end,start in zip(block_end_trials, block_start_trials)],
        'task_type': 'barrier change'
    })

    # Update 'block' and 'trial' columns in trial_data to reflect the updated block boundaries
    for _, row in new_block_data.iterrows():
        trials_in_block = (trial_data['trial_within_session'] >= row['start_trial']) & (trial_data['trial_within_session'] <= row['end_trial'])
        trial_data.loc[trials_in_block, 'block'] = row['block']
        trial_data.loc[trials_in_block, 'trial_within_block'] = range(1, trials_in_block.sum() + 1)

    return trial_data, new_block_data


def add_block_start_end_times(trial_data, block_data):
    '''
    Add the DIO start and end times to the blocks
    
    Args:
    trial_data: Dataframe of trial information with columns 'start_time' and 'end_time'
    block_data: Dataframe of block information

    Returns:
    block_data: Dataframe of block information with columns 'start_time' and 'end_time' added
    '''

    # The start time of a block is the start time of the first trial in the block
    block_data['start_time'] = block_data['start_trial'].map(lambda x: trial_data.loc[x-1, 'start_time'])
    # The end time of a block is the end time of the last trial in the block
    block_data['end_time'] = block_data['end_trial'].map(lambda x: trial_data.loc[x-1, 'end_time'])

    return block_data


def validate_trial_and_block_data(trial_data, block_data):
    """Run basic tests to check that trial and block data is valid."""

    # The number of the last trial/block must match the number of trials/blocks
    assert len(trial_data) == trial_data["trial_within_session"].max()
    assert len(block_data) == block_data["block"].max()

    # All trial numbers must be unique and match the range 1 to [num trials in session]
    assert set(trial_data["trial_within_session"]) == set(range(1, len(trial_data) + 1))

    # All block numbers must be unique and match the range 1 to [num blocks in session]
    assert set(block_data["block"]) == set(range(1, len(block_data) + 1))

    # There must be a legitimate reward value (1 or 0) for all trials
    assert set(trial_data["reward"]).issubset({0, 1})

    # There must be a legitimate p(reward) value for each block at ports A, B, and C
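    # (column-wise Series.between is used because DataFrame.applymap is deprecated in recent pandas)
    assert block_data[["pA", "pB", "pC"]].apply(lambda col: col.between(0, 100)).all().all()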

    # There must be a not-null maze_configuration for each block
    assert not block_data["maze_configuration"].isnull().any(), "Not all blocks have a maze configuration!"
    
    # There must be a valid task type for each block
    assert block_data["task_type"].isin(["probability change", "barrier change"]).all()
    # The task type must be the same for all blocks in the epoch
    assert block_data["task_type"].nunique() == 1
    
    # In a probability change session, reward probabilities vary and maze configs do not
    if block_data["task_type"].iloc[0] == "probability change":
        # All maze configs should be the same
        assert block_data["maze_configuration"].nunique() == 1
        # Reward probabilities should vary
        # They may eventually repeat with many blocks, so minimum 1 change is ok
        assert block_data["pA"].nunique() > 1
        assert block_data["pB"].nunique() > 1
        assert block_data["pC"].nunique() > 1
    # In a barrier change session, maze configs vary and reward probabilities do not
    elif block_data["task_type"].iloc[0] == "barrier change":
        # All reward probabilities should be the same for all blocks
        assert block_data["pA"].nunique() == 1
        assert block_data["pB"].nunique() == 1
        assert block_data["pC"].nunique() == 1
        # Maze configurations should be different for each block
        assert block_data["maze_configuration"].nunique() == len(block_data)
    
    summed_trials = 0
    # Check trials within each block
    for _, block in block_data.iterrows():
        block_trials = trial_data[trial_data["block"] == block["block"]]
        trial_numbers = block_trials["trial_within_block"]

        # All trial numbers in the block must be unique and match the range 1 to [num trials in block]
        num_trials_expected = block["num_trials"]
        num_trials_expected_2 = block["end_trial"] - block["start_trial"] + 1
        assert len(trial_numbers.unique()) == num_trials_expected == num_trials_expected_2
        assert set(trial_numbers) == set(range(1, int(num_trials_expected) + 1))

        # Check time alignment between trials and blocks
        first_trial = block_trials.loc[block_trials["trial_within_block"].idxmin()]
        last_trial = block_trials.loc[block_trials["trial_within_block"].idxmax()]
        block_start = block["start_time"]
        block_end = block["end_time"]

        assert first_trial["start_time"] == block_start, (
            f"First trial start {first_trial['start_time']} does not match block start {block_start}"
        )
        assert last_trial["end_time"] == block_end, (
            f"Last trial end {last_trial['end_time']} does not match block end {block_end}"
        )

        # Ensure trial times are within block bounds
        assert block_trials["start_time"].between(block_start, block_end).all(), (
            f"Some trial start_times are outside block bounds ({block_start} to {block_end})"
        )
        assert block_trials["end_time"].between(block_start, block_end).all(), (
            f"Some trial end_times are outside block bounds ({block_start} to {block_end})"
        )

        # Ensure poke_in_time and poke_out_time are within trial bounds
        assert block_trials["poke_in_time"].between(
            block_trials["start_time"], block_trials["end_time"]
        ).all(), (
            "Some poke_in_times are outside trial bounds (start_time to end_time)"
        )
        assert (block_trials["poke_out_time"] == block_trials["end_time"]).all(), (
            "Some poke_out_times do not match the trial end_time"
        )

        summed_trials += num_trials_expected

    # The summed number of trials in each block must match the total number of trials
    assert summed_trials == len(trial_data)


def validate_poke_timestamps(trial_data):
    """
    Validate that the DIO poke_in_time and poke_out_time match the statescript
    poke_in_time and poke_out_time for each trial, after converting units.
    """
    
    # Get the time of the first poke_in so we can convert all other timestamps to be relative to this
    first_poke_in_DIO = trial_data.loc[0, "poke_in_time"]
    first_poke_in_statescript = trial_data.loc[0, "poke_in_time_statescript"]

    # Get relative DIO poke_in and poke_out times, convert to ms to match statescript times
    DIO_poke_in_times = (trial_data["poke_in_time"] - first_poke_in_DIO) * 1000
    DIO_poke_out_times = (trial_data["poke_out_time"] - first_poke_in_DIO) * 1000

    # Get relative statescript poke_in and poke_out times
    statescript_poke_in_times = (trial_data["poke_in_time_statescript"] - first_poke_in_statescript)
    statescript_poke_out_times = (trial_data["poke_out_time_statescript"] - first_poke_in_statescript)

    # Make sure DIO and statescript times are close (enough) together.
    
    # It is expected for the timestamps to drift apart over the course 
    # of a session (drifting by roughly 0 to 0.5 ms per trial). 
    # By the end of a session, the DIO and statescript timestamps may be up to ~70ms apart.
    # Reduce warning_tol_ms to a lower value to watch this happen.
    # Because of this, warning_tol_ms is currently set to 100ms, which should 
    # be high enough to only warn about variations larger than this expected drift.
    warning_tol_ms = 100

    # NOTE: For now, our error tolerance is unreasonably high because we have some weird stuff
    # going on in BraveLu20240519 epoch 3 and I want to print about it but not break.
    # Apparently trodes crashed during this session. Investigate further.
    # Ultimately change this tolerance to something like 1000
    error_tol_ms = 100_000
    
    # Check poke_in times
    for i, (DIO_poke, ss_poke) in enumerate(zip(DIO_poke_in_times, statescript_poke_in_times), start=1):
        diff = abs(DIO_poke - ss_poke)
        if diff > error_tol_ms:
            raise ValueError(f"Trial {i}: DIO poke_in at {DIO_poke:.1f} and statescript poke_in at {ss_poke} are {diff:.1f} ms apart, exceeds error tolerance of {error_tol_ms} ms")
        elif diff > warning_tol_ms:
            warnings.warn(f"Trial {i}: DIO poke_in at {DIO_poke:.1f} and statescript poke_in at {ss_poke} are {diff:.1f} ms apart, exceeds warning tolerance of {warning_tol_ms} ms")

    # Check poke_out times
    for i, (DIO_poke, ss_poke) in enumerate(zip(DIO_poke_out_times, statescript_poke_out_times), start=1):
        diff = abs(DIO_poke - ss_poke)
        if diff > error_tol_ms:
            raise ValueError(f"Trial {i}: DIO poke_out at {DIO_poke:.1f} and statescript poke_out at {ss_poke} are {diff:.1f} ms apart, exceeds error tolerance of {error_tol_ms} ms")
        elif diff > warning_tol_ms:
            warnings.warn(f"Trial {i}: DIO poke_out at {DIO_poke:.1f} and statescript poke_out at {ss_poke} are {diff:.1f} ms apart, exceeds warning tolerance of {warning_tol_ms} ms")


def get_barrier_locations_from_excel(excel_data):
    """
    Load barrier locations from all rows in excel_data,
    where each row is a session.

    Args:
    excel_data: dataframe of info for this experiment
    (originally read from the excel sheet of experimental notes)

    Returns:
    If there are multiple sessions, a list of lists of sets,
    where the outer list has one entry per run session that day
    and each inner list has one set per block of that session.

    If there is a single session, a list of sets,
    with one set per block of the session.
    """

    # Helper to read the sets of barrier locations from the excel sheet
    def extract_sets_from_string(value):
        if isinstance(value, str):
            # Regular expression to find all the sets in the string
            sets = re.findall(r'\{.*?\}', value)
            return [ast.literal_eval(s) for s in sets]
        return None

    # Get barrier locations as a list of lists of sets
    list_of_barrier_sets = excel_data['barrier location'].apply(extract_sets_from_string).tolist()

    # If excel_data is for only one session, remove the outer list
    if excel_data.shape[0] == 1 and len(list_of_barrier_sets) == 1:
        return list_of_barrier_sets[0]
    # Else return a list of lists, where each outer list is for a different session
    else:
        return list_of_barrier_sets


def add_barrier_locations_to_block_data(block_data, excel_data, session_type):
    """
    Add "maze_configuration" column to block_data.

    Args:
    block_data: Dataframe of information for each block in this epoch
    excel_data: Dataframe of info for this experiment, with column "barrier location"
    session_type: "barrier change" or "probability change"

    Returns:
    block_data with added "maze_configuration" column
    """
    
    def barrier_set_to_string(barrier_set):
        """
        Helper to convert a set of ints to a sorted, comma-separated string.
        Used for going from a set of barrier locations to a string
        maze configuration that plays nicely with NWB and spyglass.
        """
        # The argument is named barrier_set to avoid shadowing the builtin `set`
        return ",".join(map(str, sorted(barrier_set)))

    # Read barrier locations from excel data
    maze_configs = get_barrier_locations_from_excel(excel_data)

    # Make sure the number of blocks matches the number of loaded maze configurations
    if len(block_data) != len(maze_configs):
        # If this is a probability change session, we have a single maze configuration
        # to be used for all blocks. If so, duplicate it so we have one maze per block.
        if len(maze_configs) == 1 and session_type == "probability change":
            maze_configs = maze_configs * len(block_data)
        else:
            raise ValueError(
                f"There are {len(block_data)} blocks, but {len(maze_configs)} maze configurations "
                "from the excel data. There should be exactly one maze configuration per block, "
                "or a single maze configuration if this is a probability change session."
            )
            
    # Convert each maze config from a set to a sorted, comma separated string for compatibility
    maze_configs = [barrier_set_to_string(maze) for maze in maze_configs]
    
    # Add the maze configuration for each block
    block_data["maze_configuration"] = maze_configs
    return block_data


def parse_state_script_log(statescriptlog, DIO_events, excel_data_for_epoch):
    """
    Read and parse the stateScriptLog file and align it to DIO events
    for a given behavioral epoch. Get barrier locations and
    other info (if needed) from excel data.
    
    Args:
    statescriptlog: tuple where statescriptlog[0] is a big string containing the log,
    statescriptlog[1] is the AssociatedFiles object (unused)
    DIO_events: dict of event_name: (data, timestamps) for each named DIO event
    excel_data_for_epoch: dataframe of info for this epoch
    (originally read from the excel sheet of experimental notes)
    
    Returns:
    trial_df: Dataframe of information for each trial in this epoch
    block_data: Dataframe of information for each block in this epoch
    """
    nosepoke_events = []
    behavior_data = []
    block_ends = []
    session_end = None

    # Read the statescriptlog line by line
    for line in str(statescriptlog).splitlines():
        # Ignore lines starting with '#'
        if line.startswith("#"):
            continue

        # Find all poke_in and poke_out events
        for match in poke_in_regex.finditer(line):
            nosepoke_events.append(
                {"timestamp": int(match.group(1)),
                "event_name": "poke_in",
                "port": int(match.group(2))})
        for match in poke_out_regex.finditer(line):
            nosepoke_events.append(
                {"timestamp": int(match.group(1)),
                "event_name": "poke_out",
                "port": int(match.group(2))})
        # Find behavioral data and reward info
        for match in behavior_data_regex.finditer(line):
            behavior_data.append(
                {"timestamp": int(match.group(1)),
                "name": match.group(2),
                "value": int(match.group(3))})
        # Check for block or session end timestamps
        for match in block_end_regex.finditer(line):
            block_ends.append({"timestamp": int(match.group(1))})
        for match in session_end_regex.finditer(line):
            session_end = int(match.group(1))

    # Create dataframes of trial and block data based on the stateScriptLog
    trial_data, block_data = parse_trial_and_block_data(behavior_data, block_ends)

    # Align statescript and DIO nosepokes and create nosepoke dataframe including only nosepoke events at a new port
    # with both statescript timestamps (trodes time) and DIO timestamps (unix time)
    nosepoke_DIOs = {key: value for key, value in DIO_events.items() if key in ['wellA_poke', 'wellB_poke', 'wellC_poke']}
    nosepoke_df = parse_nosepoke_events(nosepoke_events, nosepoke_DIOs)

    # Add combined nosepoke timestamps (both statescript and DIO) to the trial dataframe
    trial_df = combine_nosepoke_and_trial_data(nosepoke_df, trial_data, session_end)

    # Add reward pump timestamps from DIOs to the combined dataframe
    reward_DIOs = {key: value for key, value in DIO_events.items() if key in ['wellA_pump', 'wellB_pump', 'wellC_pump']}
    trial_df = combine_reward_and_trial_data(trial_df, reward_DIOs)

    # Use block data to determine if this is a probability change or barrier change session
    session_type = determine_session_type(block_data)
    print(f"This is a {session_type} session.")
    # Add the session type as a column to the block data
    block_data['task_type'] = session_type

    # If this is a barrier change session, the statescript does not accurately reflect block changes
    # Instead, a DIO ("barrier_shift") is pressed to mark the trial in which the barrier is shifted
    # For early sessions, the "barrier_shift" DIO didn't exist yet so this is recorded in the 
    # "experimental notes" excel sheet
    if session_type == "barrier change":
        trial_df, block_data = adjust_block_start_trials(trial_df, block_data, DIO_events, excel_data_for_epoch)

    # Now that we have the correct start/end trial for each block, add the block start/end times
    # (use trial_df, the combined dataframe that is returned, so all session types are checked on the same data)
    block_data = add_block_start_end_times(trial_df, block_data)

    # Add maze configs from the excel data to the block dataframe
    block_data = add_barrier_locations_to_block_data(block_data, excel_data_for_epoch, session_type)
    
    # Do even more basic checks to make sure trial and block data seems reasonable
    validate_trial_and_block_data(trial_df, block_data)
    validate_poke_timestamps(trial_df)

    return trial_df, block_data


def get_DIO_event_data(nwbfile, behavioral_event_name):
    """
    Get DIO data and timestamps from the nwbfile for a given behavioral event
    
    Args:
    nwbfile: NWB file containing behavioral_event DIOs in the behavior processing module
    behavioral_event_name: named behavioral_event to access

    Returns:
    data: 1/0 data corresponding to DIO high/low for this event
    timestamps: timestamps for each data point (in unix time)
    """

    data = nwbfile.processing["behavior"]["behavioral_events"][behavioral_event_name].data[:]
    timestamps = nwbfile.processing["behavior"]["behavioral_events"][behavioral_event_name].timestamps[:]
    return data, timestamps


def parse_DIOs(behavioral_event_data):
    """
    Parse behavioral event DIOs and timestamps into DIO pulses for actual events vs epoch starts
    
    Epoch starts are marked by a shared "0" data point and timestamp across all DIO events.
    Remove this event and timestamp from all DIOs so the data and timestamps for
    each DIO reflects the actual behavioral event of interest.

    Args:
    behavioral_event_data: dict of event_name: (data, timestamps) for each named DIO event,
    including DIO data/timestamps for both "real" events and epoch starts

    Returns:
    behavioral_event_data: dict of DIO_event: (data, timestamps) for that event 
    (with DIO data/timestamps for epoch starts removed)
    epoch_start_timestamps: List of timestamps marking epoch starts
    """

    # Get timestamps shared among all behavioral events (triggered by an epoch start)
    epoch_start_timestamps = set.intersection(*[set(ts) for _, ts in behavioral_event_data.values()])

    # Remove epoch start data/timestamps so we are left with only DIOs triggered by real behavioral events
    behavioral_event_data = {key: 
                     ([d for d, ts in zip(data, timestamps) if ts not in epoch_start_timestamps],
                      [ts for ts in timestamps if ts not in epoch_start_timestamps])
                      for key, (data, timestamps) in behavioral_event_data.items()}
    
    # After removing extra 0s for epoch starts, check that each 1 has a corresponding 0
    for key, (data, timestamps) in behavioral_event_data.items():
        for i in range(len(data) - 1):
            if not ((data[i] == 1 and data[i + 1] == 0) or (data[i] == 0 and data[i + 1] == 1)):
                # For now, just warn about it - it may end up being ok
                # It may be due to a session timeout that cut off a poke_out - we can deal with that elsewhere
                warnings.warn(f"{key} has mismatched DIO {data[i], data[i+1]} at timestamps {timestamps[i], timestamps[i+1]}")
                
    return behavioral_event_data, sorted(epoch_start_timestamps)


def get_data_from_excel_sheet(excel_path, date, sheet_name='Daily configs and notes_Bandit+'):
    """ 
    Read the excel sheet of experimental notes and return a dataframe 
    of relevant rows for this recording date.
    """

    # Read the excel sheet into a dataframe and filter for run sessions on our target date
    df = pd.read_excel(excel_path, sheet_name=sheet_name, skiprows=1)
    return df[(df['date'].astype(str) == str(date)) & (df['barrier location'].notna())].reset_index(drop=True)


def add_block_and_trial_data_to_nwb(nwbfile: NWBFile, trial_data, block_data, overwrite=False):
    """
    Add trial and block data to the nwbfile as timeintervals.
    If "block" and "trials" already exist in the nwbfile, complain and
    return without modifying the nwb unless overwrite=True.
    
    Args:
    nwbfile (NWBFile): The nwbfile
    trial_data: Dataframe of trial data for all run epochs in the nwbfile
    block_data: Dataframe of block data for all run epochs in the nwbfile
    overwrite (bool, optional): If we should overwrite existing trial 
    and block data in the nwbfile. Defaults to False
    """

    # Check if a block or trials table already exists in the nwbfile
    if not overwrite and ("block" in nwbfile.intervals or "trials" in nwbfile.intervals):
        print("Stopping. Run again with overwrite=True if you wish to overwrite the original block and trials table.")
        return 

    def get_opto_condition(delay):
        """Helper to get opto condition as a string based on delay"""
        return {1: "delay", 0: "no_delay"}.get(delay, "None")

    # Add columns for block data to the NWB file
    block_table = nwbfile.create_time_intervals(
        name="block",
        description="The block within a session. "
        "Each block is defined by a maze configuration and set of reward probabilities.",
    )
    block_table.add_column(name="epoch", description="The epoch (session) this block is in")
    block_table.add_column(name="block", description="The block number within the session")
    block_table.add_column(
        name="maze_configuration",
        description="The maze configuration for each block, "
        "defined by the set of hexes in the maze where barriers are placed.",
    )
    block_table.add_column(name="pA", description="The probability of reward at port A")
    block_table.add_column(name="pB", description="The probability of reward at port B")
    block_table.add_column(name="pC", description="The probability of reward at port C")
    block_table.add_column(name="start_trial", description="The first trial in this block")
    block_table.add_column(name="end_trial", description="The last trial in this block")
    block_table.add_column(name="num_trials", description="The number of trials in this block")
    block_table.add_column(name="task_type", description="The session type ('barrier change' or 'probability change')")

    # Add columns for trial data to the NWB file
    nwbfile.add_trial_column(name="epoch", description="The epoch (session) this trial is in")
    nwbfile.add_trial_column(name="block", description="The block this trial is in")
    nwbfile.add_trial_column(name="trial_within_block", description="The trial number within the block")
    nwbfile.add_trial_column(name="trial_within_epoch", description="The trial number within the epoch (session)")
    nwbfile.add_trial_column(name="start_port", description="The reward port the rat started at (A, B, or C)")
    nwbfile.add_trial_column(name="end_port", description="The reward port the rat ended at (A, B, or C)")
    nwbfile.add_trial_column(name="reward", description="If the rat got a reward at the port (1 or 0)")
    nwbfile.add_trial_column(name="opto_condition", description="Description of the opto condition, if any")
    nwbfile.add_trial_column(name="duration", description="The duration of the trial")
    nwbfile.add_trial_column(name="poke_in", description="The time the rat entered the reward port")
    nwbfile.add_trial_column(name="poke_out", description="The time the rat exited the reward port")

    # Add each block to the block table in the NWB
    for idx, block in block_data.iterrows():
        block_table.add_row(
            epoch=block["epoch"],
            block=block["block"],
            maze_configuration=block["maze_configuration"],
            pA=block["pA"],
            pB=block["pB"],
            pC=block["pC"],
            start_trial=block["start_trial"],
            end_trial=block["end_trial"],
            num_trials=block["num_trials"],
            task_type=block["task_type"],
            start_time=block["start_time"],
            stop_time=block["end_time"],
        )

    # Add each trial to the NWB
    for idx, trial in trial_data.iterrows():
        nwbfile.add_trial(
            epoch=trial["epoch"],
            block=trial["block"],
            trial_within_block=trial["trial_within_block"],
            trial_within_epoch=trial["trial_within_session"],
            start_port=trial["start_port"],
            end_port=trial["end_port"],
            reward=trial["reward"],
            opto_condition=get_opto_condition(trial['delay']),
            duration=trial["duration"],
            poke_in=trial["poke_in_time"],
            poke_out=trial["poke_out_time"],
            start_time=trial["start_time"],
            stop_time=trial["end_time"],
        )
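
To illustrate how `adjust_block_start_trials` turns barrier shift trials into block boundaries, here is a minimal, self-contained sketch. The helper name `block_boundaries_from_shifts` and the toy trial numbers are hypothetical and not part of this notebook. The sketch only mirrors the start/end arithmetic used above, and assumes the barrier shift trials are sorted and all fall before the last trial.

```python
def block_boundaries_from_shifts(barrier_shift_trials, last_trial):
    """Hypothetical helper: return (start_trial, end_trial, num_trials) for each block."""
    # The trial after each barrier shift is the first trial of a new block
    block_start_trials = [1] + [t + 1 for t in barrier_shift_trials]
    # Each barrier shift trial is the last trial of its block
    block_end_trials = list(barrier_shift_trials) + [last_trial]
    return [(start, end, end - start + 1)
            for start, end in zip(block_start_trials, block_end_trials)]


# Toy session (assumed values): 60 trials, barrier shifted during trials 20 and 41
blocks = block_boundaries_from_shifts([20, 41], last_trial=60)
for block_num, (start, end, num_trials) in enumerate(blocks, start=1):
    print(f"Block {block_num}: trials {start}-{end} ({num_trials} trials)")
```

With two barrier shifts the session splits into three blocks, and the per-block trial counts sum to the total number of trials, which is the same invariant `validate_trial_and_block_data` asserts.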

In [23]:
def add_behavioral_data_to_nwb(nwb_path, excel_path, 
                                 sheet_name='Daily configs and notes_Bandit+', save_type=None, overwrite=False):
    """
    Given an nwbfile, parse behavioral data into trial and block 
    dataframes for each run epoch and save them for future use.
    
    Args:
    nwb_path: Path to a Frank Lab nwbfile with statescriptlogs saved as AssociatedFiles \
    objects and behavioral event DIOs in the behavior processing module
    excel_path: Path to an excel sheet of behavioral notes for the experiment, \
    including 'date' and 'barrier location' columns
    sheet_name: Name of sheet to read from in excel. \
    Defaults to 'Daily configs and notes_Bandit+' if not specified.
    save_type: "nwb", "pickle", or "csv". \
    "nwb" to save the data as timeintervals in the nwbfile. \
    "pickle" or "csv" to save the trial and block dataframes as .pkl or .csv files. \
    Any combination of save_types is allowed, e.g. save_type="pickle,nwb"
    overwrite: If we should overwrite existing trial and block data in the nwbfile, \
    if it exists. Applies only to save_type="nwb". Defaults to False
    """

    # Hack to remove trials and block table if we need to overwrite them
    removed_old_data_from_nwb = False
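    # save_type defaults to None, so check it is set before testing for "nwb"
    if overwrite and save_type is not None and "nwb" in save_type: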
        with h5py.File(nwb_path, "r+") as f:
            if "block" in f["intervals"]:
                print("A block table already exists in the nwbfile!")
                print("The original block table in the nwb will be deleted and overwritten.")
                del f["intervals/block"]
            if "trials" in f["intervals"]:
                print("A trials table already exists in the nwbfile!")
                print("The original trials table in the nwb will be deleted and overwritten.")
                del f["intervals/trials"]
            removed_old_data_from_nwb = True

    with NWBHDF5IO(nwb_path, mode="r+") as io:
        nwbfile = io.read()
        print(f"Parsing behavior for {nwbfile.session_id} ...")

        # Get session date assuming session ID is in format rat_date
        session_date = nwbfile.session_id.split("_")[-1]
        # Read rows from excel sheets for run sessions on this date
        excel_data = get_data_from_excel_sheet(excel_path, session_date, sheet_name)

        # Get epoch table defining session boundaries with columns "start_time", "end_time", and "tags"
        epoch_table = nwbfile.intervals["epochs"][:]

        # Filter epochs to include only run sessions (should include "r" in the tags)
        run_epochs = epoch_table[epoch_table["tags"].apply(lambda x: 'r' in x[0])]
        # Filtering epochs for run sessions should be the same as taking every other epoch
        assert run_epochs.equals(epoch_table.iloc[1::2])

        # Get all stateScriptLogs from run sessions (ignoring logs from sleep sessions)
        module = nwbfile.get_processing_module("associated_files")
        run_statescript_logs = {name: log 
                                for name, log in module.data_interfaces.items()
                                if name.startswith("statescript r")
                                }
        assert len(run_statescript_logs) == len(run_epochs) == len(excel_data)

        # Get behavioral events from the nwbfile as a dict of (data, timestamps) for each named behavioral event 
        behavioral_events = ["barrier_shift", "wellA_poke", "wellA_pump", "wellB_poke", 
                        "wellB_pump", "wellC_poke", "wellC_pump"]
        behavioral_event_data = {event: get_DIO_event_data(nwbfile, event) for event in behavioral_events}

        # Separate DIOs into those for actual behavioral events vs epoch starts
        behavioral_event_data, epoch_start_timestamps = parse_DIOs(behavioral_event_data)

        # Check that we have the expected amount of epoch starts
        # NOTE: epoch_start_timestamps from the DIO pulses lag the timestamps in the epoch_table
        # by ~1 second to ~1 minute - check where this discrepancy comes from and which one to use!
        assert len(epoch_start_timestamps) == len(epoch_table)
        
        # Set up lists to store block and trial data for each epoch
        block_dataframes = []
        trial_dataframes = []
        
        # Parse behavioral data for each epoch using the statescriptlog and align to DIOs
        run_session_num = 0
        for idx, epoch in run_epochs.iterrows():

            print(f"Parsing statescript for epoch {epoch.name} ...\n")
            # Get the statescriptlog for this epoch
            statescriptlog = list(run_statescript_logs.items())[run_session_num]

            # Filter DIOs to only include those in this epoch
            # NOTE: maybe replace epoch.start_time and epoch.stop_time with DIO times 
            # (see above comment for reasoning, and commented line below for how to make this switch)
            # So far it does not seem to make a difference in our results, but something to consider.
            DIO_events_in_epoch = {
                event: (list(filtered_data), list(filtered_timestamps))
                for event, (data, timestamps) in behavioral_event_data.items()
                if (filtered := [(d, ts) for d, ts in zip(data, timestamps) if epoch.start_time <= ts <= epoch.stop_time])
                # if (filtered := [(d, ts) for d, ts in zip(data, timestamps) if epoch_start_timestamps[idx] < ts < epoch_start_timestamps[idx+1]])
                for filtered_data, filtered_timestamps in [zip(*filtered)]
            }

            # Filter excel data for this epoch
            excel_data_for_epoch = excel_data.iloc[[run_session_num]]

            # Parse statescriptlog and DIO events for this epoch into tables of trial and block data
            trial_data, block_data = parse_state_script_log(statescriptlog, DIO_events_in_epoch, excel_data_for_epoch)
            
            # Adjustment for start time of first trial/block
            # If the epoch starts after the first trial's start time, set that start time to the epoch start.
            if epoch.start_time > trial_data.loc[0, 'start_time']:
                print(f"Setting start time of the first block/trial to epoch start time {epoch.start_time}, was previously {trial_data.loc[0, 'start_time']}")
                trial_data.loc[0, 'start_time'] = epoch.start_time
                block_data.loc[0, 'start_time'] = epoch.start_time
                trial_data.loc[0, 'duration'] = trial_data.loc[0, 'end_time']-trial_data.loc[0, 'start_time']

            # Add epoch column to the dataframes
            trial_data['epoch'] = epoch.name
            block_data['epoch'] = epoch.name
            
            # Reorder columns so epoch comes first
            trial_data = trial_data[['epoch'] + [col for col in trial_data.columns if col != 'epoch']]
            block_data = block_data[['epoch'] + [col for col in block_data.columns if col != 'epoch']]
    
            # Append the dataframes for this epoch
            trial_dataframes.append(trial_data)
            block_dataframes.append(block_data)

            print(f"Trial and block data for epoch {epoch.name}:")
            display(trial_data)
            display(block_data)
            
            run_session_num += 1

        trial_data_all_epochs = pd.concat(trial_dataframes, ignore_index=True)
        block_data_all_epochs = pd.concat(block_dataframes, ignore_index=True)

        if "pickle" in save_type:
            trial_data_all_epochs.to_pickle(f"{nwbfile.session_id}_trial_data.pkl")
            block_data_all_epochs.to_pickle(f"{nwbfile.session_id}_block_data.pkl")
        if "csv" in save_type:
            trial_data_all_epochs.to_csv(f"{nwbfile.session_id}_trial_data.csv", index=False)
            block_data_all_epochs.to_csv(f"{nwbfile.session_id}_block_data.csv", index=False)
        if "nwb" in save_type:
            # Add the trial and block tables to the original nwbfile
            add_block_and_trial_data_to_nwb(nwbfile, trial_data_all_epochs, block_data_all_epochs, overwrite)
            # Write to the nwb if either overwrite=True, 
            # or overwrite=False but there was no existing block and trial data (so we are adding not overwriting)
            if overwrite or not removed_old_data_from_nwb:
                io.write(nwbfile)


def delete_blocks_and_trials_from_nwb(nwb_path):
    """
    Delete block and trials tables (stored as TimeIntervals) from an nwbfile if they exist. 
    Modifies the file in-place. Note that this will not actually reduce the file size 
    due to limitations in the HDF5 format.
    """
    with h5py.File(nwb_path, "r+") as f:
        if "block" in f["intervals"]:
            print("Deleting block table from the nwbfile")
            del f["intervals/block"]
        else:
            print("No block table to delete.")
        if "trials" in f["intervals"]:
            print("Deleting trials table from the nwbfile")
            del f["intervals/trials"]
        else:
            print("No trials table to delete.")

# NOTE: Currently this pipeline prints a lot of stuff because it is still in development - 
# ultimately we will switch to formal logging and save all of the output in a text file instead
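
The epoch loop in `add_behavioral_data_to_nwb` restricts each DIO event to the current epoch with a nested comprehension and an assignment expression. The same filtering can be written as a plain loop, which may be easier to read and debug. This is a minimal sketch, not the pipeline's code: `filter_events_to_window` is a hypothetical helper, and it assumes each event maps to a `(data, timestamps)` pair of equal-length sequences, as in `behavioral_event_data`. The toy values are made up for illustration.

```python
def filter_events_to_window(event_data, start_time, stop_time):
    """Hypothetical helper: keep only samples with start_time <= ts <= stop_time.

    event_data maps an event name to a (data, timestamps) pair.
    Events with no samples inside the window are dropped, mirroring
    the comprehension in add_behavioral_data_to_nwb.
    """
    events_in_window = {}
    for event, (data, timestamps) in event_data.items():
        kept = [(d, ts) for d, ts in zip(data, timestamps)
                if start_time <= ts <= stop_time]
        if not kept:
            continue  # nothing from this event falls inside the window
        kept_data, kept_timestamps = zip(*kept)
        events_in_window[event] = (list(kept_data), list(kept_timestamps))
    return events_in_window


# Toy values for illustration only (not real recording data)
toy_events = {
    "wellA_poke": ([1, 0, 1, 0], [10.0, 11.5, 40.0, 41.0]),
    "wellB_poke": ([1, 0], [100.0, 101.0]),
}
filtered = filter_events_to_window(toy_events, start_time=5.0, stop_time=20.0)
print(filtered)
```

Only the two `wellA_poke` samples inside the window survive, and `wellB_poke` is dropped entirely because none of its samples fall between the bounds.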

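The `save_type` argument is a comma-separated string that the function tests with substring checks such as `"nwb" in save_type`. A stricter alternative is to split and validate the string up front, so that a typo like `"nbw"` is rejected instead of being silently ignored. The sketch below is one possible approach under that assumption; `parse_save_types` and `VALID_SAVE_TYPES` are hypothetical names that do not exist in the pipeline.

```python
VALID_SAVE_TYPES = {"nwb", "pickle", "csv"}  # hypothetical constant


def parse_save_types(save_type):
    """Hypothetical helper: turn "pickle,nwb" into {"pickle", "nwb"}.

    Returns an empty set for None or an empty string and raises
    ValueError if any requested type is not recognized.
    """
    if not save_type:
        return set()
    requested = {part.strip().lower() for part in save_type.split(",") if part.strip()}
    unknown = requested - VALID_SAVE_TYPES
    if unknown:
        raise ValueError(f"Unknown save_type(s): {sorted(unknown)}")
    return requested


print(parse_save_types("pickle,nwb") == {"pickle", "nwb"})
```

With a set in hand, the later checks could become `if "nwb" in save_types:`, which tests exact membership rather than substrings.
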
In [24]:
nwb_path = 'data/BraveLu20240519_copy.nwb'
excel_path = 'data/BraveLu_experimental_notes.xlsx'
sheet_name = 'Daily configs and notes_Bandit+'

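# Parse trial and block data; with save_type="csv" the dataframes are saved as .csv files.
# Use save_type="nwb" to add them to the nwb instead (modifies the existing nwb in-place).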
add_behavioral_data_to_nwb(nwb_path, excel_path, sheet_name=sheet_name, save_type="csv", overwrite=False)

# If needed, delete block and trials table from the nwb
# delete_blocks_and_trials_from_nwb(nwb_path)

Parsing behavior for BraveLu_20240519 ...
Parsing statescript for epoch 1 ...

1908 nosepokes from the DIOs: {'A': 840, 'B': 820, 'C': 248}
1908 nosepokes from the statescript: {'A': 840, 'B': 820, 'C': 248}
All DIO and statescript nosepokes were matched successfully.
This is a probability change session.
Trial and block data for epoch 1:


Unnamed: 0,epoch,trial_within_block,trial_within_session,block,start_port,end_port,reward,delay,statescript_reference_timestamp,poke_in_time_statescript,poke_out_time_statescript,poke_in_time,poke_out_time,start_time,end_time,duration,pump_on_time,pump_off_time
0,1,1,1,1,,B,1,,75570,75324,83539,1716147462,1716147471,1716147459,1716147471,11,1716147462,1716147463
1,1,2,2,1,B,A,0,,97335,96993,106012,1716147484,1716147493,1716147471,1716147493,22,,
2,1,3,3,1,A,B,1,,115265,114951,127177,1716147502,1716147514,1716147493,1716147514,21,1716147502,1716147502
3,1,4,4,1,B,A,1,,136308,135776,149024,1716147523,1716147536,1716147514,1716147536,22,1716147523,1716147523
4,1,5,5,1,A,B,0,,155731,155344,156820,1716147542,1716147544,1716147536,1716147544,8,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
202,1,65,203,3,B,A,0,,3263628,3263508,3264676,1716150651,1716150652,1716150639,1716150652,13,,
203,1,66,204,3,A,B,0,,3272356,3271929,3273643,1716150659,1716150661,1716150652,1716150661,9,,
204,1,67,205,3,B,A,0,,3282576,3282434,3283749,1716150670,1716150671,1716150661,1716150671,10,,
205,1,68,206,3,A,B,1,,3291477,3290981,3305558,1716150678,1716150693,1716150671,1716150693,22,1716150678,1716150678


Unnamed: 0,epoch,block,pA,pB,pC,statescript_end_timestamp,start_trial,end_trial,num_trials,task_type,start_time,end_time,maze_configuration
0,1,1,10,50,90,1059629,1,69,69,probability change,1716147459,1716148449,91020212327353745
1,1,2,90,10,50,2193559,70,138,69,probability change,1716148449,1716149597,91020212327353745
2,1,3,10,90,50,3315015,139,207,69,probability change,1716149597,1716150702,91020212327353745


Parsing statescript for epoch 3 ...

1998 nosepokes from the DIOs: {'B': 918, 'A': 570, 'C': 510}
1996 nosepokes from the statescript: {'B': 916, 'A': 570, 'C': 510}
The DIO may have more pokes because it keeps recording after the statescript has been stopped (this is ok).
Mismatched rows:
      timestamp_DIO event_name port  index  timestamp_statescript
743      1716157698    poke_in    B    458                    NaN
1742     1716157699   poke_out    B    458                    NaN
This is a probability change session.
Trial and block data for epoch 3:


Unnamed: 0,epoch,trial_within_block,trial_within_session,block,start_port,end_port,reward,delay,statescript_reference_timestamp,poke_in_time_statescript,poke_out_time_statescript,poke_in_time,poke_out_time,start_time,end_time,duration,pump_on_time,pump_off_time
0,3,1,1,1,,C,0,,75363,75051,76248,1716154390,1716154391,1716154387,1716154391,4,,
1,3,2,2,1,C,B,1,,85874,85544,98981,1716154400,1716154414,1716154391,1716154414,23,1716154400,1716154401
2,3,3,3,1,B,C,1,,107835,107317,117084,1716154422,1716154432,1716154414,1716154432,18,1716154422,1716154423
3,3,4,4,1,C,B,0,,125859,125274,126299,1716154440,1716154441,1716154432,1716154441,9,,
4,3,5,5,1,B,A,1,,134591,134552,149164,1716154449,1716154464,1716154441,1716154464,23,1716154449,1716154450
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
202,3,65,203,3,B,C,1,,3326405,3313873,3326773,1716157629,1716157642,1716157619,1716157642,23,1716157629,1716157629
203,3,66,204,3,C,B,1,,3340188,3340251,3374049,1716157655,1716157670,1716157642,1716157670,29,1716157655,1716157655
204,3,67,205,3,B,C,0,,3366203,3365182,3366177,1716157680,1716157681,1716157670,1716157681,11,,
205,3,68,206,3,C,B,1,,3374084,3374289,3383567,1716157688,1716157698,1716157681,1716157698,17,1716157688,1716157689


Unnamed: 0,epoch,block,pA,pB,pC,statescript_end_timestamp,start_trial,end_trial,num_trials,task_type,start_time,end_time,maze_configuration
0,3,1,90,50,10,1073804,1,69,69,probability change,1716154387,1716155400,91018192129374143
1,3,2,50,10,90,2202466,70,138,69,probability change,1716155400,1716156522,91018192129374143
2,3,3,10,90,50,3392697,139,207,69,probability change,1716156522,1716157709,91018192129374143


Parsing statescript for epoch 5 ...

1511 nosepokes from the DIOs: {'B': 973, 'A': 516, 'C': 22}
1500 nosepokes from the statescript: {'B': 964, 'A': 514, 'C': 22}
The DIO may have more pokes because it keeps recording after the statescript has been stopped (this is ok).
Mismatched rows:
      timestamp_DIO event_name port  index  timestamp_statescript
257      1716164327    poke_in    A    257                    NaN
740      1716164315    poke_in    B    482                    NaN
741      1716164315    poke_in    B    483                    NaN
742      1716164316    poke_in    B    484                    NaN
743      1716164336    poke_in    B    485                    NaN
744      1716164337    poke_in    B    486                    NaN
1013     1716164327   poke_out    A    257                    NaN
1496     1716164315   poke_out    B    482                    NaN
1497     1716164315   poke_out    B    483                    NaN
1498     1716164317   poke_out    B    484         

Unnamed: 0,epoch,trial_within_block,trial_within_session,block,start_port,end_port,reward,delay,statescript_reference_timestamp,poke_in_time_statescript,poke_out_time_statescript,poke_in_time,poke_out_time,start_time,end_time,duration,pump_on_time,pump_off_time
0,5,1,1,1,,B,1,,119187,118917,133403,1716162121,1716162136,1716162118,1716162136,17,1716162122,1716162122
1,5,2,2,1,B,A,0,,141553,141261,151140,1716162144,1716162154,1716162136,1716162154,18,,
2,5,3,3,1,A,B,1,,157063,156623,167191,1716162159,1716162170,1716162154,1716162170,16,1716162159,1716162159
3,5,4,4,1,B,C,0,,185988,177136,186217,1716162180,1716162189,1716162170,1716162189,19,,
4,5,5,5,1,C,B,1,,194394,193973,203613,1716162197,1716162206,1716162189,1716162206,17,1716162197,1716162197
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
134,5,28,135,3,A,B,1,,2208143,2207739,2221207,1716164210,1716164224,1716164200,1716164224,24,1716164210,1716164211
135,5,29,136,3,B,A,1,,2231108,2230547,2242272,1716164233,1716164245,1716164224,1716164245,21,1716164233,1716164233
136,5,30,137,3,A,B,1,,2253106,2252598,2266679,1716164255,1716164269,1716164245,1716164269,24,1716164255,1716164256
137,5,31,138,3,B,A,1,,2277606,2277083,2290169,1716164280,1716164293,1716164269,1716164293,23,1716164280,1716164280


Unnamed: 0,epoch,block,pA,pB,pC,statescript_end_timestamp,start_trial,end_trial,num_trials,task_type,start_time,end_time,maze_configuration
0,5,1,50,90,10,,1,43,43,barrier change,1716162118,1716162699,91116212933374446
1,5,2,50,90,10,,44,107,64,barrier change,1716162699,1716163700,91114212933374446
2,5,3,50,90,10,,108,139,32,barrier change,1716163700,1716164315,91114202133374446


Parsing statescript for epoch 7 ...

1748 nosepokes from the DIOs: {'B': 878, 'C': 716, 'A': 154}
1748 nosepokes from the statescript: {'B': 878, 'C': 716, 'A': 154}
All DIO and statescript nosepokes were matched successfully.
This is a barrier change session.
Adjusting barrier shift times based on barrier_shift DIOs ...
Barrier shift DIO pressed 13.35s after poke_in of trial 72.
Next poke_in was 11.50s after barrier shift DIO pressed.
Trial 73 is the first trial of the new block.
Barrier shift DIO pressed 11.64s after poke_in of trial 100.
Next poke_in was 9.83s after barrier shift DIO pressed.
Trial 101 is the first trial of the new block.
Barrier shift DIO pressed 10.01s after poke_in of trial 164.
Next poke_in was 12.88s after barrier shift DIO pressed.
Trial 165 is the first trial of the new block.
Barrier_shift DIOs match data from excel sheet, with barrier shifts at trials [72, 100, 164]
Trial and block data for epoch 7:


Unnamed: 0,epoch,trial_within_block,trial_within_session,block,start_port,end_port,reward,delay,statescript_reference_timestamp,poke_in_time_statescript,poke_out_time_statescript,poke_in_time,poke_out_time,start_time,end_time,duration,pump_on_time,pump_off_time
0,7,1,1,1,,B,0,,47867,47791,55020,1716168699,1716168706,1716168696,1716168706,10,,
1,7,2,2,1,B,A,0,,62704,62364,72285,1716168713,1716168723,1716168706,1716168723,17,,
2,7,3,3,1,A,C,1,,80666,80278,95623,1716168731,1716168747,1716168723,1716168747,23,1716168731,1716168731
3,7,4,4,1,C,B,0,,103493,102947,103475,1716168754,1716168754,1716168747,1716168754,8,,
4,7,5,5,1,B,A,0,,109348,108764,109322,1716168760,1716168760,1716168754,1716168760,6,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
211,7,48,212,4,B,C,1,,3203537,3203481,3214297,1716171854,1716171865,1716171846,1716171865,19,1716171854,1716171855
212,7,49,213,4,C,B,1,,3222702,3222069,3235817,1716171873,1716171887,1716171865,1716171887,22,1716171873,1716171873
213,7,50,214,4,B,C,1,,3244214,3243923,3257944,1716171895,1716171909,1716171887,1716171909,22,1716171895,1716171895
214,7,51,215,4,C,B,0,,3266362,3265626,3266808,1716171917,1716171918,1716171909,1716171918,9,,


Unnamed: 0,epoch,block,pA,pB,pC,statescript_end_timestamp,start_trial,end_trial,num_trials,task_type,start_time,end_time,maze_configuration
0,7,1,10,50,90,,1,72,72,barrier change,1716168696,1716169605,7111220263436383941
1,7,2,10,50,90,,73,100,28,barrier change,1716169605,1716169989,7111214202634363839
2,7,3,10,50,90,,101,164,64,barrier change,1716169989,1716171017,7111420262934363839
3,7,4,10,50,90,,165,216,52,barrier change,1716171017,1716171927,7111420293436383945
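
The block tables can be sanity-checked against their own columns: within an epoch, each block's `num_trials` should equal `end_trial - start_trial + 1`, and each block should start on the trial after the previous block ends. The sketch below applies that check to the epoch 7 values copied from the table above. `check_block_trial_counts` is a hypothetical helper, not part of the pipeline.

```python
# (block, start_trial, end_trial, num_trials) copied from the epoch 7 block table
epoch7_blocks = [
    (1, 1, 72, 72),
    (2, 73, 100, 28),
    (3, 101, 164, 64),
    (4, 165, 216, 52),
]


def check_block_trial_counts(blocks):
    """Hypothetical helper: return a list of problems (empty if consistent)."""
    problems = []
    previous_end = 0
    for block, start_trial, end_trial, num_trials in blocks:
        if end_trial - start_trial + 1 != num_trials:
            problems.append(f"block {block}: num_trials does not match trial range")
        if start_trial != previous_end + 1:
            problems.append(f"block {block}: does not start right after the previous block")
        previous_end = end_trial
    return problems


problems = check_block_trial_counts(epoch7_blocks)
print(problems)
```

An empty list means the trial ranges are contiguous and the counts agree, which is the case for the epoch 7 rows.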

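The mismatch reports for epochs 3 and 5 list DIO pokes whose `timestamp_statescript` is `NaN`, meaning no statescript poke was paired with them. How the pipeline pairs pokes is not shown in this section, so the following is only an illustrative sketch under one assumption: that pokes are paired on `(event_name, port, index)`, the columns that appear in the report. `find_unmatched_dio_pokes` is a hypothetical helper, and the values are toy data.

```python
def find_unmatched_dio_pokes(dio_pokes, statescript_pokes):
    """Hypothetical helper: list DIO pokes with no statescript counterpart.

    Both inputs are lists of dicts with keys "event_name", "port", and
    "index". Pokes are paired on (event_name, port, index); this pairing
    rule is an assumption, not the pipeline's confirmed method.
    """
    statescript_keys = {
        (p["event_name"], p["port"], p["index"]) for p in statescript_pokes
    }
    return [
        p for p in dio_pokes
        if (p["event_name"], p["port"], p["index"]) not in statescript_keys
    ]


# Toy values for illustration only (not real recording data)
dio = [
    {"event_name": "poke_in", "port": "B", "index": 1, "timestamp": 100},
    {"event_name": "poke_out", "port": "B", "index": 1, "timestamp": 101},
    {"event_name": "poke_in", "port": "B", "index": 2, "timestamp": 110},
]
statescript = [
    {"event_name": "poke_in", "port": "B", "index": 1, "timestamp": 5000},
    {"event_name": "poke_out", "port": "B", "index": 1, "timestamp": 5100},
]
unmatched = find_unmatched_dio_pokes(dio, statescript)
print(unmatched)
```

Here the last DIO poke has no statescript partner, which resembles pokes recorded by the DIO after the statescript was stopped.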