In [None]:
# Auto-reload imported modules before each cell execution, so that edits to helper
# functions developed outside this notebook take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import labrotation.file_handling as fh
import datadoc_util as ddoc
import os
import pims_nd2
import numpy as np
import h5py
from datetime import datetime as dt

In [None]:
def load_env_file(fpath="./.env"):
    """Parse a simple KEY=VALUE ".env" file into a dict.

    Blank lines and lines starting with "#" are skipped. Only the first "=" separates
    key from value, so values may themselves contain "=" (e.g. URLs with query strings).
    Surrounding whitespace of keys and values is stripped. Lines without any "=" are
    reported and skipped instead of raising an IndexError.

    Parameters
    ----------
    fpath : str
        Path of the .env file.

    Returns
    -------
    dict
        Mapping of key (str) to value (str); empty if the file does not exist.
    """
    env = dict()
    if not os.path.exists(fpath):
        print(".env does not exist")
        return env
    with open(fpath, "r") as f:
        for i_line, line in enumerate(f, start=1):
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            # partition splits on the first "=" only; sep is "" when there is no "=".
            key, sep, value = line.partition("=")
            if not sep:
                # Do not echo the line content: .env files may hold secrets.
                print(f"Skipping line {i_line} of {fpath}: no '=' found")
                continue
            env[key.strip()] = value.strip()
    return env


env_dict = load_env_file("./.env")
print(env_dict.keys())

In [None]:
def get_datetime_for_fname():
    """Return the current local date and time as a filename-friendly string.

    Format is "YYYYMMDD-HHMMSS" (zero-padded), e.g. "20240131-235959".
    """
    return dt.now().strftime("%Y%m%d-%H%M%S")

In [None]:
# Folder the output .h5 file is written to. Without a fallback, `output_folder` would be
# undefined when DOWNLOADS_FOLDER is missing from .env, and the export cell at the end
# would fail with a NameError only after all traces had been assembled.
if "DOWNLOADS_FOLDER" in env_dict:
    output_folder = env_dict["DOWNLOADS_FOLDER"]
else:
    output_folder = os.getcwd()
    print("DOWNLOADS_FOLDER not set in .env; using the current working directory.")
print(f"Output: {output_folder}")

In [None]:
# Ask for the Excel file that lists the recovery trace segments.
# NOTE(review): fh.open_file presumably opens a file-selection dialog and returns the chosen path — confirm in labrotation.file_handling.
recovery_traces_fpath = fh.open_file("Select excel file containing recovery trace list")

In [None]:
# One row per trace segment. Columns used below: event_uuid, recording_uuid,
# begin_frame / end_frame (1-indexed, both inclusive) and segment_type (the code
# distinguishes "sz" and "am"; other values, presumably baseline "bl", add no break point).
df_recovery_traces = pd.read_excel(recovery_traces_fpath)

In [None]:
# Load the data documentation. The cells below query it for nd2 file paths, mouse ids,
# window types and recording segments, so it is required. Fail here with a clear message
# instead of a NameError on `dd` several cells later.
if "DATA_DOCU_FOLDER" not in env_dict:
    raise KeyError("DATA_DOCU_FOLDER must be set in .env to load the data documentation")
dd = ddoc.DataDocumentation(env_dict["DATA_DOCU_FOLDER"])
dd.loadDataDoc()
if "SERVER_SYMBOL" in env_dict:
    # Optional: drive letter / mount symbol under which the data server is accessible.
    dd.setDataDriveSymbol(env_dict["SERVER_SYMBOL"])
print("data documentation loaded")

In [None]:
# Optionally restrict processing to a subset of events (e.g. for debugging a single trace).
# Set EVENT_UUIDS_TO_PROCESS = None to assemble traces for every event in the Excel file.
# Filtering by a fixed set is idempotent, so re-running this cell is safe.
EVENT_UUIDS_TO_PROCESS = ["74473c5d22e04525acf53f5a5cb799f4"]
if EVENT_UUIDS_TO_PROCESS is not None:
    df_recovery_traces = df_recovery_traces[df_recovery_traces["event_uuid"].isin(EVENT_UUIDS_TO_PROCESS)]
    if df_recovery_traces.empty:
        print(f"Warning: none of {EVENT_UUIDS_TO_PROCESS} found in the recovery trace list")

In [None]:
# Recording uuids that contribute to the selected traces, in order of first appearance
# (pandas `unique` does not sort). The nd2 loading below pairs them by position with
# the entries of fpath_list_nd2.
rec_uuids = df_recovery_traces["recording_uuid"].unique()

In [None]:
# Resolve the nd2 file path of each recording.
# NOTE(review): later cells pair fpath_list_nd2 and rec_uuids by position, so this presumably
# returns paths in the same order as rec_uuids — confirm in datadoc_util.
fpath_list_nd2 = dd.getNikonFilePathForUuid(rec_uuids)

### Test that all nd2 files are present

In [None]:
# Fail early, before the slow nd2 loading, if any nd2 file is not reachable.
# Every missing path is listed, not just the first one.
missing_nd2_fpaths = [fpath for fpath in fpath_list_nd2 if not os.path.exists(fpath)]
for fpath_missing in missing_nd2_fpaths:
    print(fpath_missing)
if missing_nd2_fpaths:
    raise FileNotFoundError(f"{len(missing_nd2_fpaths)} nd2 file(s) not found (listed above)")

# Load nd2 data

In [None]:
# nd2_data_dict = { uuid: (t_start_utc, nd2_tstamps, mean_fluo)}
# Load each nd2 recording once and keep only what the trace assembly needs:
# nd2_data_dict = {recording uuid: (t_start_utc, nd2_tstamps, mean_fluo)}
#   t_start_utc: recording start time from the nd2 metadata (differences of these are
#                converted with .total_seconds() below)
#   nd2_tstamps: per-frame "t_ms" metadata, ms since start of the recording
#   mean_fluo:   per-frame mean over image axes 1 and 2
if len(fpath_list_nd2) != len(rec_uuids):
    raise ValueError(f"Got {len(fpath_list_nd2)} nd2 paths for {len(rec_uuids)} recording uuids")
nd2_data_dict = dict()
for rec_uuid, fpath_nd2 in zip(rec_uuids, fpath_list_nd2):
    print(rec_uuid)
    # The context manager closes the nd2 file handle, also when reading fails.
    with pims_nd2.ND2_Reader(fpath_nd2) as nd2r:
        t_start_utc = nd2r.metadata["time_start_utc"]
        mean_fluo = np.mean(nd2r, axis=(1, 2))
        nd2_tstamps = np.array([nd2r[i].metadata["t_ms"] for i in range(len(nd2r))])
    nd2_data_dict[rec_uuid] = (t_start_utc, nd2_tstamps, mean_fluo)

In [None]:
# Check uuid-nd2 consistency
# Check uuid-nd2 consistency: the number of frames loaded from each nd2 file should equal
# the largest segment frame_end documented for that recording (presumably its length).
# Mismatches are only reported, not raised.
for rec_uuid, fpath_nd2 in zip(rec_uuids, fpath_list_nd2):
    nd2_fname = os.path.basename(fpath_nd2)
    rec_len = dd.getSegmentsForUUID(rec_uuid).frame_end.max()
    len_to_check = len(nd2_data_dict[rec_uuid][2])
    if len_to_check != rec_len:
        print(f"Length mismatch {rec_uuid} ({nd2_fname}):\n\tshould: {rec_len} but is {len_to_check}")

# Assemble the traces

In [None]:
# Per-trace attributes assembled in the next cells and written to the output .h5 file
# (all frame indices are 0-indexed positions in the assembled trace):
# break_points: break points of segments AND recordings, merged, each index listed once.
# recording_break_points: first trace frame of each nd2 file (recording) contributing to the trace.
# segment_type_break_points: first trace frames of the bl, sz and am segments: 0, n_bl_frames, n_bl_frames + n_sz_frames
# session_uuids: uuids of the recording sessions contributing to the trace, in order of appearance.
# window_type: imaging window of the mouse, "Cx" or "CA1" (mapped via window_mapping).


In [None]:
# Translate window_type values of the data documentation into the labels stored in the output file.
window_mapping = dict(cx="Cx", ca1="CA1")

In [None]:
def merge_break_points(*break_point_lists):
    """Return the sorted union of several lists of 0-indexed trace break points.

    An index that occurs in more than one list (e.g. a recording that starts exactly
    where a new segment type starts) appears only once in the result.
    """
    return sorted(set().union(*break_point_lists))


# event uuid -> [first sz frame, last sz frame] in the assembled trace (0-indexed, both inclusive)
sz_begin_end_frames = dict()
# event uuid -> per-frame mean fluorescence, concatenated over all segments of the event
traces_dict = dict()
# event uuid -> frame time stamps in ms, shifted so that the first frame of the trace is t=0
tstamps_dict = dict()
# event uuid -> dict of per-trace attributes (mouse_id, window_type, session_uuids, break points)
trace_attributes_dict = dict()
# Events whose sz segment is missing from the trace list: sz begin/end frames and the
# sz/am break points are set manually at their first "am" segment.
EVENTS_WITHOUT_SZ_SEGMENT = ["f0442bebcd1a4291a8d0559eb47df08e"]

for event_uuid, g in df_recovery_traces.groupby("event_uuid"):
    trace = np.array([])
    tstamps = np.array([])
    mouse_id = dd.getMouseIdForUuid(g.recording_uuid.unique()[0])
    win_type = window_mapping[dd.getMouseWinInjInfo(mouse_id).window_type.iloc[0]]  # CA1 or Cx
    session_uuids = []
    recording_break_points = []  # 0-indices in the trace of first frames of each nd2 contributing to trace
    segment_type_break_points = [0]  # bl, sz, am first (0-indexing) frames in the trace
    i_current_frame = 0
    # When starting to construct a trace, the very first segment starts at 0.
    # Each consecutive segment is matched by taking the time stamps (ms since beginning of current recording),
    # and subtracting the first segment time (first recording start time stamp + ms between start and
    # first frame used for the trace). Both values are reset here for every event.
    first_recording_begin_datetime = None
    first_frame_ms_since_rec_begin = None  # dt between first frame used in trace and start of first recording
    for _, row in g.iterrows():
        rec_uuid = row["recording_uuid"]

        if rec_uuid not in session_uuids:  # first time the recording appears in this trace
            recording_break_points.append(i_current_frame)
            session_uuids.append(rec_uuid)

        rec_t_start, rec_tstamps, rec_mean_fluo = nd2_data_dict[rec_uuid]
        begin_frame = row["begin_frame"]  # 1-indexed, inclusive
        end_frame = row["end_frame"]  # 1-indexed, inclusive
        # A 0-indexed begin_frame would silently turn [begin_frame-1:end_frame] into a [-1:...] slice.
        if begin_frame < 1:
            raise ValueError(f"Event {event_uuid}, recording {rec_uuid}: begin_frame must be >= 1 (1-indexed), got {begin_frame}")
        if end_frame > len(rec_mean_fluo):
            print(f"Warning: event {event_uuid}, recording {rec_uuid}: end_frame {end_frame} exceeds the "
                  f"{len(rec_mean_fluo)} loaded frames; the segment is truncated.")

        if first_recording_begin_datetime is None:  # this is the first segment of the trace
            first_recording_begin_datetime = rec_t_start
            first_frame_ms_since_rec_begin = rec_tstamps[begin_frame - 1]

        # Recordings of an event must be listed in chronological order.
        assert (rec_t_start - first_recording_begin_datetime).total_seconds() >= 0.

        # Offset (ms) that moves this recording's time stamps into the trace time frame (t=0 at first trace frame).
        dt_ms_first_frame_current_start_frame = (rec_t_start - first_recording_begin_datetime).total_seconds()*1000. - first_frame_ms_since_rec_begin
        segment_timestamps = rec_tstamps[begin_frame - 1:end_frame] + dt_ms_first_frame_current_start_frame

        segment = rec_mean_fluo[begin_frame - 1:end_frame]  # both inclusive, 1-indexing -> convert to 0-indexing
        segment_type = row["segment_type"]

        # TODO: check how segment type break points are added. Should not add duplicates!!!
        if segment_type == "sz":  # add sz begin and end frames
            if event_uuid not in sz_begin_end_frames:  # first, and maybe last, sz segment
                sz_begin_end_frames[event_uuid] = [i_current_frame, i_current_frame + len(segment) - 1]  # 0-indexing, inclusive
                segment_type_break_points.append(i_current_frame)  # first 0-index of sz segment
            else:  # not the first "sz" segment: expand sz range in trace
                sz_begin_end_frames[event_uuid][1] = i_current_frame + len(segment) - 1
        elif segment_type == "am":
            if event_uuid in EVENTS_WITHOUT_SZ_SEGMENT:  # manually add sz begin/end frames where they are missing
                if event_uuid not in sz_begin_end_frames:
                    sz_begin_end_frames[event_uuid] = [i_current_frame, i_current_frame]
                    segment_type_break_points.append(i_current_frame - 1)  # set begin of sz segment
                    segment_type_break_points.append(i_current_frame)  # set begin of am segment
            elif len(segment_type_break_points) < 3:  # only add am begin frame if not yet in list
                segment_type_break_points.append(i_current_frame)

        trace = np.concatenate([trace, segment])
        tstamps = np.concatenate([tstamps, segment_timestamps])
        i_current_frame = len(trace)  # 0-index of the next frame to be appended

    trace_attributes_dict[event_uuid] = {
        "mouse_id": mouse_id,
        "window_type": win_type,
        "session_uuids": session_uuids,
        "recording_break_points": recording_break_points,
        "segment_type_break_points": segment_type_break_points,
        # all break points (recordings AND segment types), sorted, each index listed once
        "break_points": merge_break_points(recording_break_points, segment_type_break_points),
    }
    traces_dict[event_uuid] = trace
    tstamps_dict[event_uuid] = tstamps

In [None]:
# Validate every assembled trace BEFORE opening the output file. An invalid event then
# raises a descriptive error instead of a bare Exception / KeyError that leaves a
# partially written .h5 file behind.
for event_uuid in traces_dict:
    segment_type_break_points = trace_attributes_dict[event_uuid]["segment_type_break_points"]
    if len(segment_type_break_points) != 3:  # bl_begin, sz_begin, am_begin are the points
        raise ValueError(
            f"Event {event_uuid}: expected 3 segment type break points (bl, sz, am begin), "
            f"got {segment_type_break_points}"
        )
    if event_uuid not in sz_begin_end_frames:
        raise ValueError(f"Event {event_uuid}: no sz begin/end frames found")

output_fname = f"traces_for_recovery_{get_datetime_for_fname()}.h5"
output_fpath = os.path.join(output_folder, output_fname)
with h5py.File(output_fpath, "w") as hf:
    for event_uuid, trace_values in traces_dict.items():
        # One group per event: trace attributes as HDF5 attributes, trace and time stamps as datasets.
        uuid_grp = hf.create_group(event_uuid)
        sz_begin_frame, sz_end_frame = sz_begin_end_frames[event_uuid]
        uuid_grp.attrs["sz_begin_frame"] = sz_begin_frame
        uuid_grp.attrs["sz_end_frame"] = sz_end_frame
        for attr_name, attr_value in trace_attributes_dict[event_uuid].items():
            uuid_grp.attrs[attr_name] = attr_value
        uuid_grp.create_dataset(data=trace_values, name="mean_fluo")
        uuid_grp.create_dataset(data=tstamps_dict[event_uuid], name="tstamps")
        # Segment lengths follow from the bl/sz/am begin frames; am runs until the end of the trace.
        bl_begin, sz_begin, am_begin = trace_attributes_dict[event_uuid]["segment_type_break_points"]
        n_bl_frames = sz_begin - bl_begin
        n_sz_frames = am_begin - sz_begin
        n_am_frames = len(trace_values) - n_sz_frames - n_bl_frames
        uuid_grp.attrs["n_bl_frames"] = n_bl_frames
        uuid_grp.attrs["n_sz_frames"] = n_sz_frames
        uuid_grp.attrs["n_am_frames"] = n_am_frames
print(f"Saved to {output_fpath}")