# Data Preprocessing, Model Loading, Prediction, Evaluation

This notebook shows how to preprocess audio files, load a trained model, predict pitches, and evaluate the estimates.

In [None]:
import os
import sys
basepath = os.path.dirname(os.path.abspath('.'))
sys.path.append(basepath)

import numpy as np
import pandas as pd
import librosa
import libfmp
import matplotlib.pyplot as plt
import IPython.display as ipd
import torch
import torchinfo

import libdl

In [None]:
# Compute device used for the model and its inputs.
# Runs on the CPU; pass 'cuda' instead of 'cpu' to run on a GPU.
device = torch.device("cpu")

## 1. Load and preprocess audio

### Load audio file

In [None]:
# Sampling rate requested from librosa when loading the recording
fs = 22050

# Recording to analyse and the folder it is stored in
fn_audio = 'Schubert_D911-23_SC06.wav'
audio_folder = os.path.join(basepath, 'data', 'Schubert_Winterreise', '01_RawData', 'audio_wav')
path_audio = os.path.join(audio_folder, fn_audio)

# Read the waveform; fs_load is the sampling rate reported back by librosa
f_audio, fs_load = librosa.load(path_audio, sr=fs)

In [None]:
# Build the audio player first, then draw the waveform and show the player
audio_player = ipd.Audio(data=f_audio, rate=fs_load)
libfmp.b.plot_signal(f_audio, Fs=fs_load)
ipd.display(audio_player)

### Compute HCQT

In [None]:
# Frequency resolution of the HCQT: bins per semitone
bins_per_semitone = 3

# Keyword arguments forwarded to the HCQT computation below
hcqt_config = dict(
    fs=fs,
    fmin=librosa.note_to_hz('C1'),  # MIDI pitch 24
    fs_hcqt_target=50,
    bins_per_octave=12 * bins_per_semitone,
    num_octaves=6,
    num_harmonics=5,
    num_subharmonics=1,
    center_bins=True,
)

# Compute the HCQT; also returns its frame rate and the CQT hop size
f_hcqt, fs_hcqt, hopsize_cqt = libdl.data_preprocessing.compute_efficient_hcqt(f_audio, **hcqt_config)

### Visualize first harmonic

In [None]:
def plot_matrix_with_ticks(data, title, bins_per_semitone=bins_per_semitone, 
                           hcqt_config=hcqt_config, fs_hqct=fs_hcqt, pitches=True, **kwargs):
    """Plot the 25-50 s excerpt of a matrix with MIDI-pitch ticks on the y-axis.

    Args:
        data: 2D array whose second axis is indexed by frames (it is sliced
            with ``seconds * fs_hqct``).
        title: Title placed above the plot.
        bins_per_semitone: Number of rows per semitone; only used for the
            tick spacing when ``pitches`` is False.
        hcqt_config: Dict providing ``"num_octaves"``, which determines how
            many octave ticks/labels are drawn.
        fs_hqct: Frame rate of ``data`` in Hz. The (misspelled) parameter name
            is kept for backward compatibility with existing keyword callers.
        pitches: If True, rows are treated as one row per semitone (tick every
            12 rows); otherwise as ``bins_per_semitone`` rows per semitone.
        **kwargs: Forwarded to ``libfmp.b.plot_matrix`` (e.g. ``clim``).

    Note:
        Updates the global matplotlib font size via ``plt.rcParams``.
    """
    vis_start_sec = 25
    vis_stop_sec = 50
    vis_step_sec = 5

    num_octaves = hcqt_config["num_octaves"]
    n_bins = bins_per_semitone*12*num_octaves
    # One label per octave boundary, starting at MIDI pitch 24
    octave_labels = [str(24+12*octave) for octave in range(0, num_octaves+1)]

    plt.rcParams.update({'font.size': 11})
    fig, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1, 0.05]}, figsize=(10, 3.5))
    # Use the frame rate passed by the caller (previously the global fs_hcqt
    # was read here and the parameter was silently ignored).
    first_frame = int(vis_start_sec*fs_hqct)
    last_frame = int(vis_stop_sec*fs_hqct)
    libfmp.b.plot_matrix(data[:, first_frame:last_frame], 
                         Fs=fs_hqct, ax=ax, cmap='gray_r', ylabel='MIDI pitch', **kwargs)

    if pitches:
        # Tick every octave (12 rows); the range follows num_octaves so that
        # the number of ticks always matches the number of labels.
        ax[0].set_yticks(np.arange(0, 12*num_octaves+1, 12))
    else:
        ax[0].set_yticks(np.arange(1, n_bins+13, 12*bins_per_semitone))
    ax[0].set_yticklabels(octave_labels)
    # The plotted excerpt starts at 0 on the x-axis; relabel with absolute seconds
    ax[0].set_xticks(np.arange(0, (vis_stop_sec-vis_start_sec)+vis_step_sec, vis_step_sec))
    ax[0].set_xticklabels(np.arange(vis_start_sec, vis_stop_sec+vis_step_sec, vis_step_sec))
    ax[0].set_title(title)
    plt.tight_layout()

In [None]:
# Log-compressed magnitude of HCQT channel 1, shown as the fundamental
hcqt_fundamental_log = np.log(1+10*np.abs(f_hcqt[:, :, 1]))
plot_matrix_with_ticks(hcqt_fundamental_log, 'Harmonic 1 (fundamental)', pitches=False)

## 2. Specify and load model

In [None]:
# Folder containing the trained model checkpoints
dir_models = os.path.join(basepath, *('experiments', 'models'))

# Checkpoint to load. Other available checkpoints:
#   '02_schubert_baseline_ae.pt'
#   '03_schubert_baseline_sup.pt'
#   '04_schubert_cva.pt'
#   '05_schubert_cva_ov.pt'
#   '06_schubert_cva_b.pt'
fn_model = '07_schubert_cva_ov_b.pt'

In [None]:
# Input / output dimensions of the network
num_octaves_inp = 6
num_output_bins, min_pitch = 72, 24

# Keyword arguments for the model constructor
model_params = dict(
    n_chan_input=6,
    n_chan_layers=[20, 20, 10, 1],
    n_bins_in=num_octaves_inp * 12 * 3,
    n_bins_out=num_output_bins,
    a_lrelu=0.3,
    p_dropout=0.2,
)

# Only 03_schubert_baseline_sup uses the model without a final sigmoid
# activation; every other checkpoint uses the variant that ends in a sigmoid.
model_class = (
    libdl.nn_models.basic_cnn_segm_logit
    if fn_model == '03_schubert_baseline_sup.pt'
    else libdl.nn_models.basic_cnn_segm_sigmoid
)
model = model_class(**model_params)

In [None]:
# Read the checkpoint and copy its weights into the model.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
path_model = os.path.join(dir_models, fn_model)
state_dict = torch.load(path_model, map_location=device)
model.load_state_dict(state_dict)

# Move the model to the selected device and switch to evaluation mode
model.to(device)
model.eval();

In [None]:
# Example input shape for the summary: the 6 and 216 match n_chan_input and
# n_bins_in of the model parameters; the 574 is presumably a number of frames - TODO confirm.
summary_input_size = (1, 6, 574, 216)
torchinfo.summary(model, input_size=summary_input_size, device=device)

## 3. Predict pitches

### Create dataset object

In [None]:
# Reverse the axis order of the HCQT for the dataset object
inputs = np.transpose(f_hcqt, (2, 1, 0))
num_frames = inputs.shape[1]

# A single segment that spans all frames: segment length and stride both equal
# the total number of frames.
test_dataset_params = {
    'context': 75,
    'compression': 10,   # log-compression applied to HCQT
    'seglength': num_frames,
    'stride': num_frames,
}
half_context = test_dataset_params['context'] // 2

# Zero-pad axis 1 on both sides to account for the context frames
pad_widths = ((0, 0), (half_context, half_context+1), (0, 0))
inputs_context = torch.from_numpy(np.pad(inputs, pad_widths))

# The dataset object needs targets; all-zero placeholders are used here
targets_context = torch.zeros(inputs_context.shape[1], num_output_bins)

test_set = libdl.data_loaders.dataset_context_segm(inputs_context, targets_context, test_dataset_params)

### Make prediction

In [None]:
# Take the single segment of the dataset; the dummy target is not needed
test_batch, _ = test_set[0]

# Batch format: add a leading batch dimension and move to the compute device
test_batch = test_batch.unsqueeze(dim=0).to(device)

# Predict. Inference only, so disable autograd: otherwise a computation graph
# for the whole recording is built and kept alive by y_pred.
with torch.no_grad():
    y_pred = model(test_batch)

    # Apply sigmoid activation if not contained as last layer in model
    if model.__class__ == libdl.nn_models.basic_cnns_mctc.basic_cnn_segm_logit:
        y_pred = torch.sigmoid(y_pred)

# Convert prediction to Numpy array
pred = y_pred.to('cpu').detach().squeeze().numpy()

In [None]:
# Show the prediction with the colour range fixed to [0, 1]
plot_matrix_with_ticks(pred.T, 'Pitch prediction', pitches=True, clim=[0.0, 1.0])

### (Visualize predictions + overtone model / bias)

In [None]:
def overtone_model(pred):
    """Add shifted, attenuated copies of a prediction along its last axis.

    For each shift ``s`` in a fixed list, the input shifted upwards by ``s``
    positions along the third axis is scaled by ``0.9 ** s`` and added to a
    copy of the input. The sum is clipped to the range [0, 1].

    Args:
        pred: 3D torch tensor; the shifts are applied along its third axis.

    Returns:
        New tensor of the same shape as ``pred``; ``pred`` itself is not modified.
    """
    shifts = [12, 19, 24, 28, 31, 34, 36, 38, 40]
    strengths = 0.9 ** np.array(shifts)

    accumulated = torch.clone(pred)
    for idx in range(len(shifts)):
        offset = shifts[idx]
        # Entry k of the input contributes to entry k + offset of the output
        accumulated[:, :, offset:] += strengths[idx] * pred[:, :, :-offset]
    return torch.clamp(accumulated, 0.0, 1.0)

# Drop dimension 1 of the model output (if it has size 1) before applying the overtone model
y_pred_squeezed = y_pred.squeeze(dim=1)
pred_ov = overtone_model(y_pred_squeezed)

# Numpy copy for plotting
pred_ov_np = pred_ov.to('cpu').detach().squeeze().numpy()

plot_matrix_with_ticks(pred_ov_np.T, 'Pitch prediction + Ov', pitches=True, clim=[0.0, 1.0])

In [None]:
# Constant offset added to the overtone-augmented prediction
bias = 0.2

# Add the bias, clip back to [0, 1], and convert to a Numpy array for plotting
pred_ov_biased = torch.clip(pred_ov + bias, 0.0, 1.0)
pred_ov_b = pred_ov_biased.to('cpu').detach().squeeze().numpy()

plot_matrix_with_ticks(pred_ov_b.T, 'Pitch prediction + Ov + B', pitches=True, clim=[0.0, 1.0])

## 4. Load and convert annotations

In [None]:
# The annotation file shares the audio file's name, with the last four
# characters (the '.wav' extension) replaced by '.csv'
annot_folder = os.path.join(basepath, 'data', 'Schubert_Winterreise', '02_Annotations', 'ann_audio_note')
annot_name = fn_audio[:-4] + '.csv'
fn_annot = os.path.join(annot_folder, annot_name)

# Annotations are optional: everything below is skipped when the file is missing
if os.path.exists(fn_annot):
    # Semicolon-separated file; the first row is skipped and only the first
    # three columns are kept as note events
    df = pd.read_csv(fn_annot, sep=';', skiprows=1, header=None)
    note_events = df.to_numpy()[:, :3]

    # Convert the note events into a pitch annotation array
    f_annot_pitch = libdl.data_preprocessing.compute_annotation_array_nooverlap(
        note_events, f_hcqt, fs_hcqt, annot_type='pitch', shorten=1.0)

In [None]:
if os.path.exists(fn_annot):
    # Rows 24..96 of the annotation array.
    # NOTE(review): this selects 73 rows, whereas the evaluation uses 72 pitch bins - confirm this is intended.
    annot_excerpt = f_annot_pitch[24:97]
    plot_matrix_with_ticks(annot_excerpt, 'Pitch annotations', pitches=True)

## 5. Multi-pitch evaluation

In [None]:
# Names of the measures requested from the evaluation routine
eval_measures = [
    'precision',
    'recall',
    'f_measure',
    'cosine_sim',
    'binary_crossentropy',
    'euclidean_distance',
    'binary_accuracy',
    'soft_accuracy',
    'accum_energy',
    'roc_auc_measure',
    'average_precision_score',
]

# Threshold used to binarise the predictions
eval_thresh = 0.4

In [None]:
# Binarise the prediction: 1.0 where it exceeds the threshold, 0.0 elsewhere
above_threshold = np.greater(pred, eval_thresh)
pred_th = above_threshold.astype(float)

plot_matrix_with_ticks(pred_th.T, f'Pitch prediction after thresholding (tau={eval_thresh})', pitches=True)

In [None]:
if os.path.exists(fn_annot):
    # Targets: swap the two axes of the annotation array, then keep the
    # columns min_pitch .. min_pitch + num_output_bins - 1
    pitch_stop = min_pitch + num_output_bins
    targ = np.transpose(f_annot_pitch, axes=(1, 0))[:, min_pitch:pitch_stop]

    # Measures computed by the project's own evaluation routine
    eval_dict = libdl.metrics.calculate_eval_measures(
        targ, pred, measures=eval_measures, threshold=eval_thresh, save_roc_plot=False)
    eval_numbers = np.fromiter(eval_dict.values(), dtype=float)

    # Multi-pitch measures from the mir_eval-based routine
    metrics_mpe = libdl.metrics.calculate_mpe_measures_mireval(
        targ, pred, threshold=eval_thresh, min_pitch=min_pitch)
    mireval_measures = list(metrics_mpe)
    mireval_numbers = np.fromiter(metrics_mpe.values(), dtype=float)

In [None]:
if os.path.exists(fn_annot):
    # One line per measure: name left-aligned in a 30-character column, then the value.
    # NOTE(review): the values are matched to the names by position - this assumes
    # eval_dict preserves the order of eval_measures; confirm.
    for i in range(len(eval_measures)):
        meas_name = eval_measures[i]
        print(f'{meas_name:<30} {eval_numbers[i]}')

    print('')

    for i in range(len(mireval_measures)):
        meas_name = mireval_measures[i]
        print(f'{meas_name:<30} {mireval_numbers[i]}')