In [102]:
%matplotlib inline

In [109]:
import numpy as np
import torch
import torch.nn as nn
import sys
import torch.utils.data as data
import os
import pretty_midi


from tqdm import tqdm_notebook as tqdm
from collections import deque, Counter
from typing import Dict, List, Tuple

In [104]:
# Root directory that is searched for MIDI files. The MAESTRO_DIR environment
# variable overrides the default so the notebook can run on another machine
# without editing this cell.
dataset = os.environ.get('MAESTRO_DIR', '/home/bartek/Datasets/maestro/')

In [105]:
def midi_path_to_pianoroll(path: str, fs: int=5) -> np.ndarray:
    """Load a MIDI file and return the piano roll of its first instrument.

    Args:
        path: Location of the MIDI file, handed to ``pretty_midi.PrettyMIDI``.
        fs: Sampling frequency forwarded to ``get_piano_roll``.

    Returns:
        Whatever ``get_piano_roll(fs=fs)`` yields for ``instruments[0]``.

    Raises:
        IndexError: If the parsed file exposes no instruments.
    """
    # Only the first instrument track is used; any further tracks are ignored.
    first_instrument = pretty_midi.PrettyMIDI(path).instruments[0]
    return first_instrument.get_piano_roll(fs=fs)
    
def pianoroll_to_time_dict(pianoroll: np.ndarray) -> Dict[int, str]:
    """Map every time step that contains at least one note to its active notes.

    Args:
        pianoroll: 2-D array indexed as ``[note, time]``; any non-zero entry
            counts as an active note.

    Returns:
        A dict whose keys are the time indices (ascending) at which at least
        one note is active and whose values are the active note indices at
        that time, in ascending order, joined with commas (e.g. ``'60,64'``).
        Time steps without any note are absent. An all-zero roll gives ``{}``.
    """
    # nonzero() always reports indices in row-major order, so scanning the
    # transposed roll yields (time, note) pairs sorted by time, then by note.
    # One pass replaces re-scanning the whole index array for every time step.
    time_idx, note_idx = np.nonzero(pianoroll.T)
    if time_idx.size == 0:
        return {}

    # Positions at which the time index changes delimit the per-time groups.
    boundaries = np.flatnonzero(np.diff(time_idx)) + 1
    group_times = time_idx[np.concatenate(([0], boundaries))]
    note_groups = np.split(note_idx, boundaries)

    return {
        time: ','.join(notes.astype(str))
        for time, notes in zip(group_times, note_groups)
    }


In [162]:
class MIDIDataset:
    """Serves shuffled mini-batches drawn from a converter's song batches.

    A "song batch" is fetched from ``converter.get_batch(song_batch_size,
    seq_len)``, shuffled once, and then handed out in consecutive slices of at
    most ``nn_batch_size`` rows. When it is used up, the next call fetches a
    fresh one.
    """

    def __init__(self, converter, seq_len=50, song_batch_size=12, nn_batch_size=96):
        # Iteration state: the cached (train, target) pair and the read cursor.
        self._song_batch = None
        self._pos = 0
        self.nn_batch_size = nn_batch_size
        self.song_batch_size = song_batch_size
        self.seq_len = seq_len
        self.converter = converter

    def _load_song_batch(self):
        """Fetch a new song batch, shuffle its rows and rewind the cursor."""
        inputs, targets = self.converter.get_batch(self.song_batch_size, self.seq_len)
        order = np.random.permutation(np.arange(targets.shape[0]))
        # The same permutation is applied to both arrays so rows stay paired.
        self._song_batch = (inputs[order, :], targets[order])
        self._pos = 0

    def get_batch(self):
        """Return the next ``(train_batch, target_batch)`` slice.

        The slice holds at most ``nn_batch_size`` rows; the last slice of a
        song batch may be shorter.
        """
        if self._song_batch is None:
            self._load_song_batch()

        inputs, targets = self._song_batch
        total = targets.shape[0]
        stop = min(self._pos + self.nn_batch_size, total)
        window = slice(self._pos, stop)

        if stop == total:
            # Exhausted: force a refill on the next call.
            self._song_batch = None
        self._pos = stop
        return inputs[window, :], targets[window]

    def generate_sample(self, sample_len, temperature, start_note=None):
        """Placeholder for sequence generation; not implemented yet."""
        raise NotImplementedError()

        
class MIDIConverter:
    """Loads a random subset of MIDI files and serves them as encoded sequences.

    Each song is stored as a dict mapping a time index to an integer code.
    Codes ``>= 1`` identify one distinct comma-joined note combination (see
    ``_notes_mapping``); code ``0`` is used for time steps without any note.
    """

    def __init__(self, directory: str, frac: float=0.1) -> None:
        """Locate the MIDI files below ``directory`` and encode a fraction of them.

        Args:
            directory: Root directory that is searched recursively.
            frac: Fraction of the found files to load, in ``(0, 1]``.
        """
        assert 0 < frac <= 1, 'frac must be in the interval (0, 1]'
        self._paths = self._locate_midi_files(directory, frac)
        self._time_dicts, self._notes_mapping, self._notes_frequency = self._convert_to_time_dicts()
        self._inverse_notes_mapping = {v: k for k, v in self._notes_mapping.items()}

    def get_batch(self, batch_size, seq_len):
        """Draw ``batch_size`` songs (with replacement) and stack their samples.

        Returns:
            ``(train, target)``: ``train`` stacks the per-song input windows
            row-wise (``seq_len`` columns), ``target`` concatenates the
            matching per-row targets.

        Raises:
            ValueError: If no MIDI files are loaded.
        """
        n_songs = len(self._paths)
        if n_songs == 0:
            raise ValueError('no MIDI files are loaded; cannot draw a batch')
        batch_train, batch_target = [], []
        for i in np.random.choice(n_songs, size=batch_size):
            train_vals, target_vals = self._time_dict_to_seq(self._time_dicts[i], seq_len)
            batch_train.append(train_vals)
            batch_target.append(target_vals)
        return np.vstack(batch_train), np.hstack(batch_target)

    def _time_dict_to_seq(self, time_dict, seq_len) -> Tuple[np.ndarray, np.ndarray]:
        """Turn one encoded song into sliding input windows and next-step targets.

        Row ``i`` of the inputs holds the ``seq_len`` codes ending at time
        ``start + i`` (left-padded with zeros), and ``target[i]`` is the code
        at time ``start + i + 1``. The final row has no successor inside the
        song, so its target stays ``0``.

        Args:
            time_dict: Mapping of time index to integer note code; missing
                time indices are treated as code ``0``.
            seq_len: Length of every input window (at least 1).

        Returns:
            ``(train_values, target_values)`` with shapes
            ``(n_steps, seq_len)`` and ``(n_steps,)`` where ``n_steps`` is
            ``max(time) - min(time) + 1``; both are empty for an empty dict.

        Raises:
            ValueError: If ``seq_len`` is smaller than 1.
        """
        if seq_len < 1:
            raise ValueError('seq_len must be at least 1')
        if not time_dict:
            # A song without notes contributes no samples instead of crashing.
            return np.zeros(shape=(0, seq_len)), np.zeros(shape=(0,))

        start_time, end_time = min(time_dict), max(time_dict)
        n_samples = int(end_time - start_time)

        # Dense code sequence over [start_time, end_time]; gaps stay 0.
        dense = np.zeros(n_samples + 1)
        for time, code in time_dict.items():
            dense[int(time - start_time)] = code

        # Left-pad so that the first window ends exactly on the first note.
        padded = np.concatenate((np.zeros(seq_len - 1), dense))
        window_idx = np.arange(n_samples + 1)[:, None] + np.arange(seq_len)[None, :]
        train_values = padded[window_idx]

        # Target is the *next* step. Previously the target repeated the code
        # already at the end of the window, and the last step was never used.
        target_values = np.zeros(n_samples + 1)
        target_values[:-1] = dense[1:]
        return train_values, target_values

    def _convert_to_time_dicts(self) -> Tuple[Dict[int, Dict[int, int]], Dict[str, int], Counter]:
        """Parse every selected file and ordinal-encode its note combinations.

        Returns:
            ``(time_dicts, notes_mapping, notes_freq)``: the per-song dicts
            (keyed by position in ``self._paths``) with integer codes as
            values, the note-string to code mapping (codes start at 1), and a
            ``Counter`` of how often each note string occurs over all songs.
        """
        raw_time_dicts = {}
        notes_freq = Counter()
        for i, path in tqdm(enumerate(self._paths), total=len(self._paths)):
            time_dict = pianoroll_to_time_dict(midi_path_to_pianoroll(path))
            raw_time_dicts[i] = time_dict
            notes_freq.update(time_dict.values())

        # Sort before numbering: enumerating a set of strings would hand out
        # different codes in every interpreter session (hash randomisation).
        notes_mapping = {note: code for code, note in enumerate(sorted(notes_freq), start=1)}

        # Replace the note strings with their ordinal encoding.
        time_dicts = {
            i: {time: notes_mapping[notes] for time, notes in time_dict.items()}
            for i, time_dict in raw_time_dicts.items()
        }
        return time_dicts, notes_mapping, notes_freq

    def _locate_midi_files(self, base_dir: str, frac: float) -> List[str]:
        """Collect ``*.midi`` files below ``base_dir`` and sample a fraction.

        Returns:
            A list of ``int(frac * n_found)`` distinct paths chosen at random
            (empty if nothing was found or the fraction rounds down to zero).
        """
        midi_files = []
        for root, _dirs, files in os.walk(base_dir):
            midi_files.extend(
                os.path.join(root, file) for file in files if file.endswith('.midi')
            )
        if not midi_files:
            return []
        # os.walk order depends on the file system; sorting makes the random
        # selection reproducible whenever NumPy's global seed is fixed.
        midi_files.sort()
        n_selected = int(frac * len(midi_files))
        # tolist() yields plain ``str`` objects, matching the annotation.
        return np.random.choice(midi_files, size=n_selected, replace=False).tolist()
        

In [163]:
# Build the converter from the dataset directory held in `dataset` (defined in
# an earlier cell) instead of repeating the path literal, then wrap it so that
# training code can pull mini-batches from it.
x = MIDIConverter(dataset, frac=0.1)
dat = MIDIDataset(x)

HBox(children=(IntProgress(value=0, max=118), HTML(value='')))