In [1]:
%matplotlib notebook
import os
import pickle
import numpy as np
import pandas as pd
from scipy.io import wavfile
import sys
sys.path.append('/mnt/cube/tsmcpher/code/')
from ceciestunepipe.file import bcistructure as et
from ceciestunepipe.util.sound import boutsearch as bs
from ceciestunepipe.pipeline import searchbout as sb
from ceciestunepipe.mods import curate_bouts as cb
from ceciestunepipe.mods import preproc_oe

h5py version > 2.10.0. Some extractors might not work properly. It is recommended to downgrade to version 2.10.0: 
>>> pip install h5py==2.10.0


In [14]:
sess_par = {
    'bird':'s_b1527_23',
    'sess':'2023-09-03',
    'software':'oe'
}

stim_map_dict = {
    '/home/bird/tsmcpher/chronic_songs/s_b1527_23/2023-08-20/':'/home/AD/tsmcpher/chronic_songs/s_b1527_23/2023-08-20/',
    '/home/bird/tsmcpher/chronic_songs/s_b1555_22/2022-08-17/':'/home/AD/tsmcpher/chronic_songs/s_b1555_22/2022-08-17/',
}

In [15]:
# directories / files
exp_struct = et.get_exp_struct(sess_par['bird'],sess_par['sess'],ephys_software=sess_par['software'])
bouts_folder = exp_struct['folders']['processed'][:-2] + 'bouts_{}'.format(sess_par['software'])
os.makedirs(bouts_folder,exist_ok=True)
sess_bouts_curated_file = os.path.join(bouts_folder,'bout_curated.pickle')

# load bouts
hparams, bout_pd = sb.load_bouts(sess_par['bird'],sess_par['sess'],'',
                                 derived_folder='bouts_{}'.format(sess_par['software']),
                                 bout_file_key='bout_auto_file')

In [16]:
# sort by start sample
bout_pd.sort_values('start_sample',ascending=True,inplace=True)
bout_pd.reset_index(drop=True,inplace=True)

# bout for all epoches where stim overlap is removed
remaining_bouts_all_epoch = []
# get epoch mic files from preprocessing
epoch_wav_mic_files = np.unique(bout_pd['file'])
# loop through epochs
for this_epoch_wav_mic in epoch_wav_mic_files:
    # get bouts for this epoch
    this_epoch_bout_pd = bout_pd.iloc[np.where(bout_pd['file'] == this_epoch_wav_mic)]
    # load mic metadata file for this epoch
    pkl_path = this_epoch_wav_mic.split('.')[0] + '-npy_meta.pickle'
    with open(pkl_path, 'rb') as fp:
        meta_dict = pickle.load(fp)
    # this epoch name - get recording events path
    this_epoch = this_epoch_wav_mic.split('/')[-2]
    raw_folder = exp_struct['folders']['oe']
    epoch_path = os.path.join(raw_folder,this_epoch)
    node_path = preproc_oe.get_default_node(exp_struct,this_epoch)
    rec_path = preproc_oe.get_default_recording(node_path)
    events_path = os.path.join(rec_path,'events/Network_Events-102.0/TEXT_group_1/')
    # load stim lables / onsets
    stim_labels = np.load(os.path.join(events_path,'text.npy'))
    stim_onsets = np.load(os.path.join(events_path,'timestamps.npy'))
    
    # get stim onsets and offsets
    stim_on_all = []
    stim_off_all = []
    # loop through stim
    for stim_i in range(len(stim_labels)):
        this_stim_label = stim_labels[stim_i].astype('str')
        this_stim_onset = stim_onsets[stim_i]
        # get stim name / experimental folder
        stim_name = this_stim_label.split('/')[-1]
        # only stim from folder
        if len(this_stim_label.split('/')) > 1:
            stim_exp_dir = this_stim_label[len(this_stim_label.split('/')[0]):-len(stim_name)]
            # get stim file for processing
            stim_file = os.path.join(stim_map_dict[stim_exp_dir],stim_name)
            # load stim and get length
            sf,this_wav = wavfile.read(stim_file,mmap=True)
            stim_len = this_wav.shape[0]/sf
            # get length of stim in samples - round up
            stim_samp_len = int(np.ceil(stim_len * meta_dict['s_f']))
            # get stim on / off
            stim_on_all.append(this_stim_onset)
            stim_off_all.append(this_stim_onset+stim_samp_len)  

    # get bouts that overlap with stim
    stim_bout_list = []
    # loop through stim
    for stim_i in range(len(stim_on_all)):
        # find stim onsets / offsets that occur within a bout
        on_after_start = np.where(this_epoch_bout_pd.start_sample <= stim_on_all[stim_i])[0]
        on_before_end = np.where(this_epoch_bout_pd.end_sample >= stim_on_all[stim_i])[0]
        off_after_start = np.where(this_epoch_bout_pd.start_sample <= stim_off_all[stim_i])[0]
        off_before_end = np.where(this_epoch_bout_pd.end_sample >= stim_off_all[stim_i])[0]
        on_in_bout = np.intersect1d(on_after_start,on_before_end)
        off_in_bout = np.intersect1d(off_after_start,off_before_end)
        # if stim happens during bout store bout index
        if len(on_in_bout) == 1:
            stim_bout_list.append(on_in_bout[0])
        elif len(on_in_bout) > 1:
            breakme
        if len(off_in_bout) == 1:
            stim_bout_list.append(off_in_bout[0])
        elif len(off_in_bout) > 1:
            breakme
        # fin stim onstets / offsets that occur between bouts
        on_before_start = np.where(this_epoch_bout_pd.start_sample >= stim_on_all[stim_i])[0]
        on_after_end = np.where(this_epoch_bout_pd.end_sample <= stim_on_all[stim_i])[0]
        off_before_start = np.where(this_epoch_bout_pd.start_sample >= stim_off_all[stim_i])[0]
        off_after_end = np.where(this_epoch_bout_pd.end_sample <= stim_off_all[stim_i])[0]
        on_out_bout = np.intersect1d(on_before_start,on_after_end+1)
        off_out_bout = np.intersect1d(off_before_start,off_after_end+1)
        # stim onset and offset are both between bouts
        if (len(on_out_bout) > 0) & (len(off_out_bout) > 0):
            # stim spans multiple bouts
            if (off_out_bout[0] - on_out_bout[0]) != 0:
                # store bout indexes of all spanned stim
                for bout_span_i in range(on_out_bout[0],off_out_bout[0]):
                    stim_bout_list.append(bout_span_i)
        elif (len(on_out_bout) > 1) & (len(off_out_bout) > 1):
            breakme
    
    # get unique bouts that overlap with stims
    stim_bouts_unique = np.unique(np.array(stim_bout_list))
    # get rid of them
    remaining_bouts = this_epoch_bout_pd.drop(stim_bouts_unique,errors="ignore")
    # store
    remaining_bouts_all_epoch.append(remaining_bouts)
    
# regroup across epochs
bout_pd_no_stim = pd.concat(remaining_bouts_all_epoch)

# sort by duration
bout_pd_no_stim.sort_values('len_ms',ascending=False,inplace=True)
bout_pd_no_stim.reset_index(drop=True,inplace=True)

print(len(bout_pd),len(bout_pd_no_stim))

1158 161


In [17]:
viz_bout = cb.VizBout(hparams,bout_pd_no_stim)

<IPython.core.display.Javascript object>



In [6]:
bpd = cb.sess_bout_summary(bout_pd)
sum_dict = cb.give_summary(viz_bout.bouts_pd)
sum_dict

{'/mnt/sphere/chronic_ephys/der/s_b1555_22/2022-08-17/oe/2022-08-17_12-38-44_550/wav_mic.npy': 40}

### Lo's additions below:

In [None]:
bout_pd = viz_bout.bouts_pd

In [None]:
## some of these redundant with first code block
%matplotlib widget
import os
import sys
import pickle
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

from scipy.io import wavfile
from scipy.io import loadmat

sys.path.append('/mnt/cube/lo/envs/ceciestunepipe')
from ceciestunepipe.file import bcistructure as et
from ceciestunepipe.lo import trim_bout as tb
from ceciestunepipe.util.sound import boutsearch as bs

In [None]:
## add to second code block?
do_trim_bouts = True

In [None]:
## trim bouts using WIDGET !!!
if do_trim_bouts:
    trim_bouts = tb.TrimBout(hparams, bout_pd)

In [None]:
## generate new trimmed bouts df
if do_trim_bouts:
    if len(bout_pd) > 100: print('Trimming bouts may take a while: there are ' + str(len(bout_pd)) + ' bouts.')
    start_ms = (trim_bouts.crop_min * 1000).astype(int)
    end_ms = (trim_bouts.crop_max * 1000).astype(int)
    bout_pd, bout_dict = tb.handle_trim_bouts(bout_pd, syn_dict_path, start_ms, end_ms, hparams)

In [None]:
# update hparams
if do_trim_bouts:
    hparams['bout_curated_file'] = 'bout_checked_trimmed.pickle'
    bs.save_bouts_params_dict(hparams, os.path.join(bouts_folder, 'bout_search_params.pickle') )

In [None]:
# save curated bouts
sb.save_auto_bouts(bout_pd,sess_par,hparams,software=sess_par['software'],bout_file_key='bout_curated_file')

### Original save:

In [10]:
# save
viz_bout.bouts_pd.to_pickle(sess_bouts_curated_file)
sb.save_auto_bouts(viz_bout.bouts_pd,sess_par,hparams,software=sess_par['software'],bout_file_key='bout_curated_file')