## BW High Quality

In [None]:

import os

import numpy  as np
import torch
from matplotlib import pyplot as plt
from pytorch_lightning import seed_everything

from whale.data_io.data_loader import WhaleDataModule
from whale.models import LSTM
from whale.utils.spectrogram import show_spectrogram, cal_spectrogram

# Seed the random number generators so the notebook is reproducible.
seed_everything(1234)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint of the LSTM model to visualise.
# Alternative checkpoint from the same run (epoch 7):
# ckpt_path = "/network/projects/aia/whale_call/mlruns/214176838239808987/ef4956f212324b7fb1bcf1436c577743/artifacts/model/checkpoints/epoch=7-step=128/"
# best_model_path  = ckpt_path+'epoch=7-step=128.ckpt'
ckpt_path = "/network/projects/aia/whale_call/mlruns/214176838239808987/ef4956f212324b7fb1bcf1436c577743/artifacts/model/checkpoints/epoch=22-step=368/"
best_model_path = ckpt_path + "epoch=22-step=368.ckpt"

# Restore the weights, switch to evaluation mode and move the model to `device`.
model = LSTM.load_from_checkpoint(best_model_path)
model.eval()
model.to(device)

# Data module for the BWC_3CH_HQ labels, serving one "spec" sample per batch.
whale_dm = WhaleDataModule(
    data_dir="/network/projects/aia/whale_call/LABELS/BWC_3CH_HQ",
    batch_size=1,
    data_type="spec",
)
whale_dm.setup()

# The training split is the dataset every later cell draws its samples from.
ds_sel = whale_dm.train_ds
dataset_size = len(ds_sel)

In [None]:
# Number of random examples to visualise.
num_samples = 6
# Draw distinct indices so the same example is not shown twice; fall back to
# sampling with replacement only when the dataset has fewer than num_samples items.
idx_rands = np.random.choice(
    dataset_size, num_samples, replace=dataset_size < num_samples
)

In [None]:
# Show the dataset item at the first randomly drawn index as the cell output.
ds_sel[idx_rands[0]]

### Random Results visualization


In [None]:
# Waveform time axis: sig_length samples spaced dt seconds apart
# (dt = 0.01 corresponds to the samp_rate=100 passed to cal_spectrogram below).
sig_length = 1601
dt = 0.01
t_axis = np.arange(0, sig_length) * dt
# Class index -> human readable label used in the subplot titles.
label_dict = {0: 'Noise', 1: 'BW Call'}

# Top row: waveforms, bottom row: spectrograms, one column per sampled example.
# squeeze=False keeps `axs` two-dimensional even when num_samples == 1.
fig, axs = plt.subplots(2, num_samples, figsize=(16, 6), squeeze=False)
for i in range(num_samples):
    ax_sig, ax_spec = axs[0][i], axs[1][i]
    idx_rand = idx_rands[i]
    data_sel = ds_sel[idx_rand]
    sig = data_sel['sig']
    spectrogram = data_sel['spec']

    # Inference only: no_grad avoids building an autograd graph for each sample.
    with torch.no_grad():
        class_logits, reg_out = model(spectrogram.unsqueeze(0).to(device))
    class_pred = torch.argmax(class_logits, dim=1)
    # Batch size is 1, so take the single predicted class / predicted time as scalars.
    pred_label = int(class_pred[0])
    pred_time = reg_out.item()

    # Waveform of the first channel with target / predicted labels in the title.
    ax_sig.plot(t_axis, sig[0], color='grey')
    ax_sig.set_title(
        f"Target Label: {label_dict[data_sel['target_label']]}\n"
        f"Pred Label:{label_dict[pred_label]}\n"
        f"Group ID: {data_sel['meta_data']['group_id']}\n"
    )
    ax_sig.set_xlabel('Time (s)')
    ax_sig.set_ylabel('Normalized Amplitude')

    # Recompute the spectrogram of the first channel only to obtain its
    # frequency and time bin values; the image itself comes from the dataset.
    _, freq, time = cal_spectrogram(sig[0], samp_rate=100, per_lap=0.9, wlen=0.5, mult=4)
    halfbin_time = (time[1] - time[0]) / 2.0
    halfbin_freq = (freq[1] - freq[0]) / 2.0
    # NOTE(review): presumably the stored spectrogram is (time, freq), so the
    # transpose + vertical flip puts low frequencies at the bottom - confirm.
    specgram = np.flipud(spectrogram.T)
    # Treat time/freq as bin centres: extend the image by half a bin on each side.
    extent = (
        time[0] - halfbin_time,
        time[-1] + halfbin_time,
        freq[0] - halfbin_freq,
        freq[-1] + halfbin_freq,
    )
    ax_spec.imshow(specgram, interpolation="nearest", extent=extent)

    # show_spectrogram(sig[0],axes=axs[1][i],samp_rate=100,per_lap=0.9, wlen=0.5, mult=4)
    # Vertical lines: target time (red) and predicted time (green).
    ax_spec.axvline(x=data_sel['target_time'], color='red', label='t_trgt')
    ax_spec.axvline(x=pred_time, color='green', label='t_pred')
    # axis("tight") re-autoscales the axes, so the explicit x-limits must be
    # applied after it or they are discarded.
    ax_spec.axis("tight")
    ax_spec.set_xlim([0, time[-1]])

    ax_spec.set_xlabel('Time (s)')
    ax_spec.set_ylabel('Frequency (Hz)')
    # Legend entries come from the `label=` of the two vertical lines.
    ax_spec.legend()

plt.tight_layout()
# fig.savefig('prediction_examples_baseline_LSTM_cls_reg.png',dpi=300)


### Iterate through the entire dataset (WARNING: this will generate a large number of figures)

In [None]:
# Waveform time axis parameters (same values as in the random-sample cell).
sig_length = 1601
dt = 0.01
t_axis = np.arange(0, sig_length) * dt
# Class index -> human readable label used in the figure titles.
label_dict = {0: 'Noise', 1: 'BW Call'}

# savefig does not create missing directories, so make sure the target exists.
out_dir = 'bwc_vis_train'
os.makedirs(out_dir, exist_ok=True)

# One figure per dataset item, written to disk and closed immediately.
for i in range(dataset_size):
    data_sel = ds_sel[i]
    sig = data_sel['sig']
    group_id = data_sel['meta_data']['group_id']
    r_max_UTC = data_sel['meta_data']['time_R_max']
    spectrogram = data_sel['spec']

    # Inference only: no_grad avoids building an autograd graph for every item.
    with torch.no_grad():
        class_logits, reg_out = model(spectrogram.unsqueeze(0).to(device))
    class_pred = torch.argmax(class_logits, dim=1)
    # Batch size is 1, so take the single predicted class / predicted time as scalars.
    pred_label = int(class_pred[0])
    pred_time = reg_out.item()

    fig, ax = plt.subplots(figsize=(5, 5))

    # Recompute the spectrogram of the first channel only to obtain its
    # frequency and time bin values; the image itself comes from the dataset.
    _, freq, time = cal_spectrogram(sig[0], samp_rate=100, per_lap=0.9, wlen=0.5, mult=4)
    halfbin_time = (time[1] - time[0]) / 2.0
    halfbin_freq = (freq[1] - freq[0]) / 2.0
    # NOTE(review): presumably the stored spectrogram is (time, freq), so the
    # transpose + vertical flip puts low frequencies at the bottom - confirm.
    specgram = np.flipud(spectrogram.T)
    # Treat time/freq as bin centres: extend the image by half a bin on each side.
    extent = (
        time[0] - halfbin_time,
        time[-1] + halfbin_time,
        freq[0] - halfbin_freq,
        freq[-1] + halfbin_freq,
    )
    ax.imshow(specgram, interpolation="nearest", extent=extent)

    # Vertical lines: target time (red) and predicted time (green).
    ax.axvline(x=data_sel['target_time'], color='red', label='t_trgt')
    ax.axvline(x=pred_time, color='green', label='t_pred')
    # axis("tight") re-autoscales the axes, so the explicit x-limits must be
    # applied after it or they are discarded.
    ax.axis("tight")
    ax.set_xlim([0, time[-1]])
    ax.set_title(
        f"Target Label: {label_dict[data_sel['target_label']]} "
        f"Pred Label:{label_dict[pred_label]}\n"
        f"Group ID: {group_id}"
    )

    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    # Legend entries come from the `label=` of the two vertical lines.
    ax.legend()

    plt.tight_layout()
    fig.savefig(os.path.join(out_dir, f'group_id_{group_id}_{r_max_UTC}.png'), dpi=300)
    # Close the figure right away; otherwise every figure stays in memory.
    plt.close(fig)


## A high quality BW detection

In [None]:
import pandas as pd
import numpy as np
from obspy import read, UTCDateTime
fw_path = '/network/projects/aia/whale_call/LABELS/BW/'

# Thresholds on the group-averaged R and SNR that a detection group must exceed.
R0=5
SNR0=5

# Build the filtered table as one chain: every step returns a new frame, so no
# column is ever assigned onto a filtered slice (which raises
# SettingWithCopyWarning) and re-running the cell gives the same result.
bw_filt = (
    pd.read_csv(fw_path + 'bw_filt.csv')
    # Average SNR over all rows sharing a group_id; keep groups above SNR0.
    .assign(SNR_avg=lambda df: df.groupby('group_id')['SNR'].transform('mean'))
    .loc[lambda df: df['SNR_avg'] > SNR0]
    # Average R over all rows sharing a group_id; keep groups above R0.
    .assign(R_avg=lambda df: df.groupby('group_id')['R'].transform('mean'))
    .loc[lambda df: df['R_avg'] > R0]
)

## Get a random detection group
bw_detection_group = bw_filt['group_id'].unique()
if len(bw_detection_group) == 0:
    # np.random.choice would fail on an empty array with a less helpful message.
    raise ValueError(
        f"No detection group has SNR_avg > {SNR0} and R_avg > {R0}"
    )
group_id = np.random.choice(bw_detection_group)


In [None]:
# All rows belonging to the randomly selected detection group.
bw_data_sample = bw_filt[bw_filt['group_id']==group_id]

# The analysis window and the call times are properties of the whole group,
# so compute them once instead of on every loop iteration.
## Earliest time_window_start within the group
t0 = UTCDateTime(
    bw_data_sample.sort_values(by=['time_window_start']).iloc[0]['time_window_start']
)
## time_window_end of the window that starts last within the group
t1 = UTCDateTime(
    bw_data_sample.sort_values(by=['time_window_start'], ascending=False).iloc[0]['time_window_end']
)
## Unique time_R_max values within the group (drawn as vertical lines below)
call_list = bw_data_sample['time_R_max'].unique()

# One spectrogram per component, averaged afterwards.
list_spec = []
for component in bw_data_sample.component.unique():
    one_component = bw_data_sample[bw_data_sample.component == component]

    # Read a file that belongs to THIS component. Sampling from the whole group
    # could read the same component several times and never read the others.
    sac_file = one_component.sample(n=1).iloc[0]['file_path']

    st = read(sac_file)
    st_sliced = st.slice(starttime=t0, endtime=t1)
    data_len = len(st_sliced[0].data)

    input_spec, freq, time = cal_spectrogram(
        st_sliced[0].data,
        samp_rate=100,
        per_lap=0.9,
        wlen=0.5,
        mult=4,
    )
    list_spec.append(input_spec)

# Element-wise mean over the per-component spectrograms.
input_spec = np.average(list_spec, axis=0)

# Treat time/freq as bin centres: extend the image by half a bin on each side.
halfbin_time = (time[1] - time[0]) / 2.0
halfbin_freq = (freq[1] - freq[0]) / 2.0
# Flip vertically before handing the array to imshow.
specgram = np.flipud(input_spec)
extent = (
    time[0] - halfbin_time,
    time[-1] + halfbin_time,
    freq[0] - halfbin_freq,
    freq[-1] + halfbin_freq,
)

fig, ax = plt.subplots(figsize=(10, 3))
ax.imshow(specgram, interpolation="nearest", extent=extent)
ax.axis("tight")
ax.set_xlabel('Time (s)')
ax.set_ylabel('Frequency (Hz)')
# ax.set_yscale("log")
ax.set_title(f"Group ID: {group_id} T0: {t0}")
# Mark every call time as an offset in seconds from the window start t0.
for call_t in call_list:
    call_t = UTCDateTime(call_t)
    # ax.axvline(x=call_t-t0-4,color='red')
    ax.axvline(x=call_t - t0, color='red')
    # ax.axvline(x=call_t+4-t0,color='green')
plt.tight_layout()
# fig.savefig(f'group_id_{group_id}.png',dpi=300)
