# Whistle Detection

## Dataset Preparation

In [None]:
import warnings
warnings.filterwarnings("ignore")
from datamodules.whistle import WhistleDataset
from models import WHISTLEDATA_CONFIG as DATA_PARAMS

# Build the whistle dataset used by the training cell below and print its
# summary. NOTE(review): `tobeloaded=True` presumably restores a previously
# saved dataset stored under `name` instead of rebuilding it -- confirm in
# datamodules.whistle.
dataset = WhistleDataset(
    name='whistle/saved/boh',
    tobeloaded=True,
    params=DATA_PARAMS,
)
dataset.summarize()

## Training

In [None]:
from models import WhistleNet
from models import WHISTLETRAIN_CONFIG as TRAIN_PARAMS
from core.trainer import Trainer
import time
import matplotlib.pyplot as plt # type: ignore

# Notebook switches:
#   train_model   -- True: fit each model on `dataset` and time the fit;
#                    False: call `model.load(name)` instead.
#   complete_plot -- True: also draw the train/validation score curves.
complete_plot = False
train_model = False

# TODO: the model name is hard-coded.
names = ['whistle/boh']

for name in names:
    model = WhistleNet(name, num_classes=5)
    trainer = Trainer(params=TRAIN_PARAMS)

    if train_model:
        # Wall-clock duration of the fit is stored on the model itself.
        start_time = time.time()
        trainer.fit(model, dataset)
        model.training_time = time.time() - start_time
    else:
        model.load(name)

    # The test curve is always drawn; the other two only on request.
    plt.plot(model.test_scores, label=f'{name} - test scores')
    if complete_plot:
        plt.plot(model.train_scores, label=f'{name} - train scores')
        plt.plot(model.val_scores, label=f'{name} - val scores')

# One shared figure for every model in `names`.
plt.xlabel('epoch')
plt.ylabel('score')
plt.legend()
plt.show()

## Inference

In [None]:
import torch
# import sounddevice as sd
from datamodules.whistle import Audio
from core.utils import project_root

def classify(window):
    """Run the notebook-level ``model`` on ``window`` and report the result.

    The forward pass is done under ``torch.no_grad()``. The predicted class
    is the argmax over the whole prediction tensor (no ``dim`` is given, so
    the index refers to the flattened output). The raw prediction is printed
    followed by ``NO`` when the index is 0 and ``YES`` otherwise.

    NOTE(review): the model is not switched to eval mode here; confirm that
    ``model.load`` / the trainer already takes care of it if the network has
    dropout or batch-norm layers.

    Args:
        window: Input passed unchanged to ``model``.

    Returns:
        The predicted class index as a plain Python number.
    """
    with torch.no_grad():
        prediction = model(window)
        class_index = torch.argmax(prediction).item()

    if class_index != 0:
        print(f"{prediction}: YES")
    else:
        print(f"{prediction}: NO")
    return class_index


# Load one test recording and classify it column by column.
projroot = project_root()
audio = Audio(
    name="test4",
    datapath=f'{projroot}/data/whistle/raw/test',
)

# Each column of `audio.S` is fed to the model on its own.
# NOTE(review): presumably `S` is a spectrogram with 513 frequency bins per
# frame (hence the hard-coded reshape below) -- confirm in datamodules.whistle.
num_frames = audio.S.shape[1]
for i in range(num_frames):
    column = audio.S[:, i]
    window = column.reshape(1, 1, 513)
    print(f"frame {i}, time {audio.frame2time(i):.2f}")
    classify(window)

audio.freq_plot()