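## BW calls — high-quality (HQ) labels

Load an LSTM checkpoint and visualize its predictions on random samples from the validation split of the `BWC_3CH_HQ` label set.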

In [None]:

from whale.data_io.data_loader import WhaleDataModule
import torch
from pytorch_lightning import seed_everything
import numpy  as np
from whale.models import LSTM
from matplotlib import pyplot as plt
from whale.utils.spectrogram import show_spectrogram, cal_spectrogram

# Fix all RNG seeds up front (Lightning's seed_everything covers python, numpy and
# torch), so the random validation indices drawn below are repeatable.
seed_everything(1234)

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Checkpoint directory of the HQ run (MLflow-style `mlruns/<exp>/<run>/artifacts` layout).
# Previous run, kept for reference:
# ckpt_path = '/network/projects/aia/whale_call/mlruns/756039190186882702/e01ee35dc34641e5ac7141b1864936dd/artifacts/model/checkpoints/epoch=29-step=4350/'
ckpt_path = "/network/projects/aia/whale_call/mlruns/214176838239808987/ef4956f212324b7fb1bcf1436c577743/artifacts/model/checkpoints/epoch=7-step=128/"
best_model_path = f"{ckpt_path}epoch=7-step=128.ckpt"

# Restore the weights, switch to inference mode and move to `device`;
# both nn.Module.eval() and .to() return the module, so the calls chain.
model = LSTM.load_from_checkpoint(best_model_path).eval().to(device)

# Spectrogram-mode data module over the BWC_3CH_HQ labels; one sample per batch.
whale_dm = WhaleDataModule(
    data_dir="/network/projects/aia/whale_call/LABELS/BWC_3CH_HQ",
    batch_size=1,
    data_type="spec",
)
whale_dm.setup()

# The plots below draw from the validation split.
ds_sel = whale_dm.valid_ds
dataset_size = len(ds_sel)

In [None]:
# Pick distinct validation examples to visualize. Sampling without replacement
# avoids plotting the same example twice; clamping keeps it valid on a small split.
# Uses numpy's global RNG, which seed_everything seeded above.
assert dataset_size > 0, "validation split is empty"
num_samples = min(6, dataset_size)
idx_rands = np.random.choice(dataset_size, size=num_samples, replace=False)

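### Results visualization

For each sampled example, the top row plots the first channel of the waveform, titled with the target and predicted labels. The bottom row shows the spectrogram with the target time (red) and the predicted time (green).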


In [None]:
# Visualize HQ-model predictions on the sampled validation examples.
# Row 0: first channel of the waveform, titled with target vs. predicted class.
# Row 1: the dataset's precomputed spectrogram with target (red) / predicted (green) times.
dt = 0.01  # waveform sample spacing in seconds (100 Hz, matching samp_rate below)
label_dict = {0: 'Noise', 1: 'BW Call'}

# squeeze=False keeps `axs` 2-D even when num_samples == 1.
fig, axs = plt.subplots(2, num_samples, figsize=(16, 6), squeeze=False)
for i in range(num_samples):
    data_sel = ds_sel[idx_rands[i]]
    sig = data_sel['sig']
    spectrogram = data_sel['spec']
    ax_sig, ax_spec = axs[0, i], axs[1, i]

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        class_logits, reg_out = model(spectrogram.unsqueeze(0).to(device))
    pred_label = int(torch.argmax(class_logits, dim=1)[0])
    # Batch size is 1, so the regression head should yield a single time value;
    # .item() raises otherwise (TODO confirm reg_out shape).
    pred_time = reg_out.item()
    target_label = int(data_sel['target_label'])

    # --- waveform --------------------------------------------------------------
    # Time axis derived from the data rather than a hard-coded length.
    t_axis = np.arange(len(sig[0])) * dt
    ax_sig.plot(t_axis, sig[0], color='grey')
    ax_sig.set_title(f"Target Label: {label_dict[target_label]}\n Pred Label:{label_dict[pred_label]}")
    ax_sig.set_xlabel('Time (s)')
    ax_sig.set_ylabel('Normalized Amplitude')

    # --- spectrogram -----------------------------------------------------------
    # Only the time/frequency bin centres are taken from cal_spectrogram; the image
    # is the dataset's precomputed 'spec', drawn with imshow, which is much faster
    # than re-rendering via show_spectrogram.
    _, f_bins, t_bins = cal_spectrogram(sig[0], samp_rate=100, per_lap=0.9, wlen=0.5, mult=4)
    half_tbin = (t_bins[1] - t_bins[0]) / 2.0
    half_fbin = (f_bins[1] - f_bins[0]) / 2.0
    # Extend by half a bin on each side so every pixel is centred on its bin.
    extent = (
        t_bins[0] - half_tbin,
        t_bins[-1] + half_tbin,
        f_bins[0] - half_fbin,
        f_bins[-1] + half_fbin,
    )
    # Transpose + flip so frequency increases upward under imshow's default
    # origin='upper' (assumes 'spec' is laid out as (time, freq) — TODO confirm).
    specgram = np.flipud(spectrogram.T)
    ax_spec.imshow(specgram, interpolation="nearest", extent=extent)

    ax_spec.axvline(x=data_sel['target_time'], color='red', label='t_trgt')
    ax_spec.axvline(x=pred_time, color='green', label='t_pred')
    # axis("tight") resets imshow's fixed aspect but also re-enables autoscaling,
    # which would discard an earlier set_xlim, so the x-range is applied after it.
    ax_spec.axis("tight")
    ax_spec.set_xlim(0, t_bins[-1])
    ax_spec.set_xlabel('Time (s)')
    ax_spec.set_ylabel('Frequency (Hz)')
    # Labels live on the lines themselves, so the legend does not depend on
    # the order in which artists were added to the axes.
    ax_spec.legend()

plt.tight_layout()
# fig.savefig('prediction_examples_baseline_LSTM_cls_reg.png', dpi=300)


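## BW calls — low-quality (LQ) labels

The same analysis as above, using a separate LSTM checkpoint and the validation split of the `BWC_3CH_LQ` label set.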

In [None]:

from whale.data_io.data_loader import WhaleDataModule
import torch
from pytorch_lightning import seed_everything
import numpy  as np
from whale.models import LSTM
from matplotlib import pyplot as plt
from whale.utils.spectrogram import show_spectrogram

# Re-seed every RNG (python, numpy, torch via Lightning) so this section's
# sample draw is reproducible on its own, independent of the HQ section.
seed_everything(1234)

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Checkpoint directory of the LQ run (MLflow-style `mlruns/<exp>/<run>/artifacts` layout).
ckpt_path = "/network/projects/aia/whale_call/mlruns/214176838239808987/cca1b8e4231e4f5a82c8d5937f820f23/artifacts/model/checkpoints/epoch=29-step=19350/"
best_model_path = f"{ckpt_path}epoch=29-step=19350.ckpt"

# Restore the weights, switch to inference mode and move to `device`;
# both nn.Module.eval() and .to() return the module, so the calls chain.
model = LSTM.load_from_checkpoint(best_model_path).eval().to(device)

# Spectrogram-mode data module over the BWC_3CH_LQ labels; one sample per batch.
whale_dm = WhaleDataModule(
    data_dir="/network/projects/aia/whale_call/LABELS/BWC_3CH_LQ",
    batch_size=1,
    data_type="spec",
)
whale_dm.setup()

# The plots below draw from the validation split.
ds_sel = whale_dm.valid_ds
dataset_size = len(ds_sel)

In [None]:
# Pick distinct validation examples to visualize. Sampling without replacement
# avoids plotting the same example twice; clamping keeps it valid on a small split.
# Uses numpy's global RNG, which seed_everything seeded above.
assert dataset_size > 0, "validation split is empty"
num_samples = min(6, dataset_size)
idx_rands = np.random.choice(dataset_size, size=num_samples, replace=False)

In [None]:
# Visualize LQ-model predictions on the sampled validation examples.
# Row 0: first channel of the waveform, titled with target vs. predicted class.
# Row 1: show_spectrogram rendering with target (red) / predicted (green) times.
dt = 0.01  # waveform sample spacing in seconds (100 Hz, matching samp_rate below)
label_dict = {0: 'Noise', 1: 'BW Call'}

# squeeze=False keeps `axs` 2-D even when num_samples == 1.
fig, axs = plt.subplots(2, num_samples, figsize=(16, 6), squeeze=False)
for i in range(num_samples):
    data_sel = ds_sel[idx_rands[i]]
    sig = data_sel['sig']
    spectrogram = data_sel['spec']
    ax_sig, ax_spec = axs[0, i], axs[1, i]

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        class_logits, reg_out = model(spectrogram.unsqueeze(0).to(device))
    pred_label = int(torch.argmax(class_logits, dim=1)[0])
    # Batch size is 1, so the regression head should yield a single time value;
    # .item() raises otherwise (TODO confirm reg_out shape).
    pred_time = reg_out.item()
    target_label = int(data_sel['target_label'])

    # --- waveform --------------------------------------------------------------
    # Time axis derived from the data rather than a hard-coded length.
    t_axis = np.arange(len(sig[0])) * dt
    ax_sig.plot(t_axis, sig[0], color='grey')
    ax_sig.set_title(f"Target Label: {label_dict[target_label]}\n Pred Label:{label_dict[pred_label]}")
    ax_sig.set_xlabel('Time (s)')
    ax_sig.set_ylabel('Frequency (Hz)' if False else 'Normalized Amplitude')

    # --- spectrogram -----------------------------------------------------------
    show_spectrogram(sig[0], axes=ax_spec, samp_rate=100, per_lap=0.9, wlen=0.5, mult=4)
    ax_spec.axvline(x=data_sel['target_time'], color='red', label='t_trgt')
    ax_spec.axvline(x=pred_time, color='green', label='t_pred')
    ax_spec.set_xlabel('Time (s)')
    ax_spec.set_ylabel('Frequency (Hz)')
    # Labels live on the lines themselves, so the legend is correct regardless of
    # which other artists show_spectrogram adds to the axes.
    ax_spec.legend()

plt.tight_layout()
# fig.savefig('prediction_examples_baseline_LSTM_cls_reg.png', dpi=300)
