In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell runs and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

import torch
from sklearn.model_selection import train_test_split
from torch import nn

from pitch_tracker.utils.constants import (F_MIN, HOP_LENGTH, N_CLASS, N_FFT,
                                           N_MELS, PATCH_SIZE,
                                           PATCH_STEP,
                                           PATCH_TIME, SAMPLE_RATE,
                                           ANALYSIS_FRAME_SIZE, ANALYSIS_FRAME_TIME, WIN_LENGTH)
from pitch_tracker.utils.dataset import AudioDataset
from pitch_tracker.ml.net import create_conv2d_block, MPT2023
from pitch_tracker.ml.train_model import train_model, train, test
from pitch_tracker import THESIS_2023_MODEL_PATH

In [3]:
# Choose the compute backend in priority order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

Using mps device


In [4]:
# Checkpoint handed to MelodyExtractor below, given relative to the notebook's
# working directory.
# NOTE(review): presumably the same file as the imported THESIS_2023_MODEL_PATH — confirm.
model_path = '../pitch_tracker/saved_model/mpt_2023.pt'

In [5]:
from typing import Any, Union
import numpy as np
from pitch_tracker.utils.audio import load_audio_mono
from pitch_tracker.utils.constants import PRE_MIDI_START
from pitch_tracker.utils.dataset import build_pick_features_and_time, extract_melspectrogram_feature

class MelodyExtractor():
    """Frame-wise melody pitch estimator wrapping a pretrained `MPT2023` network.

    Calling an instance with either an audio file path or an in-memory signal
    returns one predicted pitch class per analysis frame together with the
    time stamps produced by `build_pick_features_and_time`.
    """

    def __init__(self, model_path=THESIS_2023_MODEL_PATH, device='cpu') -> None:
        """Create the network on `device` and load its weights from `model_path`.

        Args:
            model_path: path of the saved `state_dict` to load.
            device: torch device (string or `torch.device`) used for the
                weights and for inference.
        """
        # Remember the device so that inputs can be moved next to the weights
        # at call time; without this any non-CPU device fails at inference.
        self.device = device
        self.model = MPT2023().to(device)
        self.model.load_state_dict(torch.load(model_path, map_location=device))
        # The extractor is inference-only, so keep the network in eval mode.
        self.model.eval()

    def __call__(
            self,
            file_path: str = None,
            signal: Union[torch.Tensor, np.ndarray] = None,
            sample_rate: int = SAMPLE_RATE,
            n_fft: int = N_FFT,
            n_mels: int = N_MELS*2,
            hop_length: int = HOP_LENGTH,
            patch_size: int = PATCH_SIZE,
            analysis_frame_size: int = ANALYSIS_FRAME_SIZE,
            analysis_frame_time: float = ANALYSIS_FRAME_TIME,
            fmin:float = F_MIN,
            voicing_bias:float = 0.0,
    ) -> Any:
        """Predict the melody pitch of an audio file or signal.

        Args:
            file_path: audio file to analyse; takes precedence over `signal`.
            signal: already-loaded mono signal, used when `file_path` is not given.
            sample_rate, n_fft, n_mels, hop_length, fmin: mel-spectrogram settings.
            patch_size, analysis_frame_size, analysis_frame_time: patch settings
                forwarded to `build_pick_features_and_time`.
            voicing_bias: amount subtracted from the score of class 0 before the
                arg-max, so a positive value makes a non-zero (pitched) class
                more likely to win.

        Returns:
            `(pitch, pick_times)`: two flattened CPU tensors. `pitch` holds the
            winning class per frame, with `PRE_MIDI_START` added to every
            non-zero class; `pick_times` holds the matching time stamps.

        Raises:
            ValueError: if neither `file_path` nor `signal` is provided.
        """
        pick_features, pick_times = self.get_pick_features_and_time(
            file_path=file_path,
            signal=signal,
            sample_rate=sample_rate,
            n_fft=n_fft,
            n_mels=n_mels,
            hop_length=hop_length,
            patch_size=patch_size,
            analysis_frame_size=analysis_frame_size,
            analysis_frame_time=analysis_frame_time,
            fmin=fmin,
        )

        self.model.eval()
        # Inference only: skip autograd bookkeeping and run on the model's device.
        with torch.no_grad():
            pred = self.model(pick_features.to(self.device))

        # Class 0 is handled separately from the pitched classes below.
        # NOTE(review): presumably class 0 is the "no melody" class — confirm.
        pred[:, :, 0] -= voicing_bias
        pitch = pred.argmax(2).flatten()
        # Shift only the non-zero classes by PRE_MIDI_START; class 0 stays 0.
        pitch[pitch > 0] += PRE_MIDI_START

        # Return on CPU so both outputs live on the same device.
        return pitch.cpu(), pick_times.flatten()

    def get_pick_features_and_time(
            self,
            file_path: str = None,
            signal: Union[torch.Tensor, np.ndarray] = None,
            sample_rate: int = SAMPLE_RATE,
            n_fft: int = N_FFT,
            n_mels: int = N_MELS*2,
            hop_length: int = HOP_LENGTH,
            patch_size: int = PATCH_SIZE,
            analysis_frame_size: int = ANALYSIS_FRAME_SIZE,
            analysis_frame_time: float = ANALYSIS_FRAME_TIME,
            fmin:float = F_MIN):
        """Turn audio into the model's input patches and their time stamps.

        Args:
            file_path: audio file to load; when truthy, `signal` is ignored.
            signal: mono signal used when no `file_path` is given.
            Remaining arguments are forwarded to the feature helpers.

        Returns:
            `(pick_features, pick_times)` as float32 tensors; `pick_features`
            has an extra singleton dimension inserted at position 1.

        Raises:
            ValueError: if neither `file_path` nor `signal` is provided.
        """
        # An empty `file_path` counts as "not given", matching the truthiness
        # test used for loading below.
        if not file_path and signal is None:
            raise ValueError('Missing one required parameter `file_path` or `signal`.')

        # ignore `signal` param if file_path is used
        if file_path:
            signal, _ = load_audio_mono(file_path, sample_rate, keep_channel_dim=False)

        melspectrogram_features = extract_melspectrogram_feature(
            y=signal,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            sample_rate=sample_rate,
            backend='librosa',
            fmin=fmin,
        )

        # NOTE(review): the patch step is set to the patch size, presumably so
        # that consecutive patches do not overlap at inference time — confirm.
        pick_features, pick_times = build_pick_features_and_time(
            STFT_features=melspectrogram_features.T,
            patch_step=patch_size,
            patch_size=patch_size,
            analysis_frame_size=analysis_frame_size,
            analysis_frame_time=analysis_frame_time
        )

        pick_features = torch.from_numpy(pick_features).type(torch.float32)
        pick_times = torch.from_numpy(pick_times).type(torch.float32)

        # Insert a singleton dimension at position 1.
        # NOTE(review): presumably the channel axis the network expects — confirm.
        pick_features = pick_features.unsqueeze(1)
        return pick_features, pick_times

In [6]:
def consecutive(data:Union[np.ndarray,torch.Tensor], stepsize:int=1, as_indices:bool=False):
    """Split `data` into runs whose neighbouring elements differ by `stepsize`.

    A new run starts wherever the difference between two adjacent elements is
    not equal to `stepsize` (use `stepsize=0` to group equal values).

    Args:
        data: numpy array or torch tensor to partition.
        stepsize: difference between neighbours that keeps them in the same run.
        as_indices: when True, return the positions of the elements in each
            run instead of the elements themselves.

    Returns:
        The runs, as produced by `np.split` for numpy input or `torch.hsplit`
        for tensor input.

    Raises:
        TypeError: if `data` is neither an `np.ndarray` nor a `torch.Tensor`.
    """
    if isinstance(data, np.ndarray):
        is_numpy = True
    elif isinstance(data, torch.Tensor):
        is_numpy = False
    else:
        raise TypeError('`data` must be np.ndarray or torch.Tensor')

    if is_numpy:
        # Positions where a new run begins (one past each "break" in the diff).
        run_starts = np.nonzero(np.diff(data) != stepsize)[0] + 1
        target = np.arange(data.shape[0]) if as_indices else data
        return np.split(target, run_starts)

    run_starts = torch.nonzero(torch.diff(data) != stepsize, as_tuple=True)[0] + 1
    run_starts = run_starts.tolist()
    target = torch.arange(data.shape[0]) if as_indices else data
    return torch.hsplit(target, run_starts)
    

def get_consecutive_pred(pitch_pred:torch.Tensor):
    """Group a 1-D prediction tensor into runs of identical consecutive values.

    Args:
        pitch_pred: 1-D tensor with one predicted value per frame.

    Returns:
        A list of `(frame_indices, value)` pairs in frame order, where
        `frame_indices` is a 1-D CPU tensor with the positions of the run and
        `value` is the Python number shared by every frame of that run.
        An empty input yields an empty list.
    """
    n_frames = pitch_pred.shape[0]
    if n_frames == 0:
        # Nothing to group; avoids indexing the first element of an empty run.
        return []

    # Positions at which the value changes, i.e. where a new run starts.
    change_points = torch.nonzero(torch.diff(pitch_pred) != 0, as_tuple=True)[0] + 1
    change_points = change_points.tolist()
    run_starts = [0] + change_points
    run_ends = change_points + [n_frames]

    # Index with an explicit LongTensor: a Python list of 0-d tensors can be
    # interpreted as multi-dimensional indexing and fail on short inputs.
    start_index = torch.tensor(run_starts, dtype=torch.long, device=pitch_pred.device)
    run_values = pitch_pred[start_index].tolist()

    frame_indices = torch.arange(n_frames)
    return [
        (frame_indices[start:end], value)
        for start, end, value in zip(run_starts, run_ends, run_values)
    ]



# Quick manual check of `consecutive`: with stepsize 0, equal neighbouring
# values of `b` are grouped and the positions of each group are displayed.
a = torch.Tensor([0, 47, 48, 49, 50, 97, 98, 99])
b = torch.Tensor([0, 1, 1, 1, 2, 3, 3, 99, 1, 1])
consecutive(data=b, stepsize=0, as_indices=True)

(tensor([0]),
 tensor([1, 2, 3]),
 tensor([4]),
 tensor([5, 6]),
 tensor([7]),
 tensor([8, 9]))

In [7]:
from functools import partial
from pitch_tracker.utils.midi import build_note_messages
from pitch_tracker.utils.midi import convert_to_midi
import mido


def build_note_sequences(pitch_pred:torch.Tensor, analysis_frame_time:float, analysis_frame_powers=None):
    """Convert frame-wise pitch predictions into a table of notes.

    Consecutive frames with the same non-zero value are merged into one note.

    Args:
        pitch_pred: 1-D tensor with one predicted pitch value per frame;
            the value 0 marks frames that do not belong to the melody.
        analysis_frame_time: duration of one analysis frame, used to convert
            frame indices to time.
        analysis_frame_powers: value stored in the fourth column of every
            note; defaults to 50 when omitted.
            NOTE(review): presumably a note velocity/loudness — confirm.

    Returns:
        A float tensor with one row per note and the columns
        `(start_time, end_time, midi_value, analysis_frame_powers)`.
        When no note is found the result has shape `(0, 4)`.
    """
    if analysis_frame_powers is None:
        analysis_frame_powers = 50

    note_sequences = []
    for sequence, midi_value in get_consecutive_pred(pitch_pred):
        # filter non-melody sequences
        if midi_value == 0:
            continue

        # Frame indices are shifted by one before being converted to time.
        # NOTE(review): presumably aligns indices with the frame time stamps — confirm.
        # Done out of place so the tensors returned by the helper are not modified.
        first_frame = sequence[0] + 1
        last_frame = sequence[-1] + 1

        start_time = (first_frame * analysis_frame_time).item()
        # The note lasts until the end of its last frame.
        end_time = (last_frame * analysis_frame_time).item() + analysis_frame_time
        note_sequences.append((start_time, end_time, midi_value, analysis_frame_powers))

    if not note_sequences:
        # Keep the four-column layout even when there is nothing to report.
        return torch.empty((0, 4))

    return torch.Tensor(note_sequences)

In [9]:
# Build the extractor from the checkpoint at `model_path`. No `device` is
# passed, so the constructor default ('cpu') is used regardless of the
# `device` value detected earlier in the notebook.
melody_extractor = MelodyExtractor(model_path=model_path)



In [10]:
# Recordings to transcribe, given relative to the notebook's working
# directory: two MedleyDB mixes (WAV) and two MP3 files.
audio_paths = [
    '../medleydb/medleydb/data/Audio/FamilyBand_Again/FamilyBand_Again_MIX.wav',
    '../medleydb/medleydb/data/Audio/AClassicEducation_NightOwl/AClassicEducation_NightOwl_MIX.wav',
    '../content/audio/mp3/Take on Me_ORIGINAL.mp3',
    '../content/audio/mp3/Let It Happen_Original.mp3',
]

In [12]:
# Transcribe every recording in `audio_paths` and write one MIDI file per
# recording into `out_midi_dir`, named after the audio file.
out_midi_dir = '../content/midi'
# NOTE(review): this variable used to be 0.2 while the call below hard-coded
# 0.1, so the configured value was ignored. It is now the single source of
# truth and keeps the value that was actually used (0.1) — confirm the intent.
voicing_bias = 0.1
# MIDI resolution shared by the note-message builder and the file writer.
ticks_per_beat = 960

# Make sure the destination exists so saving does not fail on a fresh checkout.
os.makedirs(out_midi_dir, exist_ok=True)

for audio_path in audio_paths:
    print(audio_path)
    file_name_without_ext = os.path.splitext(os.path.basename(audio_path))[0]
    out_midi_path = os.path.join(out_midi_dir, file_name_without_ext + '.mid')

    # `time1d` (per-frame time stamps) is not needed for the MIDI export.
    pitch, time1d = melody_extractor(audio_path, voicing_bias=voicing_bias)
    note_sequences = build_note_sequences(pitch, ANALYSIS_FRAME_TIME)
    note_messages = build_note_messages(note_sequences, ticks_per_beat=ticks_per_beat)
    midi = convert_to_midi(note_messages.numpy(), ticks_per_beat=ticks_per_beat)
    midi.save(out_midi_path)




../medleydb/medleydb/data/Audio/FamilyBand_Again/FamilyBand_Again_MIX.wav


  return f(*args, **kwargs)


../medleydb/medleydb/data/Audio/AClassicEducation_NightOwl/AClassicEducation_NightOwl_MIX.wav
../content/audio/mp3/Take on Me_ORIGINAL.mp3
../content/audio/mp3/Let It Happen_Original.mp3
