# Label acoustic data

Optional notebook within the chronic ephys processing pipeline:
- 1-preprocess_acoustics
- 2-curate_acoustics
- 3-sort_spikes
- 4-curate_spikes
- **5-label_acoustics**

*Currently contains functionality to label social context and syllables*

Run this notebook in the **birdsong** environment.

In [None]:
import os
import logging
import socket
import pickle
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
from scipy import stats
from scipy import signal
from scipy.io import wavfile
import IPython.display as ipd
from tqdm.autonotebook import tqdm
from joblib import Parallel, delayed
from datetime import datetime
import copy
import sys

from praatio import textgrid
from praatio import audio

sys.path.append('/mnt/cube/lo/envs')
from plot_sonogram import plot_sonogram as ps

sys.path.append('/mnt/cube/lo/envs/ceciestunepipe')
from ceciestunepipe.file import bcistructure as et

sys.path.append('/mnt/cube/lo/envs/vocalization-segmentation')
from vocalseg.continuity_filtering import plot_labelled_elements

sys.path.append('/mnt/cube/lo/envs/avgn_paper')
from avgn.signalprocessing.filtering import butter_bandpass_filter
from avgn.utils.hparams import HParams
from avgn.signalprocessing.filtering import prepare_mel_matrix
from avgn.signalprocessing.create_spectrogram_dataset import make_spec, mask_spec, log_resize_spec, pad_spectrogram, flatten_spectrograms
from avgn.visualization.spectrogram import draw_spec_set
from avgn.visualization.quickplots import draw_projection_plots
from avgn.visualization.projections import scatter_spec
from avgn.visualization.barcodes import plot_sorted_barcodes
from avgn.visualization.network_graph import plot_network_graph

np.set_printoptions(precision=3, suppress=True)

In [None]:
# session parameters
# sess_par is handed to et.list_ephys_epochs / et.sgl_struct below to locate the session on disk.
sess_par = {
    'bird':'z_r12r13_21', # bird ID
    'sess':'2021-06-27', # session date
    'ephys_software':'sglx', # recording software, sglx or oe
    'stim_sess':False, # if song stimulus was played during the session, ignore detected bouts
    'trim_bouts':True, # manually trim bouts after curation
    'sort':'sort_0', # sort index
}

# Social-context thresholds, as 'HH:MM:SS' strings (parsed with "%H:%M:%S" in add_social_context).
# They are compared against each bout's start_ms converted to hours/minutes/seconds,
# so they are presumably elapsed time since recording start -- TODO confirm.
# None (or an empty string) disables the corresponding threshold.
time_F_in = '00:00:00'  # time the female was introduced
time_F_out = None  # time the female was removed

## Load curated acoustics

In [None]:
# Enumerate the ephys epochs available for this bird/session
sess_epochs = et.list_ephys_epochs(sess_par)
n_epochs = len(sess_epochs)
print(f"Found {n_epochs} epoch(s):", sess_epochs)

In [None]:
# Choose the epoch to process and resolve its folder structure
this_epoch = sess_epochs[0] # set epoch index
epoch_struct = et.sgl_struct(
    sess_par,
    this_epoch,
    ephys_software=sess_par['ephys_software'],
)
print('Processing epoch', this_epoch)

# load bout dataframe (curated bouts pickled in the epoch's 'derived' folder)
derived_dir = epoch_struct['folders']['derived']
bout_df_path = os.path.join(derived_dir, 'bout_pd_ap0_curated.pkl')
with open(bout_df_path, 'rb') as bout_file:
    bout_df = pickle.load(bout_file)

In [None]:
# get sample rate
# Only the first distinct rate is kept in fs; extra rates are reported but not handled.
sample_rates = bout_df.sample_rate.unique()
if len(sample_rates) > 1:
    print(f"{len(sample_rates)} sample rates found:", sample_rates)
fs = sample_rates[0]

In [None]:
# get neural sample rate
# Read from the 's_f' entry of the AP sync dictionary pickled in the 'derived' folder.
ap_path = os.path.join(epoch_struct['folders']['derived'], 'ap_0_sync_dict.pkl')
with open(ap_path, 'rb') as sync_file:
    ap_syn_dict = pickle.load(sync_file)
ap_fs = ap_syn_dict['s_f']

## Add social context

In [None]:
def get_bout_start(start_ms):
    """Convert a bout onset given in milliseconds into a ``datetime.time``.

    The onset is truncated to whole seconds and split into hours, minutes and
    seconds; sub-second precision is discarded.

    Parameters
    ----------
    start_ms : int or float
        Bout onset in milliseconds.

    Returns
    -------
    datetime.time
        The onset as HH:MM:SS.

    Raises
    ------
    ValueError
        If the onset is negative or falls at or beyond 24 hours, because the
        resulting hour cannot be parsed by ``%H``.
    """
    whole_seconds = int(start_ms // 1000)
    hour, leftover = divmod(whole_seconds, 3600)
    minute, second = divmod(leftover, 60)

    clock = f"{hour:02}:{minute:02}:{second:02}"
    return datetime.strptime(clock, "%H:%M:%S").time()

def set_behavior(row, F_in_dt=None, F_out_dt=None):
    """Classify one bout as 'directed' or 'undirected' song.

    A bout is 'undirected' when it starts before the female was introduced
    (``F_in_dt``) or after she was removed (``F_out_dt``); otherwise it is
    'directed'. A threshold that is ``None`` is ignored, so with both
    thresholds unset every bout is 'directed'.

    Parameters
    ----------
    row : mapping
        Bout record; only ``row['start_ms']`` is read.
    F_in_dt, F_out_dt : datetime.time or None
        Time the female was introduced / removed.

    Returns
    -------
    str
        ``'undirected'`` or ``'directed'``.
    """
    bout_start = get_bout_start(row['start_ms'])

    # `or` short-circuits, so the second comparison only runs when the first fails
    female_absent = (F_in_dt and bout_start < F_in_dt) or (F_out_dt and bout_start > F_out_dt)
    if female_absent:
        return 'undirected'
    return 'directed'

def add_social_context(bout_df_in, time_F_in=None, time_F_out=None):
    """Return a copy of the bout dataframe with a 'behavior' column added.

    Each bout is labeled 'directed' or 'undirected' by ``set_behavior``
    according to when the female was introduced and/or removed.

    Both thresholds are honored when both are given (previously ``time_F_out``
    was silently ignored whenever ``time_F_in`` was set). In that case the
    bouts between the two times are 'directed', which assumes the female was
    introduced before she was removed.

    Labels are computed in a plain loop: the per-bout work is a single time
    comparison, so a process pool only added the cost of pickling every row
    (waveform included).

    Parameters
    ----------
    bout_df_in : pandas.DataFrame
        Bout table; must provide a 'start_ms' column. It is not modified.
    time_F_in : str or None
        'HH:MM:SS' time the female was introduced.
    time_F_out : str or None
        'HH:MM:SS' time the female was removed.

    Returns
    -------
    pandas.DataFrame
        Copy of ``bout_df_in``. The 'behavior' column is added only when at
        least one threshold is given; otherwise the copy is returned
        unchanged and a message is printed.

    Raises
    ------
    ValueError
        If a provided time string does not match "%H:%M:%S".
    """
    bout_df_out = bout_df_in.copy()

    if not time_F_in and not time_F_out:
        # No threshold to compare against: leave the table unlabeled, as before,
        # but say so instead of failing later on a missing column.
        print('No female in/out time given; social context not labeled', '\n')
        return bout_df_out

    F_in_dt = None
    F_out_dt = None

    if time_F_in:
        F_in_dt = datetime.strptime(f"{time_F_in}", "%H:%M:%S").time()
        print('Female introduced at', F_in_dt, '\n')

    if time_F_out:
        F_out_dt = datetime.strptime(f"{time_F_out}", "%H:%M:%S").time()
        print('Female removed at', F_out_dt, '\n')

    bout_df_out['behavior'] = [
        set_behavior(row, F_in_dt=F_in_dt, F_out_dt=F_out_dt)
        for _, row in bout_df_out.iterrows()
    ]

    return bout_df_out

bout_df = add_social_context(bout_df, time_F_in, time_F_out)

# Report how many bouts fall in each social context
for context in ('undirected', 'directed'):
    n_bouts = len(bout_df[bout_df['behavior'] == context])
    print(n_bouts, f'{context} bouts')

## Export wav files to label in Praat

In [None]:
# Write one wav per bout into <derived>/praat for manual labeling in Praat.
praat_dir = os.path.join(epoch_struct['folders']['derived'],'praat')
os.makedirs(praat_dir, exist_ok=True)

# Files are named "<bout index>-<start_ms>.wav"; the import cell looks for a
# TextGrid with the same stem, so the naming scheme must stay in sync with it.
for idx, row in bout_df.iterrows():
    file_path = os.path.join(praat_dir, f"{idx}-{row['start_ms']}.wav")
    # Use the bout's own sample rate: the global fs is only the first of
    # possibly several rates found in the table.
    wavfile.write(file_path, row['sample_rate'], row['waveform'])

## Import TextGrid files from Praat

In [None]:
# Work on a copy: add a band-pass filtered waveform (300-12000, at each bout's
# own sample rate) and keep the unfiltered one under 'bout_waveform_raw'.
bouts_segmented = bout_df.copy()
bouts_segmented['bout_waveform_filt'] = bouts_segmented.apply(
    lambda bout: butter_bandpass_filter(bout['waveform'], 300, 12000, bout['sample_rate']),
    axis=1,
)
bouts_segmented = bouts_segmented.rename(columns={'waveform': 'bout_waveform_raw'})

In [None]:
# Create a dataframe for segmented syllables
# One row per labeled Praat interval; each row also carries its parent bout's
# metadata and waveforms.
praat_dir = os.path.join(epoch_struct['folders']['derived'],'praat')
dfs = []
for bout_index, row in bouts_segmented.iterrows():

    # Info from bouts
    sample_rate = row['sample_rate']  # per-bout rate, not the global fs
    bout_waveform_filt = row['bout_waveform_filt']
    start_ms_bout = row['start_ms']
    start_sample_bout = row['start_sample']
    start_ms_ap_0_bout = row['start_ms_ap_0']
    start_sample_ap_0_bout = row['start_sample_ap_0']

    # Syllable labels from praat (file stem matches the exported wav)
    tg_path = os.path.join(praat_dir, f"{bout_index}-{row['start_ms']}.TextGrid")
    tg = textgrid.openTextgrid(tg_path, includeEmptyIntervals=False)
    syllables = tg.getTier(tg.tierNames[0])  # only the first tier is read

    data = []
    for syllable_index, interval in enumerate(syllables.entries):
        # Interval times are in seconds relative to the start of the bout wav
        on_s = interval.start
        off_s = interval.end

        # Offsets into the bout waveform
        on_offset = int(on_s*sample_rate)
        off_offset = int(off_s*sample_rate)

        data.append({
            'file': row['file'],
            'sess': row['sess'],
            'epoch': row['epoch'],
            'sample_rate': sample_rate,
            'bout_index': bout_index,
            'bout_waveform_raw': row['bout_waveform_raw'],
            'bout_waveform_filt': bout_waveform_filt,
            'start_ms_ap_0': int(start_ms_ap_0_bout + on_s*1000),
            # on_s is relative to the bout, so convert it straight to AP samples.
            # Adding the absolute on_sample here would count the bout onset twice.
            'start_sample_ap_0': int(start_sample_ap_0_bout + on_s*ap_fs),
            'syllable_index': syllable_index,
            'on_sample': int(start_sample_bout + on_s*sample_rate),
            'off_sample': int(start_sample_bout + off_s*sample_rate),
            'on_ms': int(start_ms_bout + on_s*1000),
            'off_ms': int(start_ms_bout + off_s*1000),
            'label': interval.label,
            'syllable_waveform': bout_waveform_filt[on_offset:off_offset]})

    dfs.append(pd.DataFrame(data))

syl_df = pd.concat(dfs, ignore_index=True)

In [None]:
# normalize audio
def _peak_normalize(syll):
    """Scale a waveform so that its largest absolute sample equals 1.

    The divisor is the absolute peak, so polarity is preserved; dividing by
    the signed extreme flipped the waveform whenever the largest-magnitude
    sample was negative. Empty waveforms are returned unchanged, and
    non-finite results (e.g. from an all-zero waveform) are replaced via
    ``np.nan_to_num``.
    """
    syll = np.asarray(syll)
    if syll.size == 0:
        return syll
    peak = np.max(np.abs(syll))
    # A silent syllable gives 0/0; the resulting NaNs are zeroed just below
    with np.errstate(divide='ignore', invalid='ignore'):
        scaled = syll / peak
    if not np.all(np.isfinite(scaled)):
        scaled = np.nan_to_num(scaled)
    return scaled

syl_df['syllable_waveform'] = [_peak_normalize(syll) for syll in syl_df['syllable_waveform'].values]

In [None]:
# Plot some of the syllables to see how they look
nrows = 10
ncols = 10
zoom = 2
fig, axs = plt.subplots(ncols=ncols, nrows = nrows, figsize = (ncols*zoom, nrows+zoom/1.5))

# Plot at most one syllable per axis. Slicing up front gives the progress bar
# an accurate total and lets it finish instead of being cut short by a break.
example_sylls = syl_df['syllable_waveform'].values[:nrows*ncols]
for ax, syll in tqdm(zip(axs.flatten(), example_sylls), total=len(example_sylls)):
    ax.plot(syll)

## Plot syllable spectrograms

In [None]:
# Waveforms and their sample rates, in syl_df row order, used as spectrogram inputs
syllables_wav = syl_df['syllable_waveform'].values
syllables_rate = syl_df['sample_rate'].values

In [None]:
# Spectrogram hyper-parameters handed to prepare_mel_matrix / make_spec.
# The 300-12000 band edges match the band-pass filter applied to the bouts.
hparams = HParams(**{
    'num_mel_bins': 64,
    'mel_lower_edge_hertz': 300,
    'mel_upper_edge_hertz': 12000,
    'butter_lowcut': 300,
    'butter_highcut': 12000,
    'ref_level_db': 20,
    'min_level_db': -100,
    'mask_spec': True,
    'win_length_ms': 4,
    'hop_length_ms': 1,
    'nex': -1,
    'n_jobs': -1,
    'verbosity': 0,
})

In [None]:
# Settings for the joblib Parallel pools used in the spectrogram cells below
n_jobs = 36  # number of parallel workers; NOTE(review): hard-coded, adjust to the machine
verbosity = 0  # passed as Parallel(verbose=...)

In [None]:
# create spectrograms
# The mel matrix is built from hparams and the sample rate only, so build it
# once per distinct rate instead of once per syllable.
# NOTE(review): assumes prepare_mel_matrix is deterministic and that make_spec
# does not modify the matrix it receives -- confirm.
mel_matrices = {rate: prepare_mel_matrix(hparams, rate) for rate in set(syllables_rate)}

with Parallel(n_jobs=n_jobs, verbose=verbosity) as parallel:
    syllables_spec = parallel(
        delayed(make_spec)(
            syllable,
            rate,
            hparams=hparams,
            mel_matrix=mel_matrices[rate],
            use_mel=True,
            use_tensorflow=False,
        )
        for syllable, rate in tqdm(
            zip(syllables_wav, syllables_rate),
            total=len(syllables_rate),
            desc="getting syllable spectrograms",
            leave=False,
        )
    )

In [None]:
# Visual check of the mel spectrograms produced by make_spec
draw_spec_set(syllables_spec, zoom=1, maxrows=10, colsize=40)

In [None]:
# log rescale spectrograms
# syllables_spec is overwritten with the rescaled spectrograms.
log_scaling_factor = 4

rescale_jobs = (
    delayed(log_resize_spec)(spec, scaling_factor=log_scaling_factor)
    for spec in tqdm(syllables_spec, desc="scaling spectrograms", leave=False)
)
with Parallel(n_jobs=n_jobs, verbose=verbosity) as parallel:
    syllables_spec = parallel(rescale_jobs)

In [None]:
# Visual check of the spectrograms after log rescaling
draw_spec_set(syllables_spec, zoom=1, maxrows=10, colsize=40)

## Plot syllable barcodes

In [None]:
def song_barcode(start_times, stop_times, labels, label_dict, label_pal_dict, resolution=0.01):
    """Rasterize one bout's labeled syllables into a fixed-resolution barcode.

    The span from the earliest start to the latest stop is divided into bins
    of width ``resolution`` (same units as the times). Bins covered by a
    syllable hold that syllable's code; uncovered bins keep the filler string
    ``'0.0'`` and are colored white.

    Parameters
    ----------
    start_times, stop_times : array-like
        Syllable onsets and offsets.
    labels : sequence
        Syllable labels, aligned with the times.
    label_dict : dict
        Maps each label to its code.
    label_pal_dict : dict
        Maps each code to a color.
    resolution : float
        Bin width.

    Returns
    -------
    trans_list : numpy.ndarray
        Object array with one code (or ``'0.0'``) per bin.
    color_list : numpy.ndarray
        Colors per bin with a singleton axis inserted at position 1.
    """
    t_first = np.min(start_times)
    t_last = np.max(stop_times)
    n_bins = int((t_last - t_first) / resolution)
    trans_list = np.zeros(n_bins).astype("str").astype("object")

    for onset, offset, lab in zip(start_times, stop_times, labels):
        first_bin = int((onset - t_first) / resolution)
        last_bin = int((offset - t_first) / resolution)
        trans_list[first_bin:last_bin] = label_dict[lab]

    # Bins whose code has no palette entry (the silent filler) are white
    color_list = []
    for code in trans_list:
        if code in label_pal_dict:
            color_list.append(label_pal_dict[code])
        else:
            color_list.append([1, 1, 1])
    color_list = np.expand_dims(color_list, 1)

    return trans_list, color_list


def indv_barcode(this_df, time_resolution=0.01, label="label", pal="tab20"):
    """Build syllable barcodes for every bout in a syllable dataframe.

    Each unique label gets a zero-padded three-digit code and a color drawn
    from a randomly permuted seaborn palette (so colors differ between calls
    unless numpy's random state is seeded). The palette is displayed with
    ``sns.palplot`` as a side effect.

    Parameters
    ----------
    this_df : pandas.DataFrame
        Must provide 'bout_index', 'on_ms', 'off_ms' and the ``label`` column.
    time_resolution : float
        Bin width passed to ``song_barcode``, in the units of on_ms/off_ms.
    label : str
        Name of the column holding syllable labels.
    pal : str
        Seaborn palette name.

    Returns
    -------
    tuple
        ``(color_lists, trans_lists, label_pal_dict, label_pal, label_dict)``;
        the two lists hold one entry per bout, ordered like
        ``this_df.bout_index.unique()``.
    """
    unique_labels = this_df[label].unique()

    # song palette
    label_pal = np.random.permutation(sns.color_palette(pal, len(unique_labels)))
    label_dict = {}
    for code, lab in enumerate(unique_labels):
        label_dict[lab] = f"{code:03d}"
    label_pal_dict = dict(zip((label_dict[lab] for lab in unique_labels), label_pal))
    sns.palplot(list(label_pal_dict.values()))

    # one barcode per bout, in order of first appearance
    color_lists, trans_lists = [], []
    for bout_id in tqdm(this_df.bout_index.unique(), leave=False):
        bout_rows = this_df[this_df['bout_index'] == bout_id]
        trans_list, color_list = song_barcode(
            bout_rows.on_ms.values,
            bout_rows.off_ms.values,
            bout_rows[label].values,
            label_dict,
            label_pal_dict,
            resolution=time_resolution,
        )
        color_lists.append(color_list)
        trans_lists.append(trans_list)

    return color_lists, trans_lists, label_pal_dict, label_pal, label_dict

In [None]:
# Get variables for plotting
print(f"Syllable barcodes: {syl_df.label.unique()!s}")

# on_ms/off_ms are in milliseconds, so time_resolution=12 means 12 ms bins
color_lists, trans_lists, label_pal_dict, label_pal, label_dict = indv_barcode(syl_df, time_resolution=12)

In [None]:
# Plot syllable barcodes for songs
# color_lists / trans_lists hold one entry per bout, already ordered like
# syl_df.bout_index.unique(), so they are passed positionally. Indexing them by
# bout_index value breaks (IndexError or wrong bout) whenever the indices are
# not exactly 0..n-1, e.g. when a bout has no labeled syllables.
fig, ax = plt.subplots(figsize=(20, 3))
plot_sorted_barcodes(
    list(color_lists),
    list(trans_lists),
    max_list_len=600,
    seq_len=100,
    nex=200,
    figsize=(10, 4),
    ax=ax,
)
plt.show()

## Save syl_df

In [None]:
# Pickle the syllable dataframe into the epoch's 'derived' folder (overwrites any existing file)
syl_df.to_pickle(os.path.join(epoch_struct['folders']['derived'],'syl_df_ap0.pickle'))