# Recovery analysis
(TMEV+ChR2 datasets combined)
1. time to recovery
2. Sz-bl, SD-bl amplitudes
3. baseline-trough fluorescence difference
4. peak-trough time

In [None]:
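# TODO: use the recovery time point to set the upper limit of the trough search window
# TODO: the brightest point is currently taken as the absolute maximum within the first peak_window_length frames of the aftermath. This seems good enough, but a more robust method may be needed.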

## Set params

In [None]:
# Parameters for the peak/trough detection further below.
# Percentage of the darkest frames of the aftermath trace that are considered as trough candidates;
# also the percentage of lowest values used for the window metric.
percent_considered = 5  # x% darkest/brightest of complete trace to consider
# Of the candidate frames, the earliest <extreme_group_size> are used; their median index is the trough.
extreme_group_size = 15  # this many of the darkest/brightest frames to consider (earliest darkest percent_considered% frames)
# The trough is searched in at most this many frames following the peak.
n_trough_frames = 5000  # simple method to set upper limit of window where to look for darkest point.
peak_window_length = 300  # consider the first n frames when looking for peak
imaging_freq = 15.  # approx, in hertz
# For TMEV NC traces the aftermath start is moved earlier by this many frames before cutting the trace.
n_frames_before_am_start_nc = 200  # number of frames to consider additionally before aftermath begin for NC traces to look for trough. Necessary because optical end of seizure segment is "darkest point", so am category might just miss trough.

In [None]:
# NOTE(review): scratch list of the integers 0..4999; not referenced anywhere in the visible notebook — candidate for removal, confirm before deleting.
a = [i for i in range(5000)]

In [None]:
# window-related parameters
window_width_s = 10
window_step_s = 5
# Reuse the values of the first parameter cell instead of repeating the literals, so that the
# frame <-> time conversion and the window sizes can never be based on different frequencies.
imaging_frequency = imaging_freq  # in Hz
n_frames_before_nc = n_frames_before_am_start_nc  # include this many frames just before aftermath for NC recordings
n_frames_before_ca1 = 0
n_windows_post_darkest = 300 #40 # dataset consists of bl, darkest point, and this many windows post darkest point

# Default baseline window centers, given as offsets from the end of the baseline (negative = counted from the end).
default_bl_center_ca1 = -75 # 4925 when 5000 bl frames
default_bl_center_nc = -975  # 4025 when 5000 bl frames

# Window width and step converted from seconds to frames.
window_width_frames = int(window_width_s*imaging_frequency)
window_step_frames = int(window_step_s*imaging_frequency)

half_window_width_frames = window_width_frames//2

recovery_ratio = 0.95  # reach x % of baseline to be considered recovered

In [None]:
# Translate the "window_type" attribute stored in the h5 files into the labels used in this notebook.
win_types_mapping = {"CA1" : "CA1", "Cx" : "NC"}  # replace Cx with NC

In [None]:
# Flag that presumably controls whether the assembled datasets are written to disk.
# NOTE(review): not referenced in the visible part of the notebook — confirm where it is consumed.
save_dsets = False

In [None]:
# Figure export settings. EPS takes precedence over PDF; JPG is the fallback when both are disabled.
save_figs = False
save_as_eps = False
save_as_pdf = True
output_format = ".eps" if save_as_eps else (".pdf" if save_as_pdf else ".jpg")
# Only report the chosen format when figures are actually going to be saved.
if save_figs:
    print(output_format)

## Load libraries, set data

In [None]:
#Auto-reload modules (used to develop functions outside this notebook)
%load_ext autoreload
%autoreload 2

In [None]:
import labrotation.file_handling as fh
import seaborn as sns
import os
from datetime import datetime
import datadoc_util
import h5py
import numpy as np
from math import floor
import matplotlib.pyplot as plt
import pandas as pd
import warnings

In [None]:
# Global plot appearance: enlarged fonts on a white grid background.
sns.set_theme(style="whitegrid", font_scale=2)

In [None]:
# Path of the h5 file with the assembled ChR2 traces; it is opened with h5py in the "Load traces" section.
# NOTE(review): presumably fh.open_file shows a file dialog with the given title — verify in labrotation.file_handling.
chr2_fpath = fh.open_file("Open ChR2 assembled traces h5 file!")

In [None]:
# Path of the h5 file with the assembled TMEV traces; it is opened with h5py in the "Load traces" section.
# NOTE(review): presumably fh.open_file shows a file dialog with the given title — verify in labrotation.file_handling.
tmev_fpath = fh.open_file("Open TMEV assembled traces h5 file!")

In [None]:
def load_env_file(path="./.env"):
    """Read simple KEY=VALUE pairs from a .env-style text file.

    Blank lines, lines starting with "#" and lines without "=" are skipped.
    Each remaining line is split on the first "=" only, so values may contain "=".

    Parameters
    ----------
    path : str
        Location of the file to read.

    Returns
    -------
    dict
        Mapping of keys to values (both strings). Empty if the file does not exist.
    """
    entries = dict()
    if not os.path.exists(path):
        print(".env does not exist")
        return entries
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, value = line.split("=", 1)
            entries[key] = value
    return entries


env_dict = load_env_file("./.env")
print(env_dict.keys())

In [None]:
# Use the documentation folder configured in .env if there is one; otherwise let the user pick it.
docu_folder = (
    env_dict["DATA_DOCU_FOLDER"]
    if "DATA_DOCU_FOLDER" in env_dict
    else fh.open_dir("Choose folder containing folders for each mouse!")
)
print(f"Selected folder:\n\t{docu_folder}")

In [None]:
# The per-mouse folders live either in a "documentation" subfolder or directly in docu_folder.
has_documentation_subfolder = "documentation" in os.listdir(docu_folder)
mouse_folder = os.path.join(docu_folder, "documentation") if has_documentation_subfolder else docu_folder
mouse_names = os.listdir(mouse_folder)
print("Mice detected:")
for mouse in mouse_names:
    print("\t" + mouse)

In [None]:
def get_datetime_for_fname():
    """Return the current local date and time as a "YYYYMMDD-HHMMSS" string for use in file names."""
    return datetime.now().strftime("%Y%m%d-%H%M%S")

In [None]:
# Folder for output files: use DOWNLOADS_FOLDER from .env if it is set, otherwise ask the user
# (same fallback pattern as for the data documentation folder), instead of failing with a KeyError.
if "DOWNLOADS_FOLDER" in env_dict:
    output_folder = env_dict["DOWNLOADS_FOLDER"]
else:
    output_folder = fh.open_dir("Choose folder for output files!")
print(f"Output files will be saved to {output_folder}")

In [None]:
# Data documentation object; it is queried below for the experiment type and the segment table of a recording.
ddoc = datadoc_util.DataDocumentation(docu_folder)
ddoc.loadDataDoc()

## Load traces

In [None]:
# Containers filled by the loop below; all of them are keyed by event uuid.
dict_mean_fluo = {} # uuid: mean fluorescence trace cut to the aftermath, i.e. [i_begin_am:]
dict_bl_fluo = {}  # uuid: baseline part of the trace, i.e. [:i_begin_mid]
dict_mid_fluo = {}  # uuid: part between baseline and aftermath (sz or stim+sz), i.e. [i_begin_mid:i_begin_am]
dict_meta = {}  # uuid: {"exp_type": exp_type, "mouse_id": mouse_id, "win_type": win_type, "session_uuids": [session_uuids]}

# events for which no aftermath start could be determined (i_begin_am is NaN)
dict_excluded = {}  # uuid: {"exp_type": exp_type, "mouse_id": mouse_id, "win_type": window_type, "session_uuids": [session_uuids]}

dict_segment_break_points = {}  # uuid: (i_begin_mid, i_begin_am). bl: [:i_begin_mid], mid: [i_begin_mid:i_begin_am], am: [i_begin_am:]

# Load traces. Set start time to appearance of first SD wave. TODO: maybe last SD wave must be used?
for fpath in [tmev_fpath, chr2_fpath]:
    with h5py.File(fpath, "r") as hf:
        # each top-level key of the h5 file is treated as one event
        for event_uuid in hf.keys():
            win_type = win_types_mapping[hf[event_uuid].attrs["window_type"]]
            assert "session_uuids" in hf[event_uuid].attrs
            mouse_id = hf[event_uuid].attrs["mouse_id"]
            # for TMEV, traces were stitched together from multiple recordings, so uuid is not in data documentation. 
            # But the individual session uuids are stored in attributes (both for ChR2 and TMEV data)
            session_uuids = hf[event_uuid].attrs["session_uuids"]
            # the experiment type is looked up through the first session uuid
            exp_type = ddoc.getExperimentTypeForUuid(session_uuids[0])
            mean_fluo = np.array(hf[event_uuid]["mean_fluo"])
            if exp_type == "tmev":
                # as TMEV traces are stitched together, it is difficult to use data documentation.
                # But segment_type_break_points attribute contains bl, sz, am begin frames.
                # am (aftermath) is defined as visual appearance of first SD wave. Can take this as beginning
                segment_type_break_points = hf[event_uuid].attrs["segment_type_break_points"]
                assert len(segment_type_break_points) == 3  # make sure only bl, sz, am points are in list
                i_begin_am = segment_type_break_points[2]
                i_begin_mid = segment_type_break_points[1]  # one frame past end of baseline, i.e. begin of middle section (sz)
                if win_type == "NC":  # NC seizures end abruptly, manual segmentation tries to set "reaching darkest point" as end of Sz. This means trough might be missed in original "aftermath" category.
                    # start the aftermath earlier so that the trough is part of the cut trace
                    i_begin_am -= n_frames_before_am_start_nc
                    assert i_begin_am > 0
            elif exp_type in ["chr2_sd", "chr2_szsd"]:
                # for ChR2 recordings the event uuid must be the uuid of the (single) session
                assert session_uuids[0] == event_uuid
                df_segments = ddoc.getSegmentsForUUID(event_uuid)
                # set first frame of first SD appearance as beginning
                i_begin_am = df_segments[df_segments["interval_type"] == "sd_wave"].frame_begin.min() - 1  # 1-indexing to 0-indexing conversion
                # the middle section starts with the first stimulation frame
                i_begin_mid = df_segments[df_segments["interval_type"] == "stimulation"].frame_begin.min() - 1
            else:
                continue  # skip all other experiment types (e.g. chr2_ctl recordings are not added to the dataset)
            # NOTE(review): i_begin_am is presumably NaN when the recording has no "sd_wave" segment — confirm.
            # i_begin_mid is not checked for NaN here.
            if not np.isnan(i_begin_am):
                # split the trace into baseline, middle section and aftermath
                bl_fluo = mean_fluo[:i_begin_mid].copy()
                mid_fluo = mean_fluo[i_begin_mid:i_begin_am].copy()
                mean_fluo = mean_fluo[i_begin_am:]

                dict_segment_break_points[event_uuid] = (i_begin_mid, i_begin_am)

                dict_bl_fluo[event_uuid] = bl_fluo
                dict_mean_fluo[event_uuid] = mean_fluo
                dict_mid_fluo[event_uuid] = mid_fluo
                dict_meta[event_uuid] = {"exp_type": exp_type, "mouse_id": mouse_id, "win_type": win_type, "session_uuids": session_uuids}
            else:
                # keep track of the events that are left out of the analysis
                dict_excluded[event_uuid] = {"exp_type": exp_type, "mouse_id": mouse_id, "win_type": win_type, "session_uuids": session_uuids}


# Assemble recovery dataset

### Define window-related functions

In [None]:
def get_window(i_center, trace, half_width=None) -> np.ndarray:
    """Return the part of `trace` that lies within a window centered on frame `i_center`.

    The full window covers the frames i_center - half_width ... i_center + half_width (both inclusive).
    If the window extends past the beginning or the end of the trace, a warning is issued and only the
    overlapping part is returned, so the result can be shorter than 2*half_width + 1 frames.
    If `i_center` itself is not a valid frame index of the trace (negative, or >= len(trace)),
    a warning is issued and an empty array is returned.

    Parameters
    ----------
    i_center : int
        0-based index of the frame at the center of the window.
    trace : np.ndarray
        Trace to cut the window from.
    half_width : int, optional
        Number of frames on each side of the center. Defaults to the global
        parameter half_window_width_frames.

    Returns
    -------
    np.ndarray
        The frames of the trace inside the window; empty if the center is out of bounds.
    """
    if half_width is None:
        half_width = half_window_width_frames
    n_frames = len(trace)
    # the center must be an existing frame; index n_frames is already one past the end
    if i_center < 0 or i_center >= n_frames:
        warnings.warn(f"Trying to access window with center {i_center}, but only {n_frames} frames")
        return np.array([])
    left_limit = i_center - half_width
    right_limit = i_center + half_width + 1  # right limit is exclusive
    if right_limit > n_frames:
        warnings.warn(f"Part of window out of bounds: {i_center} + HW {half_width} >= {n_frames}")
        right_limit = n_frames
    if left_limit < 0:
        warnings.warn(f"Part of window out of bounds: {i_center} - HW {half_width} < 0")
        left_limit = 0
    return trace[left_limit : right_limit]

In [None]:
def get_metric_for_window(trace_window, percent=None):
    """Return the median of the lowest `percent` % of the values in `trace_window`.

    At least one value is always used, so windows that are too short for the requested
    percentage yield their minimum instead of NaN.

    Parameters
    ----------
    trace_window : np.ndarray
        1-D array of values, e.g. as returned by get_window().
    percent : float, optional
        Percentage of the lowest values to take the median of. Defaults to the global
        parameter percent_considered.

    Returns
    -------
    float
        Median of the selected values; NaN if the window is empty.
    """
    if percent is None:
        percent = percent_considered
    trace_window = np.asarray(trace_window)
    if trace_window.size == 0:
        return np.nan
    n_lowest = max(1, int(percent/100.*len(trace_window)))
    lowest_values = np.sort(trace_window)[:n_lowest]
    return np.median(lowest_values)

### Find baseline windows, metrics

In [None]:
# TODO: manually correct f0442bebcd1a4291a8d0559eb47df08e
# Manually chosen baseline window centers, given as offsets from the end of the baseline
# (negative values are converted to absolute frame indices in the loop below).
dict_uuid_manual_bl_center = {
    "aa66ae0470a14eb08e9bcadedc34ef64": -750,
    "c7b29d28248e493eab02288b85e3adee": -1000,
    "7b9c17d8a1b0416daf65621680848b6a": -950,
    "9e75d7135137444492d104c461ddcaac": -300,
    "d158cd12ad77489a827dab1173a933f9": -500,
    "a39ed3a880c54f798eff250911f1c92f": -500,
    "4e2310d2dde845b0908519b7196080e8": -500,
    "f0442bebcd1a4291a8d0559eb47df08e": -500,
    "2251bba132cf45fa839d3214d1651392": -1300,
    "cd3c1e0e3c284a89891d2e4d9a7461f4": -1500,
}

# uuid: (i_bl, bl_metric), i_bl is the center of the window
dict_bl_values = {}

# default centers for TMEV recordings, per window type
tmev_default_bl_centers = {"CA1": default_bl_center_ca1, "NC": default_bl_center_nc}

for uuid, meta in dict_meta.items():
    exp_type, win_type = meta["exp_type"], meta["win_type"]
    bl_trace = dict_bl_fluo[uuid]
    # priority: manual correction > TMEV default for the window type > window just before stim (ChR2)
    if uuid in dict_uuid_manual_bl_center:
        i_bl = dict_uuid_manual_bl_center[uuid]
    elif exp_type == "tmev":
        if win_type in tmev_default_bl_centers:
            i_bl = tmev_default_bl_centers[win_type]
    elif exp_type in ("chr2_sd", "chr2_szsd"):
        i_bl = len(bl_trace) - half_window_width_frames - 1
    # negative centers count backwards from the end of the baseline
    i_bl = i_bl if i_bl >= 0 else len(bl_trace) + i_bl
    dict_bl_values[uuid] = (i_bl, get_metric_for_window(get_window(i_bl, bl_trace)))
    

In [None]:
# Baseline center, baseline length and baseline metric of all TMEV CA1 events.
for event_uuid, meta in dict_meta.items():
    if (meta["exp_type"], meta["win_type"]) != ("tmev", "CA1"):
        continue
    bl_center, bl_value = dict_bl_values[event_uuid]
    print(f"{event_uuid}:\t{bl_center}\t{len(dict_bl_fluo[event_uuid])}\t{bl_value}")


In [None]:
# Baseline center, baseline length and baseline metric of all TMEV NC events.
for event_uuid, meta in dict_meta.items():
    if (meta["exp_type"], meta["win_type"]) != ("tmev", "NC"):
        continue
    bl_center, bl_value = dict_bl_values[event_uuid]
    print(f"{event_uuid}:\t{bl_center}\t{len(dict_bl_fluo[event_uuid])}\t{bl_value}")


### Find peak, trough (darkest point), recovery position
Same method as Baseline recovery 

In [None]:
# aftermath:
# TMEV - appearance of first SD. This could also be taken above
# ChR2 - if SD present, then appearance of first SD. Else: directly after stim (ctl).

# All indices are relative to the beginning of the aftermath trace.
dict_peak_trough = {}  # uuid: (i_peak, i_trough, i_half, y_peak, y_trough, y_half)

for event_uuid in dict_mean_fluo.keys(): 
    exp_type = dict_meta[event_uuid]["exp_type"]
    win_type = dict_meta[event_uuid]["win_type"]

    # traces already cut to "aftermath" (plus few extra frames)
    complete_trace = dict_mean_fluo[event_uuid]

    # get brightest frame: the maximum within the first peak_window_length frames
    # old method, uses same percentages and median as darkest frame. Did not work well
    #i_brightest_group = np.flip(sorted_indices)[:int(percent_considered/100.*len(sorted_indices))]
    #i_brightest = int(floor(np.median(np.sort(i_brightest_group)[:extreme_group_size])))
    sorted_beginning = np.argsort(complete_trace[:peak_window_length])
    i_brightest = sorted_beginning[-1]

    # use reduced window (starting at the peak) to look for trough
    cut_trace = complete_trace[i_brightest:i_brightest+n_trough_frames]
    sorted_indices_cut = np.argsort(cut_trace)
    i_darkest_group = sorted_indices_cut[:int(percent_considered/100.*len(complete_trace))]  # still take n percent of aftermath, not cut trace!
    # get single coordinate for darkest part
    # find darkest <percent_considered>%, take earliest <extreme_group_size> of them, get median frame index of these, round down to integer frame
    i_darkest_cut = int(floor(np.median(np.sort(i_darkest_group)[:extreme_group_size])))

    i_darkest = i_darkest_cut + i_brightest  # bring it back to original frame indices
    assert i_darkest > i_brightest

    # peak amplitude is a single frame value, trough amplitude is the window metric around the trough
    y_brightest = complete_trace[i_brightest]
    #y_darkest = complete_trace[i_darkest]  # TODO: get window value instead?
    y_darkest = get_window(i_darkest, complete_trace)
    y_darkest = get_metric_for_window(y_darkest)
    print(f"{event_uuid}\t{i_brightest}\t{i_darkest}\t{y_brightest}\t{y_darkest}")
    assert y_brightest > y_darkest

    # find time of half maximum: first frame after the peak at or below the mean of peak and trough value
    y_half = (y_brightest + y_darkest)/2.  # trough + (peak - trough)/2
    i_half = np.argmax(complete_trace[i_brightest:] <= y_half)
    i_half += i_brightest
    assert i_brightest < i_half
    assert i_darkest > i_half
    dict_peak_trough[event_uuid] = (i_brightest, i_darkest, i_half, y_brightest, y_darkest, y_half)

### Starting with trough, find time window where metric shows recovery

In [None]:
# i_recovery/y_recovery describe the last evaluated window: the window where recovery was
# detected if did_recover is True, otherwise the last complete window of the recording.
dict_recovery = {}  # event_uuid: (i_recovery, y_recovery, did_recover)
dict_windows = {}  # event_uuid: [y_bl_window, y_darkest_window, y_post_darkest1, y_post_darkest2, ..., y_recovery_window]

for event_uuid in dict_mean_fluo:
    trace = dict_mean_fluo[event_uuid]
    i_trough = dict_peak_trough[event_uuid][1]
    y_windows = []
    did_recover = False  # set to True once a window reaches the recovery threshold
    # add y_bl
    y_bl = dict_bl_values[event_uuid][1]
    y_windows.append(y_bl)  

    # add trough window to windows list
    i_evaluated = i_trough  # center of the most recently evaluated window
    current_win = get_window(i_evaluated, trace)  # start with metric at trough
    y_current = get_metric_for_window(current_win)
    y_windows.append(y_current)

    # move on to next window just after trough to start looking for recovery (FIXME: in some cases, already trough is > 95% of bl! by definition we demand recovery to happen after the trough?)
    i_current = i_trough + window_step_frames
    current_win = get_window(i_current, trace)
    while len(current_win) >= window_width_frames:  # stop algorithm upon reaching end of recording
        y_current = get_metric_for_window(current_win)
        y_windows.append(y_current)
        i_evaluated = i_current  # keep index and metric in sync
        if y_current >= recovery_ratio*y_bl:  # recovery reached
            did_recover = True
            break
        # move to next window
        i_current += window_step_frames
        current_win = get_window(i_current, trace)
    dict_windows[event_uuid] = y_windows
    # store the index that belongs to y_current, not the index of a window that was never evaluated
    dict_recovery[event_uuid] = (i_evaluated, y_current, did_recover)


## Create DataFrame

In [None]:
# Sanity check: print every event that has peak/trough values but no stored segment break points.
events_without_break_points = [key for key in dict_peak_trough if key not in dict_segment_break_points]
for event_uuid in events_without_break_points:
    print(event_uuid)

In [None]:
# (raw) columns: event_uuid, i_peak, i_trough, i_half, y_peak, y_trough, y_half
df_recovery = (
    pd.DataFrame.from_dict(dict_peak_trough, orient="index", columns=["i_peak", "i_trough", "i_half", "y_peak", "y_trough", "y_half"])
    .rename_axis("event_uuid")  # the dict keys become the "event_uuid" column
    .reset_index()
)
event_uuids = df_recovery["event_uuid"]

# metadata and baseline values, looked up per event
df_recovery["exp_type"] = event_uuids.map(lambda u: dict_meta[u]["exp_type"])
df_recovery["mouse_id"] = event_uuids.map(lambda u: dict_meta[u]["mouse_id"])

df_recovery["y_bl"] = event_uuids.map(lambda u: dict_bl_values[u][1])
df_recovery["i_bl"] = event_uuids.map(lambda u: dict_bl_values[u][0])

# baseline minus trough difference in amplitude
df_recovery["dy_bl_trough"] = df_recovery["y_bl"] - df_recovery["y_trough"]
# peak-trough time difference, s
df_recovery["dt_peak_trough"] = df_recovery["i_trough"]/imaging_freq - df_recovery["i_peak"]/imaging_freq
# peak to half amplitude time difference, s (uses i_half, the first frame at or below the half amplitude)
df_recovery["dt_peak_trough_FWHM"] = df_recovery["i_half"]/imaging_freq - df_recovery["i_peak"]/imaging_freq

df_recovery["i_recovery"] = event_uuids.map(lambda u: dict_recovery[u][0])
df_recovery["y_recovery"] = event_uuids.map(lambda u: dict_recovery[u][1])
df_recovery["did_recover"] = event_uuids.map(lambda u: dict_recovery[u][2])

df_recovery["dt_trough_recovery"] = df_recovery["i_recovery"]/imaging_freq - df_recovery["i_trough"]/imaging_freq
df_recovery["dt_peak_recovery"] = df_recovery["i_recovery"]/imaging_freq - df_recovery["i_peak"]/imaging_freq


# move i_xy to whole trace indexing frame of reference (add the index where the aftermath begins)
# NOTE(review): these three columns are not part of the final column selection below — confirm whether they should be kept.
i_begin_am = event_uuids.map(lambda u: dict_segment_break_points[u][1])
df_recovery["i_recovery_whole"] = df_recovery["i_recovery"] + i_begin_am
df_recovery["i_peak_whole"] = df_recovery["i_peak"] + i_begin_am
df_recovery["i_trough_whole"] = df_recovery["i_trough"] + i_begin_am


# final columns: event_uuid, mouse_id, exp_type, y_bl, y_peak, y_trough, y_recovery, dy_bl_trough, dt_peak_trough, dt_peak_trough_FWHM, dt_trough_recovery, dt_peak_recovery, did_recover
df_recovery = df_recovery[["event_uuid", "mouse_id", "exp_type",  "y_bl", "y_peak", "y_trough", "y_recovery", "dy_bl_trough", "dt_peak_trough", "dt_peak_trough_FWHM", "dt_trough_recovery", "dt_peak_recovery", "did_recover"]]


In [None]:
# TODO: where did_recover is False, need to implement linear extrapolation.
# Show the trough-to-recovery time (in seconds) together with the recovery flag for each event.
df_recovery[["event_uuid", "dt_trough_recovery", "did_recover"]]

### Plot detected peak/trough values

In [None]:
fig = plt.figure(figsize=(18, 42))

AMPLITUDE = 100.0  # height of each normalized trace in the stacked plot
offset = 0.0  # vertical position of the current trace
color_dict = {"tmev": "green", "chr2_sd": "red", "chr2_szsd": "blue"}
def normalize_trace(trace):
    """Linearly rescale trace so that its minimum is 0 and its maximum is AMPLITUDE.

    A constant trace has no range to scale and is returned as all zeros.
    """
    min_trace = np.min(trace)
    max_trace = np.max(trace)
    trace_range = max_trace - min_trace
    if trace_range == 0:
        return np.zeros_like(trace, dtype=float)
    return AMPLITUDE*(trace - min_trace)/trace_range

# one normalized aftermath trace per event, stacked vertically and sorted by experiment type
for event_uuid in df_recovery.sort_values(by=["exp_type", "event_uuid"]).event_uuid:
    exp_type = dict_meta[event_uuid]["exp_type"]
    plt.plot(normalize_trace(dict_mean_fluo[event_uuid]) + offset, color=color_dict[exp_type], label=exp_type)
    # mark the detected time points
    plt.vlines(x=dict_peak_trough[event_uuid][0], ymin=offset, ymax = offset+AMPLITUDE )  # peak
    plt.vlines(x=dict_peak_trough[event_uuid][1], ymin=offset, ymax = offset+AMPLITUDE, color="black" )  # trough
    plt.vlines(x=dict_peak_trough[event_uuid][2], ymin=offset, ymax = offset+AMPLITUDE, color="orange" )  # half max

    offset += AMPLITUDE


# One legend entry per experiment type, built from proxy lines. This avoids recoloring the
# handles of an existing legend through Legend.legendHandles, which newer Matplotlib no longer provides.
legend_handles = [plt.Line2D([0], [0], color=color, label=exp_type_label) for exp_type_label, color in color_dict.items()]
ax = plt.gca()
leg = ax.legend(handles=legend_handles)

plt.show()

### Plot peak-trough time per experiment type

In [None]:
# Peak to half-amplitude time per experiment type. df_recovery has no "delta_t_FWHM" column;
# the corresponding column is "dt_peak_trough_FWHM". The experiment type is passed as x so that
# one box per type is drawn (older seaborn releases ignore `hue` when no `x` is given).
fig = plt.figure(figsize=(6, 10))
g = sns.boxplot(data=df_recovery, x="exp_type", y="dt_peak_trough_FWHM")
plt.show()

In [None]:
# Distribution of the peak-trough time. df_recovery has no "delta_t" column;
# the peak-trough time difference (in seconds) is stored in "dt_peak_trough".
fig = plt.figure(figsize=(16, 10))
g = sns.histplot(data=df_recovery, x="dt_peak_trough", hue="exp_type",multiple="stack", bins=30)
plt.show()

## Peak-trough amplitude

In [None]:
# Peak-trough amplitude per experiment type. df_recovery has no "delta_amp" column, so the
# difference is computed from y_peak and y_trough on a temporary copy (df_recovery is not modified).
# The experiment type is passed as x so that one box per type is drawn.
fig = plt.figure(figsize=(6, 10))
g = sns.boxplot(data=df_recovery.assign(dy_peak_trough=df_recovery["y_peak"] - df_recovery["y_trough"]), x="exp_type", y="dy_peak_trough")
plt.show()

In [None]:
# TODO: smoothing and first derivative for minimum?