# Locomotion analysis v3.0
This script works with the pre-assembled dataset and skips most of the data-acquisition steps, which makes the analysis cleaner and more reliable (the dataset itself is checked manually, separately from this script).

# Set version

### Set up export figure parameters

In [None]:
save_figs = False  # set to True to save the figures created
save_as_eps = False  # True: vector .eps output; False: .jpg output
file_format = ".eps" if save_as_eps else ".jpg"
if save_figs:
    print(f"Going to save figures as {file_format} files.")

In [None]:
# Version tag appended to the file names of the saved figures
output_version = "v1.0"

# Import libraries

In [None]:
# Auto-reload modules (used to develop functions outside this notebook, e.g. in labrotation
# or datadoc_util, so that edits to them are picked up by this notebook)
%load_ext autoreload
%autoreload 2

In [None]:
import labrotation.file_handling as fh
import h5py
from time import time
import matplotlib.pyplot as plt
import numpy as np
import os
from labrotation import file_handling as fh
from copy import deepcopy
import pandas as pd
import labrotation.two_photon_session as tps
import seaborn as sns
import uuid  # for unique labeling of sessions and coupling arrays (mouse velocity, distance, ...) to sessions in dataframe 
from matplotlib import cm  # colormap
import datadoc_util
from labrotation import two_photon_session as tps
from datetime import datetime
import seaborn as sns

# Set seaborn parameters

In [None]:
# Global seaborn look for all figures: white grid background, fonts scaled up by a factor of 2
sns.set_theme(style="whitegrid", font_scale=2)

# If exists, load environmental variables from .env file

In [None]:
def parse_env_lines(lines):
    """Parse ``KEY=VALUE`` lines (the format of a .env file) into a dict.

    Blank lines, lines starting with "#" and lines without "=" are skipped. Only the first
    "=" separates key and value, so values may themselves contain "=". Whitespace around
    keys and values is stripped.

    Parameters
    ----------
    lines : iterable of str
        The lines of the .env file (e.g. an open text file object).

    Returns
    -------
    dict
        Mapping of key -> value (both str).
    """
    parsed = dict()
    for line in lines:
        stripped = line.strip()
        if not stripped or stripped.startswith("#") or "=" not in stripped:
            continue
        key, value = stripped.split("=", 1)
        parsed[key.strip()] = value.strip()
    return parsed


env_dict = dict()
if not os.path.exists("./.env"):
    print(".env does not exist")
else:
    with open("./.env", "r", encoding="utf-8") as f:
        env_dict = parse_env_lines(f)
print(env_dict.keys())

# Set up data documentation directory

In [None]:
# Assumption: inside the documentation folder, each mouse has its own subfolder that carries
# the id of the mouse (not necessarily exactly, but the mouse can be identified by the name).
# Inside the subfolder of mouse xy, xy_grouping.xlsx and xy_segmentation.xlsx can be found:
# - xy_grouping.xlsx is used to find the recordings belonging together; its columns are
#   folder, nd2, labview, lfp, face_cam_last, nikon_meta, experiment_type, day
# - xy_segmentation.xlsx classifies each frame of a recording (via a set of disjoint intervals
#   that together cover the whole recording) into event types ("normal", seizure ("sz"),
#   sd wave ("sd_wave") etc.); its columns are folder, interval_type, frame_begin, frame_end.

# TODO: write documentation on contents of xlsx files (what the columns are etc.)
docu_folder = (
    env_dict["DATA_DOCU_FOLDER"]
    if "DATA_DOCU_FOLDER" in env_dict
    else fh.open_dir("Choose folder containing folders for each mouse!")
)
print(f"Selected folder:\n\t{docu_folder}")

In [None]:
# The per-mouse subfolders live either in a "documentation" subfolder of docu_folder or
# directly inside docu_folder.
has_documentation_subfolder = "documentation" in os.listdir(docu_folder)
mouse_folder = os.path.join(docu_folder, "documentation") if has_documentation_subfolder else docu_folder
mouse_names = os.listdir(mouse_folder)
print(f"Mice detected:")
for mouse_name in mouse_names:
    print(f"\t{mouse_name}")

In [None]:
def get_datetime_for_fname():
    """Return the current local time as "YYYYMMDD-HHMMSS", for time-stamping output file names."""
    return datetime.now().strftime("%Y%m%d-%H%M%S")

In [None]:
# Folder where figures are written when save_figs is True. Like the other folder settings,
# it is taken from .env (DOWNLOADS_FOLDER) if present; otherwise it is chosen interactively,
# instead of failing with a KeyError.
if "DOWNLOADS_FOLDER" in env_dict:
    output_folder = env_dict["DOWNLOADS_FOLDER"]
else:
    output_folder = fh.open_dir("Choose folder to save output files to!")
print(f"Output files will be saved to {output_folder}")

### Load matlab-2p

In [None]:
# matlab-2p folder: taken from .env (MATLAB_2P_FOLDER) if set, otherwise chosen interactively
matlab_2p_folder = (
    env_dict["MATLAB_2P_FOLDER"]
    if "MATLAB_2P_FOLDER" in env_dict
    else fh.open_dir("Choose matlab-2p folder")
)
print(f"matlab-2p folder set to:\n\t{matlab_2p_folder}")

### Load data documentation

In [None]:
# Load the data documentation found in docu_folder; ddoc is queried below
# (e.g. getColorings() for the color codes).
ddoc = datadoc_util.DataDocumentation(docu_folder)
ddoc.loadDataDoc()

### Set up color coding
For now, colors can only be assigned per mouse (each event inherits the color of its mouse). Later, once event uuids are available here, each event uuid should be mapped to its own color code.

In [None]:
# Color codings table; must contain at least the "mouse_id" and "color" columns used below.
df_colors = ddoc.getColorings()

In [None]:
# mouse_id -> color; if a mouse appears more than once, its last row wins
dict_colors_mouse = dict(zip(df_colors["mouse_id"].tolist(), df_colors["color"].tolist()))

### Load events_list dataset

In [None]:
# The events list is expected directly inside the documentation folder. Fail with a clear
# error instead of a bare assert (asserts are skipped when Python runs with -O).
events_list_fpath = os.path.join(docu_folder, "events_list.xlsx")
if not os.path.exists(events_list_fpath):
    raise FileNotFoundError(f"events_list.xlsx not found: {events_list_fpath}")

df_events_list = pd.read_excel(events_list_fpath)

## Load dataset

In [None]:
# Interactively choose the assembled-traces h5 file (one top-level group per event uuid,
# as read in the next cell).
assembled_traces_fpath = fh.open_file("Open assembled_traces h5 file!")

In [None]:
# traces_dict[event_uuid][dataset_name] -> np.ndarray with the contents of that dataset
# traces_meta_dict[event_uuid][attr_name] -> attribute stored on the event's h5 group
traces_dict = dict()
traces_meta_dict = dict()
# First-level keys are event uuids; each group holds datasets such as
# 'lfp_mov_t', 'lfp_mov_y', 'lfp_t', 'lfp_y', 'lv_dist', 'lv_rounds',
# 'lv_running', 'lv_speed', 'lv_t_s', 'lv_totdist', 'mean_fluo'
# (the statistics below also use 'lv_totdist_abs').
with h5py.File(assembled_traces_fpath, "r") as hf:
    # The loop variable is deliberately not called `uuid`: that would shadow the imported uuid module.
    for event_uuid, event_group in hf.items():
        # np.array() reads each dataset fully into memory before the file is closed
        traces_dict[event_uuid] = {
            dataset_name: np.array(dataset) for dataset_name, dataset in event_group.items()
        }
        traces_meta_dict[event_uuid] = {
            attr_name: attr_value for attr_name, attr_value in event_group.attrs.items()
        }

# Calculate locomotion statistics

In [None]:
# Each entry (row) of list_statistics has the columns:
# uuid of event, mouse id, window type, segment type (bl/sz/am), segment length in frames,
# totdist, totdist_abs, running, speed, number of running episodes,
# mean running-episode length, std of running-episode lengths.
list_statistics = []
# dict_episodes[mouse_id][event_uuid][segment_type] -> np.ndarray of running-episode lengths (frames)
dict_episodes = {}

SEGMENT_TYPES = ("bl", "sz", "am")  # baseline, seizure, aftermath, in recording order


def get_trace_span(trace):
    """Return trace[-1] - trace[0], i.e. the change of a (cumulative) trace within a segment.

    An empty segment yields 0 instead of raising IndexError.
    """
    if len(trace) == 0:
        return 0
    return trace[-1] - trace[0]


def get_running_episode_lengths(running):
    """Return the lengths (in frames) of all running episodes in a running trace.

    An episode is a maximal run of consecutive frames equal to 1. This assumes a binary 0/1
    trace: any other value ends an episode. An episode that is still ongoing at the last
    frame is included. An empty trace yields an empty list.
    """
    episode_lengths = []
    current_len = 0
    for frame_value in running:
        if frame_value == 1:
            current_len += 1
        elif current_len > 0:  # first non-running frame after an episode
            episode_lengths.append(current_len)
            current_len = 0
    if current_len > 0:  # episode lasting until the end of the segment
        episode_lengths.append(current_len)
    return episode_lengths


for event_uuid in traces_dict.keys():
    event_meta = traces_meta_dict[event_uuid]
    event_traces = traces_dict[event_uuid]
    mouse_id = event_meta["mouse_id"]
    win_type = event_meta["window_type"]
    # get segment lengths; the seizure segment is what lies between baseline and aftermath
    n_bl_frames = event_meta["n_bl_frames"]
    n_am_frames = event_meta["n_am_frames"]
    n_frames = event_meta["n_frames"]
    n_sz_frames = n_frames - n_am_frames - n_bl_frames
    segment_lengths = {"bl": n_bl_frames, "sz": n_sz_frames, "am": n_am_frames}
    # frame ranges: bl = [0, n_bl), sz = [n_bl, n_bl + n_sz), am = [n_bl + n_sz, end)
    sz_end = n_bl_frames + n_sz_frames
    segment_slices = {
        "bl": slice(0, n_bl_frames),
        "sz": slice(n_bl_frames, sz_end),
        "am": slice(sz_end, None),
    }

    dict_episodes.setdefault(mouse_id, dict())[event_uuid] = dict()

    for segment_type in SEGMENT_TYPES:
        frames = segment_slices[segment_type]
        lv_totdist_seg = event_traces["lv_totdist"][frames]
        lv_totdist_abs_seg = event_traces["lv_totdist_abs"][frames]
        lv_running_seg = event_traces["lv_running"][frames]
        lv_speed_seg = event_traces["lv_speed"][frames]

        # distance covered within the segment, from the cumulative distance traces
        totdist = get_trace_span(lv_totdist_seg)
        totdist_abs = get_trace_span(lv_totdist_abs_seg)
        running = sum(lv_running_seg)
        speed = sum(lv_speed_seg)

        # number of running episodes and their lengths
        episode_lengths = np.array(get_running_episode_lengths(lv_running_seg))
        dict_episodes[mouse_id][event_uuid][segment_type] = episode_lengths
        n_episodes = len(episode_lengths)
        # mean/std are undefined without episodes: NaN (avoids numpy's empty-mean RuntimeWarning)
        if n_episodes > 0:
            episode_mean_len = episode_lengths.mean()
            episode_std = episode_lengths.std()
        else:
            episode_mean_len = np.nan
            episode_std = np.nan

        # add to data list (rows per event in the order bl, sz, am)
        list_statistics.append([event_uuid, mouse_id, win_type, segment_type, segment_lengths[segment_type], totdist, totdist_abs, running, speed, n_episodes, episode_mean_len, episode_std])

In [None]:
# One row per (event, segment type); the column order matches the rows in list_statistics.
stats_columns = [
    "event_uuid", "mouse_id", "window_type", "segment_type", "segment_length",
    "totdist", "totdist_abs", "running", "speed",
    "running_episodes", "running_episodes_mean_length", "running_episodes_length_std",
]
df_stats = pd.DataFrame(list_statistics, columns=stats_columns)

In [None]:
# Show the statistics table (three rows per event: bl, sz, am)
df_stats

In [None]:
# Set NaN to 0 for the running-episode length statistics: if a segment has no running
# episodes, both the mean episode length and its std are undefined (NaN). Treat both as 0,
# so the two columns are handled consistently.
df_stats["running_episodes_mean_length"] = df_stats["running_episodes_mean_length"].fillna(value=0)
df_stats["running_episodes_length_std"] = df_stats["running_episodes_length_std"].fillna(value=0)

In [None]:
# Divide the per-segment statistics by the segment length (in frames) -> values per frame,
# stored in new "<statistic>_norm" columns.
for stat_name in ("totdist", "totdist_abs", "running", "speed"):
    df_stats[f"{stat_name}_norm"] = df_stats[stat_name] / df_stats["segment_length"]

### Add % of time spent running

In [None]:
# % of time spent running: "running" (the sum of the running trace over the segment)
# relative to the segment length in frames
df_stats["running%"] = 100*df_stats["running"]/df_stats["segment_length"]  # get value as true % instead of [0, 1] float

### Add  color codes to entries

In [None]:
# Color of each row = color of its mouse (KeyError if a mouse has no color coding).
# A comprehension instead of row-wise DataFrame.apply(axis=1): faster, and it also works
# when df_stats is empty.
df_stats["color"] = [dict_colors_mouse[mouse_id] for mouse_id in df_stats["mouse_id"]]

In [None]:
# event uuid -> color (each event gets the color of its mouse, taken from df_stats["color"])
dict_colors_event = dict(zip(df_stats["event_uuid"].tolist(), df_stats["color"].tolist()))

In [None]:
# Show the event uuid -> color mapping
dict_colors_event

# Plot results

In [None]:
# Distribution of the length-normalized running statistic (running_norm) per segment type
fig, ax = plt.subplots(figsize=(10,10))
sns.violinplot(data=df_stats, x="segment_type", y="running_norm", ax=ax)
plt.show()

In [None]:
loco_statistic = "totdist_abs"

# Top row: point estimate of each statistic per segment type over all events.
# Bottom row: loco_statistic per window type, one column per segment type (bl, sz, am).
fig, axs = plt.subplots(2, 3, figsize=(22,12), sharey=False)
for ax_top, y_metric in zip(axs[0], [loco_statistic, "running", "speed"]):
    sns.pointplot(data=df_stats, x="segment_type", y=y_metric, ax=ax_top)

for ax_bottom, seg_type in zip(axs[1], ["bl", "sz", "am"]):
    df_segment = df_stats[df_stats["segment_type"] == seg_type]
    sns.pointplot(data=df_segment, x="window_type", y=loco_statistic, join=False, capsize=.1, ax=ax_bottom)

if save_figs:
    fig_fpath = os.path.join(output_folder, f'loco_mean_per_segment_{get_datetime_for_fname()}_{output_version}{file_format}')
    plt.savefig(fig_fpath, format=file_format.split(".")[-1])
    print(f"Saved to {fig_fpath}")
plt.show()

In [None]:
loco_statistic = "totdist_abs"

# Top row: one line per event across all segment types (bl, sz, am), colored by mouse.
# Bottom row: loco_statistic per window type, one column per segment type, showing the
# individual events (strip plot) over a light grey violin.
segment_titles = {"bl": "baseline", "sz": "seizure", "am": "aftermath"}

fig, axs = plt.subplots(2, 3, figsize=(22,12))
for ax_top, y_metric in zip(axs[0], [loco_statistic, "running", "speed"]):
    sns.lineplot(data=df_stats, x="segment_type", y=y_metric, hue="event_uuid", palette=dict_colors_event, ax=ax_top, legend=False)

for ax_bottom, seg_type in zip(axs[1], segment_titles):
    sns.stripplot(data=df_stats[df_stats["segment_type"] == seg_type], x="window_type", hue="event_uuid", palette=dict_colors_event, y=loco_statistic, size=8, ax=ax_bottom, legend=False)

# hue=True with hue_order=[True, False] and split=True: presumably a trick to draw only one
# half of each violin next to the strip points; the resulting legend is removed
for ax_bottom, seg_type in zip(axs[1], segment_titles):
    sns.violinplot(
        data=df_stats[df_stats["segment_type"] == seg_type],
        x="window_type", y=loco_statistic,
        hue_order=[True, False], split=True,
        hue=True,
        palette=["lightgrey"],
        ax=ax_bottom
    )
    ax_bottom.legend_ = None

for ax_bottom, seg_title in zip(axs[1], segment_titles.values()):
    ax_bottom.set_title(seg_title)

# the violin surfaces were found at collections[-2] and [-4] by trial and error; make them transparent
for ax_bottom in axs[1]:
    plt.setp(ax_bottom.collections[-2], alpha=.3)
    plt.setp(ax_bottom.collections[-4], alpha=.3)

plt.tight_layout()

if save_figs:
    fig_fpath = os.path.join(output_folder, f'loco_per_segment_{get_datetime_for_fname()}_{output_version}{file_format}')
    plt.savefig(fig_fpath, format=file_format.split(".")[-1])
    print(f"Saved to {fig_fpath}")

plt.show()

## Plot individual recordings, color-coded by mouse ID

### Plot all possible metrics

In [None]:
# Individual events (one line per event, colored by mouse), baseline vs aftermath only,
# for six locomotion metrics (filled row by row).
df_bl_am = df_stats[df_stats["segment_type"].isin(["bl", "am"])]
all_metrics = ["totdist_abs", "running", "speed", "running%", "running_episodes", "running_episodes_mean_length"]

fig, axs = plt.subplots(2, 3, figsize=(22,16))
for ax, y_metric in zip(axs.flat, all_metrics):
    sns.lineplot(data=df_bl_am, x="segment_type", y=y_metric, hue="event_uuid", palette=dict_colors_event, estimator=None, ax=ax, legend=False)

plt.tight_layout()

if save_figs:
    fig_fpath = os.path.join(output_folder, f'loco_per_segment_all_sources_sz_excluded_{get_datetime_for_fname()}_{output_version}{file_format}')
    plt.savefig(fig_fpath, format=file_format.split(".")[-1])
    print(f"Saved to {fig_fpath}")
plt.show()

### Plot 3 metrics along with individual points, violin plot

In [None]:
loco_statistic = "totdist_abs"

# Top row: individual events (no aggregation), baseline vs aftermath, colored by mouse.
# Bottom row: loco_statistic per window type for each segment type, individual events
# (strip plot) over a light grey violin.
segment_titles = {"bl": "baseline", "sz": "seizure", "am": "aftermath"}
df_bl_am = df_stats[df_stats["segment_type"].isin(["bl", "am"])]

fig, axs = plt.subplots(2, 3, figsize=(22,12))
for ax_top, y_metric in zip(axs[0], [loco_statistic, "running", "speed"]):
    sns.lineplot(data=df_bl_am, x="segment_type", y=y_metric, hue="event_uuid", palette=dict_colors_event, estimator=None, ax=ax_top, legend=False)

for ax_bottom, seg_type in zip(axs[1], segment_titles):
    sns.stripplot(data=df_stats[df_stats["segment_type"] == seg_type], x="window_type", y=loco_statistic, hue="event_uuid", palette=dict_colors_event, size=8, ax=ax_bottom, legend=False)

# hue=True with hue_order=[True, False] and split=True: presumably a trick to draw only one
# half of each violin next to the strip points; the resulting legend is removed
for ax_bottom, seg_type in zip(axs[1], segment_titles):
    sns.violinplot(
        data=df_stats[df_stats["segment_type"] == seg_type],
        x="window_type", y=loco_statistic,
        hue_order=[True, False], split=True,
        hue=True,
        palette=["lightgrey"],
        ax=ax_bottom
    )
    ax_bottom.legend_ = None

# the violin surfaces were found at collections[-2] and [-4] by trial and error; make them transparent
for ax_bottom in axs[1]:
    plt.setp(ax_bottom.collections[-2], alpha=.3)
    plt.setp(ax_bottom.collections[-4], alpha=.3)

for ax_bottom, seg_title in zip(axs[1], segment_titles.values()):
    ax_bottom.set_title(seg_title)

plt.tight_layout()

if save_figs:
    fig_fpath = os.path.join(output_folder, f'loco_per_segment_sz-excluded_{get_datetime_for_fname()}_{output_version}{file_format}')
    plt.savefig(fig_fpath, format=file_format.split(".")[-1])
    print(f"Saved to {fig_fpath}")
plt.show()

## Aggregate by mouse
The per-mouse lines use seaborn's default statistics: estimator='mean', errorbar=('ci', 95).

In [None]:
loco_statistic = "totdist_abs"

# Top row: per-mouse aggregate (seaborn default estimator / error bars), baseline vs aftermath.
# Bottom row: loco_statistic per window type for each segment type, individual events
# (strip plot) over a light grey violin.
segment_titles = {"bl": "baseline", "sz": "seizure", "am": "aftermath"}
df_bl_am = df_stats[df_stats["segment_type"].isin(["bl", "am"])]

fig, axs = plt.subplots(2, 3, figsize=(22,12))
for ax_top, y_metric in zip(axs[0], [loco_statistic, "running", "speed"]):
    sns.lineplot(data=df_bl_am, x="segment_type", y=y_metric, hue="mouse_id", palette=dict_colors_mouse, ax=ax_top, legend=False)

for ax_bottom, seg_type in zip(axs[1], segment_titles):
    sns.stripplot(data=df_stats[df_stats["segment_type"] == seg_type], x="window_type", y=loco_statistic, hue="event_uuid", palette=dict_colors_event, size=8, ax=ax_bottom, legend=False)

# hue=True with hue_order=[True, False] and split=True: presumably a trick to draw only one
# half of each violin next to the strip points; the resulting legend is removed
for ax_bottom, seg_type in zip(axs[1], segment_titles):
    sns.violinplot(
        data=df_stats[df_stats["segment_type"] == seg_type],
        x="window_type", y=loco_statistic,
        hue_order=[True, False], split=True,
        hue=True,
        palette=["lightgrey"],
        ax=ax_bottom
    )
    ax_bottom.legend_ = None

# the violin surfaces were found at collections[-2] and [-4] by trial and error; make them transparent
for ax_bottom in axs[1]:
    plt.setp(ax_bottom.collections[-2], alpha=.3)
    plt.setp(ax_bottom.collections[-4], alpha=.3)

for ax_bottom, seg_title in zip(axs[1], segment_titles.values()):
    ax_bottom.set_title(seg_title)

plt.tight_layout()

if save_figs:
    fig_fpath = os.path.join(output_folder, f'loco_per_segment_sz-excluded_mean_95ci_{get_datetime_for_fname()}_{output_version}{file_format}')
    plt.savefig(fig_fpath, format=file_format.split(".")[-1])
    print(f"Saved to {fig_fpath}")
plt.show()

## Plot results with illustrative example of data in top row

In [None]:
# Top row: raw locomotion traces of one example event, zoomed to a time window.
# Bottom row: per-mouse aggregate of the statistics, baseline vs aftermath.
loco_statistic = "totdist_abs"  # set here as well, so the cell does not rely on an earlier one
example_event_uuid = "d158cd12ad77489a827dab1173a933f9"
example_t_window_s = (350, 450)  # x-limits of the example traces, in seconds (lv_t_s)

fig, axs = plt.subplots(2, 3, figsize=(22,14))
dset_example = traces_dict[example_event_uuid]

# first row: example traces (trace name, y-axis label)
example_panels = [
    ("lv_totdist", "Total distance (a.u.)"),
    ("lv_running", "Running? (binary)"),
    ("lv_speed", "Velocity (a.u.)"),
]
for ax_top, (trace_name, y_label) in zip(axs[0], example_panels):
    ax_top.plot(dset_example["lv_t_s"], dset_example[trace_name])
    ax_top.set_xlim(example_t_window_s)
    ax_top.set_ylabel(y_label, fontsize=22)

axs[0][2].set_ylim((-0.2, 0.7))
axs[0][0].set_ylim((0, 1400))
axs[0][1].set_xlabel("Time (s)", fontsize=22)

# second row: statistics
df_bl_am = df_stats[df_stats["segment_type"].isin(["bl", "am"])]
for ax_bottom, y_metric in zip(axs[1], [loco_statistic, "running", "speed"]):
    sns.lineplot(data=df_bl_am, x="segment_type", y=y_metric, hue="mouse_id", palette=dict_colors_mouse, ax=ax_bottom, legend=False)

plt.tight_layout()

# save like the other figures, instead of a hard-coded path
if save_figs:
    fig_fpath = os.path.join(output_folder, f'loco_example_and_per_mouse_{get_datetime_for_fname()}_{output_version}{file_format}')
    plt.savefig(fig_fpath, format=file_format.split(".")[-1])
    print(f"Saved to {fig_fpath}")
plt.show()

In [None]:
loco_statistic = "totdist_abs"

# A single row of per-mouse aggregate lines (baseline vs aftermath). The figure used to have
# an unused second row of empty axes, and it was saved under the same name pattern as the
# 'loco_per_segment_sz-excluded_' figure; it now has one row and its own file name.
df_bl_am = df_stats[df_stats["segment_type"].isin(["bl", "am"])]

fig, axs = plt.subplots(1, 3, figsize=(22,6))
for ax, y_metric in zip(axs, [loco_statistic, "running", "speed"]):
    sns.lineplot(data=df_bl_am, x="segment_type", y=y_metric, hue="mouse_id", palette=dict_colors_mouse, ax=ax, legend=False)

plt.tight_layout()

if save_figs:
    fig_fpath = os.path.join(output_folder, f'loco_per_segment_sz-excluded_per_mouse_lines_{get_datetime_for_fname()}_{output_version}{file_format}')
    plt.savefig(fig_fpath, format=file_format.split(".")[-1])
    print(f"Saved to {fig_fpath}")
plt.show()

## Plot the segment type (bl/am) on the x-axis instead of the window type

In [None]:
loco_statistic = "totdist_abs"

# Baseline vs aftermath on the x-axis (instead of window type), one column per metric.
# Top row: one line per event (colored by mouse).
# Bottom row: individual events (strip plot) over a light grey violin.
df_bl_am = df_stats[df_stats["segment_type"].isin(["bl", "am"])]
column_metrics = [loco_statistic, "running", "speed"]

fig, axs = plt.subplots(2, 3, figsize=(22,12))
for ax_top, y_metric in zip(axs[0], column_metrics):
    sns.lineplot(data=df_bl_am, x="segment_type", y=y_metric, hue="event_uuid", palette=dict_colors_event, estimator=None, ax=ax_top, legend=False)

for ax_bottom, y_metric in zip(axs[1], column_metrics):
    sns.stripplot(data=df_bl_am, x="segment_type", y=y_metric, hue="event_uuid", palette=dict_colors_event, size=8, ax=ax_bottom, legend=False)

# hue=True with hue_order=[True, False] and split=True: presumably a trick to draw only one
# half of each violin next to the strip points; the resulting legend is removed
for ax_bottom, y_metric in zip(axs[1], column_metrics):
    sns.violinplot(
        data=df_bl_am, x="segment_type", y=y_metric,
        hue_order=[True, False], split=True,
        hue=True,
        palette=["lightgrey"],
        ax=ax_bottom
    )
    ax_bottom.legend_ = None

# the violin surfaces were found at collections[-2] and [-4] by trial and error; make them transparent
for ax_bottom in axs[1]:
    plt.setp(ax_bottom.collections[-2], alpha=.3)
    plt.setp(ax_bottom.collections[-4], alpha=.3)

plt.tight_layout()

if save_figs:
    fig_fpath = os.path.join(output_folder, f'loco_per_segment_sz-excluded_line_and_violin_{get_datetime_for_fname()}_{output_version}{file_format}')
    plt.savefig(fig_fpath, format=file_format.split(".")[-1])
    print(f"Saved to {fig_fpath}")
plt.show()

# Statistical testing

## Given two pairwise matched populations (bl and am), test significance of difference between means.

### Paired t-test (Gopal K. Kanji - 100 statistical tests: Test 10, page 35, 44/257)
The paired t-test assumes normally distributed pairwise differences. Our data are not necessarily normally distributed, so in this first round the test is only an approximation. An alternative is the Wilcoxon signed-rank test.

In [None]:
from scipy.stats import ttest_rel

In [None]:
# Keep only baseline and aftermath rows: paired_t_test() pairs these two segments per event.
stat_data = df_stats[df_stats["segment_type"].isin(["bl", "am"])]

In [None]:
def paired_t_test(column_name="totdist_abs", one_sided=False, greater_expected="am", data=None):
    """Run a paired t-test of `column_name` between the aftermath ("am") and baseline ("bl") segments.

    Each event (grouped by "event_uuid") contributes one matched (am, bl) pair. The test is
    ``ttest_rel(am, bl)``, so a positive statistic means mean(am - bl) > 0.

    Parameters
    ----------
    column_name : str
        Column of the statistics dataframe to compare.
    one_sided : bool
        False (default): two-sided test. True: one-sided test in the direction given by
        `greater_expected`. (Previously this flag was inverted.)
    greater_expected : {"am", "bl"}
        Only used when `one_sided` is True: the segment expected to have the greater mean.
    data : pandas.DataFrame, optional
        Rows with columns "event_uuid", "segment_type" and `column_name`; exactly one "bl"
        row and one "am" row per event. Defaults to the notebook-level `stat_data`.

    Returns
    -------
    The result of scipy.stats.ttest_rel (also printed).

    Raises
    ------
    ValueError
        If an event does not have exactly one "bl" and one "am" row, or if `one_sided`
        is True and `greater_expected` is neither "am" nor "bl".
    """
    if data is None:
        data = stat_data
    am_vals = []
    bl_vals = []
    for event_id, g in data.groupby("event_uuid"):
        bl_rows = g[g["segment_type"] == "bl"]
        am_rows = g[g["segment_type"] == "am"]
        if len(bl_rows) != 1 or len(am_rows) != 1:
            raise ValueError(f"paired_t_test(): event {event_id} needs exactly one bl and one am row")
        bl_vals.append(bl_rows[column_name].values[0])
        am_vals.append(am_rows[column_name].values[0])
    am_vals = np.array(am_vals)
    bl_vals = np.array(bl_vals)

    if not one_sided:
        ttest_result = ttest_rel(am_vals, bl_vals)
    elif greater_expected == "am":
        ttest_result = ttest_rel(am_vals, bl_vals, alternative="greater")  # first dataset (am) expected to be greater
    elif greater_expected == "bl":
        ttest_result = ttest_rel(am_vals, bl_vals, alternative="less")  # second dataset (bl) expected to be greater
    else:
        raise ValueError(f"paired_t_test(): invalid greater_expected value {greater_expected}")
    print(ttest_result)
    return ttest_result

In [None]:
paired_t_test("running_episodes", False)  # second positional argument is one_sided
# The test compares am vs bl (ttest_rel(am, bl)): a negative statistic means
# mean(running_episodes_am - running_episodes_bl) < 0, i.e. fewer running episodes after the sz event.

# Look at episodes

In [None]:
# Show the running-episode lengths: dict_episodes[mouse_id][event_uuid][segment_type]
dict_episodes