In [1]:
import os
from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple

import librosa
import note_seq
import pandas as pd
from note_seq.protobuf.music_pb2 import NoteSequence
from torch.utils.data import DataLoader, Dataset


# Root of the local MAESTRO v3.0.0 copy, relative to this notebook's working
# directory. Metadata CSV and MIDI/audio paths listed in it are resolved
# against this directory.
MAESTRO_DS_PATH = "../data/raw/maestro-v3.0.0/"


@dataclass
class DatasetParams:
    """Configuration used to build a WAV/MIDI dataset from a metadata CSV.

    Only ``root_path`` and ``metadata`` are required; the two filters default
    to "no filtering".
    """

    # Directory containing the metadata CSV; the MIDI/audio file names listed
    # in the CSV are joined onto this directory.
    root_path: str
    # File name of the metadata CSV, relative to ``root_path``.
    metadata: str
    # Recording years to keep; an empty list keeps every year.
    years_list: List[int] = field(default_factory=list)
    # Split name to keep (e.g. "train"); None keeps every split.
    split: Optional[str] = None


@dataclass(eq=False)
class AudioFile:
    """A decoded waveform together with its sampling rate.

    Field order matches the ``(audio, sr)`` tuple returned by
    ``librosa.load``, so ``AudioFile(*librosa.load(path, sr=...))`` works.

    ``eq=False`` keeps identity comparison on purpose: a generated ``__eq__``
    would compare the waveforms element-wise, which raises for numpy arrays
    with more than one element.
    """

    # Waveform samples as returned by librosa.load (presumably a float numpy
    # array, 1-D when loaded mono -- TODO confirm against callers).
    audio_arr: Any
    # Sampling rate of ``audio_arr`` in samples per second.
    sample_rate: int


class WavMidiDataset(Dataset):
    """Map-style dataset pairing audio recordings with their MIDI files.

    Rows of the metadata CSV are optionally filtered by split and year; each
    item loads one (audio, MIDI) pair from disk on demand.

    Args:
        params: Root directory, metadata CSV name and row filters.
        sample_rate: Rate passed to ``librosa.load`` (audio is resampled to
            it). Defaults to 44100, the value previously hard-coded.
    """

    def __init__(self, params: DatasetParams, sample_rate: int = 44100) -> None:
        super().__init__()

        self._root_path = params.root_path
        self._years = params.years_list
        self._split = params.split
        self._sample_rate = sample_rate

        metadata_path = os.path.join(self._root_path, params.metadata)
        ds_metadata = pd.read_csv(metadata_path)

        # A falsy split (None / "") keeps every split.
        if self._split:
            ds_metadata = ds_metadata[ds_metadata["split"] == self._split]
        # An empty (or None) years list keeps every year.
        if self._years:
            ds_metadata = ds_metadata[ds_metadata["year"].isin(self._years)]

        # Only the two file columns are needed per item; reset the index so
        # row labels match positions after filtering.
        self._data = ds_metadata[["midi_filename", "audio_filename"]].reset_index(drop=True)
        self._len = len(self._data)

    def _process_midi(self, midi_path: str) -> NoteSequence:
        """Parse the MIDI file at ``midi_path`` into a NoteSequence."""
        return note_seq.midi_file_to_note_sequence(midi_path)

    def _process_audio(self, audio_path: str) -> AudioFile:
        """Load ``audio_path`` with librosa, resampled to the dataset's rate."""
        return AudioFile(*librosa.load(audio_path, sr=self._sample_rate))

    def __len__(self) -> int:
        return self._len

    def __getitem__(self, idx: int) -> dict:
        """Load the ``idx``-th (audio, MIDI) pair.

        Returns:
            ``{"data": AudioFile, "target": NoteSequence}``.
        """
        midi_filename, audio_filename = self._data.iloc[idx]

        midi_path = os.path.join(self._root_path, midi_filename)
        audio_path = os.path.join(self._root_path, audio_filename)

        return {
            "data": self._process_audio(audio_path),
            "target": self._process_midi(midi_path),
        }



In [2]:
# Smoke test: build the training split and pull one batch, in order and shuffled.
params = DatasetParams(
    root_path=MAESTRO_DS_PATH,
    metadata="maestro-v3.0.0.csv",
    years_list=[],
    split="train",
)
dataset = WavMidiDataset(params)

# Each item decodes a full-length recording (hundreds of seconds per the
# metadata `duration` column), so large batches are slow to materialise;
# keep the smoke-test batch small.
SMOKE_BATCH_SIZE = 2

# Items hold AudioFile objects (and variable-length audio), which torch's
# default_collate cannot batch, so keep each batch as a plain list of samples.
dl = DataLoader(dataset, batch_size=SMOKE_BATCH_SIZE, collate_fn=list)
batch = next(iter(dl))
print(batch)

dl = DataLoader(dataset, batch_size=SMOKE_BATCH_SIZE, shuffle=True, collate_fn=list)
batch = next(iter(dl))
print(batch)

KeyboardInterrupt: 

In [3]:
# Load the full MAESTRO metadata table for inspection.
ds = pd.read_csv(os.path.join(MAESTRO_DS_PATH, "maestro-v3.0.0.csv"))
ds

Unnamed: 0,canonical_composer,canonical_title,split,year,midi_filename,audio_filename,duration
0,Alban Berg,Sonata Op. 1,train,2018,2018/MIDI-Unprocessed_Chamber3_MID--AUDIO_10_R...,2018/MIDI-Unprocessed_Chamber3_MID--AUDIO_10_R...,698.661160
1,Alban Berg,Sonata Op. 1,train,2008,2008/MIDI-Unprocessed_03_R2_2008_01-03_ORIG_MI...,2008/MIDI-Unprocessed_03_R2_2008_01-03_ORIG_MI...,759.518471
2,Alban Berg,Sonata Op. 1,train,2017,2017/MIDI-Unprocessed_066_PIANO066_MID--AUDIO-...,2017/MIDI-Unprocessed_066_PIANO066_MID--AUDIO-...,464.649433
3,Alexander Scriabin,"24 Preludes Op. 11, No. 13-24",train,2004,2004/MIDI-Unprocessed_XP_21_R1_2004_01_ORIG_MI...,2004/MIDI-Unprocessed_XP_21_R1_2004_01_ORIG_MI...,872.640588
4,Alexander Scriabin,"3 Etudes, Op. 65",validation,2006,2006/MIDI-Unprocessed_17_R1_2006_01-06_ORIG_MI...,2006/MIDI-Unprocessed_17_R1_2006_01-06_ORIG_MI...,397.857508
...,...,...,...,...,...,...,...
1271,Wolfgang Amadeus Mozart,"Sonata in F Major, K280",test,2004,2004/MIDI-Unprocessed_XP_14_R1_2004_04_ORIG_MI...,2004/MIDI-Unprocessed_XP_14_R1_2004_04_ORIG_MI...,241.470442
1272,Wolfgang Amadeus Mozart,"Sonata in F Major, K280",train,2004,2004/MIDI-Unprocessed_XP_14_R1_2004_04_ORIG_MI...,2004/MIDI-Unprocessed_XP_14_R1_2004_04_ORIG_MI...,114.696243
1273,Wolfgang Amadeus Mozart,"Sonata in F Major, K533",validation,2004,2004/MIDI-Unprocessed_SMF_12_01_2004_01-05_ORI...,2004/MIDI-Unprocessed_SMF_12_01_2004_01-05_ORI...,1139.198478
1274,Wolfgang Amadeus Mozart,"Sonata in F Major, K533/K494",validation,2018,2018/MIDI-Unprocessed_Recital17-19_MID--AUDIO_...,2018/MIDI-Unprocessed_Recital17-19_MID--AUDIO_...,1068.751602


In [4]:
# Keep only the listed recording years and count rows per year.
years = [2018, 2017]
ds = ds[ds["year"].isin(years)]
ds["year"].value_counts()

2017    140
2018     93
Name: year, dtype: int64

In [6]:
# Keep only the columns that locate each MIDI/audio pair on disk.
ds = ds.loc[:, ["midi_filename", "audio_filename"]]

In [9]:
# Peek at the first (midi, audio) file pair.
ds.iloc[0, :]

midi_filename     2018/MIDI-Unprocessed_Chamber3_MID--AUDIO_10_R...
audio_filename    2018/MIDI-Unprocessed_Chamber3_MID--AUDIO_10_R...
Name: 0, dtype: object