# Load libraries and set defaults

In [None]:
# imports libraries needed to run the rest of this code

import warnings
from glob import glob
from os.path import basename, splitext, join, dirname
from ast import literal_eval
import pickle
from collections import defaultdict
import json

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter
import seaborn as sns
import pandas as pd
import scipy as sp
import scipy.stats as stats
import sklearn as sk
import sklearn.decomposition 
import xgboost as xgb
import statsmodels.stats.multitest
import statsmodels.distributions
from statsmodels.formula.api import ols

from scipy.io import loadmat
from scipy.ndimage import uniform_filter1d
from scipy.signal import savgol_filter
from scipy.signal import find_peaks
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

from tqdm.notebook import tqdm

import arviz
import shap 
shap.initjs()
import graphviz as graphviz

from statannot import add_stat_annotation
import data_analysis
from data_analysis import get_pauses, DataLoader
from data_analysis.plotting import clean_axes, distplot, cell_heatmap, cell_mean_over_laps, cell_mean_over_laps_opto_compare
from data_analysis.settings import defaults
from data_analysis import logger
import logging
from data_analysis.place_cell import PF_analysis
from data_analysis.place_cell import mean_bin_over_laps

In [None]:
# Notebook-wide constants, warning filters and log levels.
# Constants are read from data_analysis.settings.defaults; change a default with
# defaults.set(key, value).

FRAMES_PER_SESSION = defaults.frames_per_session
FRAME_RATE = 15.49  # Hz; hard-coded in place of defaults.frame_rate
PICO_TO_CM, FIGURE_WIDTH = defaults.pico_to_cm, defaults.figure_width

# Silence this specific RuntimeWarning message.
warnings.filterwarnings(
    "ignore",
    message="invalid value encountered in double_scalars",
)

# Root logger shows WARNING and above; the data_analysis logger shows INFO and above.
logging.root.setLevel(logging.WARNING)
logger.setLevel(logging.INFO)

In [None]:
# Font sizes applied to every figure in this notebook.
SMALL_SIZE = 18
MEDIUM_SIZE = 20
BIGGER_SIZE = 20

# Default text, axes titles and legends use SMALL_SIZE; axis labels and tick labels
# use MEDIUM_SIZE; figure suptitles use BIGGER_SIZE.
plt.rc('font', size=SMALL_SIZE)
plt.rc('axes', titlesize=SMALL_SIZE, labelsize=MEDIUM_SIZE)
plt.rc('xtick', labelsize=MEDIUM_SIZE)
plt.rc('ytick', labelsize=MEDIUM_SIZE)
plt.rc('legend', fontsize=SMALL_SIZE)
plt.rc('figure', titlesize=BIGGER_SIZE)

# Load the behavior and NR-axon trace pandas DataFrames

In [None]:
# original code to load data from host computer, shown here for posterity
# NOTE(review): this cell cannot run top-to-bottom as written: multi_behavior is defined
# further down in this same cell, and mouse_name / date (used for axon_data and
# beh_data below) are not defined anywhere in this notebook section.
defaults.reset('beh_fmt')
with open(join(defaults.configs_dir, "mice.json"), "r") as f:
    mouse_config = json.load(f)
    
# Load every session whose entry in the mouse's "dates" list is truthy; the session's
# position in that list is what gets passed on (and stored as the "day" column).
all_behavior = multi_behavior({m: [i for i, t in enumerate(v["dates"]) if t] for m, v in mouse_config.items()})
# Tag each row with the first condition listed for that mouse.
all_behavior["condition"] = all_behavior.name.apply(lambda t: mouse_config[t]["condition"][0])

with open(join(defaults.configs_dir, "mice_shocks.json"), "r") as f:
    shocks_config = json.load(f)
    
# Shock sessions use a different behavior-file name pattern; it is set temporarily and
# reset below (presumably back to the library default -- confirm defaults.reset semantics).
defaults.set('beh_fmt','{mouse_name}_{date}_shocks_final.mat')
shock_behavior = multi_behavior(
    {m: [i for i, t in enumerate(v["dates"]) if t] for m, v in shocks_config.items()},
    config_name = 'mice_shocks.json'
)
# Keep rows that have a timestamp and drop every pupil-related column.
shock_behavior = shock_behavior[~shock_behavior.seconds.isna()].drop(
    columns=[col for col in shock_behavior.columns if col.startswith("pupil")]
).reset_index(drop=True)
defaults.reset('beh_fmt')

# Single-session example: transposed axon traces are selected on the behavior index and
# concatenated column-wise onto the behavior dataframe.
axon_data = pd.read_pickle(join(defaults.base_path, mouse_name, date, mouse_name + '_axon.pickle'))
beh_data = pd.read_pickle(join(defaults.base_path, mouse_name, date, mouse_name + '_behavior.pickle'))
data = pd.concat([beh_data, pd.DataFrame(axon_data.T[beh_data.index], index=beh_data.index)], axis=1)

# Helper functions used to format the data into the dataframes used for analysis and figures
# (originally run on the host computer once the data were loaded; shown here for posterity)

def get_p2r(data,
            pause_min=15, pause_max=np.inf,
            unpaused_min=15, unpaused_max=np.inf,
            half_window_length=45, **kwargs):
    """Find pause-to-running transitions and build a peri-transition window of row positions.

    A candidate transition is a row where the ``pause`` label decreases relative to
    the previous row. It is kept only if:

    * the label ``pause_min`` frames earlier (forward-filled) equals the label just
      before the transition, and the label ``pause_max`` frames earlier does not
      (i.e. the preceding pause lasted at least ``pause_min`` and less than
      ``pause_max`` frames, assuming a label is constant within one interval);
    * the label ``unpaused_min`` frames later (back-filled) equals the label at the
      transition, and the label ``unpaused_max`` frames later does not.

    Args:
        data: dataframe with ``frame`` and ``pause`` columns. ``frame`` must be sorted
            ascending (the ffill/bfill lookups and ``np.searchsorted`` need it).
        pause_min, pause_max: bounds, in frames, on the length of the preceding pause.
        unpaused_min, unpaused_max: bounds, in frames, on the length of the following run.
        half_window_length: half-width, in frames, of the window around each transition.
        **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        ``int32`` array of shape ``(n_transitions, 2 * half_window_length)``. For a
        transition at frame ``f``, entry ``k`` holds the 0-based row position of the last
        row whose frame is ``<= f - half_window_length + k``. Positions are clipped at 0,
        so a window that starts before the first recorded frame repeats row 0 instead of
        containing ``-1`` (which would silently index the *last* row). A window that runs
        past the end already repeats the last row.

    Side effects:
        Prints the number of candidate transitions before filtering, and sets the
        module-level ``pause_to_running`` to the frame numbers of the kept transitions.
    """
    pauses = pd.Series(data.pause.values, index=data.frame.values)

    def is_pause_to_running(index):
        """Return True if the transition at row position ``index`` passes the length filters."""
        frame = pauses.index[index]
        # Too close to either end of the recording to check the minimum lengths.
        if frame + unpaused_min > pauses.index[-1]:
            return False
        if frame - pause_min < pauses.index[0]:
            return False

        # Preceding pause: long enough, but not too long.
        pause_min_idx, pause_max_idx = pauses.index.get_indexer(
            [frame - pause_min, frame - pause_max], method='ffill')
        if pauses.iloc[pause_min_idx] != pauses.iloc[index - 1]:
            return False
        if pause_max_idx >= 0 and pauses.iloc[pause_max_idx] == pauses.iloc[index - 1]:
            return False

        # Following run: long enough, but not too long.
        unpaused_min_idx, unpaused_max_idx = pauses.index.get_indexer(
            [frame + unpaused_min, frame + unpaused_max], method='bfill')
        if pauses.iloc[unpaused_min_idx] != pauses.iloc[index]:
            return False
        if unpaused_max_idx >= 0 and pauses.iloc[unpaused_max_idx] == pauses.iloc[index]:
            return False

        return True

    # Row positions where the pause label drops, i.e. pausing -> running.
    global pause_to_running
    indices = np.where(np.diff(pauses) < 0)[0] + 1
    print(f'shape of p2r before filter: {indices.shape}')
    pause_to_running = pauses.index[list(filter(is_pause_to_running, indices))]

    offsets = np.arange(-half_window_length, half_window_length)
    p2r = np.empty([len(pause_to_running), 2 * half_window_length], dtype=np.int32)
    for i, frame in enumerate(pause_to_running):
        # Position of the last row at or before each frame of the window; clip so that
        # frames before the recording start map to row 0 rather than wrapping to -1.
        idx = np.searchsorted(pauses.index, frame + offsets, 'right') - 1
        p2r[i] = np.clip(idx, 0, None)
    return p2r

def get_beh(mouse_name, date, use_cache=False, **kwargs):
    """Return the behavior dataframe for one mouse/session, loaded through ``DataLoader``.

    Any extra keyword arguments are passed unchanged to the ``DataLoader`` constructor.
    """
    loader_kwargs = dict(mouse_name=mouse_name, date=date, use_cache=use_cache)
    loader_kwargs.update(kwargs)
    loader = DataLoader(**loader_kwargs)
    return loader.get_behavior()

def multi_behavior(config, **dl_kwargs):
    """Concatenate behavior dataframes for several mice and sessions.

    Args:
        config: mapping of mouse name -> iterable of session identifiers; each
            (mouse, session) pair is loaded with ``get_beh``.
        **dl_kwargs: forwarded to every ``get_beh`` call.

    Returns:
        One dataframe with leading ``name`` and ``day`` columns. The per-session index
        is moved into a column by ``reset_index`` and, when it was unnamed (so the
        column is called ``"index"``), renamed to ``"frame"``.
    """
    def _load(mouse, session):
        # Load one session and prepend its identifying columns.
        session_df = get_beh(mouse, session, **dl_kwargs)
        session_df.insert(0, "day", session)
        session_df.insert(0, "name", mouse)
        return session_df

    per_session = [
        _load(mouse, session)
        for mouse, mouse_sessions in config.items()
        for session in mouse_sessions
    ]
    combined = pd.concat(per_session, axis=0)
    return combined.reset_index().rename({"index": "frame"}, axis=1)

def context_trace(config, mouse_name, day, is_axon, default_vals=None):
    """Load one session's merged behavior + trace dataframe and select the trace column.

    Args:
        config: per-session entry from traces.json. It may contain ``"kwargs"``
            (forwarded to ``DataLoader.merge_behavior``) and an ``"axon"`` or ``"soma"``
            key naming the trace column to use.
        mouse_name: mouse identifier passed to ``DataLoader``.
        day: session identifier passed to ``DataLoader`` as ``date``.
        is_axon: if true, the selected column is copied to ``"axon"``; otherwise it is
            copied to ``"soma"``. Also forwarded to ``merge_behavior``.
        default_vals: fallback values (``multi_trace`` passes the ``"defaults"`` entry of
            traces.json) used for ``"kwargs"`` and the trace column when ``config`` lacks
            them. ``None`` (the default) means no fallbacks.

    Returns:
        The merged dataframe with the chosen trace copied to ``"axon"``/``"soma"``,
        every integer-labelled column dropped, and ``name``/``day`` columns prepended.

    Raises:
        KeyError: if neither ``config`` nor ``default_vals`` names the trace column.
    """
    # None sentinel instead of a shared mutable ``{}`` default.
    if default_vals is None:
        default_vals = {}

    dl = DataLoader(mouse_name=mouse_name, date=day, use_cache=False)
    # Session-specific kwargs override the defaults.
    kwargs = {**default_vals.get("kwargs", {}), **config.get("kwargs", {})}
    df = dl.merge_behavior(is_axon, **kwargs)

    trace_key = "axon" if is_axon else "soma"
    source = config if trace_key in config else default_vals
    df[trace_key] = df[source[trace_key]]

    # Drop the integer-labelled columns (presumably the raw per-ROI traces). Labels may be
    # Python ints or numpy integers depending on how the frame was built, so check both.
    df = df.drop([col for col in df.columns if isinstance(col, (int, np.integer))], axis=1)

    df.insert(0, 'day', day)
    df.insert(0, 'name', mouse_name)

    return df

def multi_trace(config, is_axon):
    """Concatenate ``context_trace`` output for every configured (mouse, session) pair.

    traces.json (in ``defaults.configs_dir``) supplies the per-session trace settings.
    A session is skipped when the file has no entry for the mouse or the session's
    entry is null. The file's ``"defaults"`` entry provides fallback values.

    Returns:
        The concatenated dataframe, with the index moved into a ``"frame"`` column.
    """
    with open(join(defaults.configs_dir, "traces.json"), "r") as f:
        trace_settings = json.load(f)

    fallback = trace_settings.get("defaults", {})
    collected = []
    for mouse_name, days in config.items():
        for day in days:
            if mouse_name not in trace_settings:
                continue
            session_conf = trace_settings[mouse_name][day]
            if session_conf is None:
                continue
            collected.append(context_trace(session_conf, mouse_name, day, is_axon, fallback))
    merged = pd.concat(collected, axis=0)
    return merged.reset_index().rename({"index": "frame"}, axis=1)

# traces.json is re-read here; traces_config is kept as a notebook global (not used in this cell).
with open(join(defaults.configs_dir, "traces.json"), "r") as f:
    traces_config = json.load(f)

# Minimum recorded_velocity for a non-paused frame to count as running.
movement_thresh = 0.1

# Per-frame locomotion state for the trace dataframe.
traces["is_paused"] = traces.pause > 0
traces["is_running"] = ~traces.is_paused & (traces.recorded_velocity > movement_thresh)
traces["is_backtracking"] = ~traces.is_paused & (traces.recorded_velocity < 0)

# Frames after the start of an interval that still count as "post-pause".
window_after_pause = 8

cols = ["name", "day", "context", "pause"]
# First frame of each (mouse, session, context, pause label) group, broadcast onto every
# row of that group. transform() keeps traces' own index, so the subtraction below is
# row-aligned. A merge-based lookup returns a fresh RangeIndex whose row order is not
# guaranteed to match traces (and it drops rows with NaN keys).
interval_starts = traces.groupby(cols).frame.transform("first")

traces["is_postpause"] = ~traces.is_paused & ((traces.frame - interval_starts) <= window_after_pause)

# Same locomotion flags for the behavior-only dataframe (same movement_thresh).
all_behavior["is_paused"] = all_behavior.pause > 0
all_behavior["is_running"] = ~all_behavior.is_paused & (all_behavior.recorded_velocity > movement_thresh)
all_behavior["is_backtracking"] = ~all_behavior.is_paused & (all_behavior.recorded_velocity < 0)

def first_laps_col(df, gb_cols, lap_col="lap", nlaps=3):
    """Flag rows that fall within the first ``nlaps`` laps of their group.

    Args:
        df: dataframe containing the ``gb_cols`` columns and ``lap_col``.
        gb_cols: list of column names that define a group.
        lap_col: column holding the lap number.
        nlaps: number of laps to flag, counted from the group's first lap value.

    Returns:
        Boolean Series aligned to ``df.index``. It is True where ``lap_col`` minus the
        group's first (first non-null, in row order) ``lap_col`` value is below ``nlaps``.
        Rows with a missing lap value are False.
    """
    first_lap = df.groupby(gb_cols)[lap_col].transform("first")
    # transform keeps df's index, so the comparison is row-aligned without a merge
    # (and without depending on merge's "_x"/"_y" column suffixes).
    return (df[lap_col] - first_lap < nlaps).rename(None)

# Flag rows within one lap of the first lap of their (mouse, session, context) group.
all_behavior["is_first_lap"] = first_laps_col(
    df=all_behavior, gb_cols=["name", "day", "context"], nlaps=1
)

In [None]:
# Run this instead to reproduce the paper's figures/analyses from the provided pandas dataframes.
# NOTE: pd.read_pickle can execute arbitrary code; only load pickle files from a trusted source.
# NOTE(review): later cells also use mouse_config, which only the host-computer cell above
# loads; it is not restored here.
traces = pd.read_pickle("traces.pkl")
all_behavior = pd.read_pickle("all_behavior.pkl")
shock_behavior = pd.read_pickle("shock_behavior.pkl")

# Figure 1

## Figures 1a and 1b are schematics that were generated in BioRender

## Figure 1c

In [None]:
# Figure 1c: velocity aligned to shock onset (frame 0), mean with a 95% CI band across
# shocks; the grey lines/span mark frames 0-18.
# NOTE(review): `vals` is not defined anywhere in this notebook section -- presumably a
# list of per-shock velocity Series indexed by frame relative to shock onset; confirm.
df = pd.concat(vals, axis=1)
df.columns = range(len(df.columns))
# Put all shocks on one frame index, fill gaps by interpolation, then reshape to long format.
df = df.sort_index().interpolate().stack().reset_index()
df.columns = ["frame", "shock_number", "velocity"]

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

colors = 'brown'
sns.lineplot(data=df, x="frame", y="velocity", color = colors, 
             errorbar=("ci", 95), err_style="band")

# Mean line is the first line on the axes.
ax.lines[0].set_color(colors)

ax.set(
    ylabel='velocity (cm/s)', 
    xlim = [-30,45],
    ylim = [5,70],
    # NOTE(review): tick positions 10/30/50/70 are labelled 20/40/60/80, i.e. the labels
    # are offset by +10 from the plotted values -- confirm this is intended.
    yticks = [10,30,50,70],
    yticklabels = [20,40,60,80],
    # NOTE(review): 18 frames at FRAME_RATE = 15.49 Hz is ~1.16 s, so the "1s" labels are approximate.
    xticks = [-18, 0, 18, 36],
    xticklabels = ['-1s','0s','+1s', '+2s'],
    xlabel = '',
    )

shock_color = 'xkcd:cool grey'
plt.axvline(0, 0, 1, linewidth=4, color = shock_color)
plt.axvspan(0,18, alpha=0.3, color=shock_color)
plt.axvline(18, 0, 1, linewidth=4, color =shock_color)

ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
# Empty legend: hide any legend seaborn may have added.
plt.legend([],[], frameon=False)

sns.despine()
plt.show()

## Figure 1d

In [None]:
#top
# Figure 1d (top): track position (ybinned) over time for mouse MR8a on day 0, fear and
# safe contexts only. For each pause label > 0, the span from its first to last timestamp
# is shaded red-orange.
# NOTE(review): this cell is repeated almost verbatim for day 1 and for mouse R3 (Fig 2c);
# a shared helper would keep them in sync.

single_beh = all_behavior[(all_behavior.name=='MR8a')&(all_behavior.day==0)&(all_behavior.context.isin(['fear','safe']))]
df_name = single_beh

fig, axes = plt.subplots(1, figsize=(20, 2))
clean_axes(axes)

df_name.plot(x='seconds', y='ybinned', color='xkcd:very dark brown', linewidth = 1.5, ax=axes, legend=False)

axes.set(
        ylim=[0,0.61],
        xlim=[16,651],
        yticks=(0, .61),
        yticklabels=('0m', '2m'),
        xticks=[16,651],
        xticklabels=['',''],
        xlabel = ''
            )

# First and last timestamp of each pause label.
gb = df_name.seconds.groupby(df_name.pause)
prev = df_name.seconds.iloc[0]
ranges = pd.DataFrame({'first': gb.first(), 'last': gb.last()})

# Shade each pause (label > 0), plus a faint white span since the previous pause ended.
for i, (first, last) in ranges[ranges.index > 0].sort_index().iterrows():
    for ax in np.asarray(axes).flat:
        ax.axvspan(first, last, color='xkcd:red orange', alpha=0.3)
        ax.axvspan(prev, first, color='xkcd:white', alpha=0.12)
    prev = last
# Faint white span from the last pause to the end of the trace.
for ax in np.asarray(axes).flat:
    ax.axvspan(prev, df_name.seconds.iloc[-1], color='xkcd:white', alpha=0.1)

# `ax` is the loop variable left over from above (the only axes in this figure).
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2.5)
    
ax.tick_params('both', length=10, width=2.5, which='major')
sns.despine()

plt.show()


In [None]:
#bottom
# Figure 1d (bottom): same plot as the top panel, for mouse MR8a on day 1.

single_beh = all_behavior[(all_behavior.name=='MR8a')&(all_behavior.day==1)&(all_behavior.context.isin(['fear','safe']))]
df_name = single_beh

fig, axes = plt.subplots(1, figsize=(20, 2))
clean_axes(axes)

df_name.plot(x='seconds', y='ybinned', color='xkcd:very dark brown', linewidth = 1.5, ax=axes, legend=False)

axes.set(
        ylim=[0,0.61],
        xlim=[16,651],
        yticks=(0, .61),
        yticklabels=('0m', '2m'),
        xticks=[16,651],
        xticklabels=['',''],
        xlabel = ''
            )

# First and last timestamp of each pause label.
gb = df_name.seconds.groupby(df_name.pause)
prev = df_name.seconds.iloc[0]
ranges = pd.DataFrame({'first': gb.first(), 'last': gb.last()})

# Shade each pause (label > 0), plus a faint white span since the previous pause ended.
for i, (first, last) in ranges[ranges.index > 0].sort_index().iterrows():
    for ax in np.asarray(axes).flat:
        ax.axvspan(first, last, color='xkcd:red orange', alpha=0.3)
        ax.axvspan(prev, first, color='xkcd:white', alpha=0.12)
    prev = last
# Faint white span from the last pause to the end of the trace.
for ax in np.asarray(axes).flat:
    ax.axvspan(prev, df_name.seconds.iloc[-1], color='xkcd:white', alpha=0.1)

# `ax` is the loop variable left over from above (the only axes in this figure).
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2.5)
    
ax.tick_params('both', length=10, width=2.5, which='major')
sns.despine()

plt.show()

## Figure 1e

In [None]:
# Figure 1e: freezing (fraction of frames with pause > 0) across days for nr_standard mice,
# fear (teal) vs safe (orchid) context, baseline-normalized as described below; each line
# is the mean with a 95% CI band across mice.
pause_fraction = (all_behavior.pause > 0).groupby(
    [all_behavior[col] for col in ["context", "condition", "name", "day"]]).mean().reset_index()

pause_fraction_normalized = pause_fraction[pause_fraction.context.isin(["fear", "safe"])].copy()

# For every day > 0, subtract the same mouse's day-1 value in the same context and
# condition (day 1 is labelled 'pre-shocks' below). Day-0 rows are left unnormalized.
for (context, condition, name),df in pause_fraction_normalized[pause_fraction_normalized.day > 0].groupby(["context","condition", "name"]):    
    pause_fraction_normalized.loc[df.index, "pause"] = df.pause - pause_fraction[
        (pause_fraction.context.isin([context]))
        & (pause_fraction.day == 1)
        & (pause_fraction.condition == condition)
        & (pause_fraction.name == name)
    ].pause.mean()
    
df = pause_fraction_normalized
colors = ["teal","orchid"]

# NOTE(review): color_it and is_cond are not used in this cell, and is_cond needs
# mouse_config, which only the host-computer cell loads.
color_it = iter(colors)

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

conditions = {k for mouse in mouse_config.values() for k in mouse["condition"]}
is_cond = pd.DataFrame({
    cond: df.name.apply(
        lambda t: cond in mouse_config[t]["condition"])
    for cond in conditions
})

for condition in ["nr_standard"]:
    for context in ["fear"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=colors[0]
                             )
            
for condition in ["nr_standard"]:
    for context in ["safe"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=  colors[1]
                             )
ax.set(
    # NOTE(review): the values are baseline-normalized, but unlike Figures 2d/3a/3b the
    # label does not say so.
    ylabel='time spent freezing', 
    yticks = [0,0.2,0.4,0.6,0.8],
    yticklabels = ['0%','20%','40%','60%','80%'],
    xlabel = '',
    xlim = [.97,4],
    xticks = [1,2,3,4],
    xticklabels=['pre-shocks','day 1','day 2','day 3']
    )
              
    
ax.tick_params('both', length=15, width=2.5, which='major')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
sns.despine()

plt.show()

## Figure 1f

In [None]:
# Figure 1f: same as Figure 1e, for nr_no_shock_control mice.
# Reuses pause_fraction_normalized computed in the Figure 1e cell (run that cell first).
df = pause_fraction_normalized

colors = ["teal","orchid"]

# NOTE(review): color_it and is_cond are not used in this cell.
color_it = iter(colors)

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

conditions = {k for mouse in mouse_config.values() for k in mouse["condition"]}
is_cond = pd.DataFrame({
    cond: df.name.apply(
        lambda t: cond in mouse_config[t]["condition"])
    for cond in conditions
})

for condition in ["nr_no_shock_control"]:
    for context in ["fear"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=colors[0]
                             )
            
for condition in ["nr_no_shock_control"]:
    for context in ["safe"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=  colors[1]
                             )
        
ax.set(
    # NOTE(review): values are baseline-normalized; the label does not say so.
    ylabel='time spent freezing', 
    yticks = [0,0.2,0.4,0.6,0.8],
    yticklabels = ['0%','20%','40%','60%','80%'],
    xlabel = '',
    xlim = [.97,4],
    xticks = [1,2,3,4],
    xticklabels=['pre-shocks','day 1','day 2','day 3']
    )
              
    
ax.tick_params('both', length=15, width=2.5, which='major')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
sns.despine()

plt.show()

## Figure 1g

In [None]:
# Figure 1g: distribution of freeze (pause) lengths on day 2 for nr_standard mice,
# fear (teal) vs safe (orchid) context, on a log x axis with KDE overlays.
# NOTE(review): get_pause_lengths is not defined in this notebook section; it presumably
# returns one row per pause with "context" and "length" (seconds?) columns -- confirm.
bin_size = 0.07  # with log_scale=True this is presumably in log10 units -- confirm in seaborn docs
day = 2
bin_window = 20

colors = ['teal','orchid']

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))


pause_lengths_nr_standard = get_pause_lengths(all_behavior[all_behavior.context.isin(["fear","safe"])
                                                & (all_behavior.day == day) 
                                                & (all_behavior.condition == "nr_standard")],bin_window)


ax = sns.histplot(data=pause_lengths_nr_standard[pause_lengths_nr_standard.context == "fear"], x='length',
                  multiple="layer", edgecolor = colors[0], color = colors[0], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 3,
                  alpha=1, line_kws={"color": colors[0], "linewidth": 3, "alpha":1, "linestyle":'solid'},
                  ax=ax)


ax = sns.histplot(data=pause_lengths_nr_standard[pause_lengths_nr_standard.context == "safe"], x='length',
                  multiple="layer", color = colors[1], edgecolor = colors[1], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 0,
                  alpha=1, line_kws={"color": colors[1], "linewidth": 3, "alpha":1, "linestyle":'solid'}, ax=ax)

# Force the fear-context KDE line colour (it is the first line on the axes).
ax.lines[0].set_color(colors[0])

ax.set(
    ylabel ='freeze length probability', 
    ylim = [0,.18],
    yticks = [0, 0.06, 0.12,0.18],
    xlabel = 'length of freezing epoch (log scale)',
    xlim = [1,100],
    xticks = [1,3,10,30,100],
    xticklabels=['1s','3s','10s','30s','100s']
    )

ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)


sns.despine()
plt.legend([],[], frameon=False)
plt.show()


## Figure 1h

In [None]:
# Figure 1h: probability distribution of the track positions (ybinned) at which shocks
# were delivered, with a KDE overlay; x axis zoomed to 0.097-0.55.
colors = ["teal"]

# Bin width: 1/60 of the observed ybinned range.
bin_size = (shock_behavior.ybinned.max()-shock_behavior.ybinned.min())/60
fig, ax = plt.subplots(1, figsize=(6.4, 4.8))


# alpha=0 makes the bar fill transparent. line_kws needs a single colour, not the colours list.
sns.histplot(data=shock_behavior[shock_behavior.shock], x="ybinned",
             edgecolor = colors[0], linewidth = 2, fill = True,
             stat='probability', binwidth = bin_size,kde=True,
             alpha=0, line_kws={"color": colors[0], "linewidth": 3, "alpha":1, "linestyle":'solid'},
             ax=ax)

# Force the KDE line colour explicitly (it is the first line on the axes).
ax.lines[0].set_color(colors[0])

ax.set(
    ylabel='shock delivery probability', 
    xlim = [0.097, 0.55],
    ylim = [0,0.04],
    yticks = [0,0.02,0.04],
    # NOTE(review): with this xlim the 0 m and 2 m ticks fall outside the visible range.
    xticks = [0, 0.305, 0.61],
    xticklabels = ['0m','1m','2m'],
    xlabel = '',
    )
              
ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
plt.legend([],[], frameon=False)

sns.despine()
plt.show()

## Figure 1h (alternate version: full 0–2 m track on the x-axis)

In [None]:
# Figure 1h, alternate version: same shock-location distribution over the full track
# (x axis 0-0.61, labelled 0-2 m).
colors = ["teal"]

# Bin width: 1/60 of the observed ybinned range.
bin_size = (shock_behavior.ybinned.max()-shock_behavior.ybinned.min())/60
fig, ax = plt.subplots(1, figsize=(6.4, 4.8))


# alpha=0 makes the bar fill transparent. line_kws needs a single colour, not the colours list.
sns.histplot(data=shock_behavior[shock_behavior.shock], x="ybinned",
             edgecolor = colors[0], linewidth = 2, fill = True,
             stat='probability', binwidth = bin_size,kde=True,
             alpha=0, line_kws={"color": colors[0], "linewidth": 3, "alpha":1, "linestyle":'solid'},
             ax=ax)

# Force the KDE line colour explicitly (it is the first line on the axes).
ax.lines[0].set_color(colors[0])

ax.set(
    ylabel='shock delivery probability', 
    xlim = [0, 0.61],
    ylim = [0,0.04],
    yticks = [0,0.02,0.04],
    xticks = [0, 0.305, 0.61],
    xticklabels = ['0m','1m','2m'],
    xlabel = '',
    )
              
ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
plt.legend([],[], frameon=False)

sns.despine()
plt.show()

## Figure 1i

In [None]:
# Figure 1i: distribution of freeze-start positions, i.e. the first ybinned value of each
# (condition, mouse, day, context, pause) group among paused rows; fear (teal) vs safe (orchid).
# NOTE(review): unlike Figure 1g, this pools all conditions and days -- confirm that is intended.
df = all_behavior[all_behavior.is_paused].groupby(["condition", "name", "day", "context", "pause"], as_index=False).ybinned.first()

colors = ['teal','orchid']

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

# Bin width: 1/40 of the observed ybinned range.
bin_size = (all_behavior.ybinned.max()-all_behavior.ybinned.min())/40

ax = sns.histplot(data=df[(df.context == "fear")], x="ybinned",
                  multiple="layer", edgecolor = colors[0], linewidth = 2, fill = True,
                  stat='probability', binwidth = bin_size,kde=True,
                  alpha=0.0, line_kws={"color": colors[0], "linewidth": 3, "alpha":1, "linestyle":'solid'},
                  ax=ax)

# Force each KDE line's colour explicitly (lines are added in plotting order).
ax.lines[0].set_color(colors[0])

ax = sns.histplot(data=df[(df.context == "safe")], x="ybinned",
                  multiple="layer", edgecolor = colors[1], linewidth = 2, fill = True,
                  stat='probability', binwidth = bin_size,kde=True,
                  alpha=0.0, line_kws={"color": colors[1], "linewidth": 3, "alpha":1, "linestyle":'solid'}, ax=ax)
ax.lines[1].set_color(colors[1])

ax.set(
    ylabel ='freeze start \n location probability', 
    xlim=[0.02,all_behavior.ybinned.max()],
    ylim=[0,0.06],
    # NOTE(review): the 0.061 tick lies above ylim (0.06) and the 0.019 x tick lies left
    # of xlim (0.02), so those ticks are not drawn -- confirm the intended limits.
    yticks = [0,0.0305,0.061],
    xlabel = '',
    xticks = [0.019,0.323,all_behavior.ybinned.max()],
    xticklabels=['0m','1m','2m']
    )

handles, labels = plt.gca().get_legend_handles_labels()
sns.despine()

ax.tick_params('both', length=15, width=2.5, which='major')

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
     
plt.legend([],[], frameon=False)

plt.show()

## Figure 1j

In [None]:
# Figure 1j: mean running velocity per mouse, day and context for nr_standard mice
# (fear vs safe, days 1-4), shown as boxplots with the per-mouse points overlaid.
colors = ["teal","orchid"]

df = all_behavior[all_behavior.condition.isin(["nr_standard"]) & (all_behavior.context.isin(["fear", "safe"]) & (all_behavior.day.isin([1,2,3,4])))]
                               
fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

# NOTE(review): the same is_running groupby/mean is computed twice (box and strip plots).
# fliersize=0 hides the boxplot's outlier markers, since every point is drawn by the stripplot.
ax = sns.boxplot(data=df[df.is_running].groupby(["name", "day", "context"], as_index=False).velocity.mean(), 
                 x="day", y="velocity", hue="context", 
                 palette = colors, fliersize = 0,  whis = 1, boxprops=dict(alpha=.5))

ax = sns.stripplot(data=df[df.is_running].groupby(["name", "day", "context"], as_index=False).velocity.mean(), 
                   x="day", y="velocity", hue="context", 
                 palette = colors, dodge = True)

ax.set(
    ylabel='running velocity (cm/s)', 
    ylim = [0,50],
    yticks = [0,25,50],
    xlabel = '',
    # Categorical positions 0-3 correspond to days 1-4.
    xticks = [0,1,2,3],
    xticklabels=['pre-shocks','day 1','day 2','day 3']
    )
              
ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
plt.legend([],[], frameon=False)

sns.despine()
plt.show()

# Figure 2

## Figures 2a and 2b were made with BioRender

## Figure 2c

In [None]:
# Figure 2c (top): track position (ybinned) over time for mouse R3 on day 0, fear and safe
# contexts only. For each pause label > 0, the span from its first to last timestamp is
# shaded red-orange.
single_beh = all_behavior[(all_behavior.name=='R3')&(all_behavior.day==0)&(all_behavior.context.isin(['fear','safe']))]
df_name = single_beh

fig, axes = plt.subplots(1, figsize=(20, 2))
clean_axes(axes)

df_name.plot(x='seconds', y='ybinned', color='xkcd:very dark brown', linewidth = 1.5, ax=axes, legend=False)

axes.set(
        ylim=[0,0.61],
        xlim=[16,651],
        yticks=(0, .61),
        yticklabels=('0m', '2m'),
        xticks=[16,651],
        xticklabels=['',''],
        xlabel = ''
            )

# First and last timestamp of each pause label.
gb = df_name.seconds.groupby(df_name.pause)
prev = df_name.seconds.iloc[0]
ranges = pd.DataFrame({'first': gb.first(), 'last': gb.last()})

# Shade each pause (label > 0), plus a faint white span since the previous pause ended.
for i, (first, last) in ranges[ranges.index > 0].sort_index().iterrows():
    for ax in np.asarray(axes).flat:
        ax.axvspan(first, last, color='xkcd:red orange', alpha=0.3)
        ax.axvspan(prev, first, color='xkcd:white', alpha=0.12)
    prev = last
# Faint white span from the last pause to the end of the trace.
for ax in np.asarray(axes).flat:
    ax.axvspan(prev, df_name.seconds.iloc[-1], color='xkcd:white', alpha=0.1)

# `ax` is the loop variable left over from above (the only axes in this figure).
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2.5)
    
ax.tick_params('both', length=10, width=2.5, which='major')
sns.despine()

plt.show()

In [None]:
# Figure 2c (bottom): same plot as the top panel, for mouse R3 on day 1.
single_beh = all_behavior[(all_behavior.name=='R3')&(all_behavior.day==1)&(all_behavior.context.isin(['fear','safe']))]
df_name = single_beh

fig, axes = plt.subplots(1, figsize=(20, 2))
clean_axes(axes)

df_name.plot(x='seconds', y='ybinned', color='xkcd:very dark brown', linewidth = 1.5, ax=axes, legend=False)

axes.set(
        ylim=[0,0.61],
        xlim=[16,651],
        yticks=(0, .61),
        yticklabels=('0m', '2m'),
        xticks=[16,651],
        xticklabels=['',''],
        xlabel = ''
            )

# First and last timestamp of each pause label.
gb = df_name.seconds.groupby(df_name.pause)
prev = df_name.seconds.iloc[0]
ranges = pd.DataFrame({'first': gb.first(), 'last': gb.last()})

# Shade each pause (label > 0), plus a faint white span since the previous pause ended.
for i, (first, last) in ranges[ranges.index > 0].sort_index().iterrows():
    for ax in np.asarray(axes).flat:
        ax.axvspan(first, last, color='xkcd:red orange', alpha=0.3)
        ax.axvspan(prev, first, color='xkcd:white', alpha=0.12)
    prev = last
# Faint white span from the last pause to the end of the trace.
for ax in np.asarray(axes).flat:
    ax.axvspan(prev, df_name.seconds.iloc[-1], color='xkcd:white', alpha=0.1)

# `ax` is the loop variable left over from above (the only axes in this figure).
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2.5)
    
ax.tick_params('both', length=10, width=2.5, which='major')
sns.despine()

plt.show()

## Figure 2d

In [None]:
# Figure 2d: baseline-normalized freezing in the fear context across days, one line per
# dreadd_dcz mouse.
pause_fraction = (all_behavior.pause > 0).groupby(
    [all_behavior[col] for col in ["context", "condition", "name", "day"]]).mean().reset_index()

pause_fraction_normalized = pause_fraction[pause_fraction.context.isin(["fear", "safe"])].copy()

# For every day > 0, subtract the same mouse's day-1 ('pre-shocks') value in the same
# context and condition; day-0 rows are left unnormalized.
for (context, condition, name),df in pause_fraction_normalized[pause_fraction_normalized.day > 0].groupby(["context","condition", "name"]):    
    pause_fraction_normalized.loc[df.index, "pause"] = df.pause - pause_fraction[
        (pause_fraction.context.isin([context]))
        & (pause_fraction.day == 1)
        & (pause_fraction.condition == condition)
        & (pause_fraction.name == name)
    ].pause.mean()
    
df = pause_fraction_normalized

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))


for condition in ["dreadd_dcz"]:
    for context in ["fear"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
               hue = 'name', x="day", y="pause", marker="o") 
                               
ax.set(
    ylabel='baseline normalized \n time spent freezing', 
    yticks = [0,0.2,0.4,0.6,0.8],
    yticklabels = ['0%','20%','40%','60%','80%'],
    xlabel = '',
    # NOTE(review): normalized values can be negative (below the day-1 baseline); ylim
    # starting at 0 would hide them.
    ylim = [0,.87],
    #xlim = [1,4],
    xticks = [1,2,3,4],
    xticklabels=['pre-shocks','day 1','day 2','day 3']
    )
              
    
ax.tick_params('both', length=15, width=2.5, which='major')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
sns.despine()

# Legend (one entry per mouse) placed outside the axes.
leg = plt.legend(bbox_to_anchor=(1.5,1))
plt.show()

## Figure 2e

In [None]:
# Figure 2e: baseline-normalized freezing across days for dreadd_dcz mice, fear (teal)
# vs safe (orchid) context; mean with a 95% CI band across mice.
# Reuses pause_fraction_normalized from the Figure 2d cell (run that cell first).
df = pause_fraction_normalized
colors = ["teal","orchid"]

# NOTE(review): color_it and is_cond are not used in this cell.
color_it = iter(colors)

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

conditions = {k for mouse in mouse_config.values() for k in mouse["condition"]}
is_cond = pd.DataFrame({
    cond: df.name.apply(
        lambda t: cond in mouse_config[t]["condition"])
    for cond in conditions
})

for condition in ["dreadd_dcz"]:
    for context in ["fear"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=colors[0]
                             )
            
for condition in ["dreadd_dcz"]:
    for context in ["safe"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=  colors[1]
                             )
ax.set(
    # NOTE(review): values are baseline-normalized; the label does not say so (Figure 2d's does).
    ylabel='time spent freezing', 
    yticks = [0,0.2,0.4,0.6,0.8],
    yticklabels = ['0%','20%','40%','60%','80%'],
    xlabel = '',
    xlim = [.97,4],
    xticks = [1,2,3,4],
    xticklabels=['pre-shocks','day 1','day 2','day 3']
    )
              
    
ax.tick_params('both', length=15, width=2.5, which='major')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
sns.despine()

plt.show()

# Figure 3

## Figure 3a

In [None]:
# Figure 3a: baseline-normalized freezing in the fear context across days, nr_standard
# (turquoise) vs dreadd_dcz (denim) mice; each line is the mean with a 95% CI band across mice.
pause_fraction = (all_behavior.pause > 0).groupby(
    [all_behavior[col] for col in ["context", "condition", "name", "day"]]).mean().reset_index()

pause_fraction_normalized = pause_fraction[pause_fraction.context.isin(["fear", "safe"])].copy()

# For every day > 0, subtract the same mouse's day-1 ('pre-shocks') value in the same
# context and condition; day-0 rows are left unnormalized.
for (context, condition, name),df in pause_fraction_normalized[pause_fraction_normalized.day > 0].groupby(["context","condition", "name"]):    
    pause_fraction_normalized.loc[df.index, "pause"] = df.pause - pause_fraction[
        (pause_fraction.context.isin([context]))
        & (pause_fraction.day == 1)
        & (pause_fraction.condition == condition)
        & (pause_fraction.name == name)
    ].pause.mean()
    
df = pause_fraction_normalized
# Both names need the "xkcd:" prefix to resolve as matplotlib named colours.
colors = ["xkcd:turquoise","xkcd:denim"]

# NOTE(review): color_it and is_cond are not used in this cell.
color_it = iter(colors)

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

conditions = {k for mouse in mouse_config.values() for k in mouse["condition"]}
is_cond = pd.DataFrame({
    cond: df.name.apply(
        lambda t: cond in mouse_config[t]["condition"])
    for cond in conditions
})

for condition in ["nr_standard"]:
    for context in ["fear"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=colors[0]
                             )
            
for condition in ["dreadd_dcz"]:
    for context in ["fear"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=  colors[1]
                             )
ax.set(
    # NOTE(review): this fear-context panel has the same title as the safe-context
    # Figure 3b; confirm which wording is intended here.
    title = 'VR-CFC in the control context',
    ylabel='baseline normalized \n time spent freezing', 
    yticks = [0,0.2,0.4,0.6,0.8],
    yticklabels = ['0%','20%','40%','60%','80%'],
    xlabel = '',
    xlim = [.97,4],
    xticks = [1,2,3,4],
    xticklabels=['pre-shocks','day 1','day 2','day 3']
    )
              
    
ax.tick_params('both', length=15, width=2.5, which='major')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
sns.despine()

plt.show()

## Figure 3b

In [None]:
# Figure 3b: baseline-normalized freezing in the safe context across days, nr_standard
# (rose pink) vs dreadd_dcz (merlot) mice; each line is the mean with a 95% CI band across mice.
pause_fraction = (all_behavior.pause > 0).groupby(
    [all_behavior[col] for col in ["context", "condition", "name", "day"]]).mean().reset_index()

pause_fraction_normalized = pause_fraction[pause_fraction.context.isin(["fear", "safe"])].copy()

# For every day > 0, subtract the same mouse's day-1 ('pre-shocks') value in the same
# context and condition; day-0 rows are left unnormalized.
for (context, condition, name),df in pause_fraction_normalized[pause_fraction_normalized.day > 0].groupby(["context","condition", "name"]):    
    pause_fraction_normalized.loc[df.index, "pause"] = df.pause - pause_fraction[
        (pause_fraction.context.isin([context]))
        & (pause_fraction.day == 1)
        & (pause_fraction.condition == condition)
        & (pause_fraction.name == name)
    ].pause.mean()
    
df = pause_fraction_normalized
colors = ["xkcd:rose pink","xkcd:merlot"]

# NOTE(review): color_it and is_cond are not used in this cell.
color_it = iter(colors)

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

conditions = {k for mouse in mouse_config.values() for k in mouse["condition"]}
is_cond = pd.DataFrame({
    cond: df.name.apply(
        lambda t: cond in mouse_config[t]["condition"])
    for cond in conditions
})

for condition in ["nr_standard"]:
    for context in ["safe"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=colors[0]
                             )
            
for condition in ["dreadd_dcz"]:
    for context in ["safe"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=  colors[1]
                             )
ax.set(
    # NOTE(review): Figure 3a (fear context) uses this same title; confirm both are intended.
    title = 'VR-CFC in the control context',
    ylabel='baseline normalized \n time spent freezing', 
    yticks = [0,0.2,0.4,0.6,0.8],
    yticklabels = ['0%','20%','40%','60%','80%'],
    xlabel = '',
    xlim = [.97,4],
    xticks = [1,2,3,4],
    xticklabels=['pre-shocks','day 1','day 2','day 3']
    )
              
    
ax.tick_params('both', length=15, width=2.5, which='major')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
sns.despine()

plt.show()

## Figure 3c

In [None]:
# Figure 3c: histogram + KDE of freezing-epoch lengths in the "fear" (shocked)
# context on `day`, nr_standard vs dreadd_dcz, on a log x axis.
bin_size = 0.07  # histplot binwidth; with log_scale=True presumably in log10 units
day = 2
bin_window = 20  # forwarded to get_pause_lengths (defined elsewhere in the notebook)

colors = ["xkcd:denim","xkcd:turquoise"]

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))


# get_pause_lengths is not visible here; presumably it returns one row per
# freezing epoch with at least `context` and `length` columns.
pause_lengths_nr_standard = get_pause_lengths(all_behavior[all_behavior.context.isin(["fear","safe"])
                                                & (all_behavior.day == day) 
                                                & (all_behavior.condition == "nr_standard")],bin_window)

pause_lengths_dreadd_dcz = get_pause_lengths(all_behavior[all_behavior.context.isin(["fear","safe"])
                                                & (all_behavior.day == day) 
                                                & (all_behavior.condition == "dreadd_dcz")],bin_window)

# Unfilled step histograms (probability per bin) with KDE overlays; the control
# group is drawn above (zorder=3) the DREADD group (zorder=0).
ax = sns.histplot(data=pause_lengths_nr_standard[pause_lengths_nr_standard.context == "fear"], x='length',
                  multiple="layer", edgecolor = colors[0], color = colors[0], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 3,
                  alpha=1, line_kws={"color": colors[0], "linewidth": 3, "alpha":1, "linestyle":'solid'})


ax = sns.histplot(data=pause_lengths_dreadd_dcz[pause_lengths_dreadd_dcz.context == "fear"], x='length',
                  multiple="layer", color = colors[1], edgecolor = colors[1], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 0,
                  alpha=1, line_kws={"color": colors[1], "linewidth": 3, "alpha":1, "linestyle":'solid'}, ax=ax)

# Re-apply the control color to the first Line2D on the axes (the first KDE curve).
ax.lines[0].set_color(colors[0])

ax.set(
    title ='freeze lengths in the shocked context', 
    ylim = [0,.18],
    yticks = [0, 0.06, 0.12,0.18],
    ylabel = 'freeze length probability',
    xlabel = 'length of freezed epoch (log scale)',
    xlim = [1,100],
    xticks = [1,3,10,30,100],
    xticklabels=['1s','3s','10s','30s','100s']
    )

ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
    
sns.despine()
# Hide the legend.
plt.legend([],[], frameon=False)
plt.show()


## Figure 3d

In [None]:
# Figure 3d: histogram + KDE of freezing-epoch lengths in the "safe" (control)
# context on `day`, nr_standard vs dreadd_dcz, on a log x axis.
bin_size = 0.07  # histplot binwidth; with log_scale=True presumably in log10 units
day = 2
bin_window = 20  # forwarded to get_pause_lengths (defined elsewhere in the notebook)

# NOTE(review): Figure 3b colors nr_standard "rose pink" and dreadd_dcz
# "merlot" for the same context; here the order is reversed. Confirm intended.
colors = ["xkcd:merlot","xkcd:rose pink"]

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))


# get_pause_lengths is not visible here; presumably it returns one row per
# freezing epoch with at least `context` and `length` columns.
pause_lengths_nr_standard = get_pause_lengths(all_behavior[all_behavior.context.isin(["fear","safe"])
                                                & (all_behavior.day == day) 
                                                & (all_behavior.condition == "nr_standard")],bin_window)

pause_lengths_dreadd_dcz = get_pause_lengths(all_behavior[all_behavior.context.isin(["fear","safe"])
                                                & (all_behavior.day == day) 
                                                & (all_behavior.condition == "dreadd_dcz")],bin_window)

# Unfilled step histograms (probability per bin) with KDE overlays.
ax = sns.histplot(data=pause_lengths_nr_standard[pause_lengths_nr_standard.context == "safe"], x='length',
                  multiple="layer", edgecolor = colors[0], color = colors[0], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 3,
                  alpha=1, line_kws={"color": colors[0], "linewidth": 3, "alpha":1, "linestyle":'solid'})


ax = sns.histplot(data=pause_lengths_dreadd_dcz[pause_lengths_dreadd_dcz.context == "safe"], x='length',
                  multiple="layer", color = colors[1], edgecolor = colors[1], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 0,
                  alpha=1, line_kws={"color": colors[1], "linewidth": 3, "alpha":1, "linestyle":'solid'}, ax=ax)

# Re-apply the first color to the first Line2D on the axes (the first KDE curve).
ax.lines[0].set_color(colors[0])

ax.set(
    title ='freeze lengths in the control context', 
    ylim = [0,.18],
    yticks = [0, 0.06, 0.12,0.18],
    ylabel = 'freeze length probability',
    xlabel = 'length of freezed epoch (log scale)',
    xlim = [1,100],
    xticks = [1,3,10,30,100],
    xticklabels=['1s','3s','10s','30s','100s']
    )

ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
    
sns.despine()
# Hide the legend.
plt.legend([],[], frameon=False)
plt.show()


## Figure 3e

In [None]:
def seaborn_estimates(
    x, y, data, hue=None,
    estimator="mean",
    errorbar="ci",
    **boot_kws
):
    """Reproduce seaborn's per-point estimate and error interval as a table.

    Groups `data` by `x` (and by `hue` when given) and applies seaborn's
    (private) ``EstimateAggregator`` to `y` in each group.

    Parameters
    ----------
    x, y : str
        Column names in `data`; they are renamed to "x" and "y" internally.
    data : pandas.DataFrame
    hue : str or None
        Optional extra grouping column. With ``None`` only `x` is used.
    estimator, errorbar, **boot_kws
        Forwarded to ``EstimateAggregator``.

    Returns
    -------
    pandas.DataFrame
        One row per group with the grouping columns plus "y", "ymin", "ymax".
    """
    # NOTE: sns._statistics is a private seaborn module and may change
    # between seaborn releases.
    agg = sns._statistics.EstimateAggregator(
        estimator=estimator,
        errorbar=errorbar,
        **boot_kws
    )

    data = data.rename(columns={x: "x", y: "y"})

    # Grouping by ["x", None] fails, so only add hue when one was given.
    gb_cols = ["x"] if hue is None else ["x", hue]
    rows = []
    for key, group in data.groupby(gb_cols):
        # Depending on the pandas version, a single grouping column yields
        # scalar or 1-tuple keys; normalise to a tuple.
        if not isinstance(key, tuple):
            key = (key,)
        row = dict(zip(gb_cols, key))
        row.update(agg(group, "y"))
        rows.append(row)

    return pd.DataFrame(rows)

def sns_midpoint(*args, **kwargs):
    """Return ``seaborn_estimates(*args, **kwargs)`` with an extra ``ymid``
    column halfway between the ``ymin`` and ``ymax`` interval bounds."""
    estimates = seaborn_estimates(*args, **kwargs)
    midpoint = (estimates.ymin + estimates.ymax) / 2
    return estimates.assign(ymid=midpoint)

In [None]:
# Figure 3e: contextual discrimination index per mouse and day,
# (freezing_fear - freezing_safe) / (freezing_fear + freezing_safe),
# for nr_standard vs dreadd_dcz.

# Mean of is_paused per (context, condition, name, day), restricted to the
# fear/safe contexts and the two conditions compared here.
freezing_frac = all_behavior[
    all_behavior.context.isin(["fear","safe"]) &
    all_behavior.condition.isin(["nr_standard","dreadd_dcz"])
].groupby(["context", "condition", "name", "day"]).is_paused.mean()

# xs() drops the context level, so the fear and safe Series align on
# (condition, name, day). Groups where both fractions are 0 give NaN.
di = (
    (freezing_frac.xs("fear") - freezing_frac.xs("safe")) /
    (freezing_frac.xs("fear") + freezing_frac.xs("safe"))
).reset_index()

colors = ["xkcd:bright lilac","xkcd:bruise"]
                               
fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

# Draws on the current axes (the one created just above).
sns.lineplot(
    data=di,
    x="day", y="is_paused", hue="condition", errorbar = ('ci',95),
    marker= 'o', palette = colors
        )

ax.set(
    ylabel='% freezing difference \n between contexts',
    ylim = [-0.2,0.4],
    yticks = [-0.2,0,.2,.4],
    yticklabels = ['-20%','0%','20%','40%'],
    xlabel = '',
    xticks = [1,2,3,4],
    xticklabels=['pre-shocks','day 1','day 2','day 3']
    )
              
ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
# Hide the legend.
plt.legend([],[], frameon=False)

sns.despine()
plt.show()

## Figure 3f

In [None]:
# Figure 3f: baseline-normalized freezing in the dark context,
# dreadd_saline_control vs dreadd_dcz, across days.

# Fraction of samples with pause > 0 for every (context, condition, name, day).
pause_fraction = (all_behavior.pause > 0).groupby(
    [all_behavior[col] for col in ["context", "condition", "name", "day"]]).mean().reset_index()

pause_fraction_normalized = pause_fraction[pause_fraction.context.isin(["dark"])].copy()

# Subtract each mouse's day-1 (pre-shock) freezing fraction, as in Figure 3b.
# Without the day == 1 filter the "baseline" would be the mean over all days,
# post-shock days included, which contradicts the "baseline normalized" label.
for (context, condition, name), df in pause_fraction_normalized[pause_fraction_normalized.day > 0].groupby(["context", "condition", "name"]):
    pause_fraction_normalized.loc[df.index, "pause"] = df.pause - pause_fraction[
        (pause_fraction.context.isin([context]))
        & (pause_fraction.day == 1)
        & (pause_fraction.condition == condition)
        & (pause_fraction.name == name)
    ].pause.mean()

df = pause_fraction_normalized
# "xkcd:" prefix (was misspelled "xckd:", which matplotlib rejects).
colors = ["xkcd:medium brown","xkcd:warm grey"]

color_it = iter(colors)

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

conditions = {k for mouse in mouse_config.values() for k in mouse["condition"]}
is_cond = pd.DataFrame({
    cond: df.name.apply(
        lambda t: cond in mouse_config[t]["condition"])
    for cond in conditions
})

for condition in ["dreadd_saline_control"]:
    for context in ["dark"]:
        ax = sns.lineplot(data=df[(df.context == context)
                                  & (df.condition == condition)],
                          x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band",
                          color=colors[0])

for condition in ["dreadd_dcz"]:
    for context in ["dark"]:
        ax = sns.lineplot(data=df[(df.context == context)
                                  & (df.condition == condition)],
                          x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band",
                          color=colors[1])
ax.set(
    title = 'VR-CFC in the dark',
    ylabel='baseline normalized \n time spent freezing',
    yticks = [0,0.2,0.4,0.6,0.8],
    yticklabels = ['0%','20%','40%','60%','80%'],
    xlabel = '',
    xlim = [.97,4],
    # One tick per label; [2,3,4] with four labels makes matplotlib raise.
    xticks = [1,2,3,4],
    xticklabels=['pre-shocks','day 1','day 2','day 3']
    )

ax.tick_params('both', length=15, width=2.5, which='major')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
sns.despine()

plt.show()

# Figure 4

## Figure 4a was made in BioRender

## Figure 4b

In [None]:
# Figure 4b: example mouse MR6e. axes[0] = day-1 axon trace with detected
# peaks; axes[1] = day-2 track position; axes[2] = day-2 axon trace with peaks.
# All three share the x axis (seconds from the start of the plotted window).

# Day-1 session of MR6e, indexed by frame.
df_name = traces[(traces.name=='MR6e')&(traces.day==1)].set_index("frame")

# scipy find_peaks on every int-named column and every column whose name
# contains 'axon'; each value is the (peak positions, properties) tuple.
# NOTE(review): `'axon' in roi` assumes every non-int column name is a string.
peaks = {roi: find_peaks(df_name[roi], height=0.1, distance=1, prominence=0.35) 
         for roi in df_name.columns 
         if isinstance(roi, int) or 'axon' in roi}

axon_val = 'rescaled_axon'
frame_start = 0
frame_end = 10000

context_color = ['saddlebrown', 'forestgreen']
# Peak heights of the axon trace, indexed by frame, limited to the plotted window.
peak_scatter = pd.Series(peaks[axon_val][1]['peak_heights'], 
                         index=df_name.index[peaks[axon_val][0]]).loc[frame_start:frame_end]
df_name = df_name.loc[frame_start:frame_end].copy()
# Re-zero time so the window starts at 0 s.
df_name.seconds -= df_name.seconds.iloc[0]

fig, axes = plt.subplots(3, figsize=(18, 7),sharex=True)
fig.subplots_adjust(hspace=0)
clean_axes(axes)


#axes 0: axon pre-shocks
df_name.plot(x='seconds', y=axon_val, color=context_color[1], ax=axes[0], legend=False)
axes[0].scatter(df_name.loc[peak_scatter.index, 'seconds'], peak_scatter.values, color='orchid')

# Row position of the context switch; the nested unpacking raises ValueError
# unless `context` changes exactly once within the window.
(mid_line, ), = np.where(np.diff(df_name.context.astype("category").cat.codes))
# Vertical line midway between the two samples straddling the switch.
# NOTE(review): this draws the day-1 boundary on all three axes, including the
# day-2 panels axes[1] and axes[2]; confirm that is intended.
for ax in axes:
    ax.axvline(df_name.iloc[mid_line:mid_line + 2, df_name.columns.get_loc("seconds")].mean(),
               color='black',linewidth=3.5)

#get axon val on post-shock day 1    
df_name = traces[(traces.name=='MR6e')&(traces.day==2)].set_index("frame")

peaks = {roi: find_peaks(df_name[roi], height=0.1, distance=1, prominence=0.35) 
         for roi in df_name.columns 
         if isinstance(roi, int) or 'axon' in roi}

axon_val = 'rescaled_axon'
frame_start = 0
frame_end = 10000

peak_scatter = pd.Series(peaks[axon_val][1]['peak_heights'], 
                         index=df_name.index[peaks[axon_val][0]]).loc[frame_start:frame_end]

df_name = df_name.loc[frame_start:frame_end].copy()
df_name.seconds -= df_name.seconds.iloc[0]

#axes 1: behavior post-shocks
df_name.plot(x='seconds', y='ybinned', color='xkcd:very dark brown', linewidth = 1.5, ax=axes[1], legend=False)
# First/last time (s) of each `pause` id; ids > 0 are shaded as freezing below.
gb = df_name.seconds.groupby(df_name.pause)
prev = df_name.seconds.iloc[0]
ranges = pd.DataFrame({'first': gb.first(), 'last': gb.last()})

#axes 2: axon post-shocks
df_name.plot(x='seconds', y=axon_val, color=context_color[1], ax=axes[2], legend=False)
axes[2].scatter(df_name.loc[peak_scatter.index, 'seconds'], peak_scatter.values, color='orchid')

(mid_line, ), = np.where(np.diff(df_name.context.astype("category").cat.codes))
# NOTE(review): the day-2 boundary is also drawn on axes[0] (day-1 trace).
for ax in axes:
    ax.axvline(df_name.iloc[mid_line:mid_line + 2, df_name.columns.get_loc("seconds")].mean(),
               color='black',linewidth=3.5)
    
    
# Shade day-2 freezing epochs red-orange and the gaps between them white.
# NOTE(review): shading is applied to axes[0] as well, whose data is day 1.
for i, (first, last) in ranges[ranges.index > 0].sort_index().iterrows():
    for ax in axes:
        ax.axvspan(first, last, color='xkcd:red orange', alpha=0.15)
        ax.axvspan(prev, first, color='white', alpha=0.08)
    prev = last
for ax in axes:
    ax.axvspan(prev, df_name.seconds.iloc[-1], color='xkcd:white', alpha=0.08)
    
axes[0].set(
    ylim = [0,0.6],
    yticks=([0, 0.6]),
    yticklabels=(['0.0','0.6']),
    ylabel='$Δf/f$',
    xticks=[0,600],
    xticklabels =['',''],
    xlabel = ''
    )

# Position panel: tick 0.59 is labelled '2m' (track length per the label).
axes[1].set(
    yticks=(0,.59),
    xlim = [0,600],
    yticklabels=('0', '2m'),
    ylabel = 'cm'
    #title='mouse position on virtual track (cm)',
    )

axes[2].set(
    ylim = [0,0.6],
    yticks=([0, 0.6]),
    yticklabels=(['0.0','0.6']),
    ylabel='$Δf/f$',
    xticks=[0,600],
    xticklabels =['',''],
    xlabel = ''
    )

# Hide ticks on the top two panels. subplots_adjust here overrides the earlier
# hspace=0 (it is simply re-applied on each iteration).
for ax in axes[0:2]:
    ax.tick_params('both', length=0, width=0, which='major')
    fig.subplots_adjust(hspace=0.35)
    for axis in ['left','bottom']:
        ax.spines[axis].set_linewidth(3)
        

axes[2].tick_params('y', length=0, width=0, which='major')
axes[2].tick_params('x', length=15, width=2.5, which='major')

for axis in ['left','bottom']:
    axes[2].spines[axis].set_linewidth(3)
    
plt.show()

## Figure 4c

In [None]:
# Figure 4c samples: mean zeroed_axon over peak frames per epoch.
# Freezing-epoch lengths are computed here and looked up by index label.
# Previously this read a `df_pause_length` built by a later cell for a
# different frame (fresh RangeIndex), so the boolean mask did not line up with
# this cell's rows.
id_cols = ["name", "day", "context", "pause"]
gb = traces[traces.pause > 0].groupby(id_cols).frame
pause_lengths = (gb.last() - gb.first() + 1).rename("pause_length")

# Freezing-epoch length limits (frames); same values as Figure 4e.
min_frames = 15
max_frames = np.inf

df = traces[traces.is_peak]
df_pause_length = df.join(pause_lengths, on=id_cols)["pause_length"]
# Only freezing epochs (pause > 0) have a length to filter on; running rows
# are kept so the running/paused comparison has both categories.
keep = (df.pause <= 0) | ((df_pause_length > min_frames) & (df_pause_length < max_frames))
samples = df[
    df.is_peak & keep
].groupby(["name", "day", "combo", "pause", "context", "status"]).zeroed_axon.mean().reset_index()

In [None]:
# Figure 4c: NR-axon activity during running vs freezing epochs per day,
# split by context (boxplots + jittered points).

# Freezing-epoch lengths, looked up by index label so the mask is aligned with
# this cell's `df`.
id_cols = ["name", "day", "context", "pause"]
gb = traces[traces.pause > 0].groupby(id_cols).frame
pause_lengths = (gb.last() - gb.first() + 1).rename("pause_length")

# Freezing-epoch length limits (frames); same values as Figure 4e.
min_frames = 15
max_frames = np.inf

df = traces[traces.is_peak]
df_pause_length = df.join(pause_lengths, on=id_cols)["pause_length"]
# Running rows (pause <= 0) have no freezing length and are kept.
keep = (df.pause <= 0) | ((df_pause_length > min_frames) & (df_pause_length < max_frames))
samples = df[
    df.is_peak & keep
].groupby(["name", "day", "combo", "pause", "context", "status"]).zeroed_axon.mean().reset_index()

contexts = ["fear", "safe"]
statuses = ["running", "paused"]

keys = [f"{ctxt}, {stat}" for ctxt in contexts for stat in statuses]

fig, ax = plt.subplots(1, figsize=(7.1 , 5))

axon = "zeroed_axon"

sns.boxplot(samples, x="day",hue="combo", y=axon, showfliers=False, whis = 1,
            hue_order = ['safe, running','safe, paused','fear, running', 'fear, paused'],
            palette = ['tab:green','tab:orange','tab:blue','tab:red'],
            width=0.8, boxprops={'zorder': 3, 'alpha':0.1, 'edgecolor': 'white'},
            capprops=dict(color="xkcd:gunmetal",linewidth=1.5),
            whiskerprops=dict(color="xkcd:gunmetal",linewidth=1.5),
            medianprops = dict(color="xkcd:gunmetal",linewidth=2),
            ax=ax)

ax = sns.stripplot(samples, x="day",hue="combo", y=axon,dodge = True, size = 5,jitter = 0.35,
            hue_order = ['safe, running','safe, paused','fear, running', 'fear, paused'],
            palette = ['tab:green','tab:orange','tab:blue','tab:red'],
                   alpha=0.2)

ax.set(
    # Trailing comma was missing here, which made the cell a SyntaxError.
    title = "NR-axon activity during running versus freezing epochs",
    ylim  = [0, 1],
    yticks = [0.00,0.25,0.50,0.75,1.00],
    yticklabels = [0.00,0.25,0.50,0.75,1.00],
    ylabel = 'mean normalized mean $Δf/f$',
    xticks = [0,1,2,3],
    xticklabels = ['pre-shocks', 'day 1', 'day 2', 'day 3'],
    xlabel = ''
    )

ax.tick_params('both', length=15, width=2.5, which='major')

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)

# Relabel the first four legend handles with readable names.
handles, labels = plt.gca().get_legend_handles_labels()

order = [0,1,2,3]

leg = plt.legend([handles[i] for i in order],
                 ['running \n(control)','freezing \n(control)',
                  'running \n(shocked)','freezing \n(shocked)'],
                  bbox_to_anchor=(.96,1)
                )

for line in leg.get_lines():
    line.set_linewidth(6.0)

sns.despine()
plt.show()

## Figure 4d

In [None]:
# Index labels of the peri-event windows returned by get_p2r (defined
# elsewhere), one array per (name, day, context) session.
# Iterate a dedicated loop variable instead of rebinding `df`, which
# previously left `df` pointing at the last group after the cell ran.
p2rs = {}
frame_length = 45
for key, group in traces.groupby(["name", "day", "context"]):
    p2rs[key] = np.asarray(group.index)[get_p2r(group,
                                                pause_min=frame_length,
                                                unpaused_min=frame_length,
                                                half_window_length=frame_length)]

In [None]:
# Figure 4d: mean axon trace (± SEM) around the p2r windows in p2rs, fear vs
# safe, pre-shock (day 1, left) vs post-shock (days > 1, right).
df = traces

axon_val = 'axon'
frame_end = 10000
x_tick_interpolation = FRAME_RATE

# Time axis in seconds centred on the window midpoint.
mid_line = next(iter(p2rs.values())).shape[1] / 2
index = np.arange(-mid_line, mid_line) / FRAME_RATE

axon = axon_val
columns = ['fear', 'safe']

palette = ['orchid','teal','black']

colors = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2], palette[2]]}

shading = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2],palette[2]]}

fig, axes = plt.subplots(ncols=2, sharey=False, figsize=(4.95*3,1.74*3))
axes = np.asarray(axes)

window_range = [-frame_length / FRAME_RATE, frame_length / FRAME_RATE]

for i in range(2):
    for j, context in enumerate(columns):
        p2r = np.concatenate([val for (name, day, ctxt), val in p2rs.items()
                              if ctxt == context and (day > 1 if i else day == 1)])
        # p2rs holds index *labels*; look values up by label rather than by
        # position so this stays correct for any index.
        trace = df.loc[p2r.ravel(), axon].to_numpy().reshape(p2r.shape)
        mean = savgol_filter(trace.mean(0),10,5)
        se = sp.stats.sem(trace)
        color = colors[axon][j]
        light_color = shading[axon][j]

        axes.flat[i].plot(index, mean, color=color, linewidth = 3,zorder=2.55)
        axes.flat[i].fill_between(index, mean - se, mean + se,
                                  color=light_color, alpha=0.2,zorder=2.56)
        axes.flat[i].axvline(x=0, linewidth=3, color=colors['mid_line'][j])

# Background shading: red-orange before 0, white after. set_axisbelow must
# target the current panel; the bare `ax` here was a stale axes from an
# earlier cell.
for i in range(2):
    axes.flat[i].axvspan(0, window_range[1], color='white', alpha=0.08, lw=0,zorder=0.4)
    axes.flat[i].axvspan(window_range[0], 0, color='xkcd:red orange', alpha=0.15, lw=0,zorder=0.4)
    axes.flat[i].set_axisbelow(True)

clean_axes(axes)

xlabels = ['pre-shocks','post shocks']
ylabels = ['mean $Δf/f$ within epoch range','']

for i, ax in enumerate(axes.flat):
    ax.set(
        xlim = window_range,
        xticks=[-3,-1.5,0,1.5,3],
        xticklabels=['-3s','-1.5s','0s','1.5s','3s'],
        ylim=[-0.1,.2],
        yticks=[-0.1,0,0.1,0.2],
        yticklabels=['-0.1','0.0','0.1','0.2'],
        ylabel = ylabels[i],
        xlabel = xlabels[i],
    )

    ax.legend([], [], frameon=False)
    ax.tick_params('both', length=15, width=3, which='major')
    for axis in ['left','bottom']:
        ax.spines[axis].set_linewidth(3)


fig.subplots_adjust(wspace=0.2)
plt.show()

## Figure 4e

In [None]:
# Figure 4e prep: for every frame of a freezing epoch (pause > 0), compute
# which of `nbins` equal-width bins of progress through that epoch it falls in
# (0 = first 1/nbins of the epoch ... nbins-1 = last), and flag epochs that
# contain at least one peak.
nbins = 5
df = traces[traces.pause > 0].copy()

locs = []
bins = []
for key, val in df.groupby(["name", "day", "context", "pause"]):
    frames = val.frame.values
    # Edges span [first frame, last frame + 1) so the last frame is included.
    bins.append(pd.cut(frames, np.linspace(frames[0], frames[-1] + 1, nbins + 1), right=False, labels=False))
    locs.append(val.index)

locs = np.concatenate(locs)
# Align bins back to df by index label, which does not require df's index to
# be sorted (the previous argsort-based mapping did).
df["bins"] = pd.Series(np.concatenate(bins), index=locs)
df.axon -= 1

pause_has_peak = df.groupby(["name", "day", "context", "pause"]).is_peak.sum() > 0
# Merge once: the original merged the same table twice, which produced
# pause_has_peak_x / pause_has_peak_y instead of a single pause_has_peak column.
df = pd.merge(df,
              pause_has_peak.reset_index().rename({"is_peak": "pause_has_peak"}, axis=1),
              on=["name", "day", "context", "pause"]
             )

In [None]:
# Figure 4e: NR-axon activity (rescaled_axon at peak frames) across progress
# bins of fear-context freezing epochs, pre- vs post-shock days.
day_label = {
    1: "pre-shocks",
    2: "post-shocks",
    3: "post-shocks",
    4: "post-shocks"
}

colors = ['xkcd:darkish purple','xkcd:olive']

id_cols = ["name", "day", "context", "pause"]
gb = traces[traces.pause > 0].groupby(id_cols).frame
# Length (frames) of each freezing epoch.
pause_lengths = (gb.last() - gb.first() + 1).rename("pause_length")
# Look lengths up by label so the result shares df's index. A merge returns a
# fresh RangeIndex, and older pandas versions did not always keep the left row
# order for inner merges.
df_pause_length = df.join(pause_lengths, on=id_cols)["pause_length"]

min_frames = 15
max_frames = np.inf  # np.Inf was removed in NumPy 2.0

samples = df[
    df.is_peak &
    (df.context == "fear") &
    (df_pause_length > min_frames) &
    (df_pause_length < max_frames)
].groupby(["name", "day", "pause", "bins"]).rescaled_axon.mean().reset_index()

samples["day_label"] = samples.day.apply(day_label.__getitem__)
# Percent labels assume nbins == 5 (20% per bin).
samples["bins"] = samples.bins.apply(lambda t: f"{20 * t}-{20 * (t + 1)}%")

fig, ax = plt.subplots(1, figsize=(7.1 , 4.8))

ax = sns.lineplot(data=samples,
                  x="bins", y="rescaled_axon", hue="day_label", palette = colors ,
                  hue_order = sorted(set(day_label.values())),
                  errorbar = ('se',1), style = 'day_label', dashes = False, markers = ['o']*2)

# NOTE(review): the title says "day 1" but the plot compares pre- vs
# post-shock days, and yticks 0.10/0.15 fall outside ylim [0.2, 0.5]; confirm.
ax.set(
    title = 'NR-axon activity on day 1 within a freezing epoch',
    ylabel='mean normalized $Δf/f$',
    yticks = [0.10,0.15,0.2,0.25,0.3],
    ylim = [0.2,0.5],
    xlabel = 'progress through freeze epoch',
    xticks = [0,1,2,3,4],
    )
ax.tick_params('both', length=15, width=2.5, which='major')

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)

plt.legend(bbox_to_anchor=(0.99,1.09))
sns.despine()
plt.show()

# Figure 5

## Figure 5a was made in BioRender

## Figure 5b

In [None]:
# Traces of the mice imaged across multiple days. Later cells read both the
# correctly spelled `multi_day_traces` and the historical `mutli_day_traces`,
# so define both names for the same frame.
multi_day_traces = traces[traces.name.isin(['RE6e','MR5b','MR6b','MR6e'])]
mutli_day_traces = multi_day_traces

# Index labels of get_p2r windows (get_p2r defined elsewhere), one array per
# (name, day, context) session; a dedicated loop variable avoids clobbering df.
p2rs = {}
frame_length = 45
for key, group in multi_day_traces.groupby(["name", "day", "context"]):
    p2rs[key] = np.asarray(group.index)[get_p2r(group,
                                                pause_min=frame_length,
                                                unpaused_min=frame_length,
                                                half_window_length=frame_length)]

In [None]:
# Figure 5b (day 1): mean axon trace (± SEM) around p2r windows, fear vs safe.
df = multi_day_traces

axon_val = 'axon'
frame_end = 10000
x_tick_interpolation = FRAME_RATE

mid_line = next(iter(p2rs.values())).shape[1] / 2
index = np.arange(-mid_line, mid_line) / FRAME_RATE

axon = axon_val
columns = ['fear', 'safe']

palette = ['orchid','teal','black']

colors = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2], palette[2]]}


shading = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2],palette[2]]}

fig, axes = plt.subplots(ncols=1, sharey=False, figsize=(4.95*1.5,1.74*3))
axes = np.asarray(axes)

window_range = [-frame_length / FRAME_RATE, frame_length / FRAME_RATE]

for i in range(1):
    for j, context in enumerate(columns):
        p2r = np.concatenate([val for (name, day, ctxt), val in p2rs.items()
                              if ctxt == context and (day == 1)])
        # p2rs holds index labels; df is a filtered subset of traces, so
        # positional `.values[p2r]` would read the wrong rows.
        trace = df.loc[p2r.ravel(), axon].to_numpy().reshape(p2r.shape)
        mean = savgol_filter(trace.mean(0),10,5)
        se = sp.stats.sem(trace)
        color = colors[axon][j]
        light_color = shading[axon][j]

        axes.flat[i].plot(index, mean, color=color, linewidth = 3,zorder=2.55)
        axes.flat[i].fill_between(index, mean - se, mean + se,
                                  color=light_color, alpha=0.2,zorder=2.56)
        axes.flat[i].axvline(x=0, linewidth=3, color=colors['mid_line'][j])

# set_axisbelow must target this panel (bare `ax` was stale from another cell).
for i in range(1):
    axes.flat[i].axvspan(0, window_range[1], color='white', alpha=0.08, lw=0,zorder=0.4)
    axes.flat[i].axvspan(window_range[0], 0, color='xkcd:red orange', alpha=0.15, lw=0,zorder=0.4)
    axes.flat[i].set_axisbelow(True)

clean_axes(axes)

xlabels = ['pre-shocks','post shocks']
ylabels = ['mean $Δf/f$ within epoch range','']

for i, ax in enumerate(axes.flat):
    # NOTE(review): tick positions do not match their labels. `index` is in
    # seconds, yet 1.85 s is labelled '1.5s' (and the ±3.7 ticks lie outside
    # xlim); yticks are 0.133 apart but labelled 0.1 apart. Confirm.
    ax.set(
        xlim = window_range,
        xticks=[-3.7,-1.85,0,1.85,3.7],
        xticklabels=['-3s','-1.5s','0s','1.5s','3s'],
        ylim=[.85,1.25],
        yticks=[.85+.133333,.85+.133333*2,.85+.13333*3,.85+.13333*4],
        yticklabels=['-0.1','0.0','0.1','0.2'],
        ylabel = ylabels[i],
        xlabel = xlabels[i],
    )

    ax.legend([], [], frameon=False)
    ax.tick_params('both', length=15, width=3, which='major')
    for axis in ['left','bottom']:
        ax.spines[axis].set_linewidth(3)


fig.subplots_adjust(wspace=0.2)
plt.show()

In [None]:
# Figure 5b (day 2): mean axon trace (± SEM) around p2r windows, fear vs safe.
df = mutli_day_traces

axon_val = 'axon'
frame_end = 10000
x_tick_interpolation = FRAME_RATE

mid_line = next(iter(p2rs.values())).shape[1] / 2
index = np.arange(-mid_line, mid_line) / FRAME_RATE

axon = axon_val
columns = ['fear', 'safe']

palette = ['orchid','teal','black']

colors = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2], palette[2]]}


shading = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2],palette[2]]}

fig, axes = plt.subplots(ncols=1, sharey=False, figsize=(4.95*1.5,1.74*3))
axes = np.asarray(axes)

window_range = [-frame_length / FRAME_RATE, frame_length / FRAME_RATE]

for i in range(1):
    for j, context in enumerate(columns):
        p2r = np.concatenate([val for (name, day, ctxt), val in p2rs.items()
                              if ctxt == context and (day == 2)])
        # p2rs holds index labels; df is a filtered subset of traces, so
        # positional `.values[p2r]` would read the wrong rows.
        trace = df.loc[p2r.ravel(), axon].to_numpy().reshape(p2r.shape)
        mean = savgol_filter(trace.mean(0),10,5)
        se = sp.stats.sem(trace)
        color = colors[axon][j]
        light_color = shading[axon][j]

        axes.flat[i].plot(index, mean, color=color, linewidth = 3,zorder=2.55)
        axes.flat[i].fill_between(index, mean - se, mean + se,
                                  color=light_color, alpha=0.2,zorder=2.56)
        axes.flat[i].axvline(x=0, linewidth=3, color=colors['mid_line'][j])

# set_axisbelow must target this panel (bare `ax` was stale from another cell).
for i in range(1):
    axes.flat[i].axvspan(0, window_range[1], color='white', alpha=0.08, lw=0,zorder=0.4)
    axes.flat[i].axvspan(window_range[0], 0, color='xkcd:red orange', alpha=0.15, lw=0,zorder=0.4)
    axes.flat[i].set_axisbelow(True)

clean_axes(axes)

xlabels = ['pre-shocks','post shocks']
ylabels = ['mean $Δf/f$ within epoch range','']

for i, ax in enumerate(axes.flat):
    # NOTE(review): tick positions do not match their labels. `index` is in
    # seconds, yet 1.85 s is labelled '1.5s' (and the ±3.7 ticks lie outside
    # xlim); yticks are 0.133 apart but labelled 0.1 apart. Confirm.
    ax.set(
        xlim = window_range,
        xticks=[-3.7,-1.85,0,1.85,3.7],
        xticklabels=['-3s','-1.5s','0s','1.5s','3s'],
        ylim=[.85,1.25],
        yticks=[.85+.133333,.85+.133333*2,.85+.13333*3,.85+.13333*4],
        yticklabels=['-0.1','0.0','0.1','0.2'],
        ylabel = ylabels[i],
        xlabel = xlabels[i],
    )

    ax.legend([], [], frameon=False)
    ax.tick_params('both', length=15, width=3, which='major')
    for axis in ['left','bottom']:
        ax.spines[axis].set_linewidth(3)

fig.subplots_adjust(wspace=0.2)
plt.show()

In [None]:
# Figure 5b (day 3): mean axon trace (± SEM) around p2r windows, fear vs safe.
df = mutli_day_traces

axon_val = 'axon'
frame_end = 10000
x_tick_interpolation = FRAME_RATE

mid_line = next(iter(p2rs.values())).shape[1] / 2
index = np.arange(-mid_line, mid_line) / FRAME_RATE

axon = axon_val
columns = ['fear', 'safe']

palette = ['orchid','teal','black']

colors = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2], palette[2]]}


shading = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2],palette[2]]}

fig, axes = plt.subplots(ncols=1, sharey=False, figsize=(4.95*1.5,1.74*3))
axes = np.asarray(axes)

window_range = [-frame_length / FRAME_RATE, frame_length / FRAME_RATE]

for i in range(1):
    for j, context in enumerate(columns):
        p2r = np.concatenate([val for (name, day, ctxt), val in p2rs.items()
                              if ctxt == context and (day == 3)])
        # p2rs holds index labels; df is a filtered subset of traces, so
        # positional `.values[p2r]` would read the wrong rows.
        trace = df.loc[p2r.ravel(), axon].to_numpy().reshape(p2r.shape)
        mean = savgol_filter(trace.mean(0),10,5)
        se = sp.stats.sem(trace)
        color = colors[axon][j]
        light_color = shading[axon][j]

        axes.flat[i].plot(index, mean, color=color, linewidth = 3,zorder=2.55)
        axes.flat[i].fill_between(index, mean - se, mean + se,
                                  color=light_color, alpha=0.2,zorder=2.56)
        axes.flat[i].axvline(x=0, linewidth=3, color=colors['mid_line'][j])

# set_axisbelow must target this panel (bare `ax` was stale from another cell).
for i in range(1):
    axes.flat[i].axvspan(0, window_range[1], color='white', alpha=0.08, lw=0,zorder=0.4)
    axes.flat[i].axvspan(window_range[0], 0, color='xkcd:red orange', alpha=0.15, lw=0,zorder=0.4)
    axes.flat[i].set_axisbelow(True)

clean_axes(axes)

xlabels = ['pre-shocks','post shocks']
ylabels = ['mean $Δf/f$ within epoch range','']

for i, ax in enumerate(axes.flat):
    # NOTE(review): tick positions do not match their labels. `index` is in
    # seconds, yet 1.85 s is labelled '1.5s' (and the ±3.7 ticks lie outside
    # xlim); yticks are 0.133 apart but labelled 0.1 apart. Confirm.
    ax.set(
        xlim = window_range,
        xticks=[-3.7,-1.85,0,1.85,3.7],
        xticklabels=['-3s','-1.5s','0s','1.5s','3s'],
        ylim=[.85,1.25],
        yticks=[.85+.133333,.85+.133333*2,.85+.13333*3,.85+.13333*4],
        yticklabels=['-0.1','0.0','0.1','0.2'],
        ylabel = ylabels[i],
        xlabel = xlabels[i],
    )

    ax.legend([], [], frameon=False)
    ax.tick_params('both', length=15, width=3, which='major')
    for axis in ['left','bottom']:
        ax.spines[axis].set_linewidth(3)

fig.subplots_adjust(wspace=0.2)
plt.show()

In [None]:
# Figure 5b (day 4): mean axon trace (± SEM) around p2r windows, fear vs safe.
df = mutli_day_traces

axon_val = 'axon'
frame_end = 10000
x_tick_interpolation = FRAME_RATE

mid_line = next(iter(p2rs.values())).shape[1] / 2
index = np.arange(-mid_line, mid_line) / FRAME_RATE

axon = axon_val
columns = ['fear', 'safe']

palette = ['orchid','teal','black']

colors = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2], palette[2]]}


shading = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2],palette[2]]}

fig, axes = plt.subplots(ncols=1, sharey=False, figsize=(4.95*1.5,1.74*3))
axes = np.asarray(axes)

window_range = [-frame_length / FRAME_RATE, frame_length / FRAME_RATE]

for i in range(1):
    for j, context in enumerate(columns):
        p2r = np.concatenate([val for (name, day, ctxt), val in p2rs.items()
                              if ctxt == context and (day == 4)])
        # p2rs holds index labels; df is a filtered subset of traces, so
        # positional `.values[p2r]` would read the wrong rows.
        trace = df.loc[p2r.ravel(), axon].to_numpy().reshape(p2r.shape)
        mean = savgol_filter(trace.mean(0),10,5)
        se = sp.stats.sem(trace)
        color = colors[axon][j]
        light_color = shading[axon][j]

        axes.flat[i].plot(index, mean, color=color, linewidth = 3,zorder=2.55)
        axes.flat[i].fill_between(index, mean - se, mean + se,
                                  color=light_color, alpha=0.2,zorder=2.56)
        axes.flat[i].axvline(x=0, linewidth=3, color=colors['mid_line'][j])

# set_axisbelow must target this panel (bare `ax` was stale from another cell).
for i in range(1):
    axes.flat[i].axvspan(0, window_range[1], color='white', alpha=0.08, lw=0,zorder=0.4)
    axes.flat[i].axvspan(window_range[0], 0, color='xkcd:red orange', alpha=0.15, lw=0,zorder=0.4)
    axes.flat[i].set_axisbelow(True)

clean_axes(axes)

xlabels = ['pre-shocks','post shocks']
ylabels = ['mean $Δf/f$ within epoch range','']

for i, ax in enumerate(axes.flat):
    # NOTE(review): tick positions do not match their labels. `index` is in
    # seconds, yet 1.85 s is labelled '1.5s' (and the ±3.7 ticks lie outside
    # xlim); yticks are 0.133 apart but labelled 0.1 apart. Confirm.
    ax.set(
        xlim = window_range,
        xticks=[-3.7,-1.85,0,1.85,3.7],
        xticklabels=['-3s','-1.5s','0s','1.5s','3s'],
        ylim=[.85,1.25],
        yticks=[.85+.133333,.85+.133333*2,.85+.13333*3,.85+.13333*4],
        yticklabels=['-0.1','0.0','0.1','0.2'],
        ylabel = ylabels[i],
        xlabel = xlabels[i],
    )

    ax.legend([], [], frameon=False)
    ax.tick_params('both', length=15, width=3, which='major')
    for axis in ['left','bottom']:
        ax.spines[axis].set_linewidth(3)

fig.subplots_adjust(wspace=0.2)
plt.show()

## Figure 5c

In [None]:
# Figure 5c samples: mean zeroed_axon over peak frames per epoch for the
# multi-day mice. Freezing-epoch lengths are computed here and looked up by
# index label. The previous mask reused a `df_pause_length` built for a
# different frame, which did not line up with these rows.
id_cols = ["name", "day", "context", "pause"]
gb = multi_day_traces[multi_day_traces.pause > 0].groupby(id_cols).frame
pause_lengths = (gb.last() - gb.first() + 1).rename("pause_length")

# Freezing-epoch length limits (frames); same values as Figure 5d.
min_frames = 15 * 3
max_frames = np.inf

df = multi_day_traces[multi_day_traces.is_peak]
df_pause_length = df.join(pause_lengths, on=id_cols)["pause_length"]
# Running rows (pause <= 0) have no freezing length and are kept.
keep = (df.pause <= 0) | ((df_pause_length > min_frames) & (df_pause_length < max_frames))
samples = df[
    df.is_peak & keep
].groupby(["name", "day", "combo", "pause", "context", "status"]).zeroed_axon.mean().reset_index()

contexts = ["fear", "safe"]
statuses = ["running", "paused"]

keys = [f"{ctxt}, {stat}" for ctxt in contexts for stat in statuses]

In [None]:
# Figure 5c: NR-axon activity during running vs freezing epochs per day for
# the multi-day mice, split by context (boxplots + jittered points).
fig, ax = plt.subplots(1, figsize=(7.1 , 5))

axon = "zeroed_axon"

sns.boxplot(samples, x="day",hue="combo", y=axon, showfliers=False, whis = 1,
            hue_order = ['safe, running','safe, paused','fear, running', 'fear, paused'],
            palette = ['tab:green','tab:orange','tab:blue','tab:red'],
            width=0.8, boxprops={'zorder': 3, 'alpha':0.1, 'edgecolor': 'white'},
            capprops=dict(color="xkcd:gunmetal",linewidth=1.5),
            whiskerprops=dict(color="xkcd:gunmetal",linewidth=1.5),
            medianprops = dict(color="xkcd:gunmetal",linewidth=2),
            ax=ax)

ax = sns.stripplot(samples, x="day",hue="combo", y=axon,dodge = True, size = 5,jitter = 0.35,
            hue_order = ['safe, running','safe, paused','fear, running', 'fear, paused'],
            palette = ['tab:green','tab:orange','tab:blue','tab:red'],
                   alpha=0.2)

ax.set(
    # Trailing comma was missing here, which made the cell a SyntaxError.
    title = "NR-axon activity during running versus freezing epochs",
    ylim  = [0, 1],
    yticks = [0.00,0.25,0.50,0.75,1.00],
    yticklabels = [0.00,0.25,0.50,0.75,1.00],
    ylabel = 'mean normalized mean $Δf/f$',
    xticks = [0,1,2,3],
    xticklabels = ['pre-shocks', 'day 1', 'day 2', 'day 3'],
    xlabel = ''
    )

ax.tick_params('both', length=15, width=2.5, which='major')

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)

# Relabel the first four legend handles with readable names.
handles, labels = plt.gca().get_legend_handles_labels()

order = [0,1,2,3]

leg = plt.legend([handles[i] for i in order],
                 ['running \n(control)','freezing \n(control)',
                  'running \n(shocked)','freezing \n(shocked)'],
                  bbox_to_anchor=(.96,1)
                )

for line in leg.get_lines():
    line.set_linewidth(6.0)

sns.despine()
plt.show()

## Figure 5d

In [None]:
# Figure 5d prep (multi-day mice): for every frame of a freezing epoch
# (pause > 0), compute its progress bin within that epoch (0 .. nbins-1), and
# flag epochs that contain at least one peak.
nbins = 5
df = multi_day_traces[multi_day_traces.pause > 0].copy()

locs = []
bins = []
for key, val in df.groupby(["name", "day", "context", "pause"]):
    frames = val.frame.values
    # Edges span [first frame, last frame + 1) so the last frame is included.
    bins.append(pd.cut(frames, np.linspace(frames[0], frames[-1] + 1, nbins + 1), right=False, labels=False))
    locs.append(val.index)

locs = np.concatenate(locs)
# Align bins back to df by index label, which does not require df's index to
# be sorted (the previous argsort-based mapping did).
df["bins"] = pd.Series(np.concatenate(bins), index=locs)
df.axon -= 1

pause_has_peak = df.groupby(["name", "day", "context", "pause"]).is_peak.sum() > 0
# Merge once: merging the same table twice produced pause_has_peak_x /
# pause_has_peak_y instead of a single pause_has_peak column.
df = pd.merge(df,
              pause_has_peak.reset_index().rename({"is_peak": "pause_has_peak"}, axis=1),
              on=["name", "day", "context", "pause"]
             )

In [None]:
# Figure 5d: NR-axon activity (rescaled_axon at peak frames) across progress
# bins of fear-context freezing epochs for the multi-day mice, pre- vs
# post-shock days.
day_label = {
    1: "pre-shocks",
    2: "post-shocks",
    3: "post-shocks",
    4: "post-shocks"
}

colors = ['xkcd:darkish purple','xkcd:olive']

id_cols = ["name", "day", "context", "pause"]
gb = traces[traces.pause > 0].groupby(id_cols).frame
# Length (frames) of each freezing epoch.
pause_lengths = (gb.last() - gb.first() + 1).rename("pause_length")
# Look lengths up by label so the result shares df's index. A merge returns a
# fresh RangeIndex, and older pandas versions did not always keep the left row
# order for inner merges.
df_pause_length = df.join(pause_lengths, on=id_cols)["pause_length"]

min_frames = 15*3
max_frames = np.inf  # np.Inf was removed in NumPy 2.0

samples = df[
    df.is_peak &
    (df.context == "fear") &
    (df_pause_length > min_frames) &
    (df_pause_length < max_frames)
].groupby(["name", "day", "pause", "bins"]).rescaled_axon.mean().reset_index()

samples["day_label"] = samples.day.apply(day_label.__getitem__)
# Percent labels assume nbins == 5 (20% per bin).
samples["bins"] = samples.bins.apply(lambda t: f"{20 * t}-{20 * (t + 1)}%")

fig, ax = plt.subplots(1, figsize=(7.1 , 4.8))

ax = sns.lineplot(data=samples,
                  x="bins", y="rescaled_axon", hue="day_label", palette = colors ,
                  hue_order = sorted(set(day_label.values())),
                  errorbar = ('se',1), style = 'day_label', dashes = False, markers = ['o']*2)

# NOTE(review): the title says "day 1" but the plot compares pre- vs
# post-shock days, and yticks 0.10/0.15 fall outside ylim [0.2, 0.5]; confirm.
ax.set(
    title = 'NR-axon activity on day 1 within a freezing epoch',
    ylabel='mean normalized $Δf/f$',
    yticks = [0.10,0.15,0.2,0.25,0.3],
    ylim = [0.2,0.5],
    xlabel = 'progress through freeze epoch',
    xticks = [0,1,2,3,4],
    )
ax.tick_params('both', length=15, width=2.5, which='major')

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)

plt.legend(bbox_to_anchor=(0.99,1.09))
sns.despine()
plt.show()

# Figure 6

## Figure 6a was made in BioRender

## Figure 6b

In [None]:
#build feature space for model
# Key of one interval (a pause, or the running stretch labelled by pause <= 0).
# NOTE(review): unlike feature_fwd below (grouped by name/day/context) this key has no
# "context"; that is only equivalent if pause ids are unique within a (name, day) — confirm.
cols = ["name", "day", "pause"]

# First and last frame of every interval.
interval_start = traces.groupby(cols, as_index=False).frame.first()
interval_end = traces.groupby(cols, as_index=False).frame.last()

interval_boundaries = pd.merge(interval_start, interval_end, on=cols, suffixes=('_first', '_last'))

# Attach each frame's interval boundaries. how="left" yields exactly one row per row of
# `traces`, in `traces` order, so the result can take `traces.index`. The later
# pd.concat(..., axis=1) aligns on index labels, and the merge's fresh RangeIndex would
# misalign rows whenever `traces` is not indexed 0..n-1.
intervals = (pd.merge(traces[cols + ["frame"]], interval_boundaries, on=cols, how="left")
             .set_axis(traces.index))

# Position of each frame inside its interval: frames since its start, frames until its
# end, and fractional progress (0 at the first frame, 1 at the last; NaN for 1-frame intervals).
intervals["interval_elapsed"] = intervals.frame - intervals.frame_first
intervals["interval_remaining"] = intervals.frame_last - intervals.frame
intervals["interval_progress"] = (intervals.interval_elapsed /
                                  (intervals.frame_last - intervals.frame_first))

# Split every interval_* feature into a pause_* copy (only inside pauses, pause > 0) and a
# running_* copy (only outside pauses); the inactive copy is set to -1.
in_pause = intervals.pause > 0
interval_cols = [col for col in intervals.columns if col.startswith("interval_")]
for col in interval_cols:
    intervals[col.replace("interval", "pause")] = np.where(in_pause, intervals[col], -1)
    intervals[col.replace("interval", "running")] = np.where(in_pause, -1, intervals[col])

# frame / frame_first / frame_last were only needed to derive the features above.
intervals = intervals.drop([col for col in intervals.columns if col.startswith("frame")], axis=1)

def feature_fwd(df, col, offset, grp_cols=("name", "day", "context"), idx_col="frame"):
    """Shift column ``col`` by ``offset`` steps of ``idx_col`` within each group.

    For every row, take ``col`` from the row of the same group (same ``grp_cols``
    values) at the first position whose ``idx_col`` is >= this row's
    ``idx_col + offset``. Positions past either end of the group are clipped to
    the group's first/last row, so a positive ``offset`` looks ahead and a
    negative one looks back; values never cross group boundaries.

    Parameters
    ----------
    df : pandas.DataFrame
        Input rows; the index is ignored (the result is positional).
    col : str
        Column whose values are shifted.
    offset : int
        Shift in units of ``idx_col`` (positive = later rows).
    grp_cols : str or sequence of str
        Column(s) defining independent sequences. An empty sequence treats the
        whole frame as one sequence.
    idx_col : str
        Ordering column (e.g. frame number); it need not be pre-sorted.

    Returns
    -------
    numpy.ndarray
        Length ``len(df)``, dtype of ``df[col]``, in ``df``'s row order.
    """
    keys = [grp_cols] if isinstance(grp_cols, str) else list(grp_cols)
    new_arr = np.empty(len(df), dtype=df.dtypes[col])
    # drop=True gives a positional 0..n-1 index without inserting the old index as
    # a column (which raises if a column of that name already exists).
    flat = df.reset_index(drop=True)
    # dropna=False: rows whose group key is missing are still filled instead of
    # keeping uninitialised np.empty memory.
    groups = flat.groupby(keys, dropna=False)[[idx_col, col]] if keys else [(None, flat)]
    for _, gdf in groups:
        # searchsorted needs ascending positions; a stable sort is a no-op for
        # groups that are already in order.
        order = np.argsort(gdf[idx_col].to_numpy(), kind="stable")
        positions = gdf[idx_col].to_numpy()[order]
        values = gdf[col].to_numpy()[order]
        rows = gdf.index.to_numpy()[order]
        targets = np.searchsorted(positions, positions + offset)
        new_arr[rows] = values.take(targets, mode="clip")

    return new_arr

# Velocity 8 frames ahead of / behind each frame (within name/day/context).
# NOTE(review): these two arrays are not used by the visible cells (the same shifts are
# recomputed below as velocity_fwd_8 / velocity_fwd_-8) — confirm before removing.
velocity_fwd_8 = feature_fwd(traces, "velocity", 8)
velocity_back_8 = feature_fwd(traces, "velocity", -8)

# Per-frame behavioural and imaging columns used as model inputs / target.
feature_df = traces[['name', 'day', 'lick', 'ybinned',
       'recorded_velocity', 'velocity', 'acceleration',
       'lap', 'pause', 'context', 'rescaled_axon', 'pupil_area_smoothed', 'pupil_xpos',
       'pupil_ypos', 'is_paused', 'is_running','is_backtracking', 'is_postpause']].copy()

# Add the interval features and velocity shifted by each offset (velocity_fwd_<offset>).
vel_offsets = [-12, -8, 8, 12]
feature_df = pd.concat([feature_df,
                        intervals.drop(cols, axis=1),
                        pd.DataFrame({f"velocity_fwd_{offset}":
                                      feature_fwd(traces, "velocity", offset)
                                      for offset in vel_offsets},
                                     index=traces.index
                                    )
                       ], axis=1)

# Keep only the two contexts the models are fit on. `.copy()` makes the filtered frame
# independent so the column assignment below does not act on a view of the unfiltered
# frame (pandas' SettingWithCopyWarning).
feature_df = feature_df[feature_df.context.isin(["fear", "safe"])].copy()
# Collapse pause ids into a boolean "frame is inside a pause" feature.
feature_df["pause"] = feature_df["pause"] > 0

In [None]:
#split data
# `import sklearn` does not load the model_selection submodule; import it explicitly.
import sklearn.model_selection
# NOTE(review): train_laps / test_laps are not used by the visible cells (the per-group
# splits come from split_by_lap), and `data` is defined elsewhere — confirm this is still
# needed. The training cell re-seeds np.random before fitting, so this draw does not
# change the models.
train_laps, test_laps = sklearn.model_selection.train_test_split(np.unique(data.lap), train_size=0.8)
def split_by_lap(df, test_fraction=0.2, gb_cols=("name", "day", "context"), lap_col="lap", rng=None):
    """Randomly mark whole laps as held-out test rows.

    A lap is identified by ``gb_cols + [lap_col]``. Within each ``gb_cols`` group
    (or across all laps when ``gb_cols`` is empty), ``round(test_fraction * n_laps)``
    laps are drawn without replacement. Every row of a drawn lap is flagged, so a
    lap never straddles train and test.

    Parameters
    ----------
    df : pandas.DataFrame
    test_fraction : float
        Fraction of laps per group to hold out.
    gb_cols : str or sequence of str
        Columns within which the fraction is applied; empty pools all laps.
    lap_col : str
        Column holding the lap id.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Defaults to NumPy's global RNG, so ``np.random.seed``
        controls the split as before.

    Returns
    -------
    numpy.ndarray of bool
        Length ``len(df)``, in ``df`` row order; True marks a test row. Rows whose
        lap key contains a missing value are never selected.
    """
    choice = np.random.choice if rng is None else rng.choice
    gb_cols = [gb_cols] if isinstance(gb_cols, str) else list(gb_cols)
    cols = gb_cols + [lap_col]
    # Unique lap keys in the groupby's sorted order; keeping this order keeps the
    # random draws (and therefore the split for a given seed) unchanged.
    keys = list(df.groupby(cols, as_index=False).groups.keys())

    laps = pd.DataFrame(keys, columns=cols)

    test_lap = np.zeros(len(laps), dtype=bool)

    lap_groups = laps.groupby(gb_cols) if gb_cols else [(None, laps)]
    for _, gdf in lap_groups:
        n_test = round(test_fraction * len(gdf))
        test_lap[choice(gdf.index, size=n_test, replace=False)] = True

    laps["test"] = test_lap

    # how="left" returns one value per row of df, in df's row order, even for rows whose
    # key has a missing value (dropped by the groupby above); those count as training rows.
    merged = pd.merge(df[cols], laps, on=cols, how="left")
    return merged["test"].eq(True).to_numpy()

In [None]:
#train model
import warnings
import sklearn.metrics  # `import sklearn` alone does not load the metrics submodule
from IPython.display import HTML, display

warnings.filterwarnings("ignore",
                        message="pandas.Int64Index is deprecated and "
                        "will be removed from pandas in a future version. "
                        "Use pandas.Index with the appropriate dtype instead.")

# Hide the empty output prompts left behind by tqdm's progress-bar widgets.
# See https://github.com/bstriner/keras-tqdm/issues/21#issuecomment-443019223
# A rich-display object only renders when displayed; as a non-final expression in the
# cell it was silently discarded, so display it explicitly.
display(HTML("""
<style>
.p-Widget.jp-OutputPrompt.jp-OutputArea-prompt:empty {
  padding: 0;
  border: 0;
}
</style>
"""))

df = feature_df
# One set of models per mouse/day/context. Every consumer of `results` unpacks its keys
# as (name, day, context), so "name" must be part of the grouping.
gb_cols = ["name", "day", "context"]
label = "rescaled_axon"
# Identifiers, lap id and the target itself are excluded from the features.
drop_cols = ["name", "day", "context", "lap", label]
n_draws = 100

#random seed is set to 42 for reproducibility
np.random.seed(42)
results = {}
for key, gdf in tqdm(df.groupby(gb_cols)):
    # Features and target do not depend on the draw, so build them once per group.
    X = gdf.drop(drop_cols, axis=1)
    y = gdf[label]
    for _ in tqdm(range(n_draws), leave=False):
        # Hold out ~20% of this group's laps (whole laps) for evaluation.
        split = split_by_lap(gdf, gb_cols=[])

        model = xgb.XGBRegressor(
            gamma=1,
            learning_rate=0.01,
            n_estimators=1000,
            base_score=1,
            verbosity=0
        )

        # NOTE(review): early stopping monitors the same held-out laps the r^2 below is
        # computed on, so that r^2 is slightly optimistic. Passing early_stopping_rounds to
        # fit() is deprecated since xgboost 1.6 and removed in 2.0 (it moved to the
        # constructor) — confirm the pinned xgboost version.
        model.fit(X[~split], y[~split],
                  eval_set=[(X[split], y[split])],
                  early_stopping_rounds=5,
                  verbose=False
                 )

        results.setdefault(key, []).append({
            "split": split,
            "model": model,
            "r2": sklearn.metrics.r2_score(y[split], model.predict(X[split]))
        })

In [None]:
#create dataframes to plot
# One row per fitted model: its mouse/day/context and its held-out r^2.
r2s = pd.DataFrame([[name, day, context, val["r2"]]
                    for (name, day, context), vs in results.items()
                    for val in vs],
                   columns=["name", "day", "context", "r2"])

# `df` used to be an identical table built a second time; keep an independent copy
# instead of rebuilding it.
df = r2s.copy()

peak_v_r2 = pd.merge(r2s, peak_medians, on=["name", "day", "context"])

In [None]:
#plot figure
# Figure 6b: held-out r^2 of each model against the "peak" column merged in from
# peak_medians (one point per model), excluding day 1.
df = peak_v_r2[peak_v_r2.day != 1]

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

# Regression line only (scatter=False); the points are drawn by the scatterplot below.
# NOTE(review): per the seaborn docs x_ci is only used together with x_estimator, so it
# likely has no effect here — confirm.
ax = sns.regplot(x='peak',y = 'r2', data = df, x_ci = 'sd', scatter=False,
            line_kws={"color": "dimgrey"})

# Faint per-model points coloured by mouse; the legend is blanked further down.
sns.scatterplot(data=df, x="peak", y="r2", hue="name", s=20, alpha=0.15, label="_ignore", ax=ax, palette = 'husl')

ax.set(
    ylabel='$r^2$ of models', 
    yticks = [0,0.25,0.5,0.75,1],
    xticks = [0.15,0.3,0.45,0.6],
    xlabel = 'Median axonal peak  $Δf/f$',
    ylim = [-0.2,1]
    )

ax.tick_params('both', length=15, width=2.5, which='major')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
sns.despine()
# Empty legend: hides the per-mouse hue legend.
plt.legend([],[],frameon=False)
plt.show()

## Figure 6c

In [None]:
# Actual vs. predicted signal for the best of the models fit to one mouse/day/context,
# evaluated on that model's held-out (test) laps.
filt = {"name": 'MR7e',
        "day": 1,
        "context": 'fear'
       }

vals = results[tuple(filt[k] for k in ["name", "day", "context"])]

# Draw with the highest held-out r^2.
result = max(vals, key=lambda t: t["r2"])
df = feature_df

# Rows of feature_df for this mouse/day/context. groupby keeps the original row order
# within a group, so these rows line up with result["split"].
bools = np.logical_and.reduce([df[col] == val for col, val in filt.items()])

# The models were trained on the "rescaled_axon" target, with that target and
# name/day/context/lap dropped from the features. feature_df has no "axon" column, so
# use the training target and drop exactly the training columns.
target_col = "rescaled_axon"
group_rows = df.loc[bools]
test_y = group_rows[target_col].values[result["split"]]
test_yhat = result["model"].predict(
    group_rows.drop([target_col, "lap", *filt.keys()], axis=1).values[result["split"]])

# Both traces are shifted down by 1 for display; test samples 20-793 are shown.
plt.plot((test_y -1)[20:794], color=("xkcd:blue grey"), linewidth = 1.7)
plt.plot((test_yhat -1)[20:794], color=("xkcd:rose"),linestyle = '-')
ax = plt.gca()

# x ticks convert samples to seconds at ~15.5 frames/s. The held-out samples can come from
# several laps, so the axis is concatenated test time, not one continuous recording.
# NOTE(review): set_yticks expands the view limits to include every tick and ax.set applies
# properties in the given order, so the 0.8 tick may override ylim's upper bound — confirm.
ax.set(
    ylabel = 'Actual or predicted $Δf/f$',
    ylim = [-0.4,0.6],
    xlim = [0,620],
    yticks = [
             -0.2,
              0.0,
              0.2,
              0.4,
              0.6,
              0.8,
             ],
    xticks = [0,155,310,465,620],
    xticklabels = ['0s','10s','20s','30s','40s']
    )

ax.tick_params(
        axis='y',
        which='both',
        length=15,
        width=2.5,
        )

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)

sns.despine()
plt.show()

## Figure 6d

In [None]:
# Actual vs. predicted signal for the best of the models fit to one mouse/day/context,
# evaluated on that model's held-out (test) laps.
filt = {"name": 'MR7e',
        "day": 2,
        "context": 'fear'
       }

vals = results[tuple(filt[k] for k in ["name", "day", "context"])]

# Draw with the highest held-out r^2.
result = max(vals, key=lambda t: t["r2"])
df = feature_df

# Rows of feature_df for this mouse/day/context. groupby keeps the original row order
# within a group, so these rows line up with result["split"].
bools = np.logical_and.reduce([df[col] == val for col, val in filt.items()])

# The models were trained on the "rescaled_axon" target, with that target and
# name/day/context/lap dropped from the features. feature_df has no "axon" column, so
# use the training target and drop exactly the training columns.
target_col = "rescaled_axon"
group_rows = df.loc[bools]
test_y = group_rows[target_col].values[result["split"]]
test_yhat = result["model"].predict(
    group_rows.drop([target_col, "lap", *filt.keys()], axis=1).values[result["split"]])

# Both traces are shifted down by 1 for display; test samples 20-793 are shown.
plt.plot((test_y -1)[20:794], color=("xkcd:blue grey"), linewidth = 1.7)
plt.plot((test_yhat -1)[20:794], color=("xkcd:rose"),linestyle = '-')
ax = plt.gca()

# x ticks convert samples to seconds at ~15.5 frames/s. The held-out samples can come from
# several laps, so the axis is concatenated test time, not one continuous recording.
# NOTE(review): set_yticks expands the view limits to include every tick and ax.set applies
# properties in the given order, so the 0.8 tick may override ylim's upper bound — confirm.
ax.set(
    ylabel = 'Actual or predicted $Δf/f$',
    ylim = [-0.4,0.6],
    xlim = [0,620],
    yticks = [
             -0.2,
              0.0,
              0.2,
              0.4,
              0.6,
              0.8,
             ],
    xticks = [0,155,310,465,620],
    xticklabels = ['0s','10s','20s','30s','40s']
    )

ax.tick_params(
        axis='y',
        which='both',
        length=15,
        width=2.5,
        )

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)

sns.despine()
plt.show()

## Figure 6e

In [None]:
# Actual vs. predicted signal for the best of the models fit to one mouse/day/context,
# evaluated on that model's held-out (test) laps.
filt = {"name": 'MR7e',
        "day": 3,
        "context": 'fear'
       }

vals = results[tuple(filt[k] for k in ["name", "day", "context"])]

# Draw with the highest held-out r^2.
result = max(vals, key=lambda t: t["r2"])
df = feature_df

# Rows of feature_df for this mouse/day/context. groupby keeps the original row order
# within a group, so these rows line up with result["split"].
bools = np.logical_and.reduce([df[col] == val for col, val in filt.items()])

# The models were trained on the "rescaled_axon" target, with that target and
# name/day/context/lap dropped from the features. feature_df has no "axon" column, so
# use the training target and drop exactly the training columns.
target_col = "rescaled_axon"
group_rows = df.loc[bools]
test_y = group_rows[target_col].values[result["split"]]
test_yhat = result["model"].predict(
    group_rows.drop([target_col, "lap", *filt.keys()], axis=1).values[result["split"]])

# Both traces are shifted down by 1 for display; test samples 20-793 are shown.
plt.plot((test_y -1)[20:794], color=("xkcd:blue grey"), linewidth = 1.7)
plt.plot((test_yhat -1)[20:794], color=("xkcd:rose"),linestyle = '-')
ax = plt.gca()

# x ticks convert samples to seconds at ~15.5 frames/s. The held-out samples can come from
# several laps, so the axis is concatenated test time, not one continuous recording.
# NOTE(review): set_yticks expands the view limits to include every tick and ax.set applies
# properties in the given order, so the 0.8 tick may override ylim's upper bound — confirm.
ax.set(
    ylabel = 'Actual or predicted $Δf/f$',
    ylim = [-0.4,0.6],
    xlim = [0,620],
    yticks = [
             -0.2,
              0.0,
              0.2,
              0.4,
              0.6,
              0.8,
             ],
    xticks = [0,155,310,465,620],
    xticklabels = ['0s','10s','20s','30s','40s']
    )

ax.tick_params(
        axis='y',
        which='both',
        length=15,
        width=2.5,
        )

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)

sns.despine()
plt.show()

## Figure 6f

In [None]:
# Actual vs. predicted signal for the best of the models fit to one mouse/day/context,
# evaluated on that model's held-out (test) laps.
filt = {"name": 'MR7e',
        "day": 4,
        "context": 'fear'
       }

vals = results[tuple(filt[k] for k in ["name", "day", "context"])]

# Draw with the highest held-out r^2.
result = max(vals, key=lambda t: t["r2"])
df = feature_df

# Rows of feature_df for this mouse/day/context. groupby keeps the original row order
# within a group, so these rows line up with result["split"].
bools = np.logical_and.reduce([df[col] == val for col, val in filt.items()])

# The models were trained on the "rescaled_axon" target, with that target and
# name/day/context/lap dropped from the features. feature_df has no "axon" column, so
# use the training target and drop exactly the training columns.
target_col = "rescaled_axon"
group_rows = df.loc[bools]
test_y = group_rows[target_col].values[result["split"]]
test_yhat = result["model"].predict(
    group_rows.drop([target_col, "lap", *filt.keys()], axis=1).values[result["split"]])

# Both traces are shifted down by 1 for display; test samples 20-793 are shown.
plt.plot((test_y -1)[20:794], color=("xkcd:blue grey"), linewidth = 1.7)
plt.plot((test_yhat -1)[20:794], color=("xkcd:rose"),linestyle = '-')
ax = plt.gca()

# x ticks convert samples to seconds at ~15.5 frames/s. The held-out samples can come from
# several laps, so the axis is concatenated test time, not one continuous recording.
# NOTE(review): set_yticks expands the view limits to include every tick and ax.set applies
# properties in the given order, so the 0.8 tick may override ylim's upper bound — confirm.
ax.set(
    ylabel = 'Actual or predicted $Δf/f$',
    ylim = [-0.4,0.6],
    xlim = [0,620],
    yticks = [
             -0.2,
              0.0,
              0.2,
              0.4,
              0.6,
              0.8,
             ],
    xticks = [0,155,310,465,620],
    xticklabels = ['0s','10s','20s','30s','40s']
    )

ax.tick_params(
        axis='y',
        which='both',
        length=15,
        width=2.5,
        )

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)

sns.despine()
plt.show()

## Figure 6g

In [None]:
# Mice included in the r^2 summary (each listed once; "MR7e" used to appear twice).
thresh_mice = ["A1","A2", "A4","MR7c","B8","MR7e","RE6e"]

# One row per fitted model. Each row must include the mouse name: previously three values
# were given per row for four column labels (the DataFrame constructor raises on that), and
# the name is needed for the mouse filter below.
df = pd.DataFrame([[name, day, context, val["r2"]]
                   for (name, day, context), vs in results.items()
                   for val in vs],
                  columns=["name", "day", "context", "r2"])

df = df[df.name.isin(thresh_mice)]

# Boxes per day and context, with the individual model r^2 values overlaid.
ax = sns.boxplot(data=df, x="day", y="r2", hue="context",whis = 1, showfliers=True, boxprops=dict(alpha=.5),
                 palette = ('teal','grey'))
ax = sns.stripplot(data=df, x="day", y="r2", hue="context", alpha = 0.5, dodge = True,
                 palette = ('teal','grey'))
# NOTE(review): set_yticks expands the view limits to include every tick and ax.set applies
# properties in the given order, so the ticks at -1 and 1 may widen the ylim set just
# before them — confirm the intended y range.
ax.set(
    ylim=[-0.1,.8],
    ylabel ='model fit ($r^2$)',
    yticks = [-1,-0.75,-0.5,-0.25,0,0.25,0.5,0.75,1],
    xticks = [0,1,2,3],
    xticklabels=['pre-shock baseline','recall day 1','recall day 2','recall day 3']
    )
sns.despine()
plt.legend([],[],frameon=False)
plt.tight_layout()
plt.show()

## Figure 6h

In [None]:
# Behavioural categories used to pool per-feature model importances (Figure 6h).
# Every entry is a column of feature_df; importances are summed within a category.
feature_categories = {
    "lick": [
        "lick",
    ],
    "location": [
        "ybinned",
    ],
    "pupil_information": [
        "pupil_area_smoothed",
        "pupil_xpos",
        "pupil_ypos",
    ],
    # pause_* columns are the interval features kept only for paused frames
    "pausing": [
        "pause",
        "is_paused",
        "pause_remaining",
        "is_postpause",
        "pause_progress",
        "pause_elapsed",
    ],
    # running_* columns are the interval features kept only for non-paused frames
    "running": [
        "is_running",
        "running_progress",
        "running_remaining",
        "is_backtracking",
        "running_elapsed",
    ],
    # interval_* columns describe the current pause or running stretch alike
    "combined pausing/running": [
        "interval_elapsed",
        "interval_remaining",
        "interval_progress",
    ],
    # velocity signals plus velocity shifted by -12 / -8 / +8 / +12 frames
    "velocity_and_offsets": [
        "recorded_velocity",
        "velocity",
        "acceleration",
        "velocity_fwd_-12",
        "velocity_fwd_-8",
        "velocity_fwd_8",
        "velocity_fwd_12",
    ],
}

In [None]:
# Collect the per-feature XGBoost importances of every fitted model, one row per model,
# for the mice listed in `nice_mice` (defined elsewhere in the notebook).
imps = []
for (name, day, context), vs in results.items():
    if name not in nice_mice: continue
    for v in vs:
        bst = v["model"].get_booster()
        # importance_type="weight": number of times each feature is used in a split.
        # NOTE(review): the Figure 6h axis label says "gain fraction" — confirm "weight" is intended.
        imps.append({"name": name, "day": day, "context": context, 
                     **bst.get_score(importance_type="weight")})
# get_score omits features a model never split on, so those entries become 0.
imps = pd.DataFrame(imps).fillna(0)
# Normalise each model's row so its importances (columns 3 onward) sum to 1; a model
# without any split would produce a 0/0 -> NaN row here.
imps.iloc[:, 3:] /= imps.iloc[:, 3:].sum(1).values[:, None]

In [None]:
# Figure 6h: per-category feature importance (percent), first post-shock day.
cat_imps = imps[["name", "day", "context"]].copy()
for k, v in feature_categories.items():
    # Sum the row-normalised importances of the category's features that appear in imps.
    cat_imps[k] = sum(imps[vv] for vv in v if vv in imps.columns) * 100

df = cat_imps.melt(id_vars=["name", "day", "context"])
# Day 2 is the first post-shock day (labelled "day 1" on the figures).
df = df[df.day == 2]
# Categories ranked by mean importance, highest first.
order = df.groupby("variable").value.mean().sort_values(ascending=False).index

top_n = 6
if top_n is not None:
    df = df[df.variable.isin(order[:top_n])]
    order = order[:top_n]

# Display name of each category. Tick labels follow `order`, which depends on the data;
# a fixed label list is only correct for one particular ranking.
category_labels = {
    "pausing": "freezing",
    "velocity_and_offsets": "velocities",
    "combined pausing/running": "interval",
    "running": "running",
    "location": "location",
    "pupil_information": "pupil",
    "lick": "lick",
}

fig, ax = plt.subplots(1, figsize=(6.9, 5))

ax = sns.barplot(
    data=df,
    x="value",
    y="variable",
    hue="context",
    order=order,
    orient = 'h',
    palette = ['teal','orchid']
           )

# NOTE(review): imps is built with importance_type="weight" (split counts), not gain —
# confirm the x label.
ax.set(
    ylabel = '',
    xlabel = 'Mean gain fraction on day 1',
    yticklabels = [category_labels.get(category, category) for category in order],
    xticks = [0,15,30,45],
    xticklabels = ['0%','15%','30%','45%']
    )

ax.tick_params(
        axis='y',
        which='both',
        length=15,
        width=2.5,
        )

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)

sns.despine()
plt.legend([],[],frameon=False)

plt.show()

# Supplementary Figure 1

## Supplementary Figure 1a

In [None]:
# Fraction of frames spent paused (pause > 0, i.e. freezing) per context/condition/mouse/day.
pause_fraction = (all_behavior.pause > 0).groupby(
    [all_behavior[col] for col in ["context", "condition", "name", "day"]]).mean().reset_index()

# Baseline-normalise the fear/safe rows by subtracting the same mouse's day-1 (pre-shock)
# fraction in the same context and condition, so each day-1 value becomes 0.
pause_fraction_normalized = pause_fraction[pause_fraction.context.isin(["fear", "safe"])].copy()

for (context, condition, name),df in pause_fraction_normalized[pause_fraction_normalized.day > 0].groupby(["context","condition", "name"]):    
    pause_fraction_normalized.loc[df.index, "pause"] = df.pause - pause_fraction[
        (pause_fraction.context.isin([context]))
        & (pause_fraction.day == 1)
        & (pause_fraction.condition == condition)
        & (pause_fraction.name == name)
    ].pause.mean()
    
df = pause_fraction_normalized

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))


# Single-iteration loops: only nr_standard mice in the fear context, one line per mouse.
for condition in ["nr_standard"]:
    for context in ["fear"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
               hue = 'name', x="day", y="pause", marker="o") 
                               
# NOTE(review): ylim starts at 0, so any negative (below-baseline) values are clipped — confirm.
ax.set(
    ylabel='baseline normalized \n time spent freezing', 
    yticks = [0,0.2,0.4,0.6,0.8],
    yticklabels = ['0%','20%','40%','60%','80%'],
    xlabel = '',
    ylim = [0,.87],
    #xlim = [1,4],
    xticks = [1,2,3,4],
    xticklabels=['pre-shocks','day 1','day 2','day 3']
    )
              
    
ax.tick_params('both', length=15, width=2.5, which='major')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
sns.despine()

# Per-mouse legend placed to the right of the axes.
leg = plt.legend(bbox_to_anchor=(1.5,1))
plt.show()

## Supplementary Figure 1b

In [None]:
# Fraction of frames spent paused (freezing) per context/condition/mouse/day.
pause_fraction = (all_behavior.pause > 0).groupby(
    [all_behavior[col] for col in ["context", "condition", "name", "day"]]).mean().reset_index()

df = pause_fraction
# colors[0] = fear line, colors[1] = safe line.
colors = ["teal","orchid"]

# NOTE(review): color_it, conditions and is_cond are not used in this cell; is_cond also
# raises KeyError for any mouse name missing from mouse_config.
color_it = iter(colors)

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

conditions = {k for mouse in mouse_config.values() for k in mouse["condition"]}
is_cond = pd.DataFrame({
    cond: df.name.apply(
        lambda t: cond in mouse_config[t]["condition"])
    for cond in conditions
})

# dreadd_dcz mice, fear context: mean over mice per day with a 95% CI band.
for condition in ["dreadd_dcz"]:
    for context in ["fear"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=colors[0]
                             )
            
# Same mice, safe context.
for condition in ["dreadd_dcz"]:
    for context in ["safe"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=  colors[1]
                             )
ax.set(
    ylabel='time spent freezing', 
    yticks = [0,0.2,0.4,0.6,0.8],
    yticklabels = ['0%','20%','40%','60%','80%'],
    xlabel = '',
    xlim = [.97,4],
    xticks = [1,2,3,4],
    xticklabels=['pre-shocks','day 1','day 2','day 3']
    )
              
    
ax.tick_params('both', length=15, width=2.5, which='major')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
sns.despine()

plt.show()

## Supplementary Figure 1c

In [None]:
# Day 2 = first post-shock day.
day = 2

# NOTE(review): `colors` is assigned twice and not used in this cell (palettes are passed inline).
colors = ["teal", "xkcd:turquoise", "xkcd:prussian blue"]
colors = ["orchid","xkcd:rose pink","xkcd:rich purple"] 

# Pause lengths of fear/safe sessions on `day` (get_pause_lengths is defined elsewhere).
# Note: this overwrites the module-level `pause_lengths` from earlier cells.
pause_lengths = get_pause_lengths(all_behavior[all_behavior.context.isin(["fear", "safe"])
                                               & (all_behavior.day == day)])

# One ECDF of pause length per context for each experimental condition, on a log x axis.
ax = sns.ecdfplot(data=pause_lengths[pause_lengths.condition == "dreadd_dcz"],
                  x='length', hue='context',
                  palette = ['xkcd:rose pink',"xkcd:turquoise"], stat='proportion', log_scale=True, alpha = 1)

ax = sns.ecdfplot(data=pause_lengths[pause_lengths.condition == "nr_standard"],
                  x='length', hue='context', 
                  palette = ['orchid','teal'], stat='proportion', log_scale=True, alpha = 1)


ax = sns.ecdfplot(data=pause_lengths[pause_lengths.condition == "nr_no_shock_control"],
                  x='length', hue='context',
                  palette = ['xkcd:rich purple','xkcd:prussian blue'], stat='proportion', log_scale=True, alpha = 0.8)

# NOTE(review): on a log-scaled axis a lower xlim of 0 is non-positive; matplotlib ignores it
# with a warning — confirm the intended lower limit (e.g. 1).
ax.set(
    title = 'Cumulative density',
    ylabel ='cumulative freeze length #', 
    xticks = [1,3,10,30,100],
    xlim = [0,100],
    xticklabels=['1s','3s','10s','30s','100s'],
    xlabel = 'length of freezed epoch (log scale)',
    )

sns.despine()
ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
plt.legend([],[], frameon=False)

plt.show()


## Supplementary Figure 1d

In [None]:
# Fraction of frames spent paused (freezing) per context/condition/mouse/day.
pause_fraction = (all_behavior.pause > 0).groupby(
    [all_behavior[col] for col in ["context", "condition", "name", "day"]]).mean().reset_index()

df = pause_fraction
# colors[0] = fear line, colors[1] = safe line.
colors = ["teal","orchid"]

# NOTE(review): color_it, conditions and is_cond are not used in this cell; is_cond also
# raises KeyError for any mouse name missing from mouse_config.
color_it = iter(colors)

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

conditions = {k for mouse in mouse_config.values() for k in mouse["condition"]}
is_cond = pd.DataFrame({
    cond: df.name.apply(
        lambda t: cond in mouse_config[t]["condition"])
    for cond in conditions
})

# nr_standard mice, fear context: mean over mice per day with a 95% CI band.
for condition in ["nr_standard"]:
    for context in ["fear"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=colors[0]
                             )
            
# Same mice, safe context.
for condition in ["nr_standard"]:
    for context in ["safe"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=  colors[1]
                             )
ax.set(
    ylabel='time spent freezing', 
    yticks = [0,0.2,0.4,0.6,0.8],
    yticklabels = ['0%','20%','40%','60%','80%'],
    xlabel = '',
    xlim = [.97,4],
    xticks = [1,2,3,4],
    xticklabels=['pre-shocks','day 1','day 2','day 3']
    )
              
    
ax.tick_params('both', length=15, width=2.5, which='major')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
sns.despine()

plt.show()

## Supplementary Figure 1e

In [None]:

# Hue colours for the two contexts.
colors = ["xkcd:dark peach","xkcd:slate green",]

# Fear/safe frames on days 1-4, lap 0 only, for three experimental conditions.
df = all_behavior[(all_behavior.context.isin(["fear","safe"]) 
                & (all_behavior.day.isin([1,2,3,4]))
                & (all_behavior.lap == 0)   
                & (all_behavior.condition.isin(['nr_standard','dreadd_dcz','dreadd_saline_control']))
                  )]
                               
fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

# Fraction of frames with is_backtracking per mouse/day/condition/context/is_first_lap
# (the same aggregation feeds both plots). Outliers are hidden (fliersize=0) because the
# strip plot shows every point.
ax = sns.boxplot(data=df.groupby(["name", "day","condition", "context","is_first_lap"],as_index=False).is_backtracking.mean(), 
                x="day", y="is_backtracking", hue="context",dodge = True, 
                palette = colors, fliersize = 0,  width = 0.5, boxprops=dict(alpha=.5))

ax = sns.stripplot(data=df.groupby(["name", "day", "condition","context", "is_first_lap"],as_index=False).is_backtracking.mean(), 
                x="day", y="is_backtracking", hue="context", 
                palette = colors, dodge = True, alpha =0.9)


# Values are fractions; the tick labels present them as percentages.
# NOTE(review): set_yticks expands view limits to include every tick and ax.set applies
# properties in the given order, so the 0.15 tick may override ylim's 0.10 — confirm.
ax.set(
    ylabel='% time moving backwards', 
    ylim = [-0.005,0.10],
    yticks = [0,0.05,0.10,0.15],
    yticklabels = ['0','5','10','15'],
    xlabel = '',
    xticks = [0,1,2,3],
    xticklabels=['pre-shocks','day 1','day 2','day 3']
    )
              
ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
plt.legend([],[], frameon=False)

sns.despine()
plt.show()

## Supplementary Figure 1f

In [None]:
# Fraction of frames spent paused (freezing) per context/condition/mouse/day.
pause_fraction = (all_behavior.pause > 0).groupby(
    [all_behavior[col] for col in ["context", "condition", "name", "day"]]).mean().reset_index()

df = pause_fraction
# colors[0] = fear line, colors[1] = safe line.
colors = ["teal","orchid"]

# NOTE(review): color_it, conditions and is_cond are not used in this cell; is_cond also
# raises KeyError for any mouse name missing from mouse_config.
color_it = iter(colors)

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

conditions = {k for mouse in mouse_config.values() for k in mouse["condition"]}
is_cond = pd.DataFrame({
    cond: df.name.apply(
        lambda t: cond in mouse_config[t]["condition"])
    for cond in conditions
})

# nr_no_shock_control mice, fear context: mean over mice per day with a 95% CI band.
for condition in ["nr_no_shock_control"]:
    for context in ["fear"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=colors[0]
                             )
            
# Same mice, safe context.
for condition in ["nr_no_shock_control"]:
    for context in ["safe"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=  colors[1]
                             )
ax.set(
    ylabel='time spent freezing', 
    yticks = [0,0.2,0.4,0.6,0.8],
    yticklabels = ['0%','20%','40%','60%','80%'],
    xlabel = '',
    xlim = [.97,4],
    xticks = [1,2,3,4],
    xticklabels=['pre-shocks','day 1','day 2','day 3']
    )
              
    
ax.tick_params('both', length=15, width=2.5, which='major')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
sns.despine()

plt.show()

## Supplementary Figure 1g

In [None]:
# Fraction of frames spent paused (freezing) per context/condition/mouse/day.
pause_fraction = (all_behavior.pause > 0).groupby(
    [all_behavior[col] for col in ["context", "condition", "name", "day"]]).mean().reset_index()

df = pause_fraction
# mCherry control: fear = colors[0], safe = colors[1]; saline control: fear = colors[2], safe = colors[3].
colors = ["teal","orchid","xkcd:denim","xkcd:merlot"]

# NOTE(review): color_it, conditions and is_cond are not used in this cell; is_cond also
# raises KeyError for any mouse name missing from mouse_config.
color_it = iter(colors)

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

conditions = {k for mouse in mouse_config.values() for k in mouse["condition"]}
is_cond = pd.DataFrame({
    cond: df.name.apply(
        lambda t: cond in mouse_config[t]["condition"])
    for cond in conditions
})

# Each block: one condition/context, mean over mice per day with a 95% CI band.
for condition in ["dreadd_mcherry_control"]:
    for context in ["fear"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=colors[0]
                             )
            
for condition in ["dreadd_mcherry_control"]:
    for context in ["safe"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=colors[1]
                             )

for condition in ["dreadd_saline_control"]:
    for context in ["fear"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=colors[2]
                             )
            
for condition in ["dreadd_saline_control"]:
    for context in ["safe"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=colors[3]
                             )                             
ax.set(
    ylabel='time spent freezing', 
    yticks = [0,0.2,0.4,0.6,0.8],
    yticklabels = ['0%','20%','40%','60%','80%'],
    xlabel = '',
    xlim = [.97,4],
    xticks = [1,2,3,4],
    xticklabels=['pre-shocks','day 1','day 2','day 3']
    )
              
    
ax.tick_params('both', length=15, width=2.5, which='major')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
sns.despine()

plt.show()

## Supplementary Figure 1h

In [None]:
# Fraction of frames spent paused (freezing) per context/condition/mouse/day.
pause_fraction = (all_behavior.pause > 0).groupby(
    [all_behavior[col] for col in ["context", "condition", "name", "day"]]).mean().reset_index()

df = pause_fraction
# colors[0] = fear line, colors[1] = safe line.
colors = ["teal","orchid"]

# NOTE(review): color_it, conditions and is_cond are not used in this cell; is_cond also
# raises KeyError for any mouse name missing from mouse_config.
color_it = iter(colors)

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

conditions = {k for mouse in mouse_config.values() for k in mouse["condition"]}
is_cond = pd.DataFrame({
    cond: df.name.apply(
        lambda t: cond in mouse_config[t]["condition"])
    for cond in conditions
})

# nr_good_imaging mice, fear context: mean over mice per day with a 95% CI band.
for condition in ["nr_good_imaging"]:
    for context in ["fear"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=colors[0]
                             )
            
# Same mice, safe context.
for condition in ["nr_good_imaging"]:
    for context in ["safe"]:
            ax = sns.lineplot(data=df[(df.context == context) 
                                             & (df.condition == condition)],
                x="day", y="pause", marker="o", errorbar=("ci", 95), err_style="band", color=  colors[1]
                             )
ax.set(
    ylabel='time spent freezing', 
    yticks = [0,0.2,0.4,0.6,0.8],
    yticklabels = ['0%','20%','40%','60%','80%'],
    xlabel = '',
    xlim = [.97,4],
    xticks = [1,2,3,4],
    xticklabels=['pre-shocks','day 1','day 2','day 3']
    )
              
    
ax.tick_params('both', length=15, width=2.5, which='major')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
sns.despine()

plt.show()

# Supplementary Figure 2

## Supplementary Figure 2a

In [None]:
#top
# Position trace of mouse MR8a on day 3 (second post-shock day), fear and safe sessions,
# with every pause (freezing) epoch shaded.

single_beh = all_behavior[(all_behavior.name=='MR8a')&(all_behavior.day==3)&(all_behavior.context.isin(['fear','safe']))]
df_name = single_beh

fig, axes = plt.subplots(1, figsize=(20, 2))
clean_axes(axes)

df_name.plot(x='seconds', y='ybinned', color='xkcd:very dark brown', linewidth = 1.5, ax=axes, legend=False)

# y ticks label ybinned 0 and 0.61 as 0 m and 2 m.
axes.set(
        ylim=[0,0.61],
        xlim=[16,651],
        yticks=(0, .61),
        yticklabels=('0m', '2m'),
        xticks=[16,651],
        xticklabels=['',''],
        xlabel = ''
            )

# First and last timestamp (seconds) of each pause id.
# NOTE(review): grouping by pause alone (not context) would merge same-numbered pauses from
# the fear and safe sessions into one span if pause ids restart per context — confirm ids
# are unique within a day.
gb = df_name.seconds.groupby(df_name.pause)
prev = df_name.seconds.iloc[0]
ranges = pd.DataFrame({'first': gb.first(), 'last': gb.last()})

# Pauses (id > 0) are shaded red-orange; the gaps between them get a near-transparent white span.
for i, (first, last) in ranges[ranges.index > 0].sort_index().iterrows():
    for ax in np.asarray(axes).flat:
        ax.axvspan(first, last, color='xkcd:red orange', alpha=0.3)
        ax.axvspan(prev, first, color='xkcd:white', alpha=0.12)
    prev = last
for ax in np.asarray(axes).flat:
    ax.axvspan(prev, df_name.seconds.iloc[-1], color='xkcd:white', alpha=0.1)

# `ax` is the loop variable left from the loop above (the single axes here).
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2.5)
    
ax.tick_params('both', length=10, width=2.5, which='major')
sns.despine()

plt.show()


In [None]:
#bottom
# Supplementary Figure 2a (bottom): same position/pause plot as the top panel,
# for mouse MR8a on day 4 (fear and safe rows together).

single_beh = all_behavior[(all_behavior.name=='MR8a')&(all_behavior.day==4)&(all_behavior.context.isin(['fear','safe']))]
df_name = single_beh

fig, axes = plt.subplots(1, figsize=(20, 2))
clean_axes(axes)

df_name.plot(x='seconds', y='ybinned', color='xkcd:very dark brown', linewidth = 1.5, ax=axes, legend=False)

# ybinned 0..0.61 is labelled 0 m..2 m (scale taken from the tick labels); x ticks are unlabelled.
axes.set(
        ylim=[0,0.61],
        xlim=[16,651],
        yticks=(0, .61),
        yticklabels=('0m', '2m'),
        xticks=[16,651],
        xticklabels=['',''],
        xlabel = ''
            )

# First and last timestamp of every pause id.
# NOTE(review): fear and safe rows are grouped together; if pause ids restart per
# context, two different pauses would merge into one span — confirm ids are unique per day.
gb = df_name.seconds.groupby(df_name.pause)
prev = df_name.seconds.iloc[0]
ranges = pd.DataFrame({'first': gb.first(), 'last': gb.last()})

# Pause id 0 is skipped (presumably "not paused"). Each pause is shaded red-orange,
# and the gap since the previous pause gets a faint white span.
for i, (first, last) in ranges[ranges.index > 0].sort_index().iterrows():
    for ax in np.asarray(axes).flat:
        ax.axvspan(first, last, color='xkcd:red orange', alpha=0.3)
        ax.axvspan(prev, first, color='xkcd:white', alpha=0.12)
    prev = last
# Faint white span from the end of the last pause to the end of the trace.
for ax in np.asarray(axes).flat:
    ax.axvspan(prev, df_name.seconds.iloc[-1], color='xkcd:white', alpha=0.1)

# `ax` is the (single) axes bound by the loop above.
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2.5)
    
ax.tick_params('both', length=10, width=2.5, which='major')
sns.despine()

plt.show()


## Supplementary Figure 2b

In [None]:
# Supplementary Figure 2b: distribution of pause (freeze) lengths for the
# nr_standard condition on day 0, fear (teal) vs. safe (orchid), on a log x axis.
bin_size = 0.07  # histogram bin width; with log_scale=True presumably in log10 units
day = 0
bin_window = 20  # passed through to get_pause_lengths (meaning defined there)

colors = ['teal','orchid']  # [fear, safe]

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))


pause_lengths_nr_standard = get_pause_lengths(all_behavior[all_behavior.context.isin(["fear","safe"])
                                                & (all_behavior.day == day) 
                                                & (all_behavior.condition == "nr_standard")],bin_window)

# Fear: outline histogram (bar heights sum to 1) plus KDE curve. No ax= is
# passed, so seaborn draws on the current axes created just above.
ax = sns.histplot(data=pause_lengths_nr_standard[pause_lengths_nr_standard.context == "fear"], x='length',
                  multiple="layer", edgecolor = colors[0], color = colors[0], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 3,
                  alpha=1, line_kws={"color": colors[0], "linewidth": 3, "alpha":1, "linestyle":'solid'})


# Safe: same encoding on the same axes, drawn beneath the fear layer (zorder 0).
ax = sns.histplot(data=pause_lengths_nr_standard[pause_lengths_nr_standard.context == "safe"], x='length',
                  multiple="layer", color = colors[1], edgecolor = colors[1], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 0,
                  alpha=1, line_kws={"color": colors[1], "linewidth": 3, "alpha":1, "linestyle":'solid'}, ax=ax)

# Re-apply the fear colour to the first Line2D on the axes (presumably the fear KDE curve).
ax.lines[0].set_color(colors[0])

# x tick labels read as seconds, so pause lengths are presumably in seconds — confirm in get_pause_lengths.
ax.set(
    ylabel ='freeze length probability', 
    ylim = [0,.18],
    yticks = [0, 0.06, 0.12,0.18],
    xlabel = 'length of freezed epoch (log scale)',
    xlim = [1,100],
    xticks = [1,3,10,30,100],
    xticklabels=['1s','3s','10s','30s','100s']
    )

ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
    
sns.despine()
plt.legend([],[], frameon=False)  # suppress any legend
plt.show()

## Supplementary Figure 2c

In [None]:
# Supplementary Figure 2c: distribution of pause (freeze) lengths for the
# nr_standard condition on day 3, fear (teal) vs. safe (orchid), on a log x axis.
bin_size = 0.07  # histogram bin width; with log_scale=True presumably in log10 units
day = 3
bin_window = 20  # passed through to get_pause_lengths (meaning defined there)

colors = ['teal','orchid']  # [fear, safe]

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))


pause_lengths_nr_standard = get_pause_lengths(all_behavior[all_behavior.context.isin(["fear","safe"])
                                                & (all_behavior.day == day) 
                                                & (all_behavior.condition == "nr_standard")],bin_window)


# Fear: outline histogram (bar heights sum to 1) plus KDE curve. No ax= is
# passed, so seaborn draws on the current axes created just above.
ax = sns.histplot(data=pause_lengths_nr_standard[pause_lengths_nr_standard.context == "fear"], x='length',
                  multiple="layer", edgecolor = colors[0], color = colors[0], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 3,
                  alpha=1, line_kws={"color": colors[0], "linewidth": 3, "alpha":1, "linestyle":'solid'})


# Safe: same encoding on the same axes, drawn beneath the fear layer (zorder 0).
ax = sns.histplot(data=pause_lengths_nr_standard[pause_lengths_nr_standard.context == "safe"], x='length',
                  multiple="layer", color = colors[1], edgecolor = colors[1], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 0,
                  alpha=1, line_kws={"color": colors[1], "linewidth": 3, "alpha":1, "linestyle":'solid'}, ax=ax)

# Re-apply the fear colour to the first Line2D on the axes (presumably the fear KDE curve).
ax.lines[0].set_color(colors[0])

# x tick labels read as seconds, so pause lengths are presumably in seconds — confirm in get_pause_lengths.
ax.set(
    ylabel ='freeze length probability', 
    ylim = [0,.18],
    yticks = [0, 0.06, 0.12,0.18],
    xlabel = 'length of freezed epoch (log scale)',
    xlim = [1,100],
    xticks = [1,3,10,30,100],
    xticklabels=['1s','3s','10s','30s','100s']
    )

ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
    
sns.despine()
plt.legend([],[], frameon=False)  # suppress any legend
plt.show()

## Supplementary Figure 2d

In [None]:
# Supplementary Figure 2d: distribution of pause (freeze) lengths for the
# nr_standard condition on day 4, fear (teal) vs. safe (orchid), on a log x axis.
bin_size = 0.07  # histogram bin width; with log_scale=True presumably in log10 units
day = 4
bin_window = 20  # passed through to get_pause_lengths (meaning defined there)

colors = ['teal','orchid']  # [fear, safe]

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))


pause_lengths_nr_standard = get_pause_lengths(all_behavior[all_behavior.context.isin(["fear","safe"])
                                                & (all_behavior.day == day) 
                                                & (all_behavior.condition == "nr_standard")],bin_window)

# Fear: outline histogram (bar heights sum to 1) plus KDE curve. No ax= is
# passed, so seaborn draws on the current axes created just above.
ax = sns.histplot(data=pause_lengths_nr_standard[pause_lengths_nr_standard.context == "fear"], x='length',
                  multiple="layer", edgecolor = colors[0], color = colors[0], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 3,
                  alpha=1, line_kws={"color": colors[0], "linewidth": 3, "alpha":1, "linestyle":'solid'})


# Safe: same encoding on the same axes, drawn beneath the fear layer (zorder 0).
ax = sns.histplot(data=pause_lengths_nr_standard[pause_lengths_nr_standard.context == "safe"], x='length',
                  multiple="layer", color = colors[1], edgecolor = colors[1], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 0,
                  alpha=1, line_kws={"color": colors[1], "linewidth": 3, "alpha":1, "linestyle":'solid'}, ax=ax)

# Re-apply the fear colour to the first Line2D on the axes (presumably the fear KDE curve).
ax.lines[0].set_color(colors[0])

# x tick labels read as seconds, so pause lengths are presumably in seconds — confirm in get_pause_lengths.
ax.set(
    ylabel ='freeze length probability', 
    ylim = [0,.18],
    yticks = [0, 0.06, 0.12,0.18],
    xlabel = 'length of freezed epoch (log scale)',
    xlim = [1,100],
    xticks = [1,3,10,30,100],
    xticklabels=['1s','3s','10s','30s','100s']
    )

ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
    
sns.despine()
plt.legend([],[], frameon=False)  # suppress any legend
plt.show()

## Supplementary Figure 2e

In [None]:
# Supplementary Figure 2e: distribution of pause (freeze) lengths for the
# dreadd_dcz condition on day 0, fear (teal) vs. safe (orchid), on a log x axis.
bin_size = 0.07  # histogram bin width; with log_scale=True presumably in log10 units
day = 0
bin_window = 20  # passed through to get_pause_lengths (meaning defined there)

colors = ['teal','orchid']  # [fear, safe]

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

pause_lengths_dreadd_dcz = get_pause_lengths(all_behavior[all_behavior.context.isin(["fear","safe"])
                                                & (all_behavior.day == day) 
                                                & (all_behavior.condition == "dreadd_dcz")],bin_window)

# Fear: outline histogram (bar heights sum to 1) plus KDE curve. No ax= is
# passed, so seaborn draws on the current axes created just above.
ax = sns.histplot(data=pause_lengths_dreadd_dcz[pause_lengths_dreadd_dcz.context == "fear"], x='length',
                  multiple="layer", edgecolor = colors[0], color = colors[0], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 3,
                  alpha=1, line_kws={"color": colors[0], "linewidth": 3, "alpha":1, "linestyle":'solid'})


# Safe: same encoding on the same axes, drawn beneath the fear layer (zorder 0).
ax = sns.histplot(data=pause_lengths_dreadd_dcz[pause_lengths_dreadd_dcz.context == "safe"], x='length',
                  multiple="layer", color = colors[1], edgecolor = colors[1], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 0,
                  alpha=1, line_kws={"color": colors[1], "linewidth": 3, "alpha":1, "linestyle":'solid'}, ax=ax)

# Re-apply the fear colour to the first Line2D on the axes (presumably the fear KDE curve).
ax.lines[0].set_color(colors[0])

# x tick labels read as seconds, so pause lengths are presumably in seconds — confirm in get_pause_lengths.
ax.set(
    ylabel ='freeze length probability', 
    ylim = [0,.18],
    yticks = [0, 0.06, 0.12,0.18],
    xlabel = 'length of freezed epoch (log scale)',
    xlim = [1,100],
    xticks = [1,3,10,30,100],
    xticklabels=['1s','3s','10s','30s','100s']
    )

ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
    
sns.despine()
plt.legend([],[], frameon=False)  # suppress any legend
plt.show()

## Supplementary Figure 2f

In [None]:
# Supplementary Figure 2f: distribution of pause (freeze) lengths for the
# dreadd_dcz condition on day 3, fear (teal) vs. safe (orchid), on a log x axis.
bin_size = 0.07  # histogram bin width; with log_scale=True presumably in log10 units
day = 3
bin_window = 20  # passed through to get_pause_lengths (meaning defined there)

colors = ['teal','orchid']  # [fear, safe]

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

pause_lengths_dreadd_dcz = get_pause_lengths(all_behavior[all_behavior.context.isin(["fear","safe"])
                                                & (all_behavior.day == day) 
                                                & (all_behavior.condition == "dreadd_dcz")],bin_window)

# Fear: outline histogram (bar heights sum to 1) plus KDE curve. No ax= is
# passed, so seaborn draws on the current axes created just above.
ax = sns.histplot(data=pause_lengths_dreadd_dcz[pause_lengths_dreadd_dcz.context == "fear"], x='length',
                  multiple="layer", edgecolor = colors[0], color = colors[0], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 3,
                  alpha=1, line_kws={"color": colors[0], "linewidth": 3, "alpha":1, "linestyle":'solid'})


# Safe: same encoding on the same axes, drawn beneath the fear layer (zorder 0).
ax = sns.histplot(data=pause_lengths_dreadd_dcz[pause_lengths_dreadd_dcz.context == "safe"], x='length',
                  multiple="layer", color = colors[1], edgecolor = colors[1], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 0,
                  alpha=1, line_kws={"color": colors[1], "linewidth": 3, "alpha":1, "linestyle":'solid'}, ax=ax)

# Re-apply the fear colour to the first Line2D on the axes (presumably the fear KDE curve).
ax.lines[0].set_color(colors[0])

# x tick labels read as seconds, so pause lengths are presumably in seconds — confirm in get_pause_lengths.
ax.set(
    ylabel ='freeze length probability', 
    ylim = [0,.18],
    yticks = [0, 0.06, 0.12,0.18],
    xlabel = 'length of freezed epoch (log scale)',
    xlim = [1,100],
    xticks = [1,3,10,30,100],
    xticklabels=['1s','3s','10s','30s','100s']
    )

ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
    
sns.despine()
plt.legend([],[], frameon=False)  # suppress any legend
plt.show()

## Supplementary Figure 2g

In [None]:
# Collect pause-to-run windows for every (name, day, context) session of
# `traces`, with pause_min, unpaused_min and half_window_length all equal to
# frame_length. get_p2r returns positions within the session (one row per
# window, going by the downstream .shape[1] use). Mapping them through the
# session's index stores them as `traces` index labels.
p2rs = {}
frame_length = 15
df = traces
# Iterate with a separate name: the old loop rebound `df` itself, which left
# `df` pointing at an arbitrary last session once the cell finished.
for key, group in df.groupby(["name", "day", "context"]):
    p2rs[key] = np.asarray(group.index)[get_p2r(group,
                                                pause_min=frame_length,
                                                unpaused_min=frame_length,
                                                half_window_length=frame_length)]

In [None]:
#top
# Supplementary Figure 2g (top row): mean axon Δf/f aligned to the transitions
# collected in `p2rs` by the previous cell (x = 0 at the transition).
# Left axes: day 1 sessions; right axes: sessions with day > 1. Each axes
# overlays fear and safe as mean (Savitzky-Golay smoothed) ± SEM.
df = traces

axon_val = 'axon'
frame_end = 10000  # not used in this cell
x_tick_interpolation = FRAME_RATE  # not used in this cell

# Window width in frames comes from the first p2rs entry (one row per window);
# `index` is each frame's time relative to the window centre, in seconds.
mid_line = next(iter(p2rs.values())).shape[1] / 2
index = np.arange(-mid_line, mid_line) / FRAME_RATE

axon = axon_val
columns = ['fear', 'safe']

# fear -> orchid, safe -> teal (by position in `columns`).
# NOTE(review): the pause-length histograms use the opposite mapping — confirm.
palette = ['orchid','teal','black']

colors = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2], palette[2]]}

shading = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2],palette[2]]}

fig, axes = plt.subplots(ncols=2, sharey=False, figsize=(4.95*3,1.74*3))
axes = np.asarray(axes)

# frame_length frames either side of the transition, in seconds. FRAME_RATE is
# the same 15.49 that was hard-coded here, so this stays in step with `index`.
window_range = [-frame_length/FRAME_RATE,frame_length/FRAME_RATE]

for i in range(2):
    for j, context in enumerate(columns):
        # Pool this context's windows over all mice: day 1 for i == 0, later days for i == 1.
        p2r = np.concatenate([val for (name, day, ctxt), val in p2rs.items()
                              if ctxt == context and (day > 1 if i else day == 1)])
        # Rows are windows, columns are frames. Index labels are used as positions.
        # NOTE(review): assumes `traces` has a default RangeIndex.
        trace = df[axon].values[p2r]
        mean = savgol_filter(trace.mean(0),10,5)  # only the mean is smoothed; the SEM band is raw
        std = trace.std(0)  # not used below
        se = sp.stats.sem(trace)
        color = colors[axon][j]
        light_color = shading[axon][j]

        axes.flat[i].plot(index, mean, color=color, linewidth = 3,zorder=2.55)
        axes.flat[i].fill_between(index, mean - se, mean + se,
                                  color=light_color, alpha=0.2,zorder=2.56)
        axes.flat[i].axvline(x=0, linewidth=3, color=colors['mid_line'][j])

for i in range(2):
    # Background shading: faint white after the transition, red-orange before it.
    axes.flat[i].axvspan(0, window_range[1], color='white', alpha=0.08, lw=0,zorder=0.4)
    axes.flat[i].axvspan(window_range[0], 0, color='xkcd:red orange', alpha=0.15, lw=0,zorder=0.4)
    # Apply to this panel; the old code called it on a stale `ax` from an earlier cell.
    axes.flat[i].set_axisbelow(True)

clean_axes(axes)

xlabels = ['pre-shocks','post shocks']
ylabels = ['mean $Δf/f$ within epoch range','']

# NOTE(review): with frame_length = 15 the window is only about ±0.97 s, so the
# ±1.5 s and ±3 s ticks fall outside xlim and are not drawn.
for i, ax in enumerate(axes.flat):
    ax.set(
        xlim = window_range,
        xticks=[-3,-1.5,0,1.5,3],
        xticklabels=['-3s','-1.5s','0s','1.5s','3s'],
        ylim=[-0.1,.2],
        yticks=[-0.1,0,0.1,0.2],
        yticklabels=['-0.1','0.0','0.1','0.2'],
        ylabel = ylabels[i],
        xlabel = xlabels[i],
    )

    ax.legend([], [], frameon=False)
    ax.tick_params('both', length=15, width=3, which='major')
    for axis in ['left','bottom']:
        ax.spines[axis].set_linewidth(3)


fig.subplots_adjust(wspace=0.2)
plt.show() 

In [None]:
# Collect pause-to-run windows for every (name, day, context) session of
# `traces`, with pause_min, unpaused_min and half_window_length all equal to
# frame_length. get_p2r returns positions within the session (one row per
# window, going by the downstream .shape[1] use). Mapping them through the
# session's index stores them as `traces` index labels.
p2rs = {}
frame_length = 30
df = traces
# Iterate with a separate name: the old loop rebound `df` itself, which left
# `df` pointing at an arbitrary last session once the cell finished.
for key, group in df.groupby(["name", "day", "context"]):
    p2rs[key] = np.asarray(group.index)[get_p2r(group,
                                                pause_min=frame_length,
                                                unpaused_min=frame_length,
                                                half_window_length=frame_length)]

In [None]:
#upper middle
# Supplementary Figure 2g (upper-middle row): mean axon Δf/f aligned to the
# transitions collected in `p2rs` by the previous cell (x = 0 at the transition).
# Left axes: day 1 sessions; right axes: sessions with day > 1. Each axes
# overlays fear and safe as mean (Savitzky-Golay smoothed) ± SEM.
df = traces

axon_val = 'axon'
frame_end = 10000  # not used in this cell
x_tick_interpolation = FRAME_RATE  # not used in this cell

# Window width in frames comes from the first p2rs entry (one row per window);
# `index` is each frame's time relative to the window centre, in seconds.
mid_line = next(iter(p2rs.values())).shape[1] / 2
index = np.arange(-mid_line, mid_line) / FRAME_RATE

axon = axon_val
columns = ['fear', 'safe']

# fear -> orchid, safe -> teal (by position in `columns`).
# NOTE(review): the pause-length histograms use the opposite mapping — confirm.
palette = ['orchid','teal','black']

colors = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2], palette[2]]}

shading = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2],palette[2]]}

fig, axes = plt.subplots(ncols=2, sharey=False, figsize=(4.95*3,1.74*3))
axes = np.asarray(axes)

# frame_length frames either side of the transition, in seconds. FRAME_RATE is
# the same 15.49 that was hard-coded here, so this stays in step with `index`.
window_range = [-frame_length/FRAME_RATE,frame_length/FRAME_RATE]

for i in range(2):
    for j, context in enumerate(columns):
        # Pool this context's windows over all mice: day 1 for i == 0, later days for i == 1.
        p2r = np.concatenate([val for (name, day, ctxt), val in p2rs.items()
                              if ctxt == context and (day > 1 if i else day == 1)])
        # Rows are windows, columns are frames. Index labels are used as positions.
        # NOTE(review): assumes `traces` has a default RangeIndex.
        trace = df[axon].values[p2r]
        mean = savgol_filter(trace.mean(0),10,5)  # only the mean is smoothed; the SEM band is raw
        std = trace.std(0)  # not used below
        se = sp.stats.sem(trace)
        color = colors[axon][j]
        light_color = shading[axon][j]

        axes.flat[i].plot(index, mean, color=color, linewidth = 3,zorder=2.55)
        axes.flat[i].fill_between(index, mean - se, mean + se,
                                  color=light_color, alpha=0.2,zorder=2.56)
        axes.flat[i].axvline(x=0, linewidth=3, color=colors['mid_line'][j])

for i in range(2):
    # Background shading: faint white after the transition, red-orange before it.
    axes.flat[i].axvspan(0, window_range[1], color='white', alpha=0.08, lw=0,zorder=0.4)
    axes.flat[i].axvspan(window_range[0], 0, color='xkcd:red orange', alpha=0.15, lw=0,zorder=0.4)
    # Apply to this panel; the old code called it on a stale `ax` from an earlier cell.
    axes.flat[i].set_axisbelow(True)

clean_axes(axes)

xlabels = ['pre-shocks','post shocks']
ylabels = ['mean $Δf/f$ within epoch range','']

# NOTE(review): with frame_length = 30 the window is only about ±1.94 s, so the
# ±3 s ticks fall outside xlim and are not drawn.
for i, ax in enumerate(axes.flat):
    ax.set(
        xlim = window_range,
        xticks=[-3,-1.5,0,1.5,3],
        xticklabels=['-3s','-1.5s','0s','1.5s','3s'],
        ylim=[-0.1,.2],
        yticks=[-0.1,0,0.1,0.2],
        yticklabels=['-0.1','0.0','0.1','0.2'],
        ylabel = ylabels[i],
        xlabel = xlabels[i],
    )

    ax.legend([], [], frameon=False)
    ax.tick_params('both', length=15, width=3, which='major')
    for axis in ['left','bottom']:
        ax.spines[axis].set_linewidth(3)


fig.subplots_adjust(wspace=0.2)
plt.show() 

In [None]:
# Collect pause-to-run windows for every (name, day, context) session of
# `traces`, with pause_min, unpaused_min and half_window_length all equal to
# frame_length. get_p2r returns positions within the session (one row per
# window, going by the downstream .shape[1] use). Mapping them through the
# session's index stores them as `traces` index labels.
p2rs = {}
frame_length = 60
df = traces
# Iterate with a separate name: the old loop rebound `df` itself, which left
# `df` pointing at an arbitrary last session once the cell finished.
for key, group in df.groupby(["name", "day", "context"]):
    p2rs[key] = np.asarray(group.index)[get_p2r(group,
                                                pause_min=frame_length,
                                                unpaused_min=frame_length,
                                                half_window_length=frame_length)]

In [None]:
#lower middle
# Supplementary Figure 2g (lower-middle row): mean axon Δf/f aligned to the
# transitions collected in `p2rs` by the previous cell (x = 0 at the transition).
# Left axes: day 1 sessions; right axes: sessions with day > 1. Each axes
# overlays fear and safe as mean (Savitzky-Golay smoothed) ± SEM.
df = traces

axon_val = 'axon'
frame_end = 10000  # not used in this cell
x_tick_interpolation = FRAME_RATE  # not used in this cell

# Window width in frames comes from the first p2rs entry (one row per window);
# `index` is each frame's time relative to the window centre, in seconds.
mid_line = next(iter(p2rs.values())).shape[1] / 2
index = np.arange(-mid_line, mid_line) / FRAME_RATE

axon = axon_val
columns = ['fear', 'safe']

# fear -> orchid, safe -> teal (by position in `columns`).
# NOTE(review): the pause-length histograms use the opposite mapping — confirm.
palette = ['orchid','teal','black']

colors = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2], palette[2]]}

shading = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2],palette[2]]}

fig, axes = plt.subplots(ncols=2, sharey=False, figsize=(4.95*3,1.74*3))
axes = np.asarray(axes)

# frame_length frames either side of the transition, in seconds. FRAME_RATE is
# the same 15.49 that was hard-coded here, so this stays in step with `index`.
window_range = [-frame_length/FRAME_RATE,frame_length/FRAME_RATE]

for i in range(2):
    for j, context in enumerate(columns):
        # Pool this context's windows over all mice: day 1 for i == 0, later days for i == 1.
        p2r = np.concatenate([val for (name, day, ctxt), val in p2rs.items()
                              if ctxt == context and (day > 1 if i else day == 1)])
        # Rows are windows, columns are frames. Index labels are used as positions.
        # NOTE(review): assumes `traces` has a default RangeIndex.
        trace = df[axon].values[p2r]
        mean = savgol_filter(trace.mean(0),10,5)  # only the mean is smoothed; the SEM band is raw
        std = trace.std(0)  # not used below
        se = sp.stats.sem(trace)
        color = colors[axon][j]
        light_color = shading[axon][j]

        axes.flat[i].plot(index, mean, color=color, linewidth = 3,zorder=2.55)
        axes.flat[i].fill_between(index, mean - se, mean + se,
                                  color=light_color, alpha=0.2,zorder=2.56)
        axes.flat[i].axvline(x=0, linewidth=3, color=colors['mid_line'][j])

for i in range(2):
    # Background shading: faint white after the transition, red-orange before it.
    axes.flat[i].axvspan(0, window_range[1], color='white', alpha=0.08, lw=0,zorder=0.4)
    axes.flat[i].axvspan(window_range[0], 0, color='xkcd:red orange', alpha=0.15, lw=0,zorder=0.4)
    # Apply to this panel; the old code called it on a stale `ax` from an earlier cell.
    axes.flat[i].set_axisbelow(True)

clean_axes(axes)

xlabels = ['pre-shocks','post shocks']
ylabels = ['mean $Δf/f$ within epoch range','']

for i, ax in enumerate(axes.flat):
    ax.set(
        xlim = window_range,
        xticks=[-3,-1.5,0,1.5,3],
        xticklabels=['-3s','-1.5s','0s','1.5s','3s'],
        ylim=[-0.1,.2],
        yticks=[-0.1,0,0.1,0.2],
        yticklabels=['-0.1','0.0','0.1','0.2'],
        ylabel = ylabels[i],
        xlabel = xlabels[i],
    )

    ax.legend([], [], frameon=False)
    ax.tick_params('both', length=15, width=3, which='major')
    for axis in ['left','bottom']:
        ax.spines[axis].set_linewidth(3)


fig.subplots_adjust(wspace=0.2)
plt.show() 

In [None]:
# Collect pause-to-run windows for every (name, day, context) session of
# `traces`, with pause_min, unpaused_min and half_window_length all equal to
# frame_length. get_p2r returns positions within the session (one row per
# window, going by the downstream .shape[1] use). Mapping them through the
# session's index stores them as `traces` index labels.
p2rs = {}
frame_length = 75
df = traces
# Iterate with a separate name: the old loop rebound `df` itself, which left
# `df` pointing at an arbitrary last session once the cell finished.
for key, group in df.groupby(["name", "day", "context"]):
    p2rs[key] = np.asarray(group.index)[get_p2r(group,
                                                pause_min=frame_length,
                                                unpaused_min=frame_length,
                                                half_window_length=frame_length)]

In [None]:
#bottom
# Supplementary Figure 2g (bottom row): mean axon Δf/f aligned to the
# transitions collected in `p2rs` by the previous cell (x = 0 at the transition).
# Left axes: day 1 sessions; right axes: sessions with day > 1. Each axes
# overlays fear and safe as mean (Savitzky-Golay smoothed) ± SEM.
df = traces

axon_val = 'axon'
frame_end = 10000  # not used in this cell
x_tick_interpolation = FRAME_RATE  # not used in this cell

# Window width in frames comes from the first p2rs entry (one row per window);
# `index` is each frame's time relative to the window centre, in seconds.
mid_line = next(iter(p2rs.values())).shape[1] / 2
index = np.arange(-mid_line, mid_line) / FRAME_RATE

axon = axon_val
columns = ['fear', 'safe']

# fear -> orchid, safe -> teal (by position in `columns`).
# NOTE(review): the pause-length histograms use the opposite mapping — confirm.
palette = ['orchid','teal','black']

colors = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2], palette[2]]}

shading = {"axon": [palette[0],palette[1]],
          'mid_line': [palette[2],palette[2]]}

fig, axes = plt.subplots(ncols=2, sharey=False, figsize=(4.95*3,1.74*3))
axes = np.asarray(axes)

# frame_length frames either side of the transition, in seconds. FRAME_RATE is
# the same 15.49 that was hard-coded here, so this stays in step with `index`.
window_range = [-frame_length/FRAME_RATE,frame_length/FRAME_RATE]

for i in range(2):
    for j, context in enumerate(columns):
        # Pool this context's windows over all mice: day 1 for i == 0, later days for i == 1.
        p2r = np.concatenate([val for (name, day, ctxt), val in p2rs.items()
                              if ctxt == context and (day > 1 if i else day == 1)])
        # Rows are windows, columns are frames. Index labels are used as positions.
        # NOTE(review): assumes `traces` has a default RangeIndex.
        trace = df[axon].values[p2r]
        mean = savgol_filter(trace.mean(0),10,5)  # only the mean is smoothed; the SEM band is raw
        std = trace.std(0)  # not used below
        se = sp.stats.sem(trace)
        color = colors[axon][j]
        light_color = shading[axon][j]

        axes.flat[i].plot(index, mean, color=color, linewidth = 3,zorder=2.55)
        axes.flat[i].fill_between(index, mean - se, mean + se,
                                  color=light_color, alpha=0.2,zorder=2.56)
        axes.flat[i].axvline(x=0, linewidth=3, color=colors['mid_line'][j])

for i in range(2):
    # Background shading: faint white after the transition, red-orange before it.
    axes.flat[i].axvspan(0, window_range[1], color='white', alpha=0.08, lw=0,zorder=0.4)
    axes.flat[i].axvspan(window_range[0], 0, color='xkcd:red orange', alpha=0.15, lw=0,zorder=0.4)
    # Apply to this panel; the old code called it on a stale `ax` from an earlier cell.
    axes.flat[i].set_axisbelow(True)

clean_axes(axes)

xlabels = ['pre-shocks','post shocks']
ylabels = ['mean $Δf/f$ within epoch range','']

for i, ax in enumerate(axes.flat):
    ax.set(
        xlim = window_range,
        xticks=[-3,-1.5,0,1.5,3],
        xticklabels=['-3s','-1.5s','0s','1.5s','3s'],
        ylim=[-0.1,.2],
        yticks=[-0.1,0,0.1,0.2],
        yticklabels=['-0.1','0.0','0.1','0.2'],
        ylabel = ylabels[i],
        xlabel = xlabels[i],
    )

    ax.legend([], [], frameon=False)
    ax.tick_params('both', length=15, width=3, which='major')
    for axis in ['left','bottom']:
        ax.spines[axis].set_linewidth(3)


fig.subplots_adjust(wspace=0.2)
plt.show() 

# Supplementary Figure 3

## Supplementary Figure 3a

In [None]:
# Supplementary Figure 3a: normalized Δf/f of local maxima versus the fraction
# of the enclosing pause already elapsed. Fear context, days > 0, with one
# regression per day.
df = traces[(traces.day>0) & (traces.context == 'fear')]

cols = ["name", "day", "pause"]

# Start frame and length (in frames) of each pause. Both ends now come from the
# filtered `df`. Previously the first frame came from the unfiltered `traces`,
# so a (name, day, pause) key shared with another context/day was measured from
# the wrong start.
pause_info = pd.DataFrame({
    "start": df.groupby(cols).frame.first(),
    "length": df.groupby(cols).frame.last() - df.groupby(cols).frame.first() + 1
})

# Keep local maxima only and attach their pause's start and length.
df = pd.merge(df[df.is_local_max], pause_info.reset_index(), on=cols)

# Fraction of the pause elapsed at each peak.
df["pause_progress"] = (df.frame - df.start + 1) / df.length

df = df[df.pause > 0]

# Pearson r (pandas' default), not r^2, pooled over all days on rescaled_axon.
# NOTE(review): the plot shows zeroed_axon per day, and the legend text below
# hard-codes per-day values. This number only feeds line_kws.
r_2 = round(df.rescaled_axon.corr(df.pause_progress),2)

facetgrid = sns.lmplot(x='pause_progress',y='zeroed_axon', hue = 'day', data=df, x_ci=95, legend=False, height = 5, aspect = 1.3,
            scatter_kws={"s": 10, "alpha" : 0.3,'label':'_ignore'}, 
            line_kws={ 'label':r_2})

ax = facetgrid.ax
ax.set(
    ylabel='Normalized peak $Δf/f$', 
    xlabel = 'fraction of time through pause',
    yticks = [0,0.5,1,1.5,2],
    xticks = [0,0.2,0.4,0.6,0.8,1],
    xticklabels = ['0%','20%','40%','60%','80%','100%']
    )


ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)

handles, labels = plt.gca().get_legend_handles_labels()

# Every other handle, presumably the per-day regression lines rather than the
# scatter layers. Confirm if the number of days changes.
order = [1,3,5,7]

leg = plt.legend([handles[i] for i in order], 
                 ['day 1\nr2=-0.04',
                  'day 2\nr2=-0.06',
                  'day 3\nr2=-0.09',
                  'day 4\nr2=-0.08'
                  ],
                  bbox_to_anchor=(.94,1)
                )

sns.despine()
plt.show()

## Supplementary Figure 3b

In [None]:
# Supplementary Figure 3b: per (mouse, pause), the largest normalized local-max
# Δf/f versus pause length. Fear context, day 1, with a robust linear fit.
day = 1
context = 'fear'

df = traces[(traces.day == day)&(traces.context == context)]

# Peak: max rescaled_axon over local maxima. Length: pause length in frames.
gb = df.groupby(["name", "pause"])
df = pd.DataFrame({
    "peak": df[df.is_local_max].groupby(["name", "pause"]).rescaled_axon.max(),
    "length": gb.frame.last() - gb.frame.first() + 1
}).reset_index()

# Drop pause id 0 and pauses with no local maximum (their peak is NaN after alignment).
df = df[df.pause > 0].dropna()

# Pearson r between log(length) and peak, shown as "r2" in the legend.
# NOTE(review): regplot below fits peak ~ length in linear x (robust and logx
# are mutually exclusive in seaborn), so the drawn line is not the fit this r describes.
r_2 = round(np.log(df.length).corr(df.peak),2)

# x_ci is only used with x_estimator/x_bins; it is given the documented scalar form.
ax = sns.regplot(x='length',y = 'peak',data = df, robust = True, x_ci=95,
            scatter_kws={"s": 10, "alpha" : 0.3, "color" : "tomato",'label':'_ignore'}, 
            line_kws={"color": "royalblue", 'label':r_2})

plt.xscale("log")

ax.set(
    xlim = [12,1000],
    ylabel='normalized max peak $Δf/f$', 
    xlabel = 'pause length',
    # Pin the tick positions so each label sits at the value it names. Setting
    # tick labels alone relabels whatever ticks matplotlib chose automatically.
    yticks = [0,0.5,1,1.5,2],
    yticklabels = [0,0.5,1,1.5,2],
    )

ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)

# The first legend-able artist (the scatter) is paired with '_ignore' and so
# omitted; the next (presumably the fit line) gets the r label.
plt.legend(['_ignore','r2 = %s' % (r_2)],loc = 'upper left')

sns.despine()

plt.show()

## Supplementary Figure 3c

In [None]:
# Supplementary Figure 3c: per (mouse, pause), the largest normalized local-max
# Δf/f versus pause length. Fear context, day 2, with a robust linear fit.
day = 2
context = 'fear'

df = traces[(traces.day == day)&(traces.context == context)]

# Peak: max rescaled_axon over local maxima. Length: pause length in frames.
gb = df.groupby(["name", "pause"])
df = pd.DataFrame({
    "peak": df[df.is_local_max].groupby(["name", "pause"]).rescaled_axon.max(),
    "length": gb.frame.last() - gb.frame.first() + 1
}).reset_index()

# Drop pause id 0 and pauses with no local maximum (their peak is NaN after alignment).
df = df[df.pause > 0].dropna()

# Pearson r between log(length) and peak, shown as "r2" in the legend.
# NOTE(review): regplot below fits peak ~ length in linear x (robust and logx
# are mutually exclusive in seaborn), so the drawn line is not the fit this r describes.
r_2 = round(np.log(df.length).corr(df.peak),2)

# x_ci is only used with x_estimator/x_bins; it is given the documented scalar form.
ax = sns.regplot(x='length',y = 'peak',data = df, robust = True, x_ci=95,
            scatter_kws={"s": 10, "alpha" : 0.3, "color" : "tomato",'label':'_ignore'}, 
            line_kws={"color": "royalblue", 'label':r_2})

plt.xscale("log")

ax.set(
    xlim = [12,1000],
    ylabel='normalized max peak $Δf/f$', 
    xlabel = 'pause length',
    # Pin the tick positions so each label sits at the value it names. Setting
    # tick labels alone relabels whatever ticks matplotlib chose automatically.
    yticks = [0,0.5,1,1.5,2],
    yticklabels = [0,0.5,1,1.5,2],
    )

ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)

# The first legend-able artist (the scatter) is paired with '_ignore' and so
# omitted; the next (presumably the fit line) gets the r label.
plt.legend(['_ignore','r2 = %s' % (r_2)],loc = 'upper left')

sns.despine()

plt.show()

## Supplementary Figure 3d

In [None]:
# Supplementary Figure 3d: mean normalized Δf/f at peaks, by progress-through-
# pause bin (labelled in 20% steps). Pre-shock (day 1) vs post-shock (days 2-4)
# days, fear context, pauses longer than min_frames frames only.
# NOTE(review): reads `df` left over from an earlier cell. `df` needs the columns
# name, day, context, pause, is_peak, bins and rescaled_axon. The Supplementary
# Figure 3c cell leaves a per-pause summary without them — confirm which cell
# is meant to run right before this one.
day_label = {
    1: "pre-shocks",
    2: "post-shocks",
    3: "post-shocks",
    4: "post-shocks"
}

colors = ['xkcd:darkish purple','xkcd:olive']

id_cols = ["name", "day", "context", "pause"]
# Length in frames of every pause (id > 0) in `traces`.
gb = traces[traces.pause > 0].groupby(id_cols).frame
pause_lengths = (gb.last() - gb.first() + 1).reset_index()
# Pause length for each row of df, aligned to df's own index. A left merge keeps
# exactly one output row per df row, in df's order (the grouped keys are unique).
# Rows without a pause length (e.g. pause 0) get NaN, which fails both
# comparisons below. The old inner merge dropped those rows and reset the index,
# so the mask could be misaligned with df.
df_pause_length = pd.Series(
    pd.merge(df[id_cols], pause_lengths, on=id_cols, how="left").frame.to_numpy(),
    index=df.index,
)

min_frames = 45
max_frames = np.inf  # np.Inf alias was removed in NumPy 2.0

# Mean rescaled Δf/f at peaks per (mouse, day, pause, progress bin).
samples = df[
    df.is_peak &
    (df.context == "fear") &
    (df_pause_length > min_frames) & 
    (df_pause_length < max_frames)
].groupby(["name", "day", "pause", "bins"]).rescaled_axon.mean().reset_index()

samples["day_label"] = samples.day.apply(day_label.__getitem__)
# Integer bin t -> "20t-20(t+1)%" label (assumes bins are integer indices).
samples["bins"] = samples.bins.apply(lambda t: f"{20 * t}-{20 * (t + 1)}%")

fig, ax = plt.subplots(1, figsize=(7.1 , 4.8))

ax = sns.lineplot(data=samples, 
                  x="bins", y="rescaled_axon", hue="day_label", palette = colors ,
                  hue_order = sorted(set(day_label.values())),
                  errorbar = ('se',1), style = 'day_label', dashes = False, markers = ['o']*2)

# NOTE(review): yticks start at 0.10 but ylim is [0.2, 0.5]. Only 0.2-0.3 get
# ticks, and 0.3-0.5 has none — confirm the intended range.
ax.set(
    ylabel='mean normalized $Δf/f$', 
    yticks = [0.10,0.15,0.2,0.25,0.3],
    ylim = [0.2,0.5],
    xlabel = 'progress through freeze epoch',
    xticks = [0,1,2,3,4],
    )
ax.tick_params('both', length=15, width=2.5, which='major')


for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)

plt.legend(bbox_to_anchor=(0.99,1.09))
sns.despine()
plt.show()

## Supplementary Figure 3e

In [None]:
# Supplementary Figure 3e: distribution of pause (freeze) lengths for the
# dreadd_dcz condition on day 0, fear (teal) vs. safe (orchid), on a log x axis.
# NOTE(review): this cell is identical to Supplementary Figure 2e (same
# condition, same day) — confirm it is not a copy-paste of the wrong day/condition.
bin_size = 0.07  # histogram bin width; with log_scale=True presumably in log10 units
day = 0
bin_window = 20  # passed through to get_pause_lengths (meaning defined there)

colors = ['teal','orchid']  # [fear, safe]

fig, ax = plt.subplots(1, figsize=(6.4, 4.8))

pause_lengths_dreadd_dcz = get_pause_lengths(all_behavior[all_behavior.context.isin(["fear","safe"])
                                                & (all_behavior.day == day) 
                                                & (all_behavior.condition == "dreadd_dcz")],bin_window)

# Fear: outline histogram (bar heights sum to 1) plus KDE curve. No ax= is
# passed, so seaborn draws on the current axes created just above.
ax = sns.histplot(data=pause_lengths_dreadd_dcz[pause_lengths_dreadd_dcz.context == "fear"], x='length',
                  multiple="layer", edgecolor = colors[0], color = colors[0], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 3,
                  alpha=1, line_kws={"color": colors[0], "linewidth": 3, "alpha":1, "linestyle":'solid'})


# Safe: same encoding on the same axes, drawn beneath the fear layer (zorder 0).
ax = sns.histplot(data=pause_lengths_dreadd_dcz[pause_lengths_dreadd_dcz.context == "safe"], x='length',
                  multiple="layer", color = colors[1], edgecolor = colors[1], linewidth = 2, fill = False,
                  log_scale=True, stat='probability', binwidth = bin_size, kde=True, zorder = 0,
                  alpha=1, line_kws={"color": colors[1], "linewidth": 3, "alpha":1, "linestyle":'solid'}, ax=ax)

# Re-apply the fear colour to the first Line2D on the axes (presumably the fear KDE curve).
ax.lines[0].set_color(colors[0])

# x tick labels read as seconds, so pause lengths are presumably in seconds — confirm in get_pause_lengths.
ax.set(
    ylabel ='freeze length probability', 
    ylim = [0,.18],
    yticks = [0, 0.06, 0.12,0.18],
    xlabel = 'length of freezed epoch (log scale)',
    xlim = [1,100],
    xticks = [1,3,10,30,100],
    xticklabels=['1s','3s','10s','30s','100s']
    )

ax.tick_params('both', length=15, width=2.5, which='major')
ax.tick_params('both', length=7.5, width=2.5, which='minor')
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
    
sns.despine()
plt.legend([],[], frameon=False)  # suppress any legend
plt.show()

# Supplementary Figure 4

## Supplementary Figure 4a

In [None]:
# Supplementary Figure 4a: draw one tree of a fitted XGBoost model (mouse MR7e,
# day 2, fear context, draw index 99).
# NOTE(review): `results` is built by the "train model" cell under Supplementary
# Figure 4c, which must run first.

# Inspection leftover: lists (draw index, held-out r2) but is not the cell's
# last expression, so the notebook does not display it.
[(i, t["r2"]) for i, t in enumerate(results["MR7e", 2, "fear"])]

# Per-node table of every tree in draw 99's booster.
df = results["MR7e", 2, "fear"][99]["model"].get_booster().trees_to_dataframe()

# Inspection leftover: rows of tree 0; also not displayed (not the last expression).
df[df.Tree == 0]
# Render a tree from the same booster; gcf() then returns the figure it drew into.
xgb.plot_tree(results["MR7e", 2, "fear"][99]["model"].get_booster())
fig = matplotlib.pyplot.gcf()
fig.set_size_inches(150, 100)  # very large canvas so node text stays legible
plt.show()

## Supplementary Figure 4b

In [None]:
# Supplementary Figure 4b: per-feature XGBoost importance ("weight", i.e. split
# counts) for mice in `nice_mice`. Each model's importances are normalized to
# sum to 1, then bars show the mean per feature on day 2, split by context.
imps = []
for (name, day, context), vs in results.items():
    if name not in nice_mice: continue
    for v in vs:
        bst = v["model"].get_booster()
        imps.append({"name": name, "day": day, "context": context, 
                     **bst.get_score(importance_type="weight")})
# Features a model never split on are missing from its dict -> 0.
imps = pd.DataFrame(imps).fillna(0)
# Columns 3+ are the features (after name/day/context). Normalize each row to sum to 1.
# NOTE(review): a model with no splits would divide by 0 here and yield NaN.
imps.iloc[:, 3:] /= imps.iloc[:, 3:].sum(1).values[:, None]

df = imps.melt(id_vars=["name", "day", "context"])
df = df[df.day == 2]
# Features sorted by mean normalized importance, most important first.
order = df.groupby("variable").value.mean().sort_values(ascending=False).index

top_n = 24
if top_n is not None:
    df = df[df.variable.isin(order[:top_n])]
    order = order[:top_n]

plt.figure(figsize=(10,len(order) / 2))
ax = sns.barplot(data=df,
            x="value",
            y="variable",
            hue="context",
            order=order,
            palette = ['teal','orchid']     
           )

# NOTE(review): these display names are applied by position. They only match if
# the data-driven `order` above comes out in exactly this sequence, so re-check
# them whenever models are retrained.
ax.set(
    ylabel = '',
    xlabel = 'Feature importance for Recall day 1',
    yticklabels = ['freezing elapsed',
                   'freezing progress',
                   'freezing remaining',
                   'interval elapsed',
                   'velocity',
                   'velocity 0.5s ago',
                   'velocity 1s ago', 
                   'is frozen', 
                   'interval progress',
                   'is post-freeze', 
                   'velocity in 1s',
                   'velocity in 0.5s', 
                   'running elapsed',
                   'interval remaining',
                   'running progress', 
                   'running remaining', 
                   'is running',
                   'is backtracking', 
                   'ybinned',
                   'pupil_area_smoothed',
                   'acceleration',
                   'pupil vertical position', 
                   'pupil horizontal position',
                   'is licking',
                  ])
    
ax.tick_params(
        axis='y',     
        which='both',
        length=15, 
        width=2.5,
        )
for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
sns.despine()
plt.legend([],[],frameon=False)  # suppress the hue legend
plt.tight_layout()
plt.show()

## Supplementary Figure 4c

In [None]:
#train model 
# Fit n_draws XGBoost regressors per (name, day, context) session to predict
# rescaled_axon from the remaining feature columns. Each draw holds out a random
# set of laps (via split_by_lap) and records the model and its held-out r2 in
# `results[(name, day, context)]`.
import warnings
import sklearn.metrics  # used for r2_score; the notebook header never imports this submodule explicitly
warnings.filterwarnings("ignore", 
                        message="pandas.Int64Index is deprecated and "
                        "will be removed from pandas in a future version. "
                        "Use pandas.Index with the appropriate dtype instead.")
from IPython.display import HTML, display

# See https://github.com/bstriner/keras-tqdm/issues/21#issuecomment-443019223
# Removes padding from empty output prompts left by the nested tqdm bars. The
# HTML object must go through display(): a bare expression that is not the
# cell's last statement is never rendered, so this style used to be dropped.
display(HTML("""
<style>
.p-Widget.jp-OutputPrompt.jp-OutputArea-prompt:empty {
  padding: 0;
  border: 0;
}
</style>
"""))
df = feature_df
gb_cols = ["name", "day", "context"]
label = "rescaled_axon"
drop_cols = ["lap", label]
n_draws = 100

#random seed is set to 42 for reproducibility
np.random.seed(42)
results = {}
for key, gdf in tqdm(df.groupby(gb_cols)):
    # Features and target do not depend on the draw, so build them once per session.
    # Random state is not consumed here, so the sequence of splits is unchanged.
    X = gdf.drop(drop_cols, axis=1)
    y = gdf[label]
    for _ in tqdm(range(n_draws), leave=False):
        # Held-out laps for this draw; used as a row mask (negated with ~ below).
        split = split_by_lap(gdf, gb_cols=[])

        model = xgb.XGBRegressor(
            gamma=1,
            learning_rate=0.01,
            n_estimators=1000,
            base_score=1,
            verbosity=0
        )

        # NOTE(review): passing early_stopping_rounds to fit() was deprecated in
        # xgboost 1.6 and removed in 2.0 (it moved to the constructor), so this
        # call requires xgboost < 2.0.
        model.fit(X[~split], y[~split], 
                  eval_set=[(X[split], y[split])],
                  early_stopping_rounds=5,
                  verbose=False
                 )

        results.setdefault(key, []).append({
            "split": split,
            "model": model,
            "r2": sklearn.metrics.r2_score(y[split], model.predict(X[split]))
        })

In [None]:
#create dataframes to plot
# One row per trained model: its (name, day, context) group and held-out r^2.
r2s = pd.DataFrame([[name, day, context, val["r2"]]
                    for (name, day, context), vs in results.items()
                    for val in vs], 
                   columns=["name", "day", "context", "r2"])

# `d` used to be rebuilt by an identical comprehension; an independent copy
# gives the same contents without doing the work twice.
d = r2s.copy()

# Attach the per-group peak medians (from an earlier cell) to every r^2 row.
peak_v_r2 = pd.merge(r2s, peak_medians, on=["name", "day", "context"])

In [None]:
#plot data
# Supplementary Figure 4c: distribution of held-out r^2 across the n_draws
# models, per day (x) and context (hue).
# Each row must carry `name` as well: four column names require four values
# per row (the previous 3-value rows raised ValueError).
df = pd.DataFrame([[name, day, context, val["r2"]]
                   for (name, day, context), vs in results.items()
                   for val in vs], 
                  columns=["name", "day", "context", "r2"])

ax = sns.boxplot(data=df, x="day", y="r2", hue="context", whis = 1, showfliers=True, boxprops=dict(alpha=.5),
                 palette = ('teal','grey'))
ax = sns.stripplot(data=df, x="day", y="r2", hue="context", alpha = 0.5, dodge = True,
                 palette = ('teal','grey'))
ax.set(
    ylabel ='model fit ($r^2$)', 
    yticks = [-1,-0.75,-0.5,-0.25,0,0.25,0.5,0.75,1],
    xticks = [0,1,2,3],
    xticklabels=['pre-shock baseline','recall day 1','recall day 2','recall day 3']
    )
# Apply the y-limits after the ticks: set_yticks widens the view to include
# every tick, which would otherwise override the intended [-0.1, 0.8] range.
ax.set_ylim(-0.1, 0.8)
sns.despine()
plt.legend([],[],frameon=False)
plt.tight_layout()
plt.show()

## Supplementary Figure 4d

In [None]:
#top
# Supplementary Figure 4d (top): actual vs. predicted trace for MR7e, day 1,
# safe context, using the model with the highest held-out r^2 for that group.
filt = {"name": 'MR7e',
        "day": 1,
        "context": 'safe'
       }

vals = results[tuple(filt[k] for k in ["name", "day", "context"])]

result = max(vals, key=lambda t: t["r2"])
df = feature_df

# Boolean mask selecting this group's rows of feature_df.
bools = np.logical_and.reduce([df[col] == val for col, val in filt.items()])

# Keep only the held-out rows of the chosen model's split.
# NOTE(review): the training cell uses label "rescaled_axon" and keeps the
# grouping columns as features; this cell plots 'axon' and drops different
# columns before predicting — confirm both match the trained model.
test_y = df.loc[bools, 'axon'].values[result["split"]]
test_yhat = result["model"].predict(df.drop(['axon', "lap", *filt.keys()], axis=1).values[bools][result["split"]])

# "- 1" presumably shifts a baseline-1 signal to Δf/f — TODO confirm.
# Only the first 620 plotted samples fall inside xlim below.
plt.plot((test_y -1)[20:794], color=("xkcd:blue grey"), linewidth = 1.7)
plt.plot((test_yhat -1)[20:794], color=("xkcd:rose"),linestyle = '-')
ax = plt.gca()


# 155 frames ≈ 10 s at FRAME_RATE (15.49 Hz).
ax.set(
    ylabel = 'Actual or predicted $Δf/f$',
    xlim = [0,620],
    yticks = [
             -0.2,
              0.0,
              0.2,
              0.4,
              0.6,
              0.8,
             ],
    xticks = [0,155,310,465,620],
    xticklabels = ['0s','10s','20s','30s','40s']
    )
# Set the y-limits after the ticks: set_yticks widens the view to include the
# 0.8 tick, which would silently override the intended [-0.4, 0.6] range.
ax.set_ylim(-0.4, 0.6)

ax.tick_params(
        axis='y',         
        which='both',     
        length=15, 
        width=2.5,
        )

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
sns.despine()
plt.show()

In [None]:
#bottom
# Supplementary Figure 4d (bottom): actual vs. predicted trace for MR7e, day 4,
# safe context, using the model with the highest held-out r^2 for that group.
filt = {"name": 'MR7e',
        "day": 4,
        "context": 'safe'
       }

vals = results[tuple(filt[k] for k in ["name", "day", "context"])]

result = max(vals, key=lambda t: t["r2"])
df = feature_df

# Boolean mask selecting this group's rows of feature_df.
bools = np.logical_and.reduce([df[col] == val for col, val in filt.items()])

# Keep only the held-out rows of the chosen model's split.
# NOTE(review): the training cell uses label "rescaled_axon" and keeps the
# grouping columns as features; this cell plots 'axon' and drops different
# columns before predicting — confirm both match the trained model.
test_y = df.loc[bools, 'axon'].values[result["split"]]
test_yhat = result["model"].predict(df.drop(['axon', "lap", *filt.keys()], axis=1).values[bools][result["split"]])

# "- 1" presumably shifts a baseline-1 signal to Δf/f — TODO confirm.
# Only the first 620 plotted samples fall inside xlim below.
plt.plot((test_y -1)[20:794], color=("xkcd:blue grey"), linewidth = 1.7)
plt.plot((test_yhat -1)[20:794], color=("xkcd:rose"),linestyle = '-')
ax = plt.gca()


# 155 frames ≈ 10 s at FRAME_RATE (15.49 Hz).
ax.set(
    ylabel = 'Actual or predicted $Δf/f$',
    xlim = [0,620],
    yticks = [
             -0.2,
              0.0,
              0.2,
              0.4,
              0.6,
              0.8,
             ],
    xticks = [0,155,310,465,620],
    xticklabels = ['0s','10s','20s','30s','40s']
    )
# Set the y-limits after the ticks: set_yticks widens the view to include the
# 0.8 tick, which would silently override the intended [-0.4, 0.6] range.
ax.set_ylim(-0.4, 0.6)

ax.tick_params(
        axis='y',         
        which='both',     
        length=15, 
        width=2.5,
        )

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
sns.despine()
plt.show()

In [None]:
#middle bottom
# Supplementary Figure 4d (middle bottom): actual vs. predicted trace for MR7e,
# day 3, safe context, using the model with the highest held-out r^2.
filt = {"name": 'MR7e',
        "day": 3,
        "context": 'safe'
       }

vals = results[tuple(filt[k] for k in ["name", "day", "context"])]

result = max(vals, key=lambda t: t["r2"])
df = feature_df

# Boolean mask selecting this group's rows of feature_df.
bools = np.logical_and.reduce([df[col] == val for col, val in filt.items()])

# Keep only the held-out rows of the chosen model's split.
# NOTE(review): the training cell uses label "rescaled_axon" and keeps the
# grouping columns as features; this cell plots 'axon' and drops different
# columns before predicting — confirm both match the trained model.
test_y = df.loc[bools, 'axon'].values[result["split"]]
test_yhat = result["model"].predict(df.drop(['axon', "lap", *filt.keys()], axis=1).values[bools][result["split"]])

# "- 1" presumably shifts a baseline-1 signal to Δf/f — TODO confirm.
# Only the first 620 plotted samples fall inside xlim below.
plt.plot((test_y -1)[20:794], color=("xkcd:blue grey"), linewidth = 1.7)
plt.plot((test_yhat -1)[20:794], color=("xkcd:rose"),linestyle = '-')
ax = plt.gca()


# 155 frames ≈ 10 s at FRAME_RATE (15.49 Hz).
ax.set(
    ylabel = 'Actual or predicted $Δf/f$',
    xlim = [0,620],
    yticks = [
             -0.2,
              0.0,
              0.2,
              0.4,
              0.6,
              0.8,
             ],
    xticks = [0,155,310,465,620],
    xticklabels = ['0s','10s','20s','30s','40s']
    )
# Set the y-limits after the ticks: set_yticks widens the view to include the
# 0.8 tick, which would silently override the intended [-0.4, 0.6] range.
ax.set_ylim(-0.4, 0.6)

ax.tick_params(
        axis='y',         
        which='both',     
        length=15, 
        width=2.5,
        )

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
sns.despine()
plt.show()

In [None]:
#middle top
# Supplementary Figure 4d (middle top): actual vs. predicted trace for MR7e,
# day 2, safe context, using the model with the highest held-out r^2.
filt = {"name": 'MR7e',
        "day": 2,
        "context": 'safe'
       }

vals = results[tuple(filt[k] for k in ["name", "day", "context"])]

result = max(vals, key=lambda t: t["r2"])
df = feature_df

# Boolean mask selecting this group's rows of feature_df.
bools = np.logical_and.reduce([df[col] == val for col, val in filt.items()])

# Keep only the held-out rows of the chosen model's split.
# NOTE(review): the training cell uses label "rescaled_axon" and keeps the
# grouping columns as features; this cell plots 'axon' and drops different
# columns before predicting — confirm both match the trained model.
test_y = df.loc[bools, 'axon'].values[result["split"]]
test_yhat = result["model"].predict(df.drop(['axon', "lap", *filt.keys()], axis=1).values[bools][result["split"]])

# "- 1" presumably shifts a baseline-1 signal to Δf/f — TODO confirm.
# Only the first 620 plotted samples fall inside xlim below.
plt.plot((test_y -1)[20:794], color=("xkcd:blue grey"), linewidth = 1.7)
plt.plot((test_yhat -1)[20:794], color=("xkcd:rose"),linestyle = '-')
ax = plt.gca()


# 155 frames ≈ 10 s at FRAME_RATE (15.49 Hz).
ax.set(
    ylabel = 'Actual or predicted $Δf/f$',
    xlim = [0,620],
    yticks = [
             -0.2,
              0.0,
              0.2,
              0.4,
              0.6,
              0.8,
             ],
    xticks = [0,155,310,465,620],
    xticklabels = ['0s','10s','20s','30s','40s']
    )
# Set the y-limits after the ticks: set_yticks widens the view to include the
# 0.8 tick, which would silently override the intended [-0.4, 0.6] range.
ax.set_ylim(-0.4, 0.6)

ax.tick_params(
        axis='y',         
        which='both',     
        length=15, 
        width=2.5,
        )

for axis in ['left','bottom']:
    ax.spines[axis].set_linewidth(2)
    
sns.despine()
plt.show()