## Inference

Notebook for experimenting with the inference process of a trained model.

In [None]:
import os
import json

import wandb
import torch
import pandas as pd
import torchaudio
from matplotlib import pyplot as plt

from lib.lightning_modules import ClassifierModule
from lib.models.sed import RecurrentCNNModel

## Config

In [None]:
# Every preprocessed file used by this notebook lives under the same directory,
# so the individual paths are derived from a single prefix.
preprocessed_dir = "../dataset/preprocessed_data/frame-2048_hop-1024_chunk-128_spec-mel_mels-128_silence-label"
annotations_path = f"{preprocessed_dir}/labels.tsv"  # tab-separated annotations table
label2idx_path = f"{preprocessed_dir}/label2idx.json"  # JSON file with the label -> index mapping
audio_dir = f"{preprocessed_dir}/audio_tensors"  # directory holding the saved feature and mask files
# Reference of the model artifact that is fetched through wandb further down.
ckpt_path = "chavicoski/DatathonMarine2022/model-uzdwerxt:v0"

## Load data

In [None]:
# The annotations table is stored as tab-separated values.
annotations = pd.read_csv(annotations_path, delimiter="\t")
# Last expression of the cell: shows the first rows as a quick sanity check.
annotations.head()

In [None]:
# JSON text is UTF-8, so read it with an explicit encoding instead of relying
# on the platform's locale default.
with open(label2idx_path, "r", encoding="utf-8") as file_handle:
    label2idx = json.load(file_handle)

# Order the label names by their mapped index so that labels[i] is the class
# with index i, independently of the key order inside the JSON file.
labels = sorted(label2idx, key=label2idx.get)
print(label2idx)
print(labels)

Load one sample (features and mask)

In [None]:
# Keep only the rows whose "whistle" column equals 1.
mask = annotations["whistle"] == 1
# Optional extra filters; uncomment to additionally require these labels.
#mask &= annotations.click == 1
#mask &= annotations.cetaceans_allfreq == 1
selected_annot = annotations.loc[mask]
# Draw one random row from the filtered table (a one-row DataFrame).
sample_annot = selected_annot.sample(n=1)

In [None]:
# Feature file name of the sampled row (displayed as the cell output).
sample_annot["feature_path"].values[0]

In [None]:
# Resolve the files of the sampled row relative to the tensors directory.
sample_feature_file = os.path.join(audio_dir, sample_annot["feature_path"].values[0])
sample_mask_file = os.path.join(audio_dir, sample_annot["mask_path"].values[0])

# NOTE(review): both objects are later passed to torch.from_numpy, so they are
# presumably NumPy arrays saved with torch.save -- confirm that torch.load's
# defaults in the installed torch version can unpickle them.
feature_data = torch.load(sample_feature_file)
mask_data = torch.load(sample_mask_file)
print(f"{feature_data.shape=}")
print(f"{mask_data.shape=}")

## Load the model ckpt using wandb

In [None]:
# Start a wandb run; the returned run object is used in the next cell to
# request the model artifact.
run = wandb.init()

In [None]:
# Request the model artifact referenced by ckpt_path through the active run.
artifact = run.use_artifact(ckpt_path, type="model")
# download() returns the local directory; "model.ckpt" is read from it below.
artifact_dir = artifact.download()

In [None]:
# Build the network that receives the checkpoint weights.
# NOTE(review): these hyperparameters presumably have to match the ones used
# to train the checkpoint -- confirm against the training run's config.
model = RecurrentCNNModel(
    n_conv_blocks=3,
    n_classes=len(labels),
    filters_factor=1,
    input_height=128,
    start_n_filters=128,
    drop_factor=0.2,
    lstm_h_size=64,
    pool_factor=4,
)
# Map the checkpoint onto the CPU explicitly: the cells below feed CPU tensors
# to the module and convert its outputs with .numpy(), which fails if the
# weights end up on (or require) a GPU.
module = ClassifierModule.load_from_checkpoint(
    os.path.join(artifact_dir, "model.ckpt"),
    map_location="cpu",
    model=model,
    labels=labels,
)
# Switch to evaluation mode; as the last expression this also displays the module.
module.eval()

## Predict with one sample

In [None]:
# Add a leading batch axis of size 1 in front of the loaded features.
sample_batch = torch.from_numpy(feature_data).unsqueeze(0)
# Pure inference: disable autograd so no graph is built for the forward pass.
with torch.no_grad():
    pred = module.predict_step(sample_batch, 0).detach().numpy()

In [None]:
def plot_spectrogram(specgram, title=None, ylabel="freq_bin", interpolation="antialiased", to_db=False):
    """Show a 2-D spectrogram as an image with a colorbar.

    Args:
        specgram: Image-like data handed to ``imshow``; it is drawn with its
            first row at the bottom (``origin="lower"``).
        title: Figure title. Falls back to "Spectrogram (db)" when empty/None.
        ylabel: Label of the vertical axis.
        interpolation: Interpolation mode forwarded to ``imshow``.
        to_db: When true, ``specgram`` is first passed through
            ``torchaudio.transforms.AmplitudeToDB()``.
    """
    fig, ax = plt.subplots(figsize=(30, 10))
    ax.set_title(title if title else "Spectrogram (db)")
    ax.set_ylabel(ylabel)
    ax.set_xlabel("frame")

    image = specgram
    if to_db:
        image = torchaudio.transforms.AmplitudeToDB()(image)

    drawn = ax.imshow(image, origin="lower", aspect="auto", interpolation=interpolation)
    fig.colorbar(drawn, ax=ax)
    plt.show()

def plot_mask(mask, labels):
    """Show a per-label mask as an image, one row per label.

    Args:
        mask: 2-D image-like data handed to ``imshow``; row ``i`` is annotated
            with ``labels[i]`` on the vertical axis.
        labels: Sequence of label names used as y tick labels.
    """
    fig, ax = plt.subplots(figsize=(30, 10))
    # interpolation="none" keeps the cell edges sharp.
    drawn = ax.imshow(mask, aspect="auto", interpolation="none", cmap="jet")
    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels)
    ax.set_xlabel("Frame")
    fig.colorbar(drawn, ax=ax)
    plt.show()

Convert the matrix of logits into a one-hot mask by marking, for each frame, the class with the highest logit (argmax).

In [None]:
def pred2mask(pred_tensor: torch.Tensor) -> torch.Tensor:
    """Turn per-frame class scores into a one-hot mask.

    Args:
        pred_tensor: Scores for one sample with shape (labels, frames). Both
            torch tensors and NumPy arrays are accepted (the notebook passes
            NumPy arrays).

    Returns:
        A float tensor with the same shape as ``pred_tensor`` holding a 1 at
        the highest-scoring label of every frame and 0 elsewhere. On ties the
        label with the lowest index is selected.
    """
    # as_tensor accepts NumPy arrays as well as tensors.
    scores = torch.as_tensor(pred_tensor)
    pred_mask = torch.zeros(scores.shape)
    # One argmax over the label axis replaces the per-frame Python loop; the
    # indices are moved to the CPU because the mask is allocated there.
    winners = scores.argmax(dim=0).cpu()
    pred_mask[winners, torch.arange(scores.shape[1])] = 1
    return pred_mask

In [None]:
# Input features: [0] selects the first entry along the leading axis.
sample_spectrogram = torch.from_numpy(feature_data)[0]
plot_spectrogram(sample_spectrogram)

# Mask loaded from disk for the same sample.
reference_mask = torch.from_numpy(mask_data)
plot_mask(reference_mask, labels)

# Mask derived from the model output for the only element of the batch.
predicted_mask = pred2mask(pred[0])
plot_mask(predicted_mask, labels)

## Predict one audio

In [None]:
# The audio id is the part of the feature file name before the first "_".
audio_id, id_separator, _ = sample_annot.feature_path.values[0].partition("_")
# Match on "<id>_" instead of the bare id, so that another audio whose id only
# starts with the same characters (e.g. "12" vs "123") is not selected too.
# When the name has no "_" the separator is empty and the bare id is used.
audio_chunks_annot = annotations[annotations.feature_path.str.startswith(audio_id + id_separator)]

chunks_pred_masks = []
# Pure inference: no autograd graph is needed. Chunk-specific names are used
# inside the loop so the single-sample variables from the cells above
# (feature_data, sample_batch) are not overwritten.
with torch.no_grad():
    for chunk_feature_path in audio_chunks_annot.feature_path:
        chunk_features = torch.load(os.path.join(audio_dir, chunk_feature_path))
        chunk_batch = torch.from_numpy(chunk_features).unsqueeze(0)
        logits = module.predict_step(chunk_batch, 0).detach().numpy()
        chunks_pred_masks.append(pred2mask(logits[0]))

print(len(chunks_pred_masks))
print(chunks_pred_masks[0].shape)

In [None]:
# Join the per-chunk masks along axis 1 (the frame axis of each mask) and
# show the result for the whole audio.
audio_pred_mask = torch.cat(chunks_pred_masks, dim=1)
plot_mask(audio_pred_mask, labels)