# Preprocessing of audio and MIDI files

In [1]:
import numpy as np
import pretty_midi
import matplotlib.pyplot as plt
import librosa, librosa.display
import tensorflow as tf
from scipy import stats

from os import listdir

In [2]:
import import_ipynb
import audio_prep as ap, midi_prep as mp, constants as c

ModuleNotFoundError: No module named 'import_ipynb'

In [3]:
def create_midi_wav_pairs(path):
    """
    Create pairs of WAV and MIDI file names found in the specified directory.

    Files are matched by their name without the extension: ``name.wav`` is
    paired with ``name.mid``. A WAV file without a MIDI counterpart is
    reported on stdout and skipped; a MIDI file without a WAV counterpart is
    ignored silently.

    Args:
        path: directory to scan (sub-directories are not visited)

    Returns:
        pairs: list of (wav_file_name, midi_file_name) tuples, in the order
               the WAV files are returned by listdir
    """

    files = listdir(path)
    wavs = [name[:-4] for name in files if name.endswith('.wav')]
    # A set keeps the per-file lookup below O(1) instead of scanning a list
    # for every WAV file.
    midis = {name[:-4] for name in files if name.endswith('.mid')}

    pairs = []
    for file in wavs:
        if file not in midis:
            # Inform about file without pair and continue
            print('No matching pair for file: ', file)
            continue
        pairs.append((file + '.wav', file + '.mid'))
    return pairs

def load_midi_wav_pairs(path, pairs):
    """
    Load pairs as CQT and piano roll matrices.

    Args:
        path: directory containing the files listed in pairs
        pairs: sequence of (wav_file_name, midi_file_name) tuples

    Returns:
        cqt_matrices: list with one array per WAV file, as produced by
                      ap.cqt_matrix
        midi_matrices: list with one piano roll per MIDI file, sampled at
                       c.FRAMES_PER_SEC and restricted to the rows
                       c.MIDI_MIN..c.MIDI_MAX (inclusive)
        raw_MIDIs: list of the loaded pretty_midi.PrettyMIDI objects
    """
    # Imported locally because the file only imports `listdir` from `os`.
    from os.path import join

    cqt_matrices = []
    midi_matrices = []
    raw_MIDIs = []

    print("Loading ", len(pairs), "files.")
    for i, file in enumerate(pairs):
        # Build paths with the platform separator instead of a hard-coded
        # backslash, which only works on Windows.
        wav_path = join(path, file[0])
        midi_path = join(path, file[1])

        # Load WAV file
        cqt_matrix = ap.cqt_matrix(wav_path)
        cqt_matrices.append(np.array(cqt_matrix))

        # Load MIDI file
        midi = pretty_midi.PrettyMIDI(midi_path)
        midi_matrix = midi.get_piano_roll(fs=c.FRAMES_PER_SEC)[c.MIDI_MIN:c.MIDI_MAX+1, :]
        midi_matrices.append(np.array(midi_matrix))
        raw_MIDIs.append(midi)

        # Progress report on every third file
        if i % 3 == 0:
            print("Successfully loaded ", i+1 , " file(s)")

    print("Loading successfull!")
    return cqt_matrices, midi_matrices, raw_MIDIs

def _truncate_to_common_length(cqt, midi):
    """Cut both 2-D matrices along axis 1 to the shorter of the two lengths."""
    # Convert first so that nested lists can be sliced with [:, :n] as well.
    cqt = np.asarray(cqt)
    midi = np.asarray(midi)
    common_length = min(cqt.shape[1], midi.shape[1])
    # np.array() copies, so the results never alias the inputs.
    return np.array(cqt[:, :common_length]), np.array(midi[:, :common_length])

def align_midi_wav_pairs(cqt_matrices,
                         midi_matrices,
                         matrices_type="single_pair"):
    """
    Align the time shapes of CQT and MIDI matrices.

    The time axis is axis 1 of every matrix. The longer matrix of each pair
    is truncated at its end so that both have the same number of frames.

    Args:
        cqt_matrices: CQT matrix/matrices to align
        midi_matrices: piano roll matrix/matrices to align
        matrices_type: 'single_pair' for only one sequence alignment
                       'array' for multiple sequences alignment

    Returns:
        aligned_cqts, aligned_midis: aligned pairs of sequences; two arrays
            for 'single_pair', two lists of arrays for 'array'

    Raises:
        ValueError: if you provide wrong matrices_type parameter
    """

    if matrices_type == "single_pair":
        print("Aligning single pair of CQT spectrogram and MIDI matrix")
        cqt = np.asarray(cqt_matrices)
        midi = np.asarray(midi_matrices)
        cqt_length = cqt.shape[1]
        midi_length = midi.shape[1]

        if cqt_length > midi_length:
            print("Both matrices aligned to", midi_length, "frames.")
        elif cqt_length < midi_length:
            print("Both matrices aligned to ", cqt_length, "frames.")
        else:
            print("Same length of matrices on input.")
        return _truncate_to_common_length(cqt, midi)
    elif matrices_type == "array":
        aligned_cqts = []
        aligned_midis = []
        for cqt, midi in zip(cqt_matrices, midi_matrices):
            aligned_cqt, aligned_midi = _truncate_to_common_length(cqt, midi)
            aligned_cqts.append(aligned_cqt)
            aligned_midis.append(aligned_midi)

        return aligned_cqts, aligned_midis
    else:
        raise ValueError("Wrong matrices_type option. Only 'array' and 'single_pair' types allowed.")

def crop_midi_cqt_pairs(cqt_matrices,
                        midi_matrices,
                        operation_type="sequence",
                        matrices_type='single_pair'):
    """
    Split CQT and piano roll pairs into sequences.

    Args:
        cqt_matrices: CQT matrix/matrices to split
        midi_matrices: piano roll matrix/matrices to split
        operation_type: 'sequence' to split with ap.cqt_split_to_sequence and
                        mp.midi_split_to_sequence,
                        'simple' to split with ap.split_wav and mp.split_midi;
                        only consulted when several pairs are processed
        matrices_type: 'single_pair' to split exactly one pair (always with
                       the 'sequence' helpers),
                       any other value to treat the inputs as collections of
                       matrices

    Returns:
        crop_cqt, crop_midis: split sequences; for several pairs these are
                              lists with one entry per input pair

    Raises:
        ValueError: if several pairs are processed and operation_type is
                    neither 'sequence' nor 'simple'
    """

    if matrices_type == 'single_pair':
        single_cqt = ap.cqt_split_to_sequence(cqt_matrices)
        single_midi = mp.midi_split_to_sequence(midi_matrices)
        return single_cqt, single_midi

    # Pick the splitting strategy once, before touching any data.
    if operation_type == "sequence":
        def split_pair(cqt, midi):
            return ap.cqt_split_to_sequence(cqt), mp.midi_split_to_sequence(midi)
    elif operation_type == "simple":
        def split_pair(cqt, midi):
            return np.array(ap.split_wav(cqt)), np.array(mp.split_midi(midi))
    else:
        raise ValueError("Wrong operation type.")

    crop_cqt, crop_midis = [], []
    for cqt, midi in zip(cqt_matrices, midi_matrices):
        cqt_part, midi_part = split_pair(cqt, midi)
        crop_cqt.append(cqt_part)
        crop_midis.append(midi_part)

    return crop_cqt, crop_midis

def print_shapes(cqt_matrices, midi_matrices):
    """Print the shapes of each CQT/MIDI matrix pair, one pair per line"""

    for pair in zip(cqt_matrices, midi_matrices):
        # Unpacking the two shapes prints them separated by a single space.
        print(*(matrix.shape for matrix in pair))
        
def log_normalization(cqt_matrix):
    """
    Perform basic logarithmic transformation with zero values shift constant.

    Every value is replaced by log(value + epsilon) and the result is then
    rescaled linearly so that its global minimum maps to -1 and its global
    maximum to 1.

    Args:
        cqt_matrix: array-like of non-negative magnitudes

    Returns:
        numpy array of the same shape with values in [-1, 1]; if all input
        values are equal, every output value is -1
    """
    # Small offset so that zero magnitudes do not produce log(0) = -inf.
    # Named `epsilon` to avoid shadowing the `c` constants module alias.
    epsilon = 10e-7  # note: this literal equals 1e-6
    log_values = np.log(np.asarray(cqt_matrix) + epsilon)

    # Shift to interval [-1, 1]
    lowest = np.min(log_values)
    value_range = np.max(log_values) - lowest
    if value_range == 0:
        # Constant input: avoid 0/0 = NaN. Every value equals the minimum,
        # so it maps to -1 like the minimum does in the regular case.
        value_range = 1

    return 2 * ((log_values - lowest) / value_range) - 1

def std_mean_normalization(cqt_matrix):
    """
    Normalization based on paper: An End-to-End Neural Network for Polyphonic Piano Music Transcription

    Every column (axis 0 statistics) is standardized to zero mean and unit
    standard deviation; the result is then rescaled linearly so that its
    global minimum maps to -1 and its global maximum to 1.

    Args:
        cqt_matrix: array-like matrix to normalize

    Returns:
        numpy array of the same shape with values in [-1, 1]; columns with
        zero standard deviation are standardized to 0, and if the whole
        standardized matrix is constant every output value is -1
    """
    values = np.asarray(cqt_matrix)
    std = np.std(values, axis=0)
    mean = np.mean(values, axis=0)
    # A constant column has std == 0; dividing by it would give NaN, which
    # would then poison the global min/max below. Its centered values are
    # all 0, so dividing by 1 instead leaves the column at 0.
    safe_std = np.where(std == 0, 1, std)
    standardized = (values - mean) / safe_std

    # Shift to interval [-1, 1]
    lowest = np.min(standardized)
    value_range = np.max(standardized) - lowest
    if value_range == 0:
        # Constant matrix: every value equals the minimum, so map it to -1.
        value_range = 1

    return 2 * ((standardized - lowest) / value_range) - 1

def _read_pairs_file(list_path):
    """Read one 'first,second' pair per line from a text file into a list of tuples."""
    pairs = []
    # The context manager closes the file even if a malformed line raises.
    with open(list_path, "r") as list_file:
        for line in list_file:
            # Strip only the line terminator, so a last line without a
            # trailing newline keeps its final character.
            line = line.rstrip('\r\n')
            if not line.strip():
                # Skip blank lines instead of failing on the missing field.
                continue
            fields = line.split(sep=',')
            pairs.append((fields[0], fields[1]))
    return pairs

def get_datasets_pairs(train_path=r'non-overlapping/train',
                       valid_path=r'non-overlapping/valid',
                       test_path=r'non-overlapping/test'):
    """
    Load the train, validation and test file pairs from their list files.

    Every list file is a text file with one comma-separated pair per line;
    the first two fields of each line form the pair.

    Args:
        train_path: list file with the training pairs
        valid_path: list file with the validation pairs
        test_path: list file with the test pairs

    Returns:
        train_pairs, valid_pairs, test_pairs: lists of (first, second) tuples

    Raises:
        IndexError: if a non-blank line contains no comma
    """

    train_pairs = _read_pairs_file(train_path)
    valid_pairs = _read_pairs_file(valid_path)
    test_pairs = _read_pairs_file(test_path)

    return train_pairs, valid_pairs, test_pairs


def process_data(file_pairs, predictions='frame'):
    """
    Generator function for datasets processing.

    For every pair the WAV file is turned into a CQT spectrogram and the MIDI
    file into a label matrix; both are aligned in time, the spectrogram is
    log-normalized, the pair is split into sequences and the sequences are
    yielded one by one.

    Args:
        file_pairs: iterable of (wav_path, midi_path) pairs
        predictions: 'frame' to build labels with mp.pretty_midi_to_frame_matrix
                     'onset' to build labels with mp.pretty_midi_to_onset_matrix

    Yields:
        cqt, one-hot: transposed CQT sequence with one extra trailing axis and
                      the corresponding transposed label sequence

    Raises:
        ValueError: if predictions is neither 'frame' nor 'onset'; raised
                    while the first pair is being processed
    """

    total_sequences = 0
    for pair in file_pairs:
        spectrogram = ap.cqt_matrix(pair[0])

        if predictions == 'onset':
            labels = mp.pretty_midi_to_onset_matrix(pretty_midi.PrettyMIDI(pair[1]))
        elif predictions == 'frame':
            labels = mp.pretty_midi_to_frame_matrix(mp.load_midi_file(pair[1]))
        else:
            raise ValueError("Wrong predictions operation type.")

        spectrogram, labels = align_midi_wav_pairs(spectrogram,
                                                   labels,
                                                   matrices_type="single_pair")
        cqt_chunks, label_chunks = crop_midi_cqt_pairs(log_normalization(spectrogram),
                                                       labels)

        # The last sequence of every file is discarded
        # (presumably because it may be incomplete - TODO confirm).
        cqt_chunks = cqt_chunks[:-1]
        label_chunks = label_chunks[:-1]

        total_sequences += len(cqt_chunks)
        print("Currently processed ", total_sequences, "sequences.")

        for cqt_chunk, label_chunk in zip(cqt_chunks, label_chunks):
            features = cqt_chunk.T
            targets = label_chunk.T
            # Append a trailing axis to the transposed CQT chunk; a previous
            # version reshaped explicitly to
            # (c.SEQUENCE_CHUNK_LENGTH + 2*c.CHUNK_PADDING, c.BINS_NUMBER, 1).
            yield features.reshape(*features.shape, -1), targets

def get_dataset(train_shuffle_buffer=128, valid_shuffle_buffer=9):
    """
    Process datasets as tf.data.Dataset objects which are filled with generators.

    Each dataset is fed by process_data() with frame labels and batched with
    c.BATCH_SIZE. The train and validation datasets are additionally passed
    through tf.data.experimental.shuffle_and_repeat; because this happens
    after batching, the shuffle buffers hold whole batches. The test dataset
    is neither shuffled nor repeated.

    Args:
        train_shuffle_buffer: shuffle buffer size of the train dataset
        valid_shuffle_buffer: shuffle buffer size of the validation dataset

    Returns:
        train_dataset, valid_dataset, test_dataset: dataset generators
    """

    train_pairs, valid_pairs, test_pairs = get_datasets_pairs()
    output_types = (tf.float32, tf.float32)

    def batched_dataset(file_pairs):
        # file_pairs is a parameter, so every generator closure keeps its
        # own list of pairs.
        generator = lambda: process_data(file_pairs, predictions='frame')
        return tf.data.Dataset.from_generator(generator,
                                              output_types).batch(c.BATCH_SIZE)

    train_dataset = batched_dataset(train_pairs)
    train_dataset = train_dataset.apply(
        tf.data.experimental.shuffle_and_repeat(train_shuffle_buffer))

    valid_dataset = batched_dataset(valid_pairs)
    valid_dataset = valid_dataset.apply(
        tf.data.experimental.shuffle_and_repeat(valid_shuffle_buffer))

    test_dataset = batched_dataset(test_pairs)

    return train_dataset, valid_dataset, test_dataset

def get_dataset_test():
    """Testing util (same as get_dataset() function) with shuffle buffers of size 3"""

    return get_dataset(train_shuffle_buffer=3, valid_shuffle_buffer=3)

In [1]:
# train_d, valid_d, test_d = get_dataset()
# for pair in train_pairs:
#     print(pair)

In [35]:
# path = r'D:\School\Bc\model\MAPS\AkPnBcht\MUS'
# pairs = create_midi_wav_pairs(path)
# pairs[0]

In [1]:
# cqt_matrices, midis, raw_midis = load_midi_wav_pairs(path, pairs)
# onset_midi = [mp.pretty_midi_to_onset_matrix(midi) for midi in raw_midis]
# frame_midi = [mp.pretty_midi_to_frame_matrix(midi) for midi in midis]
# cqt_matrices, frame_midi = align_midi_wav_pairs(cqt_matrices, frame_midi, matrices_type='array')
# cqt_norm = [log_normalization(wav) for wav in cqt_matrices]
# cqt_chunks, midi_chunks = crop_midi_cqt_pairs(cqt_norm, frame_midi, matrices_type='array')

In [12]:
# it1, it2, it3 = get_dataset()

In [2]:
# it = test_d.make_initializable_iterator()

# el = it.get_next()
# counter = 1
# with tf.Session() as sess:
#     sess.run(it.initializer)

#     c, m = sess.run(el)
#     print(c.shape,m.shape)
#     ap.create_spectrogram(c[2:627,:,0].T)
#     mp.plot_piano_roll(m.T)