# Imports

In [None]:
#-------------------------- Standard Imports --------------------------#
%reload_ext autoreload
%autoreload 2
import kdephys as kde
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import acr
# Base 'fast' style first, then the project style sheet on top (matplotlib applies a list in order).
# NOTE(review): absolute, user-specific path — consider making this configurable.
plt.style.use(['fast', '/home/kdriessen/gh_master/kdephys/kdephys/plot/acr_plots.mplstyle'])

# ---------------------------- EXTRAS --------------------------------#

# Load Basic Info

In [None]:
# ========================= User parameters =========================
# Replace the placeholder values below before running the notebook.
subject, exp = "ACR_#", "EXP"   # subject ID and experiment name (placeholders)
stores = ["NNXo", "NNXr"]       # store names; combined with `exp` to build sort IDs below
rel_state = "NREM"              # state passed to the relative-normalisation helpers below
# ===================================================================

In [None]:
# ------------------------- Subject info + hypnogram -------------------------
# One sort ID per store, following the "<exp>-<store>" naming convention.
sort_ids = [f'{exp}-{store_name}' for store_name in stores]
# Hypnogram covering the full experiment.
h = acr.io.load_hypno_full_exp(subject, exp)
# Subject metadata; later cells read recording times from si['rec_times'].
si = acr.info_pipeline.load_subject_info(subject)
# Recordings that belong to this experiment.
recordings = acr.info_pipeline.get_exp_recs(subject, exp)
# -----------------------------------------------------------------------------

In [None]:
# ----------------------------------------- Load Basic Info -----------------------------------------
# Key time points: stimulation bookends, rebound-sleep onset, and the day-anchored
# start times of the baseline and sleep-deprivation days.

# Search window (relative to stim_end) for the first NREM bout of the rebound.
reb_search_pre = pd.Timedelta('15min')
reb_search_post = pd.Timedelta('1h')
# Clock time each recording day is anchored to (presumably lights-on — TODO confirm).
day_anchor_time = "T09:00:00"

stim_start, stim_end = acr.stim.stim_bookends(subject, exp)

# First NREM bout in [stim_end - reb_search_pre, stim_end + reb_search_post].
reb_nrem = h.hts(stim_end - reb_search_pre, stim_end + reb_search_post).st('NREM')
if len(reb_nrem) == 0:
    raise ValueError(
        f'No NREM bout found between {stim_end - reb_search_pre} and '
        f'{stim_end + reb_search_post}; cannot determine rebound start'
    )
reb_start = reb_nrem.iloc[0].start_time

if reb_start < stim_end:
    # If stim ends in the middle of a NREM bout, then it can be the start of the rebound.
    # `<=` on start_time also matches a bout that begins exactly at stim_end; with a strict
    # `<` that case produced an empty selection and an IndexError below.
    stim_end_hypno = h.loc[(h.start_time <= stim_end) & (h.end_time > stim_end)]
    if len(stim_end_hypno) > 0 and stim_end_hypno.state.values[0] == 'NREM':
        reb_start = stim_end
    else:
        raise ValueError('Rebound start time is before stim end time, need to inspect')

assert reb_start >= stim_end, 'Rebound start time is before stim end time'

# Baseline day, anchored at day_anchor_time on the date the baseline recording started.
# Assumes rec_times start values are ISO-style 'YYYY-MM-DDTHH:MM:SS' strings (split on "T").
bl_start_actual = si["rec_times"][f'{exp}-bl']["start"]
bl_day = bl_start_actual.split("T")[0]
bl_start = pd.Timestamp(bl_day + day_anchor_time)

# Sleep deprivation: use the dedicated '<exp>-sd' recording when it exists; otherwise the
# main experiment recording, with SD taken to end at stimulation onset.
if f'{exp}-sd' in si['rec_times']:
    sd_rec = f'{exp}-sd'
    sd_end = pd.Timestamp(si['rec_times'][sd_rec]['end'])
else:
    sd_rec = exp
    sd_end = stim_start
sd_start_actual = pd.Timestamp(si['rec_times'][sd_rec]['start'])
sd_day = si['rec_times'][sd_rec]['start'].split("T")[0]
sd_start = pd.Timestamp(sd_day + day_anchor_time)

# Load Data

In [None]:
# ------------------------------- BANDPOWER DATA -------------------------------
# Bandpower concatenated across this experiment's recordings (hypno=True), then made
# relative to `rel_state` via kdephys' rel_by_store (no t1/t2 restriction).
bp = acr.io.load_concat_bandpower(subject, recordings, stores, hypno=True)
bp_rel = kde.xr.utils.rel_by_store(
    bp,
    state=rel_state,
    t1=None,
    t2=None,
)

In [None]:
# --------------------------------- UNIT DATA ---------------------------------
# Spikes and per-unit info for every sort ID (exclude_bad_units=True).
df, idf = acr.pl_units.load_spikes_polars(subject, sort_ids, info=True, exclude_bad_units=True)

# Options shared by both relative-firing-rate computations below.
rel_fr_opts = dict(rel_state=rel_state, t1=None, t2=None)

# Relative firing rate — fine-grained (30 s window).
window = '30s'
df_rel = acr.pl_units.get_rel_fr_df(df, h, window=window, **rel_fr_opts)

# Relative firing rate — coarse-grained (300 s window).
window_course = '300s'
df_rel_course = acr.pl_units.get_rel_fr_df(df, h, window=window_course, **rel_fr_opts)

In [None]:
# ------------------------------- ON/OFF DATA -------------------------------
# Detect OFF periods from the spike data, then annotate each detection with datetimes,
# hypnogram state, and whether it falls in the baseline or rebound time zone.
off_period_min = .05                         # passed as off_period_min (units per acr.onoff — presumably seconds)
bl_zone_duration = pd.Timedelta('12h')       # length of the "baseline" zone starting at bl_start
reb_zone_duration = pd.Timedelta('6h')       # length of the "rebound" zone starting at reb_start

oodf = acr.onoff.on_off_detection_basic(df, off_period_min=off_period_min)
oodf = acr.onoff.add_datetime_to_oodf(oodf, subject, exp)
oodf = acr.onoff.states_to_oodf(oodf, h)
oodf = acr.onoff.time_zones_to_oodf(oodf, bl_start, bl_start + bl_zone_duration, "baseline")
oodf = acr.onoff.time_zones_to_oodf(oodf, reb_start, reb_start + reb_zone_duration, "rebound")
# The steps above produce a pandas frame; it is converted to polars before the final step
# (presumably oodf_durations_rel2bl expects polars — verify against acr.onoff).
oodf = pl.from_pandas(oodf)
oodf = acr.onoff.oodf_durations_rel2bl(oodf)