# Recovery analysis
(TMEV+ChR2 datasets combined)
1. time to recovery
2. Sz-bl, SD-bl amplitudes
3. baseline-trough fluorescence difference
4. peak-trough time

## Set params

In [None]:
# Parameters used when locating the peak and the trough (darkest point) in the aftermath trace.
percent_considered = 5  # x% darkest frames of the aftermath trace to consider (also used for the window metric)
extreme_group_size = 15  # of the darkest percent_considered% frames, only the earliest this many are used to locate the trough
n_trough_frames = 5000  # simple method to set upper limit (in frames, counted from the peak) of window where to look for darkest point.
peak_window_length = 300  # consider the first n frames of the aftermath when looking for peak
imaging_freq = 15.  # approx, in hertz; used to convert frame differences to seconds
# n_frames_before_am_start_nc used to be a single number here (0 or 200); it was replaced with a manually set dict.
# It is the number of frames to consider additionally before aftermath begin for NC traces to look for trough. Necessary because
# the optical end of the seizure segment is the "darkest point", so the aftermath category might just miss the trough.
# WARNING: reducing n_frames_before_am_start_nc might result in the trough falling before the peak (in NC traces). Weird, but it happens...

In [None]:
# window-related parameters
window_width_s = 10  # width of one analysis window, in seconds
window_step_s = 5  # distance between the centers of consecutive windows, in seconds
imaging_frequency = 15. # in Hz. NOTE(review): duplicates imaging_freq; consider keeping a single constant
n_frames_before_nc = 200  # include 200 frames just before aftermath for NC recordings  
n_frames_before_ca1 = 0
n_windows_post_darkest = 300 #40 # dataset consists of bl, darkest point, and this many windows post darkest point

# default baseline window centers for TMEV traces, given as negative offsets from the end of the baseline
default_bl_center_ca1 = -75 # 4925 when 5000 bl frames
default_bl_center_nc = -975  # 4025 when 5000 bl frames

# window width and step converted from seconds to frames
window_width_frames = int(window_width_s*imaging_frequency)
window_step_frames = int(window_step_s*imaging_frequency)

# number of frames taken on each side of a window center
half_window_width_frames = window_width_frames//2

recovery_ratio = 0.95  # reach x % of baseline to be considered recovered

In [None]:
# The aftermath categories for NC are arbitrary: they begin around the trough, but the starting frame was set visually.
# Manual correction per event: event uuid -> number of frames by which the aftermath begin of a TMEV NC trace
# is moved earlier when the traces are loaded, so that the trough is not missed.
n_frames_before_am_start_nc = {
    "2251bba132cf45fa839d3214d1651392": 125,
    "4dea78a01bf5408092f498032d67d84e": 205,
    "54c31c3151944cfd86043932d3a19b9a": 60,
    "5cfb012d47f14303a40680d2b333336a": 125,
    "7753b03a2a554cccaab42f1c0458d742": 70,
    "cd3c1e0e3c284a89891d2e4d9a7461f4": 192,
    "f481149fa8694621be6116cb84ae2d3c": 115,
    "f5ccb81a34bb434482e2498bfdf88784": 58,
}

In [None]:
win_types_mapping = {"CA1" : "CA1", "Cx" : "NC"}  # replace Cx with NC

In [None]:
save_dsets = False

In [None]:
# Figure export settings. Format precedence: EPS if requested, otherwise PDF if requested, otherwise JPG.
save_figs = False
save_as_eps = False
save_as_pdf = True
output_format = ".eps" if save_as_eps else (".pdf" if save_as_pdf else ".jpg")
if save_figs:
    print(output_format)

## Load libraries, set data

In [None]:
#Auto-reload modules (used to develop functions outside this notebook)
%load_ext autoreload
%autoreload 2

In [None]:
import labrotation.file_handling as fh
import seaborn as sns
import os
from datetime import datetime
import datadoc_util
import h5py
import numpy as np
from math import floor, ceil
import matplotlib.pyplot as plt
import pandas as pd
import warnings
from numpy.polynomial.polynomial import Polynomial

In [None]:
sns.set_theme(font_scale=2)
sns.set_style("whitegrid")

In [None]:
chr2_fpath = fh.open_file("Open ChR2 assembled traces h5 file!")

In [None]:
tmev_fpath = fh.open_file("Open TMEV assembled traces h5 file!")

In [None]:
def parse_env_file(env_path):
    """Read KEY=VALUE pairs from a .env-style text file.

    Blank lines, lines starting with "#" and lines without "=" are skipped. Only the first "=" separates
    key and value, so values may themselves contain "=". Whitespace around key and value is removed.

    Parameters
    ----------
    env_path : str
        Path of the file to read.

    Returns
    -------
    dict
        Mapping of key to value (both str).
    """
    parsed = {}
    with open(env_path, "r") as f:
        for raw_line in f:
            entry = raw_line.strip()
            if not entry or entry.startswith("#") or "=" not in entry:
                continue
            key, value = entry.split("=", 1)  # split on first "=" only, keep the rest as value
            parsed[key.strip()] = value.strip()
    return parsed


env_dict = dict()
if not os.path.exists("./.env"):
    print(".env does not exist")
else:
    env_dict = parse_env_file("./.env")
print(env_dict.keys())

In [None]:
if "DATA_DOCU_FOLDER" in env_dict.keys():
    docu_folder = env_dict["DATA_DOCU_FOLDER"]
else:
    docu_folder = fh.open_dir("Choose folder containing folders for each mouse!")
print(f"Selected folder:\n\t{docu_folder}")

In [None]:
if "documentation" in os.listdir(docu_folder):
    mouse_folder = os.path.join(docu_folder, "documentation")
else:
    mouse_folder = docu_folder
mouse_names = os.listdir(mouse_folder)
print(f"Mice detected:")
for mouse in mouse_names:
    print(f"\t{mouse}")

In [None]:
def get_datetime_for_fname():
    """Return the current local date and time as a "YYYYMMDD-HHMMSS" string, for use in output file names."""
    return datetime.now().strftime("%Y%m%d-%H%M%S")

In [None]:
# Folder where exported datasets are written. Taken from .env if defined there; otherwise ask the user,
# the same way the data documentation folder is chosen when DATA_DOCU_FOLDER is missing.
if "DOWNLOADS_FOLDER" in env_dict:
    output_folder = env_dict["DOWNLOADS_FOLDER"]
else:
    output_folder = fh.open_dir("Choose folder to save output files to!")
print(f"Output files will be saved to {output_folder}")

In [None]:
ddoc = datadoc_util.DataDocumentation(docu_folder)
ddoc.loadDataDoc()

## Load traces

In [None]:
# Containers filled by the loading loop below; all are keyed by event uuid (the name of the hdf5 group).
dict_mean_fluo = {} # uuid: [mean_fluo], cut to aftermath (+ extra frames) only!
dict_bl_fluo = {}  # uuid: baseline part of the trace, [:i_begin_mid]
dict_mid_fluo = {}  # uuid: rest of trace up to the aftermath, [i_begin_mid:i_begin_am]: sz or stim+sz
# to get complete trace for event_uuid, np.concatenate([dict_bl_fluo[event_uuid], dict_mid_fluo[event_uuid], dict_mean_fluo[event_uuid]])
dict_meta = {}  # uuid: {"exp_type": exp_type, "mouse_id": mouse_id, "win_type": win_type, "session_uuids": [session_uuids], "segment_type_break_points": [segment_type_break_points]}

dict_excluded = {}  # same entries as dict_meta, for events where no aftermath begin was found (i_begin_am is NaN)

dict_segment_break_points = {}  # uuid: (i_begin_mid, i_begin_am). bl: [:i_begin_mid], mid: [i_begin_mid:i_begin_am], am: [i_begin_am:]

# Load traces. Set start time to appearance of first SD wave. TODO: maybe last SD wave must be used?
for fpath in [tmev_fpath, chr2_fpath]:
    with h5py.File(fpath, "r") as hf:
        for event_uuid in hf.keys():
            # window type as stored in the file ("CA1"/"Cx") is translated to "CA1"/"NC"; other values raise KeyError
            win_type = win_types_mapping[hf[event_uuid].attrs["window_type"]]
            assert "session_uuids" in hf[event_uuid].attrs
            mouse_id = hf[event_uuid].attrs["mouse_id"]
            # for TMEV, traces were stitched together from multiple recordings, so uuid is not in data documentation. 
            # But the individual session uuids are stored in attributes (both for ChR2 and TMEV data)
            session_uuids = hf[event_uuid].attrs["session_uuids"]
            # the experiment type is looked up using the first session uuid
            exp_type = ddoc.getExperimentTypeForUuid(session_uuids[0])
            mean_fluo = np.array(hf[event_uuid]["mean_fluo"])
            segment_type_break_points = hf[event_uuid].attrs["segment_type_break_points"]
            if exp_type == "tmev":
                # as TMEV traces are stitched together, it is difficult to use data documentation.
                # But segment_type_break_points attribute contains bl, sz, am begin frames.
                # am (aftermath) is defined as visual appearance of first SD wave. Can take this as beginning
                assert len(segment_type_break_points) == 3  # make sure only bl, sz, am points are in list
                i_begin_am = segment_type_break_points[2]
                i_begin_mid = segment_type_break_points[1]  # one frame past end of baseline, i.e. begin of middle section (sz)
                if win_type == "NC":  # NC seizures end abruptly, manual segmentation tries to set "reaching darkest point" as end of Sz. This means trough might be missed in original "aftermath" category.
                    # move the aftermath begin earlier by the manually set number of frames (KeyError if the event has no entry)
                    i_begin_am -= n_frames_before_am_start_nc[event_uuid]
                    assert i_begin_am > 0
            elif exp_type in ["chr2_sd", "chr2_szsd"]:
                assert session_uuids[0] == event_uuid
                df_segments = ddoc.getSegmentsForUUID(event_uuid)
                # set first frame of first SD appearance as beginning
                i_begin_am = df_segments[df_segments["interval_type"] == "sd_wave"].frame_begin.min() - 1  # 1-indexing to 0-indexing conversion
                # the middle section begins with the first stimulation frame
                i_begin_mid = df_segments[df_segments["interval_type"] == "stimulation"].frame_begin.min() - 1
            else:
                continue  # do not add chr2_ctl recordings to dataset 
            # i_begin_am is NaN if no sd_wave segment exists; such events go to dict_excluded
            # NOTE(review): i_begin_mid is not checked for NaN here - confirm every ChR2 recording has a stimulation segment
            if not np.isnan(i_begin_am):
                bl_fluo = mean_fluo[:i_begin_mid].copy()
                mid_fluo = mean_fluo[i_begin_mid:i_begin_am].copy()
                if not len(mid_fluo) > 0:
                    # report events with an empty middle section
                    print(f"{i_begin_mid} - {i_begin_am}")
                mean_fluo = mean_fluo[i_begin_am:]

                dict_segment_break_points[event_uuid] = (i_begin_mid, i_begin_am)

                dict_bl_fluo[event_uuid] = bl_fluo
                dict_mean_fluo[event_uuid] = mean_fluo
                dict_mid_fluo[event_uuid] = mid_fluo
                dict_meta[event_uuid] = {"exp_type": exp_type, "mouse_id": mouse_id, "win_type": win_type, "session_uuids": session_uuids, "segment_type_break_points": segment_type_break_points}
            else:
                dict_excluded[event_uuid] = {"exp_type": exp_type, "mouse_id": mouse_id, "win_type": win_type, "session_uuids": session_uuids, "segment_type_break_points": segment_type_break_points}


In [None]:
# Optional visual check of the (manually shifted) aftermath begin of TMEV NC traces.
# Each trace is min-max normalized to AMPLITUDE and stacked vertically; a black vertical line marks
# the first aftermath frame. Only the last 200 baseline frames are shown.
check_am_begin_nc = False
if check_am_begin_nc:
    fig = plt.figure(figsize=(20, 42))
    AMPLITUDE = 100
    offset = 0.0
    for event_uuid in dict_meta:
        exp_type = dict_meta[event_uuid]["exp_type"]
        win_type = dict_meta[event_uuid]["win_type"]
        if exp_type == "tmev" and win_type == "NC":
            full_trace = np.concatenate([dict_bl_fluo[event_uuid][-200:], dict_mid_fluo[event_uuid], dict_mean_fluo[event_uuid]])
            norm_trace = AMPLITUDE*(full_trace - np.min(full_trace))/(np.max(full_trace) - np.min(full_trace))
            plt.plot(norm_trace + offset)
            # x position of the first aftermath frame within full_trace
            plt.vlines(x=len(full_trace) - len(dict_mean_fluo[event_uuid]), ymin=offset, ymax = offset+1.1*AMPLITUDE, color="black")
            plt.text(800, offset+0.3*AMPLITUDE, event_uuid, fontdict={"fontsize":40})
            offset += 1.1*AMPLITUDE  # shift the next trace upwards
            print(event_uuid)
            

    plt.xlim((0, 1500))
    plt.show()

# Assemble recovery dataset

### Define window-related functions

In [None]:
def get_window(i_center, trace, half_width=None) -> np.ndarray:
    """Return the window of trace centered on i_center, with inclusive borders at i_center - half_width and
    i_center + half_width. If a border falls outside of trace, a warning is issued and the window is truncated to
    [0, i_center + half_width] or [i_center - half_width, len(trace) - 1]. If the center itself is not a valid
    0-based index of trace (negative, or >= len(trace)), a warning is issued and an empty array is returned.

    Parameters
    ----------
    i_center : int
        The 0-based index of the center of the window in trace
    trace : np.array
        The trace to extract the window from
    half_width : int, optional
        Number of frames to include on each side of the center. Defaults to the global half_window_width_frames.

    Returns
    -------
    np.array
        The window, a subarray of trace (empty if i_center is out of bounds)
    """
    if half_width is None:
        half_width = half_window_width_frames
    n_frames = len(trace)
    # A negative center would otherwise produce a slice counted from the end of the trace, i.e. a wrong window.
    if i_center < 0 or i_center >= n_frames:
        warnings.warn(f"Trying to access window with center {i_center}, but only {n_frames} frames")
        return np.array([])
    right_limit = i_center + half_width + 1  # right limit is exclusive
    if right_limit > n_frames:
        warnings.warn(f"Part of window out of bounds: {i_center} + HW {half_width} > last index {n_frames - 1}")
        right_limit = n_frames
    left_limit = i_center - half_width
    if left_limit < 0:
        warnings.warn(f"Part of window out of bounds: {i_center} - HW {half_width} < 0")
        left_limit = 0
    return trace[left_limit : right_limit]

In [None]:
SD_WINDOW_WIDTH_FRAMES = 450  # 30 s window beginning with "am" segment to look for amplitude of SD
def get_window_for_event_type(event_uuid, event_type="sz"):
    """Given the event uuid and the event type (sz, sd) to look for, return the corresponding window of the
    complete trace (baseline + middle + aftermath parts concatenated).

    Windows returned:
    - TMEV, "sz": from the second to the third entry of segment_type_break_points (sz begin until aftermath begin).
    - TMEV, "sd": SD_WINDOW_WIDTH_FRAMES frames starting at the third break point; CA1 only, NC yields no window.
    - chr2_szsd, "sz": the first "sz" segment of the data documentation (frame_begin to frame_end, inclusive).
    - chr2_sd / chr2_szsd, "sd": SD_WINDOW_WIDTH_FRAMES frames starting at the first "sd_wave" segment.
    Any other combination issues the warning "No window found!" and returns an empty array.

    Parameters
    ----------
    event_uuid : str
        the event_uuid of the trace (the name of the hdf5 group)
    event_type : str
        "sd" or "sz": the event for which the window is returned.

    Returns
    -------
    np.array
        The window (empty array if event_type does not exist for the recording type)
    """
    exp_type = dict_meta[event_uuid]["exp_type"]
    win_type = dict_meta[event_uuid]["win_type"]
    # rebuild the complete trace from its three stored parts
    complete_trace = np.concatenate([dict_bl_fluo[event_uuid], dict_mid_fluo[event_uuid], dict_mean_fluo[event_uuid]])

    if exp_type == "tmev":
        # the original break points are used here, without the manual NC shift of the aftermath begin
        break_points = dict_meta[event_uuid]["segment_type_break_points"]  # [bl_begin, sz_begin, SD_begin]
        if event_type == "sz":  # am begins with appearance of first SD wave -> if sz, take the trace between second and third indices
            return complete_trace[break_points[1]:break_points[2]]
        elif event_type == "sd" and win_type != "NC":  # prove me wrong, but no SD in NC. :)  
            return complete_trace[break_points[2]: break_points[2] + SD_WINDOW_WIDTH_FRAMES]
    elif "chr2" in exp_type:
        df_segments = ddoc.getSegmentsForUUID(event_uuid)  # sessions consist of one recording, so event_uuid = recording_uuid
        if event_type == "sz" and exp_type == "chr2_szsd":
            i_begin_sz = df_segments[df_segments["interval_type"] == "sz"].frame_begin.iloc[0] - 1  # switch to 0-based indexing
            i_end_sz = df_segments[df_segments["interval_type"] == "sz"].frame_end.iloc[0]  # upper limit exclusive
            return complete_trace[i_begin_sz: i_end_sz]
        elif event_type == "sd" and "sd" in exp_type:
            i_begin_sd = df_segments[df_segments["interval_type"] == "sd_wave"].frame_begin.iloc[0] - 1  # switch to 0-based indexing
            i_end_sd = i_begin_sd + SD_WINDOW_WIDTH_FRAMES
            return complete_trace[i_begin_sd:i_end_sd]
    # reached when no branch above returned a window
    warnings.warn("No window found!")
    return np.array([])

def get_metric_for_window(trace_window, percent=None):
    """Given a window, calculate the following metric: 
    1. Take the lowest percent % of the values within the window (at least one value, so that short
       windows still yield a number instead of NaN)
    2. Get the median value of the values found in step 1.

    Parameters
    ----------
    trace_window : np.array
        The window to calculate the metric for.
    percent : float, optional
        Percentage of lowest values to consider. Defaults to the global percent_considered.

    Returns
    -------
    float
        The calculated metric (np.nan for an empty window)
    """
    if percent is None:
        percent = percent_considered
    n_samples = len(trace_window)
    if n_samples == 0:  # nothing to measure; avoid numpy's empty-slice warning
        return np.nan
    n_lowest = max(1, int(percent/100.*n_samples))
    lowest_values = np.sort(trace_window)[:n_lowest]
    return np.median(lowest_values)

def get_peak_metric(trace_window, top_fraction=0.05):
    """Given a trace window, calculate the mean of the highest values. Intended use: SD and Sz amplitudes.

    Parameters
    ----------
    trace_window : np.array
        The window to calculate the metric for.
    top_fraction : float, optional
        Fraction of the highest values to average (default 0.05, i.e. top 5%). At least one value is always
        used, so that short windows still yield a number instead of NaN.

    Returns
    -------
    float
        The calculated metric (np.nan for an empty window)
    """
    n_samples = len(trace_window)
    if n_samples == 0:  # e.g. no SD/Sz window exists for this recording type
        return np.nan
    n_top = max(1, int(top_fraction*n_samples))
    # sort descending, then average the n_top highest values
    return np.flip(np.sort(trace_window))[:n_top].mean()

### Find baseline windows, metrics

In [None]:
# Manually set baseline window centers, per event uuid.
# They were previously given as absolute frame indices in a 5000-frame baseline (e.g. 4250); they are now negative
# offsets from the end of the baseline (e.g. 4250 - 5000 = -750), so that they work with a baseline of arbitrary length.
dict_uuid_manual_bl_center = {"aa66ae0470a14eb08e9bcadedc34ef64": -750, "c7b29d28248e493eab02288b85e3adee": -1000,  "7b9c17d8a1b0416daf65621680848b6a": -950, "9e75d7135137444492d104c461ddcaac": -300, "d158cd12ad77489a827dab1173a933f9": -500, "a39ed3a880c54f798eff250911f1c92f" : -500, "4e2310d2dde845b0908519b7196080e8" : -500, "f0442bebcd1a4291a8d0559eb47df08e": -500, "2251bba132cf45fa839d3214d1651392": -1300, "cd3c1e0e3c284a89891d2e4d9a7461f4": -1500}

# uuid: (i_bl, bl_metric), i_bl is the center of the window (0-based index within the baseline trace)
dict_bl_values = {}

for uuid in dict_meta.keys():  # uuid: {"exp_type": exp_type, "mouse_id": mouse_id, "session_uuids": [session_uuids]}
    exp_type = dict_meta[uuid]["exp_type"]
    win_type = dict_meta[uuid]["win_type"]
    # check if manually corrected. If not, check if TMEV or not. If TMEV, use default_bl_center_ca1/default_bl_center_nc
    # if ChR2, can use a window right before stim
    bl_trace = dict_bl_fluo[uuid]
    if uuid in dict_uuid_manual_bl_center:
        i_bl = dict_uuid_manual_bl_center[uuid]
    elif exp_type == "tmev":
        if win_type == "CA1":
            i_bl = default_bl_center_ca1
        elif win_type == "NC":
            i_bl = default_bl_center_nc
    elif exp_type in ["chr2_sd", "chr2_szsd"]:
        # take a window just before stim
        i_bl = len(bl_trace) - half_window_width_frames - 1
    # negative centers count from the end of the baseline; convert them to regular indices
    # NOTE(review): if the baseline is shorter than the offset, i_bl stays negative - confirm baseline lengths
    if i_bl < 0:
        i_bl = len(bl_trace) + i_bl
    bl_win = get_window(i_bl, bl_trace)
    bl_metric = get_metric_for_window(bl_win)
    dict_bl_values[uuid] = (i_bl, bl_metric)
    

### Find significant time points and corresponding values
Sz amplitude (if exists), peak, trough (darkest point), recovery position

In [None]:
# aftermath:
# TMEV - appearance of first SD. This could also be taken above
# ChR2 - if SD present, then appearance of first SD. Else: directly after stim (ctl).

# uuid: (i_sd_peak, i_trough, i_half, y_sd_peak, y_trough, y_half)
# The i_* values are 0-based frame indices within the aftermath trace; y_half is the midpoint between y_sd_peak and y_trough.
dict_significant_tpoints = {}

for event_uuid in dict_mean_fluo.keys(): 
    exp_type = dict_meta[event_uuid]["exp_type"]
    win_type = dict_meta[event_uuid]["win_type"]

    # traces already cut to "aftermath" (plus few extra frames)
    complete_trace = dict_mean_fluo[event_uuid]
    
    # NOTE(review): sorted_indices is only referenced by the commented-out old method below
    sorted_indices = np.argsort(complete_trace)  # this cut should not influence the index
    
    # get brightest frame
    # old method, uses same percentages and median as darkest frame. Did not work well
    #i_brightest_group = np.flip(sorted_indices)[:int(percent_considered/100.*len(sorted_indices))]
    #i_brightest = int(floor(np.median(np.sort(i_brightest_group)[:extreme_group_size])))
    # current method: index of the largest value within the first peak_window_length frames of the aftermath
    sorted_beginning = np.argsort(complete_trace[:peak_window_length])
    i_brightest = sorted_beginning[-1]  # this is supposed to be SD amplitude. Later, set it to np.nan if there is no SD.
    i_sd_peak = i_brightest 

    # the trough is searched in at most n_trough_frames frames, starting at the brightest frame
    cut_trace = complete_trace[i_brightest:i_brightest+n_trough_frames]
    # use reduced window to look for trough
    sorted_indices_cut = np.argsort(cut_trace)
    i_darkest_group = sorted_indices_cut[:int(percent_considered/100.*len(complete_trace))]  # still take n percent of aftermath, not cut trace!
    # get single coordinate for darkest part
    # find darkest <percent_considered>%, take earliest <extreme_group_size> of them, get median frame index of these, round down to integer frame
    i_darkest_cut = int(floor(np.median(np.sort(i_darkest_group)[:extreme_group_size])))

    i_darkest = i_darkest_cut + i_brightest  # bring it back to original frame indices


    # get Sz and SD amplitude metrics
    #y_brightest = complete_trace[i_brightest]
    # TODO: originally, i_brightest was the index of maximum brightness. In Baseline recovery, the SD amplitude uses different approach
    #       as implemented below. Need to remove i_brightest and old y_brightest = complete_trace[i_brightest]! (brightest is SD peak)
    sd_window = get_window_for_event_type(event_uuid, "sd")
    y_sd_peak = get_peak_metric(sd_window)
    # an empty SD window means there is no SD for this recording type -> no peak index
    if len(sd_window) == 0:
        i_sd_peak = np.nan 

    # NOTE(review): y_sz_peak is computed but not stored in dict_significant_tpoints
    sz_window = get_window_for_event_type(event_uuid, "sz")
    y_sz_peak = get_peak_metric(sz_window)


    #y_darkest = complete_trace[i_darkest]  # TODO: get window value instead?
    # trough amplitude: window metric of the window centered on the darkest frame
    y_darkest = get_window(i_darkest, complete_trace)
    y_darkest = get_metric_for_window(y_darkest)

    # find time of half maximum
    if not np.isnan(i_sd_peak):
        y_half = (y_sd_peak + y_darkest)/2.  # midpoint between SD peak and trough: trough + (peak - trough)/2
    if not np.isnan(i_sd_peak):
        # first frame at or after the brightest frame where the trace is at or below y_half
        # (np.argmax returns 0 if no such frame exists)
        i_half = np.argmax(complete_trace[i_brightest:] <= y_half)
        i_half += i_sd_peak
    else:
        i_half = np.nan
        y_half = np.nan

    #print()
    #print(i_darkest)
    #print(i_half)
    #assert i_brightest < i_half
    #assert i_darkest > i_half

    if win_type == "NC":  # no SD in NC windows... Correct this logic if I'm wrong :)
        print(f"{event_uuid} window type is NC. No SD = no peak.")
        # NOTE(review): the two values set below are not used afterwards in this cell (they are not part of the stored tuple)
        i_brightest = -1
        y_brightest = np.nan
    dict_significant_tpoints[event_uuid] = (i_sd_peak, i_darkest, i_half, y_sd_peak, y_darkest, y_half)

### Find seizure amplitude (if exists)

In [None]:
# Absolute seizure amplitude per event (NOT compared to baseline): the maximum of the mid segment (dict_mid_fluo).
# For chr2_szsd recordings the stimulation frames at the beginning of the mid segment are ignored when finding the max.
# TODO: the loop that fills dict_significant_tpoints also computes a Sz amplitude (y_sz_peak), but does not store it;
#       until it does, this dict is the source of the y_sz_max column.
dict_sz_amps = {}  # uuid: y_mid_max, or np.nan if no seizure

for event_uuid in dict_mean_fluo.keys(): 
    exp_type = dict_meta[event_uuid]["exp_type"]
    if exp_type not in ["chr2_szsd", "tmev"]:  # no seizure in experiment
        dict_sz_amps[event_uuid] = np.nan
        continue
    # only recordings where a seizure may occur are left here
    mid_trace = dict_mid_fluo[event_uuid]
    if exp_type == "chr2_szsd":  # ignore stim frames
        df_segments = ddoc.getSegmentsForUUID(event_uuid)
        assert "sz" in df_segments.interval_type.unique()  # make sure sz actually occurred
        # get number of stim frames to ignore. mid section begins with stim frames.
        i_begin_stim = df_segments[df_segments["interval_type"] == "stimulation"].frame_begin.iloc[0]  # inclusive
        i_end_stim = df_segments[df_segments["interval_type"] == "stimulation"].frame_end.iloc[0]  # inclusive
        n_stim_frames = i_end_stim - i_begin_stim + 1
        mid_trace = mid_trace[n_stim_frames:]
    else:  # if tmev, make sure sz segment exists in at least one of the stitched sessions
        session_uuids = dict_meta[event_uuid]["session_uuids"]
        sz_present = any("sz" in ddoc.getSegmentsForUUID(session_uuid).interval_type.unique() for session_uuid in session_uuids)
        if not sz_present:
            dict_sz_amps[event_uuid] = np.nan
            print(event_uuid)
            continue
    # np.max raises ValueError on an empty array; an empty mid segment has no amplitude
    dict_sz_amps[event_uuid] = np.max(mid_trace) if len(mid_trace) > 0 else np.nan

### Starting with trough, find time window where metric shows recovery

In [None]:
def get_windows_from(i_begin_center, trace):
    """Slide a window over trace, starting at a given center, and collect the window metric at each position.

    Window centers start at i_begin_center and advance by the global window_step_frames, as long as the center
    is less than len(trace) - half_window_width_frames (i.e. until the end of the recording is reached).

    Parameters
    ----------
    i_begin_center : int
        The 0-based index of the first window center to include
    trace : np.array
        The trace

    Returns
    -------
    tuple(np.array, np.array)
        A tuple with two arrays: at location 0, the 0-based window centers and at location 1, the corresponding window metrics.
    """
    center_limit = len(trace) - half_window_width_frames  # centers must stay below this value
    centers = list(range(i_begin_center, center_limit, window_step_frames))
    metrics = [get_metric_for_window(get_window(center, trace)) for center in centers]
    return (np.array(centers), np.array(metrics))

def try_extrapolate_recovery(x_vals, y_vals, y_expol, no_recovery_value=25000):
    """Given the x values x_vals and the corresponding y-values y_vals, try to find the x value corresponding to y_expol, 
    based on linear extrapolation. This algorithm is specialized on finding recovery time point, so a line with positive slope is sought.
    If this cannot be found, a large time point (no_recovery_value) is returned.

    y_expol is used as given; the caller is expected to have applied the recovery threshold already
    (i.e. to pass recovery_ratio*y_baseline).

    Parameters
    ----------
    x_vals : np.array
        The x values (in the notebook, intended use is frames in 0-based indexing, the center of windows)
    y_vals : np.array
        The y values (intended use is the metrics of the windows specified by x_vals)
    y_expol : int (or scalar, same as y_vals.dtype)
        The y value to extrapolate to. (intended use case: recovery_ratio*y_baseline)
    no_recovery_value : int, optional
        Returned when no ascending line can be fitted (fewer than two points, or non-positive/undefined slope).
        TODO: come up with better value! np.inf messes up statistics...

    Returns
    -------
    int
        The found extrapolated time point, in same units as x_vals. (intended use case: the frame index where 95% of baseline is reached)
    """
    if len(x_vals) < 2:  # a line cannot be fitted to fewer than two points
        return no_recovery_value
    try:
        # the coefficients [a, b] from y = a + b*x
        line_fit_coeffs = Polynomial.fit(x_vals, y_vals, deg=1).convert().coef
    except np.linalg.LinAlgError:
        print(f"Could not extrapolate. Returning last window center as extrapolation time point...")
        return x_vals[-1]
    intercept = line_fit_coeffs[0]
    # a trailing zero coefficient may have been trimmed, in which case the slope is 0
    slope = line_fit_coeffs[1] if len(line_fit_coeffs) > 1 else 0.0
    if not slope > 0:  # flat, descending or undefined (NaN) line -> no recovery possible
        return no_recovery_value
    # line is ascending, i.e. there will be a recovery time: invert y = a + b*x to x = (y - a)/b
    x_recovery = (y_expol - intercept)/slope
    return ceil(x_recovery)
        

In [None]:
# event_uuid: (i_recovery, y_recovery, did_recover). i_recovery is a frame index within the aftermath trace;
# if did_recover is False, i_recovery is an extrapolated value and y_recovery is recovery_ratio*y_bl.
dict_recovery = {}
dict_windows = {}  # event_uuid: [y_bl_window, y_darkest_window, y_post_darkest1, y_post_darkest2, ..., y_recovery_window]

for event_uuid in dict_mean_fluo:
    trace = dict_mean_fluo[event_uuid]
    i_trough = dict_significant_tpoints[event_uuid][1]
    
    x_windows = []  # the 0-indexed window center coordinates [x_bl, x_trough, x_window1, ...]
    y_windows = []  # the corresponding window metrics [y_bl, y_trough, y_window1, ...]
    
    did_recover = False  # set to True once a window reaches the recovery threshold
    # add y_bl (note: i_bl is an index within the baseline trace, unlike the other window centers)
    i_bl = dict_bl_values[event_uuid][0]
    y_bl = dict_bl_values[event_uuid][1]
    x_windows.append(i_bl)
    y_windows.append(y_bl)  
    
    # add trough window to windows list
    i_current = i_trough
    current_win = get_window(i_current, trace)  # start with metric at trough
    y_current = get_metric_for_window(current_win)
    y_windows.append(y_current)
    x_windows.append(i_current)

    # move on to next window just after trough to start looking for recovery (FIXME: in some cases, already trough is > 95% of bl! by definition we demand recovery to happen after the trough?)
    i_current += window_step_frames
    current_win = get_window(i_current, trace)
    # only complete windows are evaluated; a truncated window means the end of the trace was reached
    while len(current_win) >= window_width_frames:  # stop algorithm upon reaching end of recording
        y_current = get_metric_for_window(current_win)
        y_windows.append(y_current)
        x_windows.append(i_current)
        if y_current >= recovery_ratio*y_bl:  # recovery reached
            did_recover = True
            break
        else:  # move to next window
            i_current += window_step_frames
            current_win = get_window(i_current, trace)
    # if no recovery found within trace, try to extrapolate. Start with window after trough.
    if not did_recover:
        # [2:] skips the baseline and the trough windows
        x_recovery = try_extrapolate_recovery(x_windows[2:], y_windows[2:], recovery_ratio*y_bl)
        y_recovery = recovery_ratio*y_bl
    else:
        x_recovery = i_current
        y_recovery = y_current

    dict_windows[event_uuid] = y_windows

    dict_recovery[event_uuid] = (x_recovery, y_recovery, did_recover)


## Create DataFrame

In [None]:
# Assemble one row per event from the dictionaries computed above.
# i_* columns are 0-based frame indices within the aftermath trace (except i_bl, which indexes the baseline trace);
# y_* columns are fluorescence metrics; dt_* columns are time differences in seconds.
df_recovery = pd.DataFrame.from_dict(dict_significant_tpoints, "index", columns=["i_peak", "i_trough", "i_half", "y_peak", "y_trough", "y_half"]).reset_index()
# the former index holds the event uuids
df_recovery = df_recovery.rename(columns={"index": "event_uuid"})

# experiment metadata
df_recovery["exp_type"] = df_recovery["event_uuid"].map(lambda uuid: dict_meta[uuid]["exp_type"])
df_recovery["mouse_id"] = df_recovery["event_uuid"].map(lambda uuid: dict_meta[uuid]["mouse_id"])
df_recovery["win_type"] = df_recovery["event_uuid"].map(lambda uuid: dict_meta[uuid]["win_type"])

# baseline window center and metric
df_recovery["y_bl"] = df_recovery["event_uuid"].map(lambda uuid: dict_bl_values[uuid][1])
df_recovery["i_bl"] = df_recovery["event_uuid"].map(lambda uuid: dict_bl_values[uuid][0])

# baseline minus trough difference in amplitude (amount of depression)
df_recovery["dy_bl_trough"] = df_recovery["y_bl"] - df_recovery["y_trough"]
# peak-trough time difference, s
df_recovery["dt_peak_trough"] = df_recovery["i_trough"]/imaging_freq - df_recovery["i_peak"]/imaging_freq
# time from the half-amplitude point of the peak (i_half) to the trough, s.
# Previously this repeated the dt_peak_trough formula, so i_half was never used.
# NOTE(review): confirm that half-amplitude-to-trough (and not peak-to-half-amplitude) is the intended quantity.
df_recovery["dt_peak_trough_FWHM"] = df_recovery["i_trough"]/imaging_freq - df_recovery["i_half"]/imaging_freq

# recovery point; "extrapolated" marks events where recovery was not reached within the trace
df_recovery["i_recovery"] = df_recovery["event_uuid"].map(lambda uuid: dict_recovery[uuid][0])
df_recovery["y_recovery"] = df_recovery["event_uuid"].map(lambda uuid: dict_recovery[uuid][1])
df_recovery["did_recover"] = df_recovery["event_uuid"].map(lambda uuid: dict_recovery[uuid][2])
df_recovery["extrapolated"] = ~df_recovery["did_recover"]

df_recovery["dt_trough_recovery"] = df_recovery["i_recovery"]/imaging_freq - df_recovery["i_trough"]/imaging_freq
df_recovery["dt_peak_recovery"] = df_recovery["i_recovery"]/imaging_freq - df_recovery["i_peak"]/imaging_freq

# move i_xy to whole trace indexing frame of reference by adding the first aftermath frame index.
# Note: these three columns are not part of the final column selection below.
i_begin_am = df_recovery["event_uuid"].map(lambda uuid: dict_segment_break_points[uuid][1])
df_recovery["i_recovery_whole"] = df_recovery["i_recovery"] + i_begin_am
df_recovery["i_peak_whole"] = df_recovery["i_peak"] + i_begin_am
df_recovery["i_trough_whole"] = df_recovery["i_trough"] + i_begin_am

# absolute seizure amplitude
df_recovery["y_sz_max"] = df_recovery["event_uuid"].map(lambda uuid: dict_sz_amps[uuid])

# amplitudes relative to baseline
df_recovery["dy_bl_sz"] = df_recovery["y_sz_max"] - df_recovery["y_bl"]
df_recovery["dy_bl_sd"] = df_recovery["y_peak"] - df_recovery["y_bl"]  # peak of aftermath is the largest SD amplitude
# final columns: event_uuid, mouse_id, exp_type, win_type, y_bl, y_sz_max, y_peak, y_trough, y_recovery, dy_bl_sz, dy_bl_sd, dy_bl_trough, dt_peak_trough, dt_peak_trough_FWHM, dt_trough_recovery, dt_peak_recovery, extrapolated
df_recovery = df_recovery[["event_uuid", "mouse_id", "exp_type", "win_type",  "y_bl", "y_sz_max", "y_peak", "y_trough", "y_recovery", "dy_bl_sz", "dy_bl_sd", "dy_bl_trough", "dt_peak_trough", "dt_peak_trough_FWHM", "dt_trough_recovery", "dt_peak_recovery", "extrapolated"]].sort_values(by=["exp_type", "win_type", "event_uuid"])


# 1. Recovery time

In [None]:
# the following dataframe contains recovery time and experiment metadata:
# if did_recover is false, extrapolation was used
df_recovery_time = df_recovery[[ "mouse_id", "exp_type", "win_type", "event_uuid", "dt_trough_recovery", "extrapolated"]].sort_values(by=["exp_type", "win_type", "mouse_id"])
df_recovery_time = df_recovery_time.rename(columns={"win_type": "window_type", "dt_trough_recovery": "t_recovery_s"})

In [None]:
if save_dsets:
    fpath_recovery = os.path.join(output_folder, f"recovery_times_{get_datetime_for_fname()}.xlsx")
    df_recovery_time.to_excel(fpath_recovery, index=False)
    print(f"Saved file to {fpath_recovery}")

# 2. Bl-Sz, Bl-SD amplitudes

In [None]:
df_amplitudes = df_recovery[["mouse_id", "event_uuid", "exp_type", "win_type", "dy_bl_sz", "dy_bl_sd"]].sort_values(by=["exp_type", "win_type", "mouse_id"])
df_amplitudes = df_amplitudes.rename(columns={"dy_bl_sz": "Sz-bl", "dy_bl_sd": "SD-bl"})

In [None]:
if save_dsets:
    fpath_amplitudes = os.path.join(output_folder, f"sz_sd_amplitudes_{get_datetime_for_fname()}.xlsx")
    df_amplitudes.to_excel(fpath_amplitudes, index=False)
    print(f"Saved file to {fpath_amplitudes}")

# 3. Baseline-trough difference amplitude (amount of depression)

In [None]:
df_bl_darkest = df_recovery[[ "mouse_id", "event_uuid", "exp_type", "win_type", "y_bl", "y_trough", "dy_bl_trough", "extrapolated"]].sort_values(by=["exp_type", "win_type", "mouse_id"])
df_bl_darkest = df_bl_darkest.rename(columns={"y_bl": "baseline", "y_trough": "darkest_postictal", "dy_bl_trough": "bl-darkest"})

In [None]:
if save_dsets:
    fpath_bl_darkest = os.path.join(output_folder, f"bl-to-darkest-point_{get_datetime_for_fname()}.xlsx")
    df_bl_darkest.to_excel(fpath_bl_darkest, index=False)
    print(f"Saved file to {fpath_bl_darkest}")

# 4. Peak FWHM - trough time

In [None]:
df_peak_trough_fwhm = df_recovery[["event_uuid", "mouse_id", "exp_type", "win_type", "dt_peak_trough_FWHM", "extrapolated"]].sort_values(by=["exp_type", "win_type", "mouse_id"])

In [None]:
if save_dsets:
    fpath_peak_trough = os.path.join(output_folder, f"peak_trough_fwhm_{get_datetime_for_fname()}.xlsx")
    df_peak_trough_fwhm.to_excel(fpath_peak_trough, index=False)
    print(f"Saved file to {fpath_peak_trough}")

# Plotting

### Plot detected peak/trough values

In [None]:
fig = plt.figure(figsize=(18, 42))

AMPLITUDE = 100.0
offset = 0.0
dict_colors = {"tmev": "green", "chr2_sd": "red", "chr2_szsd": "blue"}
dict_significant_frames_colors = {"peak": "limegreen", "half_max": "darkgreen", "trough": "black", "recovery":"green" }

def normalize_trace(trace, amplitude=None):
    """Min-max scale `trace` so that it spans the range [0, amplitude].

    Parameters
    ----------
    trace : array-like
        Values to rescale.
    amplitude : float, optional
        Height of the rescaled trace. When omitted, the notebook-level
        AMPLITUDE constant is used (looked up at call time).

    Returns
    -------
    np.ndarray
        Float array of the same shape as `trace`. A constant trace has no range
        to scale, so it is returned as all zeros instead of the NaN values that
        the 0/0 division would otherwise produce.
    """
    if amplitude is None:
        amplitude = AMPLITUDE
    trace = np.asarray(trace, dtype=float)
    min_trace = np.min(trace)
    trace_range = np.max(trace) - min_trace
    if trace_range == 0:
        # flat trace: avoid dividing by zero, draw it on the baseline of its row
        return np.zeros_like(trace)
    return amplitude*(trace - min_trace)/trace_range

# Draw every event (ordered by experiment type, then mouse) as a normalized trace
# on its own row, and mark its peak, half-maximum, trough and recovery frames.
for event_uuid in df_recovery.sort_values(by=["exp_type", "mouse_id"]).event_uuid:
    event_meta = dict_meta[event_uuid]
    exp_type = event_meta["exp_type"]
    mouse_id = event_meta["mouse_id"]
    trace_color = ddoc.getColorForMouseId(mouse_id)

    # Displayed trace: last 200 baseline frames, the whole middle section,
    # and the first 5000 aftermath frames.
    bl_full = dict_bl_fluo[event_uuid]
    bl_trace = bl_full[-200:]
    n_cut_frames = len(bl_full) - len(bl_trace)  # baseline frames dropped from the front
    mid_trace = dict_mid_fluo[event_uuid]
    am_trace = dict_mean_fluo[event_uuid][:5000]
    full_trace = np.concatenate([bl_trace, mid_trace, am_trace])

    # Peak, trough, etc. indices need shifting by the baseline and middle frames placed in front.
    i_shift = len(bl_trace) + len(mid_trace)

    if exp_type in ("chr2_sd", "chr2_szsd"):
        # The stimulation amplitude would dominate the scaling; flatten it to the
        # maximum of the signal after stimulation so Sz and SD stay visible.
        df_segments = ddoc.getSegmentsForUUID(event_uuid)
        stim_rows = df_segments[df_segments["interval_type"] == "stimulation"]
        i_begin_stim = stim_rows.frame_begin.iloc[0] - n_cut_frames - 1  # switch to 0-based indexing
        i_end_stim = stim_rows.frame_end.iloc[0] - n_cut_frames  # exclusive upper limit of the slice
        full_trace[i_begin_stim:i_end_stim] = np.max(full_trace[i_end_stim:])

    plt.plot(normalize_trace(full_trace) + offset, color=trace_color, label=exp_type)

    # Vertical markers, drawn in the order peak, half max, trough, recovery.
    # dict_significant_tpoints entries hold peak at [0], trough at [1], half max at [2].
    marker_specs = (
        (dict_significant_tpoints, 0, "peak"),
        (dict_significant_tpoints, 2, "half_max"),
        (dict_significant_tpoints, 1, "trough"),
        (dict_recovery, 0, "recovery"),
    )
    for frame_lookup, i_entry, marker_name in marker_specs:
        plt.vlines(
            x=frame_lookup[event_uuid][i_entry] + i_shift,
            ymin=offset,
            ymax=offset+1.1*AMPLITUDE,
            color=dict_significant_frames_colors[marker_name],
        )

    offset += AMPLITUDE  # next event goes one row higher


# Build the legend from explicit proxy lines, one per marker type, coloured
# straight from dict_significant_frames_colors.
# Passing only labels to plt.legend() attaches them to the first artists found on
# the axes, which is why the colours previously had to be patched by hand. That
# patching used Legend.legendHandles (deprecated in Matplotlib 3.7, removed in 3.9)
# and gave the "half_max" and "trough" entries each other's colours.
ax = plt.gca()
legend_handles = [
    plt.Line2D([], [], color=marker_color, label=marker_name)
    for marker_name, marker_color in dict_significant_frames_colors.items()
]
leg = ax.legend(handles=legend_handles)


plt.xlim((0, 7500))
plt.show()

In [None]:
# Sanity check: print every event whose `y_peak` value in df_recovery differs
# from the maximum of its complete dict_mean_fluo trace.
for event_uuid, mean_trace in dict_mean_fluo.items():
    is_event = df_recovery["event_uuid"] == event_uuid
    y_peak = df_recovery.loc[is_event, "y_peak"].iloc[0]
    full_trace = np.concatenate([mean_trace])
    if y_peak != np.max(full_trace):
        print(f"{dict_meta[event_uuid]['mouse_id']}: {y_peak}, {np.max(full_trace)}")

In [None]:
# Show the recovery table for visual inspection.
df_recovery

### Plot peak-trough time per experiment type

In [None]:
# Box plot of the `delta_t_FWHM` column of df_recovery, with hue set to the experiment type.
# NOTE(review): the export cell earlier in this notebook selects a column named
# `dt_peak_trough_FWHM` from df_recovery; confirm that `delta_t_FWHM` exists as well.
fig = plt.figure(figsize=(6, 10))
g = sns.boxplot(
    data=df_recovery,
    y="delta_t_FWHM",
    hue="exp_type",
    ax=fig.gca(),
)
plt.show()

In [None]:
# Stacked histogram (30 bins) of the `delta_t` column of df_recovery, coloured by experiment type.
# NOTE(review): `delta_t` is not among the df_recovery columns used elsewhere in
# this part of the notebook; confirm the column exists.
fig = plt.figure(figsize=(16, 10))
g = sns.histplot(
    data=df_recovery,
    x="delta_t",
    hue="exp_type",
    multiple="stack",
    bins=30,
    ax=fig.gca(),
)
plt.show()

## Peak-trough amplitude

In [None]:
# Box plot of the `delta_amp` column of df_recovery, with hue set to the experiment type.
# NOTE(review): `delta_amp` is not among the df_recovery columns used elsewhere in
# this part of the notebook; confirm the column exists.
fig = plt.figure(figsize=(6, 10))
g = sns.boxplot(
    data=df_recovery,
    y="delta_amp",
    hue="exp_type",
    ax=fig.gca(),
)
plt.show()

In [None]:
# TODO: smoothing and first derivative for minimum?
# TODO: implement all 3 analyses (recovery time, Sz/SD amplitudes, bl-trough amplitude difference = amount of depression)