In [None]:
# Auto-reload modules (used to develop functions outside this notebook):
# with mode 2, imported modules are reloaded before each cell is executed,
# so edits to their source take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import labrotation.file_handling as fh
import h5py
from time import time
import matplotlib.pyplot as plt
import numpy as np
import os
from labrotation import file_handling as fh
from copy import deepcopy
import pandas as pd
import labrotation.two_photon_session as tps
import seaborn as sns
import uuid  # for unique labeling of sessions and coupling arrays (mouse velocity, distance, ...) to sessions in dataframe 
from matplotlib import cm  # colormap
import datadoc_util
from labrotation import two_photon_session as tps
from datetime import datetime
import seaborn as sns
from math import floor

In [None]:
# Notebook-wide plot appearance: doubled font size and a white background with grid lines.
sns.set(style="whitegrid", font_scale=2)

In [None]:
env_dict = dict()
if not os.path.exists("./.env"):
    print(".env does not exist")
else:
    with open("./.env", "r") as f:
        for line in f.readlines():
            l = line.rstrip().split("=")
            env_dict[l[0]] = l[1]
print(env_dict.keys())

In [None]:
# Prefer the documentation folder configured in .env; otherwise ask for it interactively.
docu_folder = env_dict.get("DATA_DOCU_FOLDER")
if docu_folder is None:
    docu_folder = fh.open_dir("Choose folder containing folders for each mouse!")
print(f"Selected folder:\n\t{docu_folder}")

In [None]:
# The per-mouse folders are either directly in docu_folder or in its "documentation" subfolder.
documentation_subfolder = os.path.join(docu_folder, "documentation")
if os.path.isdir(documentation_subfolder):
    mouse_folder = documentation_subfolder
else:
    mouse_folder = docu_folder
# Only sub-directories count as mice (stray files such as .DS_Store are skipped);
# sorted so the listing is stable across operating systems.
mouse_names = sorted(
    entry for entry in os.listdir(mouse_folder)
    if os.path.isdir(os.path.join(mouse_folder, entry))
)
print("Mice detected:")
for mouse in mouse_names:
    print(f"\t{mouse}")

In [None]:
def get_datetime_for_fname():
    """Return the current local time formatted as 'YYYYMMDD-HHMMSS', for use in file names."""
    return datetime.now().strftime("%Y%m%d-%H%M%S")

In [None]:
# Use DOWNLOADS_FOLDER from .env if configured. Otherwise ask interactively (same pattern as
# for DATA_DOCU_FOLDER) instead of failing with a KeyError when .env is missing or incomplete.
if "DOWNLOADS_FOLDER" in env_dict:
    output_folder = env_dict["DOWNLOADS_FOLDER"]
else:
    output_folder = fh.open_dir("Choose folder to save output files to!")
print(f"Output files will be saved to {output_folder}")

In [None]:
# Data documentation of the selected folder; it is queried below for the events table
# (getEventsDf) and for per-recording segments (getSegmentForFrame / getSegmentsForUUID).
ddoc = datadoc_util.DataDocumentation(docu_folder)
ddoc.loadDataDoc()

## Load all seizures dataset

In [None]:
# Keep only the seizure ("sz") rows of the documented events table.
all_events = ddoc.getEventsDf()
df_events = all_events.loc[all_events["event_type"] == "sz"]

In [None]:
# Ask for the .h5 file with the assembled per-seizure traces; it is read in the next cell.
event_traces_fpath = fh.open_file("Open .h5 file containing assembled traces for all seizures!")
print(event_traces_fpath)

In [None]:
traces_ca1 = []
traces_nc = []

uuids_ca1 = []
uuids_nc = []

session_uuids_ca1 = []
session_uuids_nc = []

recording_break_points_ca1 = []
recording_break_points_nc = []

n_bl_frames = 5000  # baseline frames per event; checked against each event's attrs below
n_am_frames = 5000  # aftermath frames per event; checked against each event's attrs below

# window_type stored in the file -> lists the event is appended to
# (traces, event uuids, session uuids, recording break points). "Cx" events go to the *_nc lists.
target_lists_by_window = {
    "Cx": (traces_nc, uuids_nc, session_uuids_nc, recording_break_points_nc),
    "CA1": (traces_ca1, uuids_ca1, session_uuids_ca1, recording_break_points_ca1),
}

# first keys are event uuids, inside the following dataset names:
# 'lfp_mov_t', 'lfp_mov_y', 'lfp_t', 'lfp_y', 'lv_dist', 'lv_rounds',
# 'lv_running', 'lv_speed', 'lv_t_s', 'lv_totdist', 'mean_fluo'
with h5py.File(event_traces_fpath, "r") as hf:
    # Loop variable is `event_uuid`, not `uuid`: the latter would shadow the imported `uuid` module.
    for event_uuid in hf.keys():
        event_group = hf[event_uuid]
        event_attrs = event_group.attrs
        assert n_bl_frames == event_attrs["n_bl_frames"]
        assert n_am_frames == event_attrs["n_am_frames"]
        win_type = event_attrs["window_type"]
        if win_type not in target_lists_by_window:
            print(f"{win_type} not recognized window type")
            continue
        trace_list, uuid_list, session_uuid_list, break_point_list = target_lists_by_window[win_type]
        trace_list.append(np.array(event_group["mean_fluo"]))
        uuid_list.append(event_uuid)
        session_uuid_list.append(event_attrs["session_uuids"])
        break_point_list.append(event_attrs["recording_break_points"])

## Get baseline values
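Calculated as the minimum of the baseline segment (the first `n_bl_frames` frames of each trace). The originally planned measure, the lowest 5% of the baseline data points (`lowest_percent`), is not implemented yet.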

In [None]:
lowest_percent = 5 / 100  # fraction (5 %) of the baseline intended for the baseline estimate; currently unused

In [None]:
# Baseline value of each event: the minimum of its baseline segment (first n_bl_frames frames).
# NOTE(review): the "lowest 5 %" approach (lowest_percent) is not implemented; this is a plain minimum.
baselines_ca1 = [np.min(trace[:n_bl_frames]) for trace in traces_ca1]
baselines_nc = [np.min(trace[:n_bl_frames]) for trace in traces_nc]

## Get aftermath values
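In consecutive windows of `interval_length_seconds` (currently 10 s), starting at the first aftermath frame to analyse, take the minimum fluorescence value of each window.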

### Calculate first normal frames
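Use the data documentation of the corresponding recording(s). For CA1 events, the first frame to analyse is the first frame of the segment following the `sd_extinction` segment. For NC events, it is simply the first aftermath frame.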

In [None]:
# ca1: each trace is (n_bl_frames + sz + n_am_frames) long. Find the first segment after the
# "sd_extinction" segment and convert its start into an index into that trace.
def _first_frame_after_sd_extinction(event_uuid, trace_length):
    """Return (trace index of the first frame after sd_extinction, uuid of that frame's recording).

    Walks the documented segments of every recording contributing to the aftermath ("am") interval
    of `event_uuid`. It sums segment lengths starting from the first aftermath frame of the trace
    and stops at the segment that follows the "sd_extinction" segment.
    If no such segment exists, a warning is printed. The returned index then points just past the
    last walked segment, and the uuid is that of the last aftermath recording (None if there is none).
    """
    df_event = df_events[(df_events["event_uuid"] == event_uuid) & (df_events["interval_type"] == "am")]
    i_frame = trace_length - n_am_frames  # first aftermath frame in the trace
    sd_extinction_passed = False
    am_rec_uuid = None
    for _, am_row in df_event.iterrows():  # recordings contributing to the aftermath trace
        am_rec_uuid = am_row["recording_uuid"]
        i_first_am = ddoc.getSegmentForFrame(am_rec_uuid, am_row["begin_frame"]).index[0]
        i_last_am = ddoc.getSegmentForFrame(am_rec_uuid, am_row["end_frame"]).index[0]
        # .loc label slicing includes BOTH ends; adding +1 would pull in the segment after the aftermath
        am_segments = ddoc.getSegmentsForUUID(am_rec_uuid).loc[i_first_am:i_last_am]
        for _, segment_row in am_segments.iterrows():
            if sd_extinction_passed:
                # first segment after sd_extinction: i_frame now points to its first frame
                return i_frame, am_rec_uuid
            if segment_row["interval_type"] == "sd_extinction":
                sd_extinction_passed = True
            # NOTE(review): whole segments are counted, which assumes the aftermath starts on a
            # segment boundary - confirm
            i_frame += segment_row["frame_end"] - segment_row["frame_begin"] + 1  # both inclusive -> +1
    print(f"Warning: no segment after sd_extinction found in the aftermath of event {event_uuid}")
    return i_frame, am_rec_uuid


first_frames_ca1 = []
rec_uuids_ca1 = []
for event_uuid, trace in zip(uuids_ca1, traces_ca1):
    first_frame, rec_uuid = _first_frame_after_sd_extinction(event_uuid, len(trace))
    first_frames_ca1.append(first_frame)
    rec_uuids_ca1.append(rec_uuid)

# nc: there is no SD, so just take the first am frame as it is
first_frames_nc = [len(trace) - n_am_frames for trace in traces_nc]


In [None]:
frame_rate_hz = 15  # imaging frame rate, used to convert seconds into frames
interval_length_seconds = 10  # length of one aftermath window in seconds
interval_length = frame_rate_hz*interval_length_seconds  # frames per window (15 Hz * 10 s = 150)
n_intervals = 11  # number of consecutive aftermath windows (11 * 10 s = 110 s in total)

In [None]:
def _window_minima(trace, first_frame):
    """Minimum of each of `n_intervals` consecutive windows of `interval_length` frames from `first_frame`.

    A window running past the end of the trace uses only the frames that are left. A window with
    no frames left yields NaN instead of raising (np.min fails on an empty slice).
    """
    minima = []
    for j in range(n_intervals):
        window = trace[first_frame + j*interval_length : first_frame + (j+1)*interval_length]
        minima.append(np.min(window) if len(window) > 0 else np.nan)
    return np.array(minima)


aftermath_ca1 = [_window_minima(trace, first_frame) for trace, first_frame in zip(traces_ca1, first_frames_ca1)]
aftermath_nc = [_window_minima(trace, first_frame) for trace, first_frame in zip(traces_nc, first_frames_nc)]

## Create dataframe
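Columns: uuid, value (numeric), value_type (`bl` for the baseline, then one label per aftermath window: `10s`, `20s`, …, `110s` with the current `interval_length_seconds = 10` and `n_intervals = 11`).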

In [None]:
# Long-format table: one row per (event, value_type).
# Row order: all baseline rows (CA1, then NC), then all aftermath rows (CA1, then NC).
data_dict = {"uuid": [], "value": [], "value_type": []}


def _add_event_rows(event_uuid, values, value_types):
    """Append one row per (value, value_type) pair to data_dict, all labelled with event_uuid."""
    assert len(values) == len(value_types)
    data_dict["uuid"].extend([event_uuid] * len(values))
    data_dict["value"].extend(values)
    data_dict["value_type"].extend(value_types)


# baseline: exactly one "bl" row per event
for event_uuid, bl_value in zip(uuids_ca1, baselines_ca1):
    _add_event_rows(event_uuid, [bl_value], ["bl"])
for event_uuid, bl_value in zip(uuids_nc, baselines_nc):
    _add_event_rows(event_uuid, [bl_value], ["bl"])

# aftermath: one row per window, labelled with the window's end time ("10s", "20s", ...)
window_labels = [f"{(i+1)*interval_length_seconds}s" for i in range(n_intervals)]
for event_uuid, am_values in zip(uuids_ca1, aftermath_ca1):
    _add_event_rows(event_uuid, list(am_values), window_labels)
for event_uuid, am_values in zip(uuids_nc, aftermath_nc):
    _add_event_rows(event_uuid, list(am_values), window_labels)
     

In [None]:
df = pd.DataFrame(data_dict)  # columns: uuid, value, value_type

In [None]:
# One line per event: baseline value followed by the minimum of each aftermath window.
fig, ax = plt.subplots(figsize=(18, 18))
sns.lineplot(data=df, x="value_type", y="value", hue="uuid", palette="tab10",
             linewidth=2.5, legend=False, ax=ax)
plt.show()

In [None]:
n_bl_plot_frames = 1000  # number of baseline frames (directly before the seizure) to plot per event
assert n_bl_plot_frames <= n_bl_frames, "cannot plot more baseline frames than the baseline contains"

all_bl_traces = []
all_am_traces = []
# CA1 events first, then NC events. Each aftermath trace starts at the first frame chosen above.
for trace, first_frame in zip(traces_ca1, first_frames_ca1):
    all_am_traces.append(trace[first_frame:])
    all_bl_traces.append(trace[n_bl_frames - n_bl_plot_frames:n_bl_frames])
for trace, first_frame in zip(traces_nc, first_frames_nc):
    all_am_traces.append(trace[first_frame:])
    all_bl_traces.append(trace[n_bl_frames - n_bl_plot_frames:n_bl_frames])

# x values of the baseline traces: -(n_bl_plot_frames - 1), ..., 0, so the last baseline frame is at x = 0.
# Built from the constant rather than from all_bl_traces[0], so it also works when no traces were loaded.
bl_x = np.arange(1 - n_bl_plot_frames, 1)

In [None]:
fig, ax = plt.subplots(figsize=(18, 18))
# Baseline traces end at x = 0 (via bl_x); aftermath traces are drawn against their frame index from 0.
for bl_trace in all_bl_traces:
    ax.plot(bl_x, bl_trace)
for am_trace in all_am_traces:
    ax.plot(am_trace)
ax.set_ylim((0, 60))
plt.show()

In [None]:
# TODO: lowess filter? Somehow filter this signal!