# Import 

In [None]:
import glob
import os
import re

import librosa.display
import numpy as np
import pretty_midi
from matplotlib import pyplot as plt
from pypianoroll import Track, Multitrack
from pypianoroll.plot import plot_track
from tqdm import tqdm_notebook as tqdm

# Utils 

In [1]:
# Maps each row of the learning sequence to the MIDI pitches it stands for.
#   'encoding': list of [pitch] or [start, stop] entries that are merged into
#               the row (collapse_index_range expands a pair with range(), so
#               the stop value itself is not included).
#   'decoding': the single MIDI pitch written when the row is converted back.
LEARNING_DRUM_MAP = {
    0: {'encoding': [[36]], 'decoding': 36},                # kick
    1: {'encoding': [[37, 44]], 'decoding': 38},            # snares
    2: {'encoding': [[47, 63], [75, 82]], 'decoding': 54},  # cymbals
    3: {'encoding': [[65, 72]], 'decoding': 65},            # toms
}

def note_gen(track):
    """
    Generator over the non-silent rows of a track's piano roll.

    Walks track.pianoroll row by row (first axis). Rows whose values sum to
    zero are skipped; for every other row it yields np.where(row != 0), i.e.
    a one-element tuple holding the array of non-zero column indices.
    usage:
        gen = note_gen(pypianoroll.Track)
        next(gen)
    """
    for step in range(track.pianoroll.shape[0]):
        row = track.pianoroll[step]
        if np.sum(row) == 0:
            # nothing is played at this step
            continue
        yield np.where(row != 0)
            
def load_midi_map(file_path):
    """
    Read a text file of "<midi id> <drum name>" lines into a dict.

    The first run of digits on a line is taken as the MIDI id; the rest of
    the line (stripped of surrounding whitespace) is the drum name.
    Lines that contain no digits (e.g. blank lines) are skipped.

    Args:
        file_path: path of the text file to read.

    Returns:
        dict mapping int MIDI id -> drum name string.
    """
    res = {}
    with open(file_path, 'r') as f:
        for line in f:
            match = re.search(r'\d+', line)
            if match is None:
                # blank line or line without an id: nothing to record
                continue
            # Cut out only the matched id, so digits that are part of the
            # drum name itself (e.g. "Bass Drum 1") are preserved.
            drum_name = (line[:match.start()] + line[match.end():]).strip()
            res[int(match.group(0))] = drum_name
    return res

def search_folder_for_file_format(folder_path, file_format=r'\.rar$'):
    """
    Recursively collect the paths under folder_path that match a regex.

    Args:
        folder_path: directory to walk; works with or without a trailing
            path separator.
        file_format: regular expression applied with re.search to every
            path found (default matches paths ending in ".rar").

    Returns:
        list of matching paths, in the order glob produces them.
    """
    # os.path.join inserts the separator when folder_path lacks one; plain
    # string concatenation would produce a pattern like "data**/*".
    pattern = os.path.join(folder_path, '**', '*')
    return [
        path
        for path in glob.iglob(pattern, recursive=True)
        if re.search(file_format, path) is not None
    ]

def load_midi_files(root_path):
    """
    Recursively load every ".mid" file found under root_path.

    Args:
        root_path: directory to walk; works with or without a trailing
            path separator.

    Returns:
        list of dicts, one per file, with keys:
            'midi'      - Multitrack built from the file path
            'file_path' - the path the Multitrack was loaded from
    """
    res = []
    # os.path.join inserts the separator when root_path lacks one; plain
    # string concatenation would produce a pattern like "data**/*".
    pattern = os.path.join(root_path, '**', '*')
    # walk recursively over root_path, with a progress bar
    for file_path in tqdm(glob.iglob(pattern, recursive=True)):
        # find midi files(.mid)
        if re.search(r'\.mid$', file_path) is not None:
            res.append(
                dict(midi=Multitrack(file_path), file_path=file_path)
            )
    return res

def collapse_index_range(drum_ranges):
    """
    Flatten a list of index specs into one flat list of indices.

    Each entry of drum_ranges is either a one-element sequence [i], which
    contributes i itself, or a sequence passed to range(*entry), e.g.
    [start, stop] contributes start .. stop - 1 (stop is excluded).
    """
    flat = []
    for bounds in drum_ranges:
        if len(bounds) == 1:
            flat.append(bounds[0])
            continue
        flat.extend(range(*bounds))
    return flat

def binarize_array(arr):
    """Return 1 where arr is greater than zero and 0 elsewhere."""
    is_positive = arr > 0
    # multiplying by 1 turns the boolean mask into integers
    return is_positive * 1

def track_to_learn_seq(track, learning_map:'{key: {encoding=drum_range, decoding=drum}}'):
    """
    Reduce a track's piano roll to one binary row per learning_map entry.

    The track is first passed through track.binarize() and
    track.trim_trailing_silence() (called on the passed-in object --
    presumably modifying it in place; confirm against the pypianoroll
    version in use). For every map entry, the piano-roll columns listed in
    its 'encoding' are summed per time step and reduced to 0/1.

    Args:
        track: object exposing binarize(), trim_trailing_silence() and a
            2-D pianoroll indexed as [time step, pitch].
        learning_map: {key: {'encoding': ..., 'decoding': ...}}; keys are
            used as list indices, so they must be 0 .. len(learning_map) - 1.

    Returns:
        np.array with one row per map key and one column per time step.
    """
    track.binarize()
    track.trim_trailing_silence()

    rows = [0] * len(learning_map.keys())
    for slot, spec in learning_map.items():
        columns = collapse_index_range(spec['encoding'])
        # any hit in one of the selected pitch columns marks the step as active
        hits_per_step = np.sum(track.pianoroll[:, columns], axis=1)
        rows[slot] = binarize_array(hits_per_step)

    return np.array(rows)

def learn_seq_to_midi(seq, learning_map):
    """
    Expand a learning sequence back into a 128-pitch piano roll.

    Args:
        seq: 2-D array with one row per learning_map key and one column
            per time step.
        learning_map: {key: {'decoding': pitch, ...}}; row `key` of seq is
            written to column `pitch` of the result. Rows are looked up by
            key, mirroring how track_to_learn_seq stores them.

    Returns:
        np.ndarray of zeros with shape (time steps, 128), filled in at the
        'decoding' pitch of each map entry.
    """
    res = np.zeros([seq.shape[1], 128])
    seq = seq.T  # now indexed as [time step, drum row]

    # Iterate over the map itself rather than over range(seq.shape[0]):
    # after the transpose that length is the number of time steps, so short
    # sequences used to lose drums.
    for drum_num, drum_coding in learning_map.items():
        res[:, drum_coding['decoding']] = seq[:, drum_num]

    return res

def notes_used_in_track(track, midi_map=None):
    """
    Report which piano-roll columns (pitches) are non-zero anywhere in a track.

    Args:
        track: object with a 2-D pianoroll indexed as [time step, pitch].
        midi_map: optional mapping pitch -> name.

    Returns:
        Without midi_map: sorted 1-D array of the distinct pitches used
        (empty for a silent track).
        With midi_map: dict {pitch: midi_map[pitch]} for those pitches; a
        pitch missing from midi_map propagates the mapping's lookup error.
    """
    # Column index of every non-zero cell; np.unique sorts and de-duplicates.
    # Working on the whole roll at once also copes with a silent track, where
    # stacking per-step results would have nothing to concatenate.
    notes = np.unique(np.nonzero(track.pianoroll)[1])

    if midi_map is None:
        return notes
    return {note: midi_map[note] for note in notes}
    
def plot_piano_roll(pm, start_pitch, end_pitch, fs=100):
    """
    Draw pm.pianoroll with librosa's specshow, labelled with note names.

    Args:
        pm: object exposing a 2-D pianoroll attribute.
        start_pitch: start of the slice taken from pm.pianoroll; also the
            MIDI note number converted to Hz for the lowest plotted bin.
        end_pitch: end (exclusive) of the slice taken from pm.pianoroll.
        fs: value passed to specshow as the sampling rate.

    NOTE(review): the slice is applied to the FIRST axis of pm.pianoroll
    before transposing, and the callers in this file pass
    pianoroll.shape[0] as end_pitch -- so the "pitch" arguments may in fact
    select rows along the other axis. Confirm the intended axis.
    """
    roll = pm.pianoroll[start_pitch:end_pitch].T
    lowest_hz = pretty_midi.note_number_to_hz(start_pitch)
    # Use librosa's specshow function for displaying the piano roll
    librosa.display.specshow(
        roll,
        hop_length=1,
        sr=fs,
        x_axis='time',
        y_axis='cqt_note',
        fmin=lowest_hz,
    )
    
def moving_window(seq, stride=1, window_len=96):
    """
    Yield successive column windows of a 2-D sequence.

    Windows start at columns 0, stride, 2 * stride, ... and span up to
    window_len columns; windows near the end are shorter because the slice
    is simply cut off at the last column.
    """
    window_starts = range(0, seq.shape[1], stride)
    for begin in window_starts:
        end = begin + window_len
        yield seq[:, begin:end]
        
def plot_learn_seq(seq, learning_map):
    """
    Plot a learning sequence as a drum piano roll.

    The sequence is expanded with learn_seq_to_midi, wrapped in a drum
    Track (program 0) and handed to plot_piano_roll.
    """
    roll = learn_seq_to_midi(seq, learning_map)
    drum_track = Track(pianoroll=roll, program=0, is_drum=True)
    plot_piano_roll(drum_track, 0, drum_track.pianoroll.shape[0])

In [None]:
# Load the MIDI id -> drum name lookup from the text file and display it.
midi_map = load_midi_map('./midi_map.txt')
midi_map

In [None]:
# Recursively load every .mid file under ./_datasets/; each entry is a dict
# with the loaded 'midi' object and its 'file_path'.
midis = load_midi_files('./_datasets/')

In [None]:
# Pick the first track (and its source path) of the first loaded file.
# The break stops after one iteration; if midis is empty the loop body never
# runs and `track` / `file` are left unassigned.
for midi_mtrack in midis:
    track = midi_mtrack['midi'].tracks[0]
    file = midi_mtrack['file_path']
    break

In [None]:
# Show which pitches occur in the selected track, with their drum names.
notes_used_in_track(track, midi_map)

In [None]:
# Plot the track using its own plot() method.
track.plot()

In [None]:
# Plot the same track with the librosa-based helper; the slice bounds run
# from 0 to the length of the piano roll's first axis.
plot_piano_roll(track, 0, track.pianoroll.shape[0])

In [None]:
# Encode the track as one binary row per LEARNING_DRUM_MAP entry.
# Note: track_to_learn_seq calls binarize() and trim_trailing_silence() on `track`.
seq = track_to_learn_seq(track, LEARNING_DRUM_MAP)

In [None]:
# Display the encoded sequence.
seq

In [None]:
# Round trip: decode the learning sequence back to a 128-pitch drum track
# (rebinding `track`) and plot it to compare with the original plot.
track = Track(pianoroll=learn_seq_to_midi(seq, LEARNING_DRUM_MAP), program=0, is_drum=True)
plot_piano_roll(track, 0, track.pianoroll.shape[0])