# Create waterfall plot of all optostim experiments
This includes the bilateral stimulation experiments and the ChR2 experiments with window. The locomotion traces are plotted together with the stimulation time point and the baseline/post-stimulation time windows.

In [None]:
# Figure export settings. NOTE: the same flags are assigned again in a later
# configuration cell of this notebook.
save_figs = True  # set to True to save the figures created
save_as_eps = False
save_as_pdf = True

# Choose the file extension: PDF takes precedence over EPS, JPG is the fallback.
if save_as_pdf:
    file_format = ".pdf"
else:
    file_format = ".eps" if save_as_eps else ".jpg"

if save_figs:
    print(f"Going to save figures as {file_format} files.")

In [None]:
#Auto-reload modules (used to develop functions outside this notebook)
%load_ext autoreload
%autoreload 2

In [None]:
import labrotation.file_handling as fh
import h5py
from time import time
import matplotlib.pyplot as plt
import numpy as np
import os
from labrotation import file_handling as fh
from copy import deepcopy
import pandas as pd
import labrotation.two_photon_session as tps
import seaborn as sns
import uuid  # for unique labeling of sessions and coupling arrays (mouse velocity, distance, ...) to sessions in dataframe 
from matplotlib import cm  # colormap
import datadoc_util
from labrotation import two_photon_session as tps
from datetime import datetime
import seaborn as sns
from math import floor
import matlab.engine  # for saving data to workspace
from scipy.stats import ttest_rel
import json
from loco_functions import apply_threshold, get_episodes, calculate_avg_speed, calculate_max_speed, get_trace_delta
import matplotlib.patches as mpatches

In [None]:
# --- What to export from this notebook ---
save_data = True  # export results of this script?
save_sanity_check = False  # make sure to set save_figs to True as well
save_waterfall = False

# --- Figure export settings ---
save_figs = True  # set to True to save the figures created
save_as_eps = False
save_as_pdf = True

# Extension used for saved figures: PDF wins over EPS, JPG is the fallback.
if save_as_pdf:
    file_format = ".pdf"
else:
    file_format = ".eps" if save_as_eps else ".jpg"

if save_figs:
    print(f"Going to save figures as {file_format} files.")


# Set seaborn parameters

In [None]:
# Global seaborn appearance for all figures created below:
# first scale the fonts by a factor of 3, ...
sns.set(font_scale=3)
# ... then switch to the "whitegrid" style.
sns.set_style("whitegrid")

# If exists, load environmental variables from .env file

In [None]:
def load_env_file(env_fpath="./.env"):
    """Read KEY=VALUE pairs from a .env-style text file.

    Parameters
    ----------
    env_fpath : str
        Path of the file to read.

    Returns
    -------
    dict
        Maps each key to its value (both as strings). Empty if the file does not exist.

    Notes
    -----
    Only the first "=" separates key and value, so values may contain "=" themselves
    (e.g. URLs with query strings). Empty lines, lines starting with "#" and lines
    without any "=" are skipped instead of raising an IndexError.
    """
    env_vars = dict()
    if not os.path.exists(env_fpath):
        print(".env does not exist")
        return env_vars
    with open(env_fpath, "r") as f:
        for line in f:
            entry = line.rstrip()
            if not entry.strip() or entry.lstrip().startswith("#"):
                continue  # blank line or comment
            key, separator, value = entry.partition("=")
            if not separator:
                continue  # not a KEY=VALUE line
            env_vars[key] = value
    return env_vars


# Environment variables of this notebook (empty dict if there is no .env file)
env_dict = load_env_file("./.env")
print(env_dict.keys())

# Set up data documentation directory

In [None]:
# assumption: inside the documentation folder, the subfolders carry the id of each mouse (not exact necessarily, but they 
# can be identified by the name of the subfolder). 
# Inside the subfolder xy (for mouse xy), xy_grouping.xlsx and xy_segmentation.xlsx can be found.
# xy_grouping.xlsx serves the purpose of finding the recordings belonging together, and has columns:
# folder, nd2, labview, lfp, face_cam_last, nikon_meta, experiment_type, day
# xy_segmentation.xlsx contains frame-by-frame (given by a set of disjoint intervals forming a cover for the whole recording) 
# classification of the events in the recording ("normal", seizure ("sz"), sd wave ("sd_wave") etc.). The columns:
# folder, interval_type, frame_begin, frame_end.

# TODO: write documentation on contents of xlsx files (what the columns are etc.)
# Prefer the folder configured in .env; only ask the user when it is not configured.
docu_folder = (
    env_dict["DATA_DOCU_FOLDER"]
    if "DATA_DOCU_FOLDER" in env_dict
    else fh.open_dir("Choose folder containing folders for each mouse!")
)
print(f"Selected folder:\n\t{docu_folder}")

In [None]:
# The per-mouse folders live either directly in docu_folder or in its "documentation" subfolder.
has_documentation_subfolder = "documentation" in os.listdir(docu_folder)
mouse_folder = os.path.join(docu_folder, "documentation") if has_documentation_subfolder else docu_folder

# Every entry of the mouse folder is reported as one mouse.
mouse_names = os.listdir(mouse_folder)
print("Mice detected:")
for mouse in mouse_names:
    print("\t" + str(mouse))

In [None]:
def get_datetime_for_fname():
    """Return the current local date and time as "YYYYMMDD-HHMMSS", for use in file names."""
    return datetime.now().strftime("%Y%m%d-%H%M%S")

In [None]:
# Folder that receives all exported files. Prefer the entry from .env; if it is
# missing (no .env file or no DOWNLOADS_FOLDER entry), ask the user instead of
# failing with a KeyError - same strategy as for the data documentation folder.
if "DOWNLOADS_FOLDER" in env_dict:
    output_folder = env_dict["DOWNLOADS_FOLDER"]
else:
    output_folder = fh.open_dir("Choose folder to save the output files to!")
print(f"Output files will be saved to {output_folder}")

## Set a uniform datetime string for output files

In [None]:
# Computed once here so that every file exported by this run carries the same timestamp
output_dtime = get_datetime_for_fname()

### Load data documentation

In [None]:
# Data documentation helper rooted at docu_folder; used below for colors, segments and stim durations.
ddoc = datadoc_util.DataDocumentation(docu_folder)
# presumably reads the documentation files into memory - confirm in datadoc_util
ddoc.loadDataDoc()

### Set up color coding
For now, it is only possible to assign a color to each mouse. Later, once event UUIDs are available, each event UUID will need to be mapped to a color code.

In [None]:
# Color table from the data documentation; the cells below rely on its "mouse_id" and "color" columns.
df_colors = ddoc.getColorings()

In [None]:
# Lookup table mouse_id -> color, built from the two columns of the color table.
color_table_mouse_ids = df_colors["mouse_id"].tolist()
color_table_colors = df_colors["color"].tolist()
dict_colors_mouse = dict(zip(color_table_mouse_ids, color_table_colors))

In [None]:
#dict_colors_mouse["T413"] = "#000000"  # set one to black

### Load events_list dataset

In [None]:
# The list of events is expected directly in the data documentation folder.
events_list_fpath = os.path.join(docu_folder, "events_list.xlsx")
# Explicit check instead of an assert: asserts are skipped when Python runs with -O,
# and the error message now names the missing file.
if not os.path.exists(events_list_fpath):
    raise FileNotFoundError(f"events_list.xlsx not found at {events_list_fpath}")

df_events_list = pd.read_excel(events_list_fpath)

In [None]:
# Path of the h5 file with the assembled traces of the bilateral stim experiments.
# presumably fh.open_file shows a file selection dialog with the given title - confirm in labrotation.file_handling
assembled_traces_bilat_fpath = fh.open_file("Open assembled_traces for bilat stim h5 file!")
print(assembled_traces_bilat_fpath)

In [None]:
# Path of the h5 file with the assembled traces of the ChR2 experiments with window.
# presumably fh.open_file shows a file selection dialog with the given title - confirm in labrotation.file_handling
assembled_traces_chr2_fpath = fh.open_file("Open assembled_traces chr2 with window h5 file!")
print(assembled_traces_chr2_fpath)

In [None]:
# Mice included in this analysis, per experimental group.
# NOTE(review): the trailing comments of the two group lists used to be swapped relative
# to the variable names; the names are taken as authoritative here - please confirm.
used_mouse_ids_chr2win = ["OPI-2239", "WEZ-8917", "WEZ-8924", "WEZ-8922"]  # chr2 + win mice
used_mouse_ids_bilat = ["WEZ-8946", "WEZ-8960", "WEZ-8961"]  # bilat stim mice
# All mice: bilateral stim mice first, followed by the chr2 window mice
used_mouse_ids = [*used_mouse_ids_bilat, *used_mouse_ids_chr2win]

In [None]:
# Load the traces and the metadata of every event that belongs to one of the used mice.
traces_dict = dict()  # event uuid -> {dataset name: numpy array}
traces_meta_dict = dict()  # event uuid -> {attribute name: attribute value}
# first keys are event uuids, inside the following dataset names:
# 'lfp_mov_t', 'lfp_mov_y', 'lfp_t', 'lfp_y', 'lv_dist', 'lv_rounds', 
# 'lv_running', 'lv_speed', 'lv_t_s', 'lv_totdist', 'mean_fluo'
for assembled_traces_fpath in [assembled_traces_bilat_fpath, assembled_traces_chr2_fpath]:
    with h5py.File(assembled_traces_fpath, "r") as hf:
        # The loop variable is deliberately not called "uuid": at notebook level that
        # would overwrite the imported uuid module.
        for event_uuid, event_group in hf.items():
            if event_group.attrs["mouse_id"] not in used_mouse_ids:
                continue
            # np.array copies the data out of the file, so it stays valid after the file is closed
            traces_dict[event_uuid] = {
                dataset_name: np.array(dataset) for dataset_name, dataset in event_group.items()
            }
            traces_meta_dict[event_uuid] = {
                attr_name: attr_value for attr_name, attr_value in event_group.attrs.items()
            }

In [None]:
# Find the smallest and largest speed value over all loaded recordings.
min_speed = np.inf
max_speed = -np.inf
for event_uuid in traces_dict:
    speed = traces_dict[event_uuid]["lv_speed"]
    min_candidate, max_candidate = np.min(speed), np.max(speed)
    # keep the running extremes; the current value stays unless the candidate is strictly beyond it
    min_speed = min(min_speed, min_candidate)
    max_speed = max(max_speed, max_candidate)
print(f"Speed range: {min_speed} to {max_speed}")
# Global speed amplitude, used as the common vertical scale of the waterfall plot
LV_SPEED_AMPL = max_speed - min_speed

In [None]:
# Number of frames used as baseline segment (and as post segment of equal length) per mouse group.
# for tmev and chr2: 4500 bl/post-sz frames, bilat stim: 4425, should match the value that was used in Loco analysis 3.0
n_segment_frames_chr2win = 4500
n_segment_frames_bilat = 4425

### Unify categories
* Make single control
* Make single sz+sd
* Keep the uni-/bilateral SD
* Make window SD experiments unilateral (for now)

In [None]:
# TODO: properly handle SD experiments with window! Maybe some of them are bilateral SD
# The loop variable is deliberately not called "uuid": at notebook level that would
# overwrite the imported uuid module.
for event_uuid, event_meta in traces_meta_dict.items():
    exp_type = event_meta["exp_type"]
    if "chr2_ctl" in exp_type:  # unify control naming
        event_meta["exp_type"] = "chr2_ctl"
    elif "chr2_szsd" in exp_type:  # unify szsd naming
        event_meta["exp_type"] = "chr2_szsd"
    elif exp_type == "chr2_sd":  # window experiments do not specify stim is unilateral
        event_meta["exp_type"] = "chr2_sd_unilat"

In [None]:
#for uuid in traces_meta_dict.keys():
#    print(traces_meta_dict[uuid]["exp_type"])

In [None]:
# Group the event uuids by experiment type, then window type, then mouse.
exptype_wintype_id_dict = {}   # keys: experiment_type, window_type, mouse_id, value: [uuid1, uuid2, ...]
# The loop variable is deliberately not called "uuid": at notebook level that would
# overwrite the imported uuid module.
for event_uuid, event_meta in traces_meta_dict.items():
    exp_type = event_meta["exp_type"]
    win_type = event_meta["window_type"]
    mouse_id = event_meta["mouse_id"]
    # setdefault creates each nesting level the first time its key is seen
    wintype_id_dict = exptype_wintype_id_dict.setdefault(exp_type, dict())
    id_dict = wintype_id_dict.setdefault(win_type, dict())
    id_dict.setdefault(mouse_id, []).append(event_uuid)  # list of uuids

In [None]:
# Sanity check: every event must provide the metadata that the plotting code reads.
required_meta_keys = ("n_bl_frames", "n_am_frames", "n_frames", "i_stim_begin_frame")
# The loop variable is deliberately not called "uuid": at notebook level that would
# overwrite the imported uuid module.
for event_uuid, event_meta in traces_meta_dict.items():
    for required_key in required_meta_keys:
        assert required_key in event_meta, f"{event_uuid}: metadata attribute {required_key} is missing"



In [None]:
# Manual inspection: display the documented segments of one hard-coded event uuid
ddoc.getSegmentsForUUID("0708b5892bf4459ca1aeed2d317efe19")

In [None]:
def waterfallLoco(exp_type, show_segments=False, bl_equal_post=True, show_stim_duration=False, show_legend=False, lims=None):
    """Draw the locomotion speed traces of all recordings of one experiment type as a waterfall plot.

    Every recording is drawn as one row; consecutive rows are shifted vertically by 1.3 times
    the global speed amplitude (LV_SPEED_AMPL). The time axis of each recording is shifted so
    that the stimulation begins at t = 0. Traces are colored by mouse (df_colors).

    The function reads notebook-level variables: exptype_wintype_id_dict, traces_dict,
    traces_meta_dict, used_mouse_ids_bilat, used_mouse_ids_chr2win, n_segment_frames_bilat,
    n_segment_frames_chr2win, LV_SPEED_AMPL, df_colors, ddoc, save_figs, file_format,
    output_folder and output_dtime.

    Parameters
    ----------
    exp_type : str
        Experiment type to plot; must be a key of exptype_wintype_id_dict.
    show_segments : bool
        If True, mark begin and end of the baseline and of the post segment with black lines.
    bl_equal_post : bool
        Only used with show_segments. If True, the post segment gets the same number of frames
        as the baseline (cut at the end of the recording); if False, it lasts until the end of
        the recording.
    show_stim_duration : bool
        If True, mark begin and end of the stimulation in red and print the stim duration.
    show_legend : bool
        If True, add a legend with one colored patch per mouse that appears in the plot.
    lims : tuple or None
        If not None, passed to plt.xlim to restrict the displayed time range.

    Raises
    ------
    KeyError
        If exp_type is not present in exptype_wintype_id_dict.
    Exception
        If a mouse is neither in used_mouse_ids_bilat nor in used_mouse_ids_chr2win.
    AssertionError
        If show_segments is set and a recording has too few frames before the stimulation
        to fit the baseline segment.
    """
    AMPLITUDE = LV_SPEED_AMPL
    wintype_id_dict = exptype_wintype_id_dict[exp_type]
    # one row per recording; the figure height grows with the number of rows
    n_recordings_with_type = sum(
        len(event_uuids)
        for id_dict in wintype_id_dict.values()
        for event_uuids in id_dict.values()
    )
    plt.figure(figsize=(18, n_recordings_with_type*3))
    offset = 0  # vertical position of the current row
    appearing_mice = []
    for id_dict in wintype_id_dict.values():
        for mouse_id, event_uuids in id_dict.items():
            if mouse_id in used_mouse_ids_bilat:
                mouse_type = "bilat"
            elif mouse_id in used_mouse_ids_chr2win:
                mouse_type = "chr2win"
            else:
                raise Exception("Mouse neither bilat nor chr2win")
            if mouse_id not in appearing_mice:
                appearing_mice.append(mouse_id)
            # both values only depend on the mouse, so they are determined once per mouse
            n_bl_frames = n_segment_frames_bilat if mouse_type == "bilat" else n_segment_frames_chr2win
            color = df_colors[df_colors["mouse_id"] == mouse_id].color.iloc[0]
            for event_uuid in event_uuids:
                metadata_dict = traces_meta_dict[event_uuid]
                # The mouse group decides where the stim begin is read from. (This used to
                # compare exp_type with "chr2win", which is a mouse group and not an
                # experiment type, so the first branch could never be taken.)
                if mouse_type == "chr2win":
                    i_frame_stim_begin = metadata_dict["i_stim_begin_frame"]
                else:  # there is an issue with bilat recordings; use break points: [0, stim_begin, ...]
                    i_frame_stim_begin = metadata_dict["break_points"][1]
                stim_duration = ddoc.getStimDurationForUuid(event_uuid)
                t = traces_dict[event_uuid]["lv_t_s"]
                if mouse_type == "bilat":
                    # first frame at or after the documented end of the stimulation; clamped
                    # because searchsorted returns len(t) if the stim outlasts the recording
                    i_frame_stim_end = min(np.searchsorted(t, t[i_frame_stim_begin] + stim_duration), len(t) - 1)
                else:
                    # FIXME: this is not stim end, but the beginning of the post segment
                    i_frame_stim_end = metadata_dict["i_stim_end_frame"]
                # some recordings have a slight difference in calculated stim (t[] - t[]) vs stim duration
                # written in session description, so the two are not checked against each other here.
                # TODO: make sure that begin of post segment is correct! (not always post-stim! )

                # shift the time axis: stimulation begins at t = 0
                t = t - t[i_frame_stim_begin]

                if show_segments:
                    # get begin and end time points of baseline and post-stim segments
                    i_frame_bl_end = i_frame_stim_begin
                    if mouse_type == "chr2win":
                        i_frame_post_begin = metadata_dict["break_points"][-1]
                    else:
                        # the post segment is made up of the last n_am_frames frames of the recording
                        i_frame_post_begin = metadata_dict["n_frames"] - metadata_dict["n_am_frames"]

                    assert n_bl_frames < i_frame_stim_begin
                    i_frame_bl_begin = i_frame_bl_end - n_bl_frames

                    # the post segment ends with the recording, unless it should (and can) match the baseline length
                    i_frame_post_end = len(t) - 1
                    if bl_equal_post:
                        i_frame_post_end = min(i_frame_post_begin + n_bl_frames, i_frame_post_end)
                    # plot them
                    begin_end_frames = [i_frame_bl_begin, i_frame_bl_end, i_frame_post_begin, i_frame_post_end]
                    plt.vlines(x=t[begin_end_frames], ymin=offset, ymax=offset+AMPLITUDE, color="black", linewidth=2)
                if show_stim_duration:
                    plt.vlines(x=[t[i_frame_stim_begin], t[i_frame_stim_end]], ymin=offset, ymax=offset+0.7*AMPLITUDE, color="red", linewidth=1)
                    plt.text(t[i_frame_stim_end]+15, offset+0.7*AMPLITUDE, f"{stim_duration} s", fontsize=20, color="red")

                labview_trace = traces_dict[event_uuid]["lv_speed"]
                # shift the trace so that its minimum sits on the baseline of its row
                plt.plot(t, labview_trace - np.min(labview_trace) + offset, color=color)

                offset += 1.3*AMPLITUDE
    if show_legend:
        patches = [mpatches.Patch(color=ddoc.getColorForMouseId(mouse_id), label=mouse_id) for mouse_id in sorted(appearing_mice)]
        plt.legend(handles=patches)
    plt.suptitle(exp_type, fontsize=22)
    plt.yticks([])
    plt.xlabel("Time (s)", fontsize=14)
    plt.tight_layout()
    if lims is not None:
        plt.xlim(lims)  # 250, 500
    if save_figs:
        # save to the configured output folder instead of a hard-coded, machine-specific path
        out_fpath = os.path.join(output_folder, f"loco_waterfall_{exp_type}_{output_dtime}{file_format}")
        plt.savefig(out_fpath, bbox_inches='tight', dpi=300)
        print(f"Saved as {out_fpath}")
    plt.show()

In [None]:
# Manual inspection: display the metadata of one hard-coded event uuid
traces_meta_dict["77c076ebfc5543ea93a9c0b2ba9e8b8c"]

In [None]:
# Experiment types present in the loaded data, each listed once, in order of first appearance.
# dict.fromkeys drops the repetitions while keeping that order.
cats = list(dict.fromkeys(event_meta["exp_type"] for event_meta in traces_meta_dict.values()))

In [None]:
# Show the available experiment types (valid values for the exp_type argument of waterfallLoco)
print(cats)

In [None]:
# Waterfall plot of the "chr2_sd_bilat" recordings with segment and stim markers and legend, x axis limited to -305..350
waterfallLoco("chr2_sd_bilat", show_segments=True, bl_equal_post=True, show_stim_duration=True, show_legend=True, lims=(-305,350))


In [None]:
# TODO: maybe need to redo analysis: instead of n frames, get n seconds before stim, and calculate back the number of frames needed for this