In [1]:

import sys
sys.path.append('..')

from src.data.dataset import AudioDataset
import pandas as pd
import torch
from pandas import DataFrame

%load_ext autoreload
# We use the development dataset for this example:
base_dir = '../data/training'
metadata_df: DataFrame | None = pd.read_csv('../data/training/metadata.tsv', sep='\t', names=['file_reference', 'start_time', 'end_time', 'label'])




In [2]:

from src.data.split import compare_wav_file_refs, split_list
import functools
from src.data.dataset import get_file_names

wav_refs, labels = get_file_names(base_dir)
split_percentage = 0.8

gunshots = [wav_ref for wav_ref in wav_refs if wav_ref.label == 'Gunshot']
rumbles = [wav_ref for wav_ref in wav_refs if wav_ref.label == 'Rumble']

gunshots_sorted = sorted(gunshots, key=functools.cmp_to_key(compare_wav_file_refs))
rumbles_sorted = sorted(rumbles, key=functools.cmp_to_key(compare_wav_file_refs))

# Split the data
train_gunshots, val_gunshots = split_list(gunshots_sorted, split_percentage)
train_rumbles, val_rumbles = split_list(rumbles_sorted, split_percentage)

# Merge the data
training_wav_files = train_gunshots + train_rumbles
val_wav_files = val_gunshots + val_rumbles

In [3]:
val_dataset = AudioDataset('', metadata_df, wav_files = val_wav_files)

# Load the AST

In [4]:
from src.models.utils import load_model

model = load_model('E:\Python Projects\Fruitpunch\Elephants\model-exploration\checkpoints_ast\example_simple_transformer_best.pt')

  from .autonotebook import tqdm as notebook_tqdm


FileNotFoundError: [Errno 2] No such file or directory: 'E:\\Python Projects\\Fruitpunch\\Elephants\\model-exploration\\checkpoints_ast\\example_simple_transformer_best.pt'

Load the HTS

In [None]:
# import htsat
from hts_transformer.htsat import HTSAT_Swin_Transformer
import hts_transformer.config as config

model = HTSAT_Swin_Transformer(
    spec_size=config.htsat_spec_size,
    patch_size=config.htsat_patch_size,
    in_chans=1,
    num_classes=config.classes_num,
    window_size=config.htsat_window_size,
    config = config,
    depths = config.htsat_depth,
    embed_dim = config.htsat_dim,
    patch_stride=config.htsat_stride,
    num_heads=config.htsat_num_head
)

# Load from checkpoint
checkpoint = torch.load('E:\Python Projects\Fruitpunch\Elephants\model-exploration\checkpoints\epoch=33-step=27302.ckpt', map_location=torch.device('cuda'))
model.load_state_dict(checkpoint['state_dict'])

In [None]:
from src.data.dataset import collate_fn
# Create a dataloader to iterate over the dataset
from torch.utils.data import DataLoader
dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

In [None]:

import torchmetrics
from tqdm import tqdm
from src.metrics import get_metrics
import torch

@torch.no_grad()
def calculate_metrics_for_rumbles(model, dataloader):
    model = model.to('cuda')
    metrics  = {
        'acc': torchmetrics.Accuracy(task='binary').to('cuda'),
        'f1': torchmetrics.F1Score(task='binary', average='macro').to('cuda'),
        'precision': torchmetrics.Precision(task='binary', average='macro').to('cuda'),
        'recall': torchmetrics.Recall(task='binary', average='macro').to('cuda'),

    }

    for batch in tqdm(dataloader):
        audio, labels = batch
        audio = audio.to('cuda')
        labels = labels.to('cuda')
        predictions = model(audio)  # Swap to model.infer(audio) for HTS

        for key, metric in metrics.items():
            metric( predictions[:, 0], labels[:, 0])
    return metrics

In [None]:

import torchmetrics
from tqdm import tqdm
from src.metrics import get_metrics
import torch

device = 'cpu'

@torch.no_grad()
def calculate_metrics_for_all(model, dataloader):
    model = model.to(device)
    metrics  = {
        'acc': torchmetrics.Accuracy(task='binary').to(device),
        'f1': torchmetrics.F1Score(task='binary', average='macro').to(device),
        'precision': torchmetrics.Precision(task='binary', average='macro').to(device),
        'recall': torchmetrics.Recall(task='binary', average='macro').to(device),

    }

    for batch in tqdm(dataloader):
        audio, labels = batch
        audio = audio.to(device)
        labels = labels.to(device)
        predictions = model(audio) # Swap to model.infer(audio) for HTS

        for key, metric in metrics.items():
            metric( predictions, labels)
    return metrics

In [None]:
metrics = calculate_metrics_for_all(model, dataloader)

In [None]:
metric_results = {key: m.compute().to('cpu') for key, m in metrics.items()}

In [None]:
metrics = calculate_metrics_for_all(model, dataloader)

In [None]:
metric_results = {key: m.compute().to('cpu') for key, m in metrics.items()}

In [57]:
print(metric_results)

{'acc': tensor(0.8483), 'f1': tensor(0.7650), 'precision': tensor(0.7073), 'recall': tensor(0.8330)}
