# ACR_9

In [None]:
%matplotlib widget
%reload_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import tdt
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns

import kd_analysis.main.kd_utils as kd
import kd_analysis.main.kd_plotting as kp
import kd_analysis.main.kd_hypno as kh
import kd_analysis.ACR.acr_utils as acu
import sleep_score_for_me.v4 as ssfm

# Frequency bands as (low, high) edges in Hz; handed to kd.get_bp_set2 by the functions below.
# NOTE(review): sigma (11-16) overlaps both alpha (8-13) and beta (13-30) — confirm this is intended.
bp_def = dict(delta=(0.75, 4), theta=(4, 8), alpha = (8, 13), sigma = (11, 16), beta = (13, 30), gamma=(35, 55))

# Channel lists and data roots.
# NOTE(review): both roots point at ACHR_3 although this notebook is titled ACR_9, and kd_ref is
# not referenced by any other cell visible here — presumably a leftover from another notebook; confirm.
kd_ref = {}
kd_ref['echans'] = [1,2]
kd_ref['fchans']=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]
kd_ref['analysis_root'] = Path('/Volumes/opto_loc/Data/ACHR_PROJECT_MATERIALS/ACHR_3/ACHR_3-analysis-data')
kd_ref['tank_root'] = Path('/Volumes/opto_loc/Data/ACHR_3/ACHR_3_TANK')

import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

pio.templates.default = "plotly_dark"


In [None]:
# Subject metadata read by load_data(): the subject ID and the experiment keys to load.
acr9_info = {
    'subject': 'ACR_9',
    'complete_key_list': ['control1', 'laser1'],
}

In [None]:
# Two-color palette passed as palette=cds to the seaborn plots below
# (presumably white = Control, cyan = Laser, following hue order — confirm).
cds = ['white', 'cyan']

# Functions

In [None]:
def ss_times(sub, exp):
    """Read the epoc markers of one TDT block and print the two scoring windows.

    Parameters
    ----------
    sub : str
        Subject ID (e.g. 'ACR_9').
    exp : str
        Experiment key (e.g. 'control1'). The block is read from
        /Volumes/opto_loc/Data/<sub>/<sub>-<exp>.

    Returns
    -------
    dict
        'bl_sleep_start' : onset of the first ``Bttn`` epoc (seconds into the block),
        'stim_on' / 'stim_off' : onset / offset of the last ``Wdr_`` epoc (seconds),
        'stim_on_dt' / 'stim_off_dt' : the same two instants as absolute datetimes
        (block start date + offset).

    Side effects
    ------------
    Prints the start/end (in seconds) of the two files to score: file #1 starts
    30 s before 'bl_sleep_start' and lasts 7200 s; file #2 runs from there to 'stim_off'.
    """
    pre_roll_s = 30        # scoring starts 30 s before the Bttn sleep-start marker
    first_file_s = 7200    # length of the first scoring file, in seconds

    def acr_get_times(sub, exp):
        block_path = '/Volumes/opto_loc/Data/'+sub+'/'+sub+'-'+exp
        # t1=0, t2=0 with evtype=['epocs'] loads only the epoc stores of the block.
        ep = tdt.read_block(block_path, t1=0, t2=0, evtype=['epocs'])
        times = {}
        times['bl_sleep_start'] = ep.epocs.Bttn.onset[0]
        times['stim_on'] = ep.epocs.Wdr_.onset[-1]
        times['stim_off'] = ep.epocs.Wdr_.offset[-1]

        dt_start = pd.to_datetime(ep.info.start_date)

        # Lowercase 's' is the supported seconds alias; uppercase 'S' is deprecated in pandas.
        on_sec = pd.to_timedelta(times['stim_on'], unit='s')
        off_sec = pd.to_timedelta(times['stim_off'], unit='s')

        times['stim_on_dt'] = dt_start+on_sec
        times['stim_off_dt'] = dt_start+off_sec
        return times

    times = acr_get_times(sub, exp)

    start1 = times['bl_sleep_start'] - pre_roll_s
    end1 = start1 + first_file_s
    print('FILE #1')
    print(start1)
    print(end1)

    start2 = end1
    end2 = times['stim_off']
    print('FILE #2')
    print(start2)
    print(end2)
    return times

In [None]:
def load_data(sub_info, exp_list=None, add_time=None):
    """Load EEG/LFP data, spectrograms and hypnograms for a subject's experiments.

    Parameters
    ----------
    sub_info : dict
        Must contain 'subject' (subject ID) and 'complete_key_list' (all experiment keys).
    exp_list : list of str, optional
        Experiment keys to load; defaults to sub_info['complete_key_list'].
    add_time : number, optional
        Extra seconds to load past 'stim_off'. None loads up to 'stim_off' exactly.

    Returns
    -------
    a : dict
        For each condition, the two values returned by kd.get_data_spg, stored as
        '<condition>-e-d' / '<condition>-e-s' (store 'EEGr', channels 1-2) and
        '<condition>-f-d' / '<condition>-f-s' (store 'LFP_', channels 2, 8, 15).
        Presumably '-d' is the raw data and '-s' the spectrogram — confirm against kd.get_data_spg.
    h : dict
        condition -> hypnogram from acu.load_hypno_set, with scoring anchored to the
        first datetime of the loaded EEG data.
    times : dict
        condition -> the dict returned by ss_times.
    """
    sub = sub_info['subject']

    # `is None` rather than `== None`: identity is the right test for the sentinel.
    if exp_list is None:
        exp_list = sub_info['complete_key_list']

    times = {condition: ss_times(sub, condition) for condition in exp_list}
    paths = acu.get_paths(sub, sub_info['complete_key_list'])

    # One code path for both cases: no extra time is simply an extension of 0 s.
    extra = 0 if add_time is None else add_time

    a = {}
    h = {}
    for condition in exp_list:
        # Window: 30 s before the sleep-start marker through stim offset (+ optional extension).
        t1 = times[condition]['bl_sleep_start'] - 30
        t2 = times[condition]['stim_off'] + extra
        a[condition+'-e-d'], a[condition+'-e-s'] = kd.get_data_spg(paths[condition], store='EEGr', t1=t1, t2=t2, channel=[1, 2])
        a[condition+'-f-d'], a[condition+'-f-s'] = kd.get_data_spg(paths[condition], store='LFP_', t1=t1, t2=t2, channel=[2, 8, 15])
        start_time = a[condition+'-e-d'].datetime.values[0]
        h[condition] = acu.load_hypno_set(sub, condition, scoring_start_time=start_time)
    return a, h, times


In [None]:
def acr_rel2peak(spg, hyp, times, band='delta', ylim=None):
    """Return NREM-only smoothed band power relative to its pre-stim mean, as a DataFrame.

    Parameters
    ----------
    spg : xarray.DataArray
        Spectrogram with a 'datetime' dimension.
    hyp : hypnogram object
        Used by kh.keep_states to keep NREM samples only.
    times : dict
        Times for ONE condition; only 'stim_on_dt' is read.
    band, ylim :
        Unused here; kept so the signature matches the other acr_* functions.

    Returns
    -------
    pandas.DataFrame
        Band power (every band in bp_def) divided by its mean over the window from the
        first NREM sample to times['stim_on_dt'], index reset, plus a 'time_rel' column
        holding the running NREM sample number (0, 1, 2, ...).
    """
    bp = kd.get_bp_set2(spg, bp_def)
    smooth_bp = kd.get_smoothed_ds(bp, smoothing_sigma=6)
    smooth_nrem_bp = kh.keep_states(smooth_bp, hyp, ['NREM'])

    # Reference ("peak") window: first NREM sample up to stim onset.
    avg_period = slice(smooth_nrem_bp.datetime.values[0], times['stim_on_dt'])
    avgs = smooth_nrem_bp.sel(datetime=avg_period).mean(dim='datetime')

    bp_nrem_rel2peak = smooth_nrem_bp/avgs

    # Gap-free x axis: consecutive integers over the retained NREM samples.
    rel_time_index = np.arange(0, len(smooth_nrem_bp.datetime.values))
    bp_nrem_rel2peak = bp_nrem_rel2peak.assign_coords(time_rel=('datetime', rel_time_index))

    return bp_nrem_rel2peak.to_dataframe().reset_index()

In [None]:
def acr_rel_allstates(spg, hyp, times, band='delta', ylim=None):
    """Return smoothed band power for ALL states relative to the pre-stim NREM mean.

    Parameters
    ----------
    spg : xarray.DataArray
        Spectrogram with a 'datetime' dimension.
    hyp : hypnogram object
        Used by kh.keep_states to find the NREM samples that define the reference mean.
    times : dict
        Times for ONE condition; only 'stim_on_dt' is read.
    band, ylim :
        Unused here; kept so the signature matches the other acr_* functions.

    Returns
    -------
    pandas.DataFrame
        Smoothed band power of every sample (any state) divided by the NREM mean over the
        window from the first NREM sample to times['stim_on_dt'], index reset, plus a
        'time_rel' column holding the running sample number (0, 1, 2, ...).
    """
    bp = kd.get_bp_set2(spg, bp_def)
    smooth_bp = kd.get_smoothed_ds(bp, smoothing_sigma=6)
    smooth_nrem_bp = kh.keep_states(smooth_bp, hyp, ['NREM'])

    # Reference ("peak") window is NREM-only: first NREM sample up to stim onset.
    avg_period = slice(smooth_nrem_bp.datetime.values[0], times['stim_on_dt'])
    avgs = smooth_nrem_bp.sel(datetime=avg_period).mean(dim='datetime')

    # Unlike acr_rel2peak, the numerator keeps every state.
    bp_rel2peak = smooth_bp/avgs

    rel_time_index = np.arange(0, len(smooth_bp.datetime.values))
    bp_rel2peak = bp_rel2peak.assign_coords(time_rel=('datetime', rel_time_index))

    return bp_rel2peak.to_dataframe().reset_index()

In [None]:
def acr(spg, hyp, times, band='delta', ylim=None):
    """Plot NREM-only band power relative to its pre-stim mean, one row per channel.

    Parameters
    ----------
    spg : xarray.DataArray
        Spectrogram with a 'datetime' dimension.
    hyp : hypnogram object
        Used to keep NREM samples only.
    times : dict
        Times for ONE condition; only 'stim_on_dt' is read.
    band : str
        Column to plot on the y axis (a key of bp_def).
    ylim : tuple, optional
        y-axis limits forwarded to seaborn.FacetGrid.

    Returns
    -------
    seaborn.FacetGrid
        Line plot of `band` against 'time_rel', faceted by 'channel'.
    """
    # acr_rel2peak builds exactly the frame this plot needs; reuse it instead of
    # repeating the normalization pipeline here.
    df4plot = acr_rel2peak(spg, hyp, times)

    g = sns.FacetGrid(df4plot, row='channel', ylim=ylim, height=3, aspect=6)
    g.map(sns.lineplot, 'time_rel', band)

    return g

# Nights

In [None]:
# Boundaries of five consecutive 12-hour windows, 2022-06-10 21:00 through 2022-06-13 09:00.
t1, t2, t3, t4, t5, t6 = (
    np.datetime64(stamp)
    for stamp in (
        '2022-06-10T21:00',
        '2022-06-11T09:00',
        '2022-06-11T21:00',
        '2022-06-12T09:00',
        '2022-06-12T21:00',
        '2022-06-13T09:00',
    )
)

# s<k> spans boundary k to boundary k+1.
s1, s2, s3, s4, s5 = (
    slice(begin, end)
    for begin, end in zip((t1, t2, t3, t4, t5), (t2, t3, t4, t5, t6))
)

In [None]:
# Friday night: run ssfm_v4 on the 'test1' recording restricted to window s1.
# NOTE(review): a9ed / a9md are not assigned in any cell visible in this notebook —
# presumably EEG / EMG data dicts left over from an earlier kernel session; confirm.
en1 = a9ed['test1'].sel(datetime=s1)
mn1 = a9md['test1'].sel(datetime=s1)
hyp1 = ssfm.ssfm_v4(en1, mn1, 1)

In [None]:
# Saturday night: run ssfm_v4 on the 'test1' recording restricted to window s3.
# NOTE(review): a9ed / a9md are not assigned in any cell visible here — confirm their source.
en2 = a9ed['test1'].sel(datetime=s3)
mn2 = a9md['test1'].sel(datetime=s3)
hyp2 = ssfm.ssfm_v4(en2, mn2, 1)

In [None]:
# Sunday night: run ssfm_v4 on the 'test1' recording restricted to window s5.
# NOTE(review): a9ed / a9md are not assigned in any cell visible here — confirm their source.
en3 = a9ed['test1'].sel(datetime=s5)
mn3 = a9md['test1'].sel(datetime=s5)
hyp3 = ssfm.ssfm_v4(en3, mn3, 1)

# Peak in Morn

In [None]:
# Window boundaries for the 'mon1' recording: 2022-06-15 09:00 through 2022-06-16 15:00.
# NOTE: the scoring cells further down rebind m1, m2 and m3 to data selections, so this cell
# must be run (and ms1-ms3 built) before them; re-run it if the boundaries are needed again.
m1 = np.datetime64('2022-06-15T09:00')
m2 = np.datetime64('2022-06-15T21:00')
m3 = np.datetime64('2022-06-16T09:00')
m4 = np.datetime64('2022-06-16T15:00')

# ms1: 09:00-21:00 on 06-15, ms2: 21:00 on 06-15 to 09:00 on 06-16, ms3: 09:00-15:00 on 06-16.
ms1 = slice(m1, m2)
ms2 = slice(m2, m3)
ms3 = slice(m3, m4)

In [None]:
#Tuesday night - Light was incorrectly ON
# Scores the 'control1' recording from 2022-06-14 21:00 to 2022-06-15 09:00.
# NOTE: this rebinds t1 / t2 (first set in the "Nights" cell) and overwrites hyp1.
# NOTE(review): a9ed / a9md are not assigned in any cell visible here — confirm their source.
t1 = np.datetime64('2022-06-14T21:00')
t2 = np.datetime64('2022-06-15T09:00')
ts = slice(t1, t2)
ec = a9ed['control1'].sel(datetime=ts)
mc = a9md['control1'].sel(datetime=ts)
hyp1 = ssfm.ssfm_v4(ec, mc, 1)

In [None]:
# Wednesday, light period: score 'mon1' over window ms1.
# NOTE: m1 is rebound here from a datetime64 boundary to a data selection, and hyp1 is overwritten.
# NOTE(review): a9ed / a9md are not assigned in any cell visible here — confirm their source.
e1 = a9ed['mon1'].sel(datetime=ms1)
m1 = a9md['mon1'].sel(datetime=ms1)
hyp1 = ssfm.ssfm_v4(e1, m1, 1)

In [None]:
# Wednesday NIGHT into thursday morn (DARK): score 'mon1' over window ms2.
# NOTE: m2 is rebound here from a datetime64 boundary to a data selection; the ssfm_v4
# result is not assigned, so it is only available as this cell's displayed output.
e2 = a9ed['mon1'].sel(datetime=ms2)
m2 = a9md['mon1'].sel(datetime=ms2)
ssfm.ssfm_v4(e2, m2, 1)

In [None]:
# Thursday morn (light): score 'mon1' over window ms3.
# NOTE: m3 is rebound here from a datetime64 boundary to a data selection; the ssfm_v4
# result is not assigned, so it is only available as this cell's displayed output.
e3 = a9ed['mon1'].sel(datetime=ms3)
m3 = a9md['mon1'].sel(datetime=ms3)
ssfm.ssfm_v4(e3, m3, 1)

# Relevant Times for Scoring/Analysis

Scoring plan:
- Start scoring 30 s before the `Bttn` sleep-start marker.
- Score the entire 15-min 'baseline' period and the entire 4-hour stimulation period.

In [None]:
# Prints the two scoring windows for the control block. load_data() calls ss_times itself,
# and ctrl_times is not read by any later cell visible here.
ctrl_times = ss_times('ACR_9', 'control1')

In [None]:
# Prints the two scoring windows for the laser block. load_data() calls ss_times itself,
# and lsr_times is not read by any later cell visible here.
lsr_times = ss_times('ACR_9', 'laser1')

# Load Data + Quick Plots

In [None]:
# a9: data/spectrograms keyed '<condition>-e-d|-e-s|-f-d|-f-s'; h9: hypnograms; a9_times: ss_times dicts.
# Loads both conditions in acr9_info['complete_key_list'], ending exactly at stim offset (add_time=None).
a9, h9, a9_times = load_data(acr9_info, add_time=None)

In [None]:
# Quick look: delta band power of the control EEG spectrogram, drawn on a fresh axes together
# with the control hypnogram. The second argument (1) is presumably the channel — confirm
# against kp.plot_shaded_bp.
f, ax = plt.subplots()
kp.plot_shaded_bp(a9['control1-e-s'], 1, bp_def, 'delta', h9['control1'], ax=ax)

# Quantify relative to "Sleep Peak"

Procedure:
- Get nrem only
- chunk in some way
    - or take the average of the entire sleep period
- express everything relative to that peak

We should then have the option to plot either NREM only or all states, each expressed relative to the NREM peak.

Need:
- A way to get continuous NREM data (i.e., without gaps from other states)

In [None]:
# NREM band power relative to the pre-stim mean, tagged with its condition.
# EEG spectrograms ('-e-s'):
c = acr_rel2peak(a9['control1-e-s'], h9['control1'], a9_times['control1']).assign(Condition='Control')
l = acr_rel2peak(a9['laser1-e-s'], h9['laser1'], a9_times['laser1']).assign(Condition='Laser')
cl = pd.concat([c, l])

# LFP spectrograms ('-f-s'):
cf = acr_rel2peak(a9['control1-f-s'], h9['control1'], a9_times['control1']).assign(Condition='Control')
lf = acr_rel2peak(a9['laser1-f-s'], h9['laser1'], a9_times['laser1']).assign(Condition='Laser')
clf = pd.concat([cf, lf])

In [None]:
# LFP: normalized NREM delta power vs. running NREM sample index, Control vs Laser, one row per channel.
title = "Delta Power (0.75-4Hz) During Sinusoidal Laser Stimulation vs Control - LFP | NREM Only"
fig = px.line(clf, x='time_rel', y='delta', color='Condition', facet_row='channel', height=600, width=2200, color_discrete_sequence=['lightgray', 'cyan'], title=title)
fig.update_xaxes(range=[0, 4000], title='Time')
fig.update_yaxes(range=[0, 2], title='Norm. Delta Power')
# NOTE(review): 307 is hard-coded — presumably the time_rel index at which stimulation starts; confirm
# it holds for both conditions. The shaded band and red line mark everything from that index on.
fig.add_vrect(x0=307, x1=4000, line_width=0, fillcolor="turquoise", opacity=0.05)
fig.add_vline(x=307, line_width=2, opacity=1, line_color='red')
fig.update_traces(line=dict(width=3))
fig

In [None]:
# EEG: normalized NREM delta power vs. running NREM sample index, Control vs Laser, one row per channel.
title = "Delta Power (0.75-4Hz) During Sinusoidal Laser Stimulation vs Control - EEG | NREM Only"
fig = px.line(cl, x='time_rel', y='delta', color='Condition', facet_row='channel', height=600, width=2200, color_discrete_sequence=['lightgray', 'cyan'], title=title)
fig.update_xaxes(range=[0, 4000], title='Time')
fig.update_yaxes(range=[0, 2], title='Norm. Delta Power')
# NOTE(review): 307 is hard-coded — presumably the time_rel index at which stimulation starts; confirm
# it holds for both conditions. The shaded band and red line mark everything from that index on.
fig.add_vrect(x0=307, x1=4000, line_width=0, fillcolor="turquoise", opacity=0.05)
fig.add_vline(x=307, line_width=2, opacity=1, line_color='red')
fig.update_traces(line=dict(width=2))
fig

# Histograms/Quantify

In [None]:
def x2df(xl, keys):
    """Stack several array-like objects into one long DataFrame.

    Each element of `xl` is converted with ``to_dataframe(name=key)``, its index is
    turned into columns, and a 'key' column holding the matching entry of `keys` is
    added. The per-element frames are concatenated in order (row labels are not renumbered).
    """
    frames = [
        x.to_dataframe(name=key).reset_index().assign(key=key)
        for x, key in zip(xl, keys)
    ]
    return pd.concat(frames)

In [None]:
def acr_bp(spg, hyp, times, state=['NREM'], type='df', key=''):
    """Band power during the stim period, restricted to the given state(s).

    Parameters
    ----------
    spg : spectrogram with a 'datetime' dimension.
    hyp : hypnogram object used by kh.keep_states.
    times : dict for ONE condition; reads 'stim_on_dt' and 'stim_off_dt'.
    state : list of state names to keep.
    type : 'xr' returns the band-power object as is; 'df' returns a DataFrame with
        the index reset and a 'key' column set to `key`. Any other value returns None.
    key : label written into the 'key' column (only for type='df').
    """
    stim_period = slice(times['stim_on_dt'], times['stim_off_dt'])

    # Band power first, then crop to the stim period, then drop unwanted states.
    stim_bp = kd.get_bp_set2(spg, bp_def).sel(datetime=stim_period)
    stim_bp = kh.keep_states(stim_bp, hyp, state)

    if type == 'xr':
        return stim_bp
    if type == 'df':
        frame = stim_bp.to_dataframe().reset_index()
        frame['key'] = key
        return frame
    return None
    

In [None]:
def acr_bp_rel(spg, hyp, times, state=['NREM'], type='df', key=''):
    """Stim-period band power for the given state(s), relative to the pre-stim mean.

    Parameters
    ----------
    spg : spectrogram with a 'datetime' dimension.
    hyp : hypnogram object used by kh.keep_states.
    times : dict for ONE condition; reads 'stim_on_dt' and 'stim_off_dt'.
    state : list of state names to keep.
    type : 'xr' returns the normalized band-power object as is; 'df' returns a DataFrame
        with the index reset and a 'Key' column set to `key`. Any other value returns None.
    key : label written into the 'Key' column (only for type='df').
    """
    recording_start = spg.datetime.values[0]
    stim_on = times['stim_on_dt']
    stim_off = times['stim_off_dt']

    # Band power for the whole recording, kept only where the hypnogram is in `state`.
    state_bp = kh.keep_states(kd.get_bp_set2(spg, bp_def), hyp, state)

    # Reference level: mean over the "peak" window, i.e. recording start up to stim onset.
    peak_mean = state_bp.sel(datetime=slice(recording_start, stim_on)).mean(dim='datetime')

    # Normalize everything, then keep only the stim period.
    rel_stim_bp = (state_bp / peak_mean).sel(datetime=slice(stim_on, stim_off))

    if type == 'xr':
        return rel_stim_bp
    if type == 'df':
        frame = rel_stim_bp.to_dataframe().reset_index()
        frame['Key'] = key
        return frame
    return None
    

In [None]:
# Stim-period NREM band power relative to the pre-stim mean, as DataFrames tagged
# 'Control' / 'Laser' in their 'Key' column; EEG ('-e-s') first, then LFP ('-f-s').
ct_eeg = acr_bp_rel(a9['control1-e-s'], h9['control1'], a9_times['control1'], key='Control')
ls_eeg = acr_bp_rel(a9['laser1-e-s'], h9['laser1'], a9_times['laser1'], key='Laser')
ct_lfp = acr_bp_rel(a9['control1-f-s'], h9['control1'], a9_times['control1'], key='Control')
ls_lfp = acr_bp_rel(a9['laser1-f-s'], h9['laser1'], a9_times['laser1'], key='Laser')

In [None]:
# Wide frames: one row per sample, one column per band (Control rows followed by Laser rows).
rel_stim_bp = pd.concat([ct_eeg, ls_eeg])
rel_stim_bp_lfp = pd.concat([ct_lfp, ls_lfp])
# Long frames for per-band plots: columns 'Key', 'channel', 'band', 'power'.
# NOTE(review): bp_def also defines 'sigma', which is not listed in value_vars — confirm the omission is intended.
new_rsbp = rel_stim_bp.melt(id_vars=['Key', 'channel'], value_vars=['delta', 'theta', 'alpha', 'beta', 'gamma'], var_name='band', value_name='power')
new_rsbp_lfp = rel_stim_bp_lfp.melt(id_vars=['Key', 'channel'], value_vars=['delta', 'theta', 'alpha', 'beta', 'gamma'], var_name='band', value_name='power')

In [None]:
# EEG channel 1 only (both conditions); used by the single-channel box plot further down.
bp1eeg = rel_stim_bp[rel_stim_bp['channel']==1]

In [None]:
# Overlaid histograms (with marginal box plots) of baseline-normalized NREM delta power, EEG channel 1.
# The condition label lives in the 'Key' column (capital K, as written by acr_bp_rel).
title = 'Contralateral-EEG Delta Values During Sinusoidal Laser Stimulation vs Control, Normalized to Baseline Period | NREM Only'
f = px.histogram(rel_stim_bp.loc[(rel_stim_bp['channel']==1)], x='delta', color='Key', barmode='overlay', opacity=0.6, marginal='box', color_discrete_sequence=['white', 'cornflowerblue'], title=title)
f.update_xaxes(title='Normalized Delta Power')

In [None]:
# Overlaid histograms (with marginal box plots) of baseline-normalized NREM delta power, EEG channel 2.
# The condition label lives in the 'Key' column (capital K, as written by acr_bp_rel).
title = 'Ipsilateral-EEG Delta Values During Sinusoidal Laser Stimulation vs Control, Normalized to Baseline Period | NREM Only'
f = px.histogram(rel_stim_bp.loc[(rel_stim_bp['channel']==2)], x='delta', color='Key', barmode='overlay', opacity=0.6, marginal='box', color_discrete_sequence=['white', 'cornflowerblue'], title=title)
f.update_xaxes(title='Normalized Delta Power')

In [None]:
# Overlaid histograms (with marginal box plots) of baseline-normalized NREM delta power, LFP channel 2.
# The condition label lives in the 'Key' column (capital K, as written by acr_bp_rel).
title = 'Superficial-LFP Delta Values During Sinusoidal Laser Stimulation vs Control, Normalized to Baseline Period | NREM Only'
f = px.histogram(rel_stim_bp_lfp.loc[(rel_stim_bp_lfp['channel']==2)], x='delta', color='Key', barmode='overlay', opacity=0.6, marginal='box', color_discrete_sequence=['white', 'cornflowerblue'], title=title, nbins=250)
f.update_xaxes(title='Normalized Delta Power', range=[0, 2.5])

In [None]:
# Overlaid histograms (with marginal box plots) of baseline-normalized NREM delta power, LFP channel 8.
# The condition label lives in the 'Key' column (capital K, as written by acr_bp_rel).
title = 'Mid-LFP Delta Values During Sinusoidal Laser Stimulation vs Control, Normalized to Baseline Period | NREM Only'
f = px.histogram(rel_stim_bp_lfp.loc[(rel_stim_bp_lfp['channel']==8)], x='delta', color='Key', barmode='overlay', opacity=0.6, marginal='box', color_discrete_sequence=['white', 'cornflowerblue'], title=title, nbins=250)
f.update_xaxes(title='Normalized Delta Power', range=[0, 2.5])

In [None]:
# Overlaid histograms (with marginal box plots) of baseline-normalized NREM delta power, LFP channel 15.
# The condition label lives in the 'Key' column (capital K, as written by acr_bp_rel).
title = 'Deep-LFP Delta Values During Sinusoidal Laser Stimulation vs Control, Normalized to Baseline Period | NREM Only'
f = px.histogram(rel_stim_bp_lfp.loc[(rel_stim_bp_lfp['channel']==15)], x='delta', color='Key', barmode='overlay', opacity=0.6, marginal='box', color_discrete_sequence=['white', 'cornflowerblue'], title=title, nbins=250)
f.update_xaxes(title='Normalized Delta Power', range=[0, 2.5])

In [None]:
# Global look for the seaborn/matplotlib figures that follow: tick style, "talk" sizing, dark background.
sns.set(style="ticks", context="talk")
plt.style.use("dark_background")

In [None]:
# pd.concat kept each per-condition frame's own 0..n row labels, so labels repeat.
# Renumber the rows for the seaborn plots below. drop=True keeps this cell idempotent:
# re-running it no longer piles up extra 'index' / 'level_0' columns.
rel_stim_bp = rel_stim_bp.reset_index(drop=True)
rel_stim_bp_lfp = rel_stim_bp_lfp.reset_index(drop=True)

In [None]:
%matplotlib inline

In [None]:
# Filled KDE of baseline-normalized NREM delta power, Control vs Laser, one row per EEG channel.
f = sns.displot(rel_stim_bp, x="delta", hue="Key", palette=cds, row='channel', kind="kde", fill=True)
f.set_axis_labels("Delta Power Normalized to Baseline", "")
f.set_titles("")

In [None]:
# Filled KDE of baseline-normalized NREM delta power, Control vs Laser, one row per LFP channel.
f = sns.displot(rel_stim_bp_lfp, x="delta", hue="Key", palette=cds, row='channel', kind="kde", fill=True)
f.set_axis_labels("Delta Power Normalized to Baseline", "")
f.set_titles("")

In [None]:
# Box plot of baseline-normalized NREM delta power by condition, EEG channel 1 only (bp1eeg).
f = sns.catplot(x="Key", y="delta", data=bp1eeg, palette=cds, kind="box", height=8, aspect=1)
f.set(ylim=(0,1.5))
f.set_axis_labels("Condition", "Delta Power Normalized to Baseline")
f.set_titles("")

In [None]:
# Box plot of baseline-normalized NREM delta power by condition, one row per LFP channel.
f = sns.catplot(x="Key", y="delta", row='channel', data=rel_stim_bp_lfp, palette=cds, kind="box", height=8, aspect=1)
f.set(ylim=(0,2))
f.set_axis_labels("Condition", "Delta Power Normalized to Baseline")
f.set_titles("")

In [None]:
# Filled KDE of baseline-normalized NREM delta power, Control vs Laser, one row per LFP channel.
# The plotted values are ratios to the baseline mean (1.0 = baseline), not percentages,
# so the axis label says "fraction" rather than "%".
f = sns.displot(rel_stim_bp_lfp, x="delta", hue="Key", palette=cds, row='channel', kind="kde", fill=True)
f.set_axis_labels("Delta Power as Fraction of Baseline", "")
f.set_titles("")

In [None]:
# Box plots of baseline-normalized NREM power for every band, Control vs Laser, one row per EEG channel.
f = sns.catplot(x='band', y='power', hue='Key', kind='box', row='channel', data=new_rsbp, palette=cds)
# The x axis spans all bands, so the y label must not say "Delta".
f.set_axis_labels( "", "Band Power Normalized to Baseline")
f.set_titles("")
f.set(ylim=(0,2))

In [None]:
# NOTE(review): `bands` is not passed to the plot below, which therefore shows every band in new_rsbp_lfp.
bands=['delta', 'theta', 'alpha', 'beta']
# Box plots of baseline-normalized NREM power for every band, Control vs Laser, one row per LFP channel.
f = sns.catplot(x='band', y='power', hue='Key', kind='box', row='channel', data=new_rsbp_lfp, palette=cds)
# The x axis spans all bands, so the y label must not say "Delta".
f.set_axis_labels( "", "Band Power Normalized to Baseline")
f.set_titles("")
f.set(ylim=(0,2))

In [None]:
# Boxen plot of baseline-normalized power per band (EEG, both conditions and channels pooled).
# catplot is figure-level: it always creates its own figure and cannot draw into `f`.
# The axes-level boxenplot draws on the axes created here, so figsize takes effect.
f, ax = plt.subplots(figsize=(10, 18))
sns.boxenplot(x="band", y="power", data=new_rsbp, ax=ax)

# Boxplots

In [None]:
# Notched box plots of baseline-normalized power per band, Control vs Laser, one row per EEG channel.
# The condition label lives in the 'Key' column (capital K, kept as an id_var by the melt above).
fig = px.box(x='band', y='power', data_frame=new_rsbp, color='Key', points=False, notched=True, facet_row='channel', color_discrete_sequence=['white', 'cornflowerblue'])
fig

In [None]:
# Notched box plots of baseline-normalized power per band, Control vs Laser, one row per LFP channel.
# The condition label lives in the 'Key' column (capital K, kept as an id_var by the melt above).
fig = px.box(x='band', y='power', data_frame=new_rsbp_lfp, color='Key', points=False, notched=True, facet_row='channel', color_discrete_sequence=['white', 'cornflowerblue'])
fig

# PSD Plots

In [None]:
# Control windows: "peak" = start of the loaded EEG data up to stim onset; "stim" = stim onset to offset.
control_peak = slice(a9['control1-e-d'].datetime.values[0], a9_times['control1']['stim_on_dt'])
control_stim = slice(a9_times['control1']['stim_on_dt'], a9_times['control1']['stim_off_dt'])

In [None]:
# Laser windows: "peak" = start of the loaded EEG data up to stim onset; "stim" = stim onset to offset.
laser_peak = slice(a9['laser1-e-d'].datetime.values[0], a9_times['laser1']['stim_on_dt'])
laser_stim = slice(a9_times['laser1']['stim_on_dt'], a9_times['laser1']['stim_off_dt'])

In [None]:
# Control EEG: kd.get_ss_psd restricted to ['NREM'], once for the peak window and once for the stim window.
control_psd_nrem_peak = kd.get_ss_psd(a9['control1-e-s'].sel(datetime=control_peak), h9['control1'], ['NREM'])
control_psd_nrem_stim = kd.get_ss_psd(a9['control1-e-s'].sel(datetime=control_stim), h9['control1'], ['NREM'])

In [None]:
# Laser EEG: kd.get_ss_psd restricted to ['NREM'], once for the peak window and once for the stim window.
laser_psd_nrem_peak = kd.get_ss_psd(a9['laser1-e-s'].sel(datetime=laser_peak), h9['laser1'], ['NREM'])
laser_psd_nrem_stim = kd.get_ss_psd(a9['laser1-e-s'].sel(datetime=laser_stim), h9['laser1'], ['NREM'])

In [None]:
# Control EEG: stim-period NREM PSD as a ratio to the peak-period NREM PSD.
control_psd_eeg = control_psd_nrem_stim/control_psd_nrem_peak

In [None]:
# Laser EEG: stim-period NREM PSD as a ratio to the peak-period NREM PSD.
laser_psd_eeg = laser_psd_nrem_stim/laser_psd_nrem_peak

In [None]:
# Channel 1 of the stim/peak PSD ratios (control and laser).
c1 = control_psd_eeg.sel(channel=1)
l1 = laser_psd_eeg.sel(channel=1)

In [None]:
# Laser vs control stim/peak PSD ratio, EEG channel 1, linear scale, x limited to 0.75-30.
f = kp.compare_psd(l1, c1, 'NREM', keys=['Laser', 'Control'], scale='linear')
f.set(ylim=(0,1.25), xlim=(0.75,30))

In [None]:
# Laser vs control NREM PSD during the peak (pre-stim) window, EEG, all channels, x limited to 0-10.
# NOTE(review): no keys= are passed here, unlike the ratio comparison — confirm the default labels are right.
f = kp.compare_psd(laser_psd_nrem_peak, control_psd_nrem_peak, 'NREM', scale='linear')
f.set(xlim=(0,10))

In [None]:
# Laser vs control NREM PSD during the stim window, EEG channel 1, x limited to 0-20.
# NOTE(review): no keys= are passed here, unlike the ratio comparison — confirm the default labels are right.
f = kp.compare_psd(laser_psd_nrem_stim.sel(channel=1), control_psd_nrem_stim.sel(channel=1), 'NREM', scale='linear')
f.set(xlim=(0,20), ylabel='NREM Power Spectral Density')

In [None]:
# Control LFP: kd.get_ss_psd restricted to ['NREM'], once for the peak window and once for the stim window.
control_psd_nrem_peak_lfp = kd.get_ss_psd(a9['control1-f-s'].sel(datetime=control_peak), h9['control1'], ['NREM'])
control_psd_nrem_stim_lfp = kd.get_ss_psd(a9['control1-f-s'].sel(datetime=control_stim), h9['control1'], ['NREM'])

In [None]:
# Laser LFP: kd.get_ss_psd restricted to ['NREM'], once for the peak window and once for the stim window.
laser_psd_nrem_peak_lfp = kd.get_ss_psd(a9['laser1-f-s'].sel(datetime=laser_peak), h9['laser1'], ['NREM'])
laser_psd_nrem_stim_lfp = kd.get_ss_psd(a9['laser1-f-s'].sel(datetime=laser_stim), h9['laser1'], ['NREM'])

In [None]:
# Control LFP: stim-period NREM PSD as a ratio to the peak-period NREM PSD.
control_psd_lfp = control_psd_nrem_stim_lfp/control_psd_nrem_peak_lfp

In [None]:
# Laser LFP: stim-period NREM PSD as a ratio to the peak-period NREM PSD.
laser_psd_lfp = laser_psd_nrem_stim_lfp/laser_psd_nrem_peak_lfp

In [None]:
# Laser vs control NREM PSD during the peak (pre-stim) window, LFP, all loaded channels.
kp.compare_psd(laser_psd_nrem_peak_lfp, control_psd_nrem_peak_lfp, 'NREM', scale='linear')

In [None]:
# Laser vs control NREM PSD during the stim window, LFP channel 15, x limited to 0-20.
f = kp.compare_psd(laser_psd_nrem_stim_lfp.sel(channel=15), control_psd_nrem_stim_lfp.sel(channel=15), 'NREM', scale='linear')
f.set(xlim=(0,20), ylabel='NREM Power Spectral Density')

In [None]:
# Laser vs control stim/peak PSD ratio, LFP, all loaded channels, x limited to 0.75-30.
f = kp.compare_psd(laser_psd_lfp, control_psd_lfp, 'NREM', keys=['Laser', 'Control'], scale='linear')
f.set(ylim=(0,1.25), xlim=(0.75,30))

In [None]:
# Band power of the control EEG spectrogram as a flat pandas table (pandas=True, index reset).
# NOTE: the name `bp` is reassigned to a loaded dataset in the "Scratch" section at the end of the notebook.
bp = kd.get_bp_set2(a9['control1-e-s'], bp_def, pandas=True).reset_index()

# Sleep Pressure Relief

In [None]:
# NOTE: `ds` is only assigned in the "Scratch" section at the very end of this notebook
# (ds = bp['laser1-eeg']), so this cell fails on a top-to-bottom run.
# Presumably filt_state() keeps one state (NREM by default?), ch(1) selects channel 1 and
# ri() resets the index — confirm against the accessor definitions.
dsn = ds.filt_state().ch(1).ri()

In [None]:
def get_time_iterables(df, time='1h', period = '5m'):
    """Return evenly spaced timestamps starting at the earliest datetime in `df`.

    Parameters
    ----------
    df : object with a ``datetime`` column/attribute supporting ``.min()``
        (e.g. a pandas DataFrame).
    time : str
        Total span to cover, as a pandas timedelta string (e.g. '3h').
    period : str
        Spacing between consecutive timestamps (e.g. '5m').

    Returns
    -------
    list
        ``int(time / period) + 1`` timestamps: the start plus one per full period.

    Side effects
    ------------
    Prints the (float) number of periods.
    """
    step = pd.to_timedelta(period)
    num_periods = pd.to_timedelta(time) / step
    print(num_periods)
    # Anchor on the frame that was passed in, not on a notebook-level variable.
    start = df.datetime.min()
    return [start + i * step for i in range(int(num_periods) + 1)]

In [None]:
# 5-minute boundaries spanning 3 hours from the first sample in dsn (36 periods, per the printed output).
times = get_time_iterables(dsn, time='3h', period = '5m')

36.0


In [None]:
def get_delta_means_by_dt(df, times):
    """Mean of ``delta`` within each consecutive pair of boundaries in `times`.

    `df` must provide ``ts(slice)`` for time selection. For every adjacent pair
    (times[k-1], times[k]) the selected rows are counted (printed) and their mean
    delta is collected. Returns a list with ``len(times) - 1`` entries.
    """
    means = []
    for k in range(1, len(times)):
        window = df.ts(slice(times[k - 1], times[k]))
        print(len(window))
        means.append(window.delta.mean())
    return means

In [None]:
len(times[0:12])

12

In [None]:
def get_delta_means_by_index(dsn, period=5, total_time=60):
    """Average ``delta`` over consecutive windows of `period` minutes, selected by row label.

    Parameters
    ----------
    dsn : pandas.DataFrame
        Needs 'datetime' and 'delta' columns and row labels 0, 1, 2, ... (the sampling
        interval is taken from rows labelled 0 and 1).
    period : number
        Window length in minutes.
    total_time : number
        Total span to cover in minutes.

    Returns
    -------
    (delta_means, datetimes) : tuple of lists
        Mean delta of each window and the earliest datetime in that window.

    Notes
    -----
    Windows are cut with ``.loc[i:i+step]``, which includes both end labels, so adjacent
    windows share their boundary row. Windows that contain no rows are skipped.
    """
    delta_means = []
    datetimes = []
    # total_seconds() keeps fractional sampling intervals; .seconds would truncate them
    # to whole seconds (and to 0 for sub-second steps, causing a division by zero below).
    interval = (dsn.datetime[1] - dsn.datetime[0]).total_seconds()
    total_ints = (total_time * 60) / interval
    step_size = (period * 60) / interval
    for i in np.arange(0, total_ints, step_size):
        df_slice = dsn.loc[i:i+step_size]
        if len(df_slice) == 0:
            # Requested span runs past the available rows; nothing to average here.
            continue
        datetimes.append(df_slice.datetime.values.min())
        delta_means.append(df_slice.delta.mean())
    return delta_means, datetimes

In [None]:
# 3-minute delta means over the first 180 minutes of dsn, with the first datetime of each window.
dm_ix, dt = get_delta_means_by_index(dsn, period=3, total_time=180)

In [None]:
# Reference level for normalization: mean delta over rows whose time_class is 'Post Stim, 4-6Hr'.
# NOTE(review): no cell visible here adds a time_class column to dsn — confirm where it comes from.
late_dm = dsn[dsn.time_class=='Post Stim, 4-6Hr'].delta.mean()

In [None]:
# Table of windowed delta means, plus the same values divided by the late post-stim mean.
delta_decay = pd.DataFrame({'datetime': dt, 'delta_mean': dm_ix})
delta_decay['delta_mean_norm'] = delta_decay['delta_mean']/late_dm

In [None]:
# Normalized 3-minute delta means over time, with a dashed green line at laser onset.
# NOTE(review): laser_laser_on is not assigned in any cell visible in this notebook — confirm its source.
fig = px.line(x=delta_decay['datetime'], y=delta_decay['delta_mean_norm'], title='Delta Power in 3 Min Averages, Laser Experiment')
fig.add_vline(x=laser_laser_on, line_width=3, line_dash="dash", line_color="green")

In [None]:
# Un-normalized 3-minute delta means over time (`go` is imported with the other plotly modules at the top).
fig = go.Figure(data=[go.Scatter(x=delta_decay['datetime'], y=delta_decay['delta_mean'])])
fig

# Units

In [None]:
# Block paths for the laser and control recordings.
lp = '/Volumes/opto_loc/Data/ACR_9/ACR_9-laser1'
cp = '/Volumes/opto_loc/Data/ACR_9/ACR_9-control1'

# Spike tables for channel 15, one per condition.
# NOTE(review): `kpd` is not imported in any cell visible here, and `h` is only assigned in the
# "Scratch" section at the end of the notebook. t1 / t2 are hard-coded — presumably seconds
# into each block; confirm.
sp15_con = kpd.get_spike_df(cp, t1=5723, t2=42575, chan=15, hyp=h['control1'], condition='control1')
sp15_las = kpd.get_spike_df(lp, t1=9077, t2=45922, chan=15, hyp=h['laser1'], condition='laser1')

In [None]:
hyp.fractional_occupancy()

state
Brief-Arousal         0.000860
NREM                  0.503469
REM                   0.096036
Transition-to-NREM    0.001612
Transition-to-REM     0.018505
Unsure                0.000105
Wake                  0.379413
Name: duration, dtype: float64

In [None]:
def total_time_in_state_over_period(hyp, state, t1, t2):
    """Total duration of `state` bouts whose end_time lies within [t1, t2].

    A bout is selected by its end_time alone (both bounds inclusive) and then counted
    with its full duration. Returns the sum of the 'duration' values of the selected bouts.
    """
    ends_in_period = hyp.end_time.between(t1, t2)
    bouts_in_state = hyp[ends_in_period].keep_states([state])
    return bouts_in_state.duration.sum()

In [None]:
total_time_in_state_over_period(hyp, 'NREM', laser_laser_on, laser_laser_off)

Timedelta('0 days 02:18:09.478446959')

In [None]:
def get_spikes_per_min(spike_df, hyp, state=None, t1=None, t2=None):
    """Spikes per minute of time spent in a state, within the window [t1, t2].

    Parameters
    ----------
    spike_df : pandas.DataFrame
        One row per spike, with 'datetime' and 'state' columns.
    hyp : hypnogram
        Needs 'end_time' and 'duration' columns and a ``keep_states(list)`` method.
        Time in a state is the summed duration of that state's bouts whose end_time
        lies in [t1, t2].
    state : str, optional
        State to report. If None, every state present among the spikes in the window
        is reported.
    t1, t2 : datetime-like, optional
        Window bounds (inclusive). Default to the first / last spike time in `spike_df`.

    Returns
    -------
    float or pandas.Series
        Rate for `state`, or a Series of rates indexed by state when `state` is None.
        A state with no scored time in the window gets NaN.
    """
    if t1 is None:
        t1 = spike_df['datetime'].min()
    if t2 is None:
        t2 = spike_df['datetime'].max()

    spikes_in_window = spike_df[spike_df['datetime'].between(t1, t2)]
    hyp_in_window = hyp[hyp.end_time.between(t1, t2)]
    one_minute = pd.Timedelta(minutes=1)

    def _rate(st):
        # Number of spikes labelled `st` divided by the minutes the hypnogram spent in `st`.
        n_spikes = (spikes_in_window['state'] == st).sum()
        # to_timedelta makes an empty selection sum to a zero Timedelta rather than the int 0.
        minutes = pd.to_timedelta(hyp_in_window.keep_states([st]).duration).sum() / one_minute
        return n_spikes / minutes if minutes > 0 else float('nan')

    if state is not None:
        return _rate(state)
    return pd.Series(
        {st: _rate(st) for st in spikes_in_window['state'].unique()},
        name='spikes_per_min',
        dtype=float,
    )

In [None]:
# Adds a 'time_class' column to the control spike table.
# NOTE(review): neither add_time_class nor `t` is defined in any cell visible here — confirm their source.
sp15_con = add_time_class(sp15_con, t)

In [None]:
# Adds a 'time_class' column to the laser spike table.
# NOTE(review): neither add_time_class nor `t` is defined in any cell visible here — confirm their source.
sp15_las = add_time_class(sp15_las, t)

In [None]:
sp15_las.time_class.unique()

array(['Baseline', 'Photostim, 0-2Hr', 'Photostim, 2-4Hr',
       'Post Stim, 0-2Hr', 'Post Stim, 2-4Hr', 'Post Stim, 4-6Hr', 'NA'],
      dtype=object)

In [None]:
# First and last spike time within each time class of the laser recording.
# TODO: the per-state selection (an unfinished `spikes_in_tc_state` assignment in the original
# cell) still has to be written; it was removed because the dangling `=` was a syntax error.
# NOTE: t1 / t2 are overwritten on every pass, so only the last time class's bounds survive the loop.
for tc in sp15_las.time_class.unique():
    spikes_in_tc = sp15_las.loc[sp15_las.time_class == tc]
    t1 = spikes_in_tc.datetime.min()
    t2 = spikes_in_tc.datetime.max()

In [None]:
get_spikes_per_min(sp15_con, hyp, 'NREM', laser_laser_on, laser_laser_off)

KeyError: 'hypnogram'

In [None]:
nrem_times[nrem_times.end_time.between(laser_laser_on, laser_laser_off)].duration.sum()

Timedelta('0 days 02:18:09.478446959')

In [None]:
sp15_con.ts(slice(control_laser_on, control_laser_off)).filt_state()

Unnamed: 0,datetime,spikes,state,condition
0,2022-06-14 10:41:14.536746520,1.0,NREM,control1
1,2022-06-14 10:41:14.863935000,1.0,NREM,control1
2,2022-06-14 10:41:15.022491160,1.0,NREM,control1
3,2022-06-14 10:41:15.039489560,1.0,NREM,control1
4,2022-06-14 10:41:15.046247960,1.0,NREM,control1
...,...,...,...,...
125155,2022-06-14 14:40:10.494844440,1.0,NREM,control1
125156,2022-06-14 14:40:10.503036439,1.0,NREM,control1
125157,2022-06-14 14:40:10.792787480,1.0,NREM,control1
125158,2022-06-14 14:40:10.980875800,1.0,NREM,control1


In [None]:
sp15_las.ts(slice(laser_laser_on, laser_laser_off)).filt_state()

Unnamed: 0,datetime,spikes,state,condition
0,2022-06-17 11:19:14.181180440,1.0,NREM,laser1
1,2022-06-17 11:19:14.225499160,1.0,NREM,laser1
2,2022-06-17 11:19:14.246061080,1.0,NREM,laser1
3,2022-06-17 11:19:14.316225560,1.0,NREM,laser1
4,2022-06-17 11:19:14.538515480,1.0,NREM,laser1
...,...,...,...,...
55773,2022-06-17 15:18:32.137092120,1.0,NREM,laser1
55774,2022-06-17 15:18:32.305314840,1.0,NREM,laser1
55775,2022-06-17 15:18:32.708688920,1.0,NREM,laser1
55776,2022-06-17 15:18:32.905010200,1.0,NREM,laser1


# Scratch

In [None]:
# Loads saved band-power datasets and hypnograms for both conditions.
# NOTE: `bp` and `h` assigned here are read by cells in earlier sections ("Sleep Pressure Relief",
# "Units"), so this cell has to run before them.
# NOTE(review): `kpd` and `load_hypnos` are not imported or defined in any cell visible here — confirm their source.
cond_list = ['control1-eeg', 'control1-lfp', 'laser1-eeg', 'laser1-lfp']
path_root = '/Volumes/opto_loc/Data/ACR_PROJECT_MATERIALS/ACR_9/analysis-data/'
bp = kpd.load_dataset(path_root, cond_list, '-bp')
h = load_hypnos(path_root, ['control1', 'laser1'])

In [None]:
# Laser hypnogram; `hyp` is read by cells in the "Units" section above, so this must run before them.
hyp = h['laser1']

In [None]:
# Laser EEG band-power dataset; `ds` is read by the "Sleep Pressure Relief" section above, so this must run before it.
ds = bp['laser1-eeg']