In [None]:
import torch
import os
import torchaudio
import torchaudio.functional as F
import glob
import numpy as np
import soundfile as sf
from tqdm import tqdm
from pesto import load_model
import mir_eval
from mir1k_dataset import MIR1KDataset
from torch.utils.data import Dataset
import mirdata
import pandas as pd
from scipy import signal

#### EVAL ON MIR-1K

In [None]:
class MIR1KDataset(torch.utils.data.Dataset):
    """MIR-1K vocal clips with optional synthetic reverb and pitch labels in Hz.

    The dataset expects ``root`` to contain a ``Wavfile`` directory with the
    ``*.wav`` clips and a ``PitchLabel`` directory with one ``.pv`` file per
    clip (same basename). Each item is ``(x, labels_hz)``:

    * ``x`` -- tensor of shape ``[1, T]`` holding the second channel of the
      clip (the vocals), resampled to ``target_sr`` when needed and, when
      ``mix > 0``, blended with a reverberated copy of itself.
    * ``labels_hz`` -- 1-D NumPy array of per-frame pitch in Hz, with 0 for
      unvoiced frames; empty when the label file is missing.

    Args:
        root: Path to the MIR-1K root directory.
        mix: Wet/dry ratio of the reverb in ``[0, 1]``; 0 disables the reverb.
        target_sr: Sampling rate (Hz) the audio is resampled to.
    """

    def __init__(self, root, mix=0.0, target_sr=16000):
        self.wav_files = sorted(glob.glob(os.path.join(root, "Wavfile", "*.wav")))
        self.pitch_dir = os.path.join(root, "PitchLabel")
        self.mix = mix
        self.target_sr = target_sr
        # One resampler per source rate, built lazily by _get_resampler.
        self.resampler_cache = {}

    def __len__(self):
        """Number of wav clips found under ``<root>/Wavfile``."""
        return len(self.wav_files)

    def _get_resampler(self, sr):
        """Return (and cache) a resampler from ``sr`` to ``self.target_sr``."""
        if sr not in self.resampler_cache:
            # Reach the transform through the `torchaudio` module the notebook
            # imports; no `T` alias is defined anywhere in the notebook.
            self.resampler_cache[sr] = torchaudio.transforms.Resample(sr, self.target_sr)
        return self.resampler_cache[sr]

    def generate_rir(self, sr, duration=1.0):
        """Draw a random synthetic room impulse response of unit L2 norm.

        The RIR is Gaussian noise shaped by an ``exp(-5 t)`` envelope over
        ``duration`` seconds at ``sr`` Hz. It is re-drawn on every call, so the
        result depends on the global torch RNG state.
        """
        t = torch.linspace(0, duration, int(duration * sr))
        envelope = torch.exp(-5.0 * t)
        noise = torch.randn_like(t)
        rir = noise * envelope
        return rir / torch.norm(rir)

    def _load_pitch_hz(self, pv_path):
        """Read a ``.pv`` label file and convert MIDI semitones to Hz.

        Labels equal to 0 mark unvoiced frames and stay 0 in the output.
        Returns an empty array when ``pv_path`` does not exist.
        """
        if not os.path.exists(pv_path):
            return np.array([])

        labels = np.loadtxt(pv_path)
        # Only voiced frames go through the conversion Hz = 440 * 2^((d-69)/12);
        # unvoiced frames keep the value 0.
        voiced = labels > 0
        labels_hz = np.zeros_like(labels)
        labels_hz[voiced] = 440.0 * (2.0 ** ((labels[voiced] - 69.0) / 12.0))
        return labels_hz

    def __getitem__(self, idx):
        """Load clip ``idx`` and return ``(audio [1, T], labels_hz)``."""
        wav_path = self.wav_files[idx]
        # Swap only the extension, so a ".wav" elsewhere in the name is kept.
        stem = os.path.splitext(os.path.basename(wav_path))[0]
        pv_path = os.path.join(self.pitch_dir, stem + ".pv")

        x, sr = torchaudio.load(wav_path)
        x = x[1:2, :]  # take vocals only (second channel), keeps shape [1, T]
        if sr != self.target_sr:
            x = self._get_resampler(sr)(x)

        # Apply reverb
        if self.mix > 0:
            rir = self.generate_rir(self.target_sr).to(x.device).unsqueeze(0)  # [1, T_rir]
            rev = F.fftconvolve(x, rir, mode="full")
            rev = rev[:, :x.shape[1]]  # keep original length (drops the reverb tail)

            # Rescale the wet signal to the energy of the dry one so that
            # `mix` controls the wet/dry balance only.
            x_norm = torch.norm(x)
            rev_norm = torch.norm(rev)
            if rev_norm > 0:
                rev = rev * (x_norm / rev_norm)

            x = (1 - self.mix) * x + self.mix * rev

        labels_hz = self._load_pitch_hz(pv_path)

        return x, labels_hz


In [None]:
# --- MIR-1K evaluation settings ---------------------------------------------
# Location of the MIR-1K dataset, relative to the notebook.
MIR_1K_PATH = "./MIR-1K"
# Spacing between consecutive frames, in seconds; used to build the timestamps
# handed to mir_eval.
HOP_SIZE_SECONDS = 0.020

# PESTO pitch estimator used for the MIR-1K runs. `step_size` is the value
# matching HOP_SIZE_SECONDS (20 ms between predictions).
pesto_model = load_model(
    "mir-1k_g7",
    sampling_rate=16000,  # MIR-1K audio is sampled at 16 kHz
    step_size=20.0,
    max_batch_size=4,
)

In [None]:
def run_evaluation(dataset_root):
    """Evaluate PESTO on MIR-1K for a sweep of synthetic reverb levels.

    For every wet/dry mix level a ``MIR1KDataset`` is built, each clip is run
    through the module-level ``pesto_model``, and the prediction is scored
    against the reference pitch with ``mir_eval.melody.evaluate``. Raw Pitch
    Accuracy and Raw Chroma Accuracy are averaged over the scored clips and
    printed per mix level.

    Clips with no labels, or whose labels contain no voiced frame, are skipped.
    As a side effect, the input audio is written to ``audio_{mix}.wav`` for
    every clip up to and including the first scored one, so the file ends up
    holding that first scored clip.

    Args:
        dataset_root: Path to the MIR-1K root directory.
    """
    mix_levels = [0.0, 0.1, 0.2, 0.3, 0.6, 0.9]
    target_sr = 16000  # must match the sampling rate `pesto_model` was loaded with

    for mix in mix_levels:
        print(f"\n--- Evaluating Reverb Mix: {mix} ---")

        dataset = MIR1KDataset(dataset_root, mix=mix, target_sr=target_sr)
        # Clips have different lengths and are not padded, hence batch_size=1.
        loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

        # Accumulators for metrics
        total_rpa = 0.0
        total_rca = 0.0
        count = 0
        for x, y in tqdm(loader):
            # x: [1, 1, T], y: [1, L]
            audio_input = x.squeeze(0)  # [1, T]
            if count == 0:
                # Keep a listening example for this mix level.
                sf.write(f"audio_{mix}.wav", audio_input[0].detach().cpu().numpy(), samplerate=target_sr)

            with torch.no_grad():
                # Pitch in Hz; the confidence and third output are not used.
                pitch, _, _ = pesto_model(audio_input, convert_to_freq=True, return_activations=False)

            # No confidence-based voicing threshold is applied: every frame
            # keeps its pitch estimate.
            est_freq = pitch.cpu().numpy().flatten()
            ref_hz = y.numpy().flatten()

            # Explicit timestamps let mir_eval align estimate and reference
            # even when the two sequences have different lengths.
            est_time = np.arange(len(est_freq)) * HOP_SIZE_SECONDS
            ref_time = np.arange(len(ref_hz)) * HOP_SIZE_SECONDS

            # Skip files where the ground truth is empty or entirely unvoiced.
            if len(ref_hz) > 0 and np.sum(ref_hz) > 0:
                scores = mir_eval.melody.evaluate(ref_time, ref_hz, est_time, est_freq)
                total_rpa += scores['Raw Pitch Accuracy']
                total_rca += scores['Raw Chroma Accuracy']
                count += 1

        if count > 0:
            avg_rpa = total_rpa / count
            avg_rca = total_rca / count
            print(f"Result (Mix {mix}): RPA={avg_rpa*100:.2f}% | RCA={avg_rca*100:.2f}%")
        else:
            print("No valid samples evaluated.")

In [None]:
# Run the MIR-1K reverb sweep on the dataset at MIR_1K_PATH; the averaged
# RPA / RCA for each mix level are printed as the sweep progresses.
run_evaluation(MIR_1K_PATH)

### EVAL ON MDB

The dataset is downloaded automatically through `mirdata` when `MDBDataset` is created.

In [None]:
class MDBDataset(Dataset):
    """Torch dataset over a ``mirdata`` corpus yielding audio and f0 annotations.

    Every item is a ``(audio, times, frequencies)`` triple of float32 NumPy
    arrays. With ``mix > 0`` the audio is blended with a reverberated copy of
    itself, obtained by convolving it with a freshly drawn synthetic impulse
    response. Items are not padded, so they must be loaded with a batch size
    of 1.
    """

    def __init__(self, dataset_name: str, mix: float = 0.0):
        # Set up the mirdata loader, fetching and validating the data.
        loader = mirdata.initialize(dataset_name)
        loader.download()
        loader.validate()
        self.loader = loader

        # Wet/dry ratio of the synthetic reverb; 0 leaves the audio untouched.
        self.mix = mix

    def __len__(self) -> int:
        """Number of tracks exposed by the loader."""
        return len(self.loader.track_ids)

    def generate_rir(self, sr, duration=1.0):
        """Return a random, unit-norm impulse response as a NumPy array.

        The response is Gaussian noise under an ``exp(-5 t)`` envelope spanning
        ``duration`` seconds at ``sr`` Hz.
        """
        n_samples = int(duration * sr)
        time_axis = torch.linspace(0, duration, n_samples)
        decaying_noise = torch.randn_like(time_axis) * torch.exp(-5.0 * time_axis)
        unit_rir = decaying_noise / torch.norm(decaying_noise)
        return unit_rir.numpy()

    def _blend_with_reverb(self, dry, sr):
        """Mix ``dry`` with an energy-matched reverberated copy of itself."""
        impulse = self.generate_rir(sr)
        # The full convolution is cut back to the input length (tail dropped).
        wet = signal.fftconvolve(dry, impulse, mode="full")[:dry.shape[0]]

        # Bring the wet signal to the same L2 norm as the dry one.
        dry_norm = np.linalg.norm(dry)
        wet_norm = np.linalg.norm(wet)
        if wet_norm > 0:
            wet = wet * (dry_norm / wet_norm)

        return (1 - self.mix) * dry + self.mix * wet

    def __getitem__(self, item: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return ``(audio, times, frequencies)`` for the track at index ``item``."""
        track = self.loader.track(self.loader.track_ids[item])

        waveform, sample_rate = track.audio
        if waveform.ndim > 1:
            # Collapse multi-dimensional audio by averaging over the last axis.
            # NOTE(review): assumes channels are on the last axis -- confirm.
            waveform = waveform.mean(axis=-1)

        if self.mix > 0:
            waveform = self._blend_with_reverb(waveform, sample_rate)

        return (
            waveform.astype(np.float32),
            track.f0.times.astype(np.float32),
            track.f0.frequencies.astype(np.float32),
        )

In [None]:
def run_evaluation_2():
    """Evaluate PESTO on ``mdb_stem_synth`` for a sweep of synthetic reverb levels.

    A PESTO model is loaded for 44.1 kHz audio with a 20 ms step. For each
    wet/dry mix level every track is scored with ``mir_eval.melody.evaluate``
    and the Raw Pitch Accuracy, Raw Chroma Accuracy and Overall Accuracy are
    averaged over the tracks. Averages (in percent) are printed per mix level
    (RPA and RCA only) and all three are written to ``results.csv``, one row
    per mix level. A mix level with no evaluated track keeps 0 for all metrics.
    """
    fs = 44100
    hpsz = 20  # ms
    pesto_model = load_model("mir-1k_g7", step_size=hpsz, sampling_rate=fs)
    pesto_model.eval()
    mix_levels = [0.0, 0.1, 0.2, 0.3, 0.6, 0.9]

    # Build the dataset once: its constructor downloads and validates the
    # data, which does not depend on the mix level. Only `mix` changes per sweep.
    dataset = MDBDataset("mdb_stem_synth")

    metrics = {}
    for mix in mix_levels:
        dataset.mix = mix
        count = 0
        total_rpa = 0
        total_rca = 0
        total_oa = 0
        metrics[mix] = {"RPA": 0, "RCA": 0, "OA": 0}
        print(f"\n--- Evaluating Reverb Mix: {mix} ---")
        # Tracks are not padded, hence batch_size=1. Order is irrelevant for
        # the averages, so shuffling is disabled.
        loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False)
        for audio, times, freqs in tqdm(loader):
            with torch.no_grad():
                f0_pred, _, _ = pesto_model(
                    audio,
                    convert_to_freq=True,
                    return_activations=False,
                )
            # Convert explicitly to NumPy before post-processing instead of
            # relying on NumPy's implicit tensor conversion.
            f0_pred = f0_pred.detach().cpu().numpy()
            f0_pred = np.nan_to_num(f0_pred, nan=0.0)

            # One timestamp per predicted frame, spaced by the model step.
            times_pred = np.arange(f0_pred.shape[-1]) * (hpsz / 1000.0)
            f0_pred = f0_pred.flatten()
            times_ref = times.numpy().flatten()
            freqs_ref = freqs.numpy().flatten()

            scores = mir_eval.melody.evaluate(times_ref, freqs_ref, times_pred, f0_pred)
            total_rpa += scores['Raw Pitch Accuracy']
            total_rca += scores['Raw Chroma Accuracy']
            total_oa += scores['Overall Accuracy']
            count += 1

        if count > 0:
            avg_rpa = total_rpa / count
            avg_rca = total_rca / count
            avg_oa = total_oa / count
            print(f"Result (Mix {mix}): RPA={avg_rpa*100:.2f}% | RCA={avg_rca*100:.2f}%")
            metrics[mix]["RPA"] = avg_rpa*100
            metrics[mix]["RCA"] = avg_rca*100
            metrics[mix]["OA"] = avg_oa*100
        else:
            print("No valid samples evaluated.")

    df = pd.DataFrame.from_dict(metrics, orient='index')
    df.index.name = 'Mix Level'

    df.to_csv('results.csv')

In [None]:
# Run the mdb_stem_synth reverb sweep; per-mix results are printed and the
# averaged metrics are written to results.csv.
run_evaluation_2()