## Inference

Notebook for experimenting with the model inference process.

In [None]:
import os
import json

import wandb
import torch
import pandas as pd
import torchaudio
import sed_eval
from matplotlib import pyplot as plt

from lib.lightning_modules import ClassifierModule
from lib.models.sed import RecurrentCNNModel

## Config

In [None]:
# Every preprocessed artefact used below lives under this one directory.
data_root = "../dataset/preprocessed_data/frame-2048_hop-1024_chunk-512_spec-mel_mels-128"
annotations_path = f"{data_root}/labels.tsv"
label2idx_path = f"{data_root}/label2idx.json"
audio_dir = f"{data_root}/audio_tensors"
# W&B model artifact to download (consumed by `run.use_artifact` further down).
ckpt_path = "chavicoski/DatathonMarine2022/model-3jcsvq90:v0"

## Load data

In [None]:
# Tab-separated annotation table; `delimiter` is pandas' alias for `sep`.
annotations = pd.read_csv(annotations_path, delimiter="\t")
# Preview the first rows (last expression of the cell is displayed).
annotations.head()

In [None]:
# Load the label -> index mapping saved alongside the preprocessed data.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying
# on the platform's default encoding.
with open(label2idx_path, "r", encoding="utf-8") as file_handle:
    label2idx = json.load(file_handle)

# Label names ordered by their class index, so that labels[i] names row i of
# the masks/predictions (plot_mask places labels at rows range(len(labels))).
# A plain `list(label2idx)` follows the JSON key order instead, which is only
# correct if the file happens to be written in index order.
# NOTE(review): assumes the values are integer class indices — confirm against
# the preprocessing code.
labels = sorted(label2idx, key=label2idx.__getitem__)
print(label2idx)
print(labels)

Load one randomly selected sample (its feature tensor and its label mask).

In [None]:
# Boolean row filter over the annotation table. More conditions can be AND-ed
# in to narrow the selection, e.g.:
# mask &= annotations["click"] == 1
# mask &= annotations["cetaceans_allfreq"] == 1
mask = annotations["whistle"].eq(1)
selected_annot = annotations.loc[mask]
# Draw a single random row; no seed is given, so it changes on every run.
sample_annot = selected_annot.sample(n=1)

In [None]:
# Relative path (under audio_dir) of the selected sample's feature file.
sample_annot["feature_path"].iloc[0]

In [None]:
# Load the features first, then the mask, for the selected annotation row.
# NOTE(review): these files hold NumPy arrays (they are passed to
# torch.from_numpy later). Recent PyTorch releases default torch.load to
# weights_only=True, which may reject them. If loading fails, pass
# weights_only=False (only for trusted files).
feature_data, mask_data = (
    torch.load(os.path.join(audio_dir, sample_annot[column].values[0]))
    for column in ("feature_path", "mask_path")
)
print(f"{feature_data.shape=}")
print(f"{mask_data.shape=}")

## Load the model checkpoint from W&B

In [None]:
# Start a W&B run. In this notebook it is used to fetch the model checkpoint
# artifact (see `run.use_artifact` in the next cell).
run = wandb.init()

In [None]:
# Resolve the model artifact named by `ckpt_path` through the active run, then
# download its files. `artifact_dir` is the local directory that holds them
# (the checkpoint is read from `<artifact_dir>/model.ckpt` below).
artifact = run.use_artifact(ckpt_path, type="model")
artifact_dir = artifact.download()

In [None]:
# Rebuild the network architecture; the checkpoint weights are loaded into it below.
# NOTE(review): input_height=128 presumably matches the 128 mel bins of the
# data config ("mels-128") — confirm if the data config changes.
model = RecurrentCNNModel(
    n_conv_blocks=5,
    n_classes=len(labels),
    filters_factor=1,
    input_height=128,
    start_n_filters=128,
)
checkpoint_file = os.path.join(artifact_dir, "model.ckpt")
module = ClassifierModule.load_from_checkpoint(
    checkpoint_file,
    model=model,
    labels=labels,
)
del checkpoint_file
# Switch to evaluation mode before predicting.
module.eval()

## Predict on one sample

In [None]:
# Add a leading batch dimension of size 1 so the single sample forms a batch.
sample_batch = torch.unsqueeze(torch.from_numpy(feature_data), 0)
# Inference only: disable autograd so no computation graph is recorded for
# this forward pass. detach() is kept as a harmless safeguard before the
# NumPy conversion. The second argument is the batch index.
with torch.no_grad():
    pred = module.predict_step(sample_batch, 0).detach().numpy()

In [None]:
def plot_spectrogram(specgram, title=None, ylabel="freq_bin", interpolation="antialiased"):
    """Display a spectrogram tensor as a dB-scaled heat map.

    Args:
        specgram: Tensor to display; it is converted with torchaudio's
            ``AmplitudeToDB`` (default settings) before plotting. Assumed to
            be 2-D (frequency bins x frames) — confirm at the call site.
        title: Figure title; falls back to "Spectrogram (db)" when falsy.
        ylabel: Label for the vertical (frequency) axis.
        interpolation: Forwarded unchanged to ``plt.imshow``.
    """
    to_db = torchaudio.transforms.AmplitudeToDB()
    plt.figure(figsize=(30, 10))
    plt.title(title if title else "Spectrogram (db)")
    plt.xlabel("frame")
    plt.ylabel(ylabel)
    # origin="lower" puts the first frequency bin at the bottom of the image.
    plt.imshow(to_db(specgram), origin="lower", aspect="auto", interpolation=interpolation)
    plt.colorbar()
    plt.show()

def plot_mask(mask, labels, title=None):
    """Display a per-class mask as a heat map, with one row per label.

    Args:
        mask: 2-D array-like whose row i is shown next to ``labels[i]``. The
            columns are plotted as frames. The row count should equal
            ``len(labels)`` for the tick labels to line up.
        labels: Class names used as y-tick labels, in row order.
        title: Optional figure title. Use it to tell apart masks drawn one
            after another (e.g. ground truth vs. prediction). No title is set
            when None, matching the previous behavior.
    """
    plt.figure(figsize=(30, 10))
    if title is not None:
        plt.title(title)
    # interpolation="none" keeps each frame/class cell sharp instead of blurred.
    plt.imshow(mask, aspect="auto", interpolation="none", cmap="jet")
    plt.yticks(range(len(labels)), labels=labels)
    plt.xlabel("Frame")
    plt.colorbar()
    plt.show()

In [None]:
# Input features: index 0 of the loaded array, shown in dB.
plot_spectrogram(torch.from_numpy(feature_data)[0])
# Mask loaded from disk for this sample.
plot_mask(torch.from_numpy(mask_data), labels)
# Model prediction: [0] drops the batch dimension added before predict_step.
plot_mask(pred[0], labels)

## Test the sed_eval library

In [None]:
#sample_annot["path"]