# ComMU Data Processing Module - Overview

In [52]:
# conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch

In [53]:
import sys
sys.path.append("..") # Adds higher directory to python modules path.

In [54]:
import copy
import enum
import os
import shutil
import tempfile
from ast import literal_eval
from dataclasses import dataclass, field, fields
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import miditoolkit
import numpy as np
import pandas as pd
import parmap

from commu.preprocessor import augment
from commu.preprocessor.utils import sync_key_augment
from commu.preprocessor.utils.exceptions import UnprocessableMidiError
from commu.preprocessor.encoder import MetaEncoder, EventSequenceEncoder
from commu.preprocessor.parser import MetaParser

import argparse
from multiprocessing import cpu_count
from pathlib import Path
from commu.preprocessor import PreprocessPipeline

from tqdm.notebook import tqdm

# Helper Functions

In [55]:
# File suffixes recognised as MIDI files (compared case-sensitively).
MIDI_EXTENSIONS = (
    ".mid",
    ".MID",
    ".midi",
    ".MIDI",
)

class OutputSubDirName(str, enum.Enum):
    """Sub-directory names created under an output root by get_output_sub_dir."""
    RAW = "raw"
    ENCODE_NPY = "output_npy"
    ENCODE_TMP = "npy_tmp"


class SubDirName(str, enum.Enum):
    """Sub-directory names created under a dataset root (or split) by get_sub_dir."""
    RAW = "raw"
    ENCODE_NPY = "output_npy"
    ENCODE_TMP = "npy_tmp"
    AUGMENTED_TMP = "augmented_tmp"
    AUGMENTED = "augmented"


@dataclass
class OutputSubDirectory:
    """Paths of the output sub-directories, one field per OutputSubDirName member."""
    encode_npy: Union[str, Path]
    encode_tmp: Union[str, Path]
    # OutputSubDirName has a RAW member, so get_output_sub_dir passes raw=...;
    # the field is defaulted so existing two-argument constructions still work.
    raw: Optional[Union[str, Path]] = field(default=None)


@dataclass
class SubDirectory:
    """Paths of the dataset sub-directories, one field per SubDirName member."""
    raw: Union[str, Path]
    encode_npy: Union[str, Path]
    encode_tmp: Union[str, Path]
    augmented_tmp: Optional[Union[str, Path]] = field(default=None)
    augmented: Optional[Union[str, Path]] = field(default=None)


def _make_sub_dirs(base_dir: Path, names: type) -> Dict[str, Path]:
    """Create ``base_dir/<member.value>`` for each member of enum *names*, keyed by lower-cased member name."""
    result = {}
    for name, member in names.__members__.items():
        sub_dir = base_dir.joinpath(member.value)
        sub_dir.mkdir(exist_ok=True, parents=True)
        result[name.lower()] = sub_dir
    return result


def get_output_sub_dir(root_dir: Union[str, Path]) -> OutputSubDirectory:
    """
    Create (if missing) the output sub-directories under ``root_dir``.

    Args:
        root_dir: output root; ``str`` or ``Path``.

    Returns:
        OutputSubDirectory with one Path per OutputSubDirName member.
    """
    # Path() accepts both str and Path, so str roots no longer fail on .joinpath
    return OutputSubDirectory(**_make_sub_dirs(Path(root_dir), OutputSubDirName))


def get_sub_dir(
        root_dir: Union[str, Path], split: Optional[str]) -> SubDirectory:
    """
    Create (if missing) the dataset sub-directories under ``root_dir[/split]``.

    Args:
        root_dir: dataset root; ``str`` or ``Path``.
        split: optional split name (e.g. "train") inserted between the root
            and the sub-directory names; ``None`` uses the root directly.

    Returns:
        SubDirectory with one Path per SubDirName member.
    """
    base_dir = Path(root_dir)
    if split is not None:
        base_dir = base_dir.joinpath(split)
    return SubDirectory(**_make_sub_dirs(base_dir, SubDirName))


@dataclass
class EncodingOutput:
    """Container for one encoded MIDI sample: metadata tokens plus event tokens."""
    meta: np.ndarray  # metadata tokens; preprocess_midi builds this array with dtype=object
    event_sequence: np.ndarray  # event tokens; preprocess_midi builds this array with dtype=int16


def augment_data(
        source_dir: Union[str, Path],
        augmented_dir: Union[str, Path],
        augmented_tmp_dir: Union[str, Path],
        num_cores: int,
):
    """
    Delegate to ``commu.preprocessor.augment.augment_data``.

    All directory arguments are converted to ``str`` before the call;
    ``source_dir`` is forwarded as ``midi_path``.
    """
    directory_args = {
        "midi_path": source_dir,
        "augmented_dir": augmented_dir,
        "augmented_tmp_dir": augmented_tmp_dir,
    }
    augment.augment_data(
        **{arg_name: str(arg_value) for arg_name, arg_value in directory_args.items()},
        num_cores=num_cores,
    )


def gather_sample_file(
        *source_dirs: Union[str, Path],
        extensions: Optional[Iterable[str]] = None,
) -> Dict[str, str]:
    """
    Map sample ids (file stems) to MIDI file paths found under ``source_dirs``.

    Each directory is searched recursively. When a stem occurs more than once,
    the path seen last wins (so later directories override earlier ones).

    Args:
        *source_dirs: directories to search.
        extensions: accepted file suffixes (case-sensitive); defaults to
            ``MIDI_EXTENSIONS``.

    Returns:
        dict mapping ``stem`` -> ``str(path)``.
    """
    suffixes = tuple(MIDI_EXTENSIONS if extensions is None else extensions)
    result: Dict[str, str] = {}
    for source_dir in source_dirs:
        # rglob("*") already recurses; "**/*" only added a redundant second
        # recursive level. Directories named like "x.mid" are not samples.
        for path in Path(source_dir).rglob("*"):
            if path.suffix in suffixes and path.is_file():
                result[path.stem] = str(path)
    return result

## Load the ComMU Encoder Dictionary

In [56]:
# testing out the encoding functions
import math
from commu.preprocessor.encoder import encoder_utils
from commu.preprocessor.encoder.event_tokens import TOKEN_OFFSET
from commu.preprocessor.utils.constants import (
    DEFAULT_POSITION_RESOLUTION,
    SIG_TIME_MAP,
    BPM_INTERVAL,
    DEFAULT_TICKS_PER_BEAT,
    VELOCITY_INTERVAL,
    KEY_NUM_MAP,
    KEY_MAP
)


# Velocity quantisation: the 128 MIDI velocities are split into bins of width
# VELOCITY_INTERVAL; the bin values are evenly spaced from 2 to 127 inclusive
# (truncated to int).
NUM_VELOCITY_BINS = int(128 / VELOCITY_INTERVAL)
DEFAULT_VELOCITY_BINS = np.linspace(start=2, stop=127, num=NUM_VELOCITY_BINS, dtype=int)

class Item:
    """A timed musical item: a note, a chord, or the empty-bar placeholder ("None")."""

    def __init__(self, name, start, end, velocity, pitch):
        self.name, self.start, self.end = name, start, end
        self.velocity, self.pitch = velocity, pitch

    def __repr__(self):
        return (
            f"Item(name={self.name}, start={self.start}, end={self.end}, "
            f"velocity={self.velocity}, pitch={self.pitch})"
        )


class Event:
    """A single REMI event: a name plus its time, value and free-form text."""

    def __init__(self, name, time, value, text):
        self.name, self.time = name, time
        self.value, self.text = value, text

    def __repr__(self):
        return (
            f"Event(name={self.name}, time={self.time}, "
            f"value={self.value}, text={self.text})"
        )


# REMI vocabulary: mk_remi_map's result is unpacked into event2word / word2event;
# only event2word is then passed through the flat-chord and chord-type extensions.
event2word, word2event = encoder_utils.mk_remi_map()
event2word = encoder_utils.abstract_chord_types(
    encoder_utils.add_flat_chord2map(event2word)
)
# grid resolution used by encode() for its duration bins and OOV fallback
position_resolution = DEFAULT_POSITION_RESOLUTION

In [57]:
def read_items(file_path):
    """
        Read the notes of the first instrument of a MIDI file as Item objects.
        Input:
            file_path: path to the MIDI file
        Output:
            list of Item objects named "Note", ordered by start time and, for
            notes starting together, by pitch
    """
    midi_obj = miditoolkit.midi.parser.MidiFile(file_path)
    # sorting by (start, pitch) already orders by start time, so no second
    # sort of the resulting items is needed
    ordered_notes = sorted(
        midi_obj.instruments[0].notes, key=lambda note: (note.start, note.pitch)
    )
    return [
        Item(
            name="Note",
            start=note.start,
            end=note.end,
            velocity=note.velocity,
            pitch=note.pitch,
        )
        for note in ordered_notes
    ]


def group_items(items, max_time, ticks_per_bar):
    """
        Function to group a list of items by bars
        Inputs:
            items: list of Item objects (sorted in place by start time)
            max_time: maximum time of the MIDI file, in ticks
            ticks_per_bar: ticks per bar
        Output:
            groups: one list per bar, ``[bar_start, *items_in_bar, bar_end]``;
                a bar with no items holds a single placeholder Item
                (name "None", pitch "NN")
    """
    import bisect

    # sort the items by start time (in place) and compute the downbeat
    # (first tick) of every bar
    items.sort(key=lambda x: x.start)
    downbeats = np.arange(0, max_time + ticks_per_bar, ticks_per_bar)

    # items are sorted, so each bar's members form a contiguous slice; locate
    # it by binary search instead of rescanning every item for every bar
    starts = [item.start for item in items]

    groups = []
    for db1, db2 in zip(downbeats[:-1], downbeats[1:]):
        # items with db1 <= start < db2
        lo = bisect.bisect_left(starts, db1)
        hi = bisect.bisect_left(starts, db2)
        insiders = items[lo:hi]

        # if there is no note in a given bar, add a placeholder item
        if not insiders:
            insiders = [Item(name="None", start=None, end=None, velocity=None, pitch="NN")]

        groups.append([db1] + insiders + [db2])
    return groups

In [58]:
def item2event(groups, duration_bins):
    """
        Function to convert a list of bar groups into a list of events
        Inputs:
            groups: list of bar groups, each ``[bar_start, *items, bar_end]``
                (the layout produced by group_items)
            duration_bins: numpy array of duration bin values, in ticks
        Output:
            events: list of Event objects (Bar, Position, Note Velocity,
                Note On, Note Duration and Chord events)
    """
    events = []
    n_downbeat = 0
    for group in groups:
        bar_st, bar_items, bar_et = group[0], group[1:-1], group[-1]
        # bars holding the empty-bar placeholder (pitch "NN") produce no
        # events and are not counted as downbeats
        if "NN" in [item.pitch for item in bar_items]:
            continue
        n_downbeat += 1
        if group[1].name == "Chord":
            events.append(
                Event(name="Bar", time=bar_st, value=None, text="{}".format(n_downbeat))
            )
        # the position grid only depends on the bar, so build it once per bar
        flags = np.linspace(bar_st, bar_et, DEFAULT_POSITION_RESOLUTION, endpoint=False)
        for item in bar_items:
            # position: nearest grid slot, 1-based
            index = np.argmin(abs(flags - item.start))
            events.append(
                Event(
                    name="Position",
                    time=item.start,
                    value="{}/{}".format(index + 1, DEFAULT_POSITION_RESOLUTION),
                    text="{}".format(item.start),
                )
            )
            if item.name == "Note":
                # velocity: index of the last bin value <= velocity. Velocities
                # below DEFAULT_VELOCITY_BINS[0] would give -1, which silently
                # wraps to the top bin, so clamp them to the lowest bin.
                velocity_index = int(max(
                    np.searchsorted(DEFAULT_VELOCITY_BINS, item.velocity, side="right") - 1,
                    0,
                ))
                events.append(
                    Event(
                        name="Note Velocity",
                        time=item.start,
                        value=velocity_index,
                        text="{}/{}".format(item.velocity, DEFAULT_VELOCITY_BINS[velocity_index]),
                    )
                )
                # pitch
                events.append(
                    Event(
                        name="Note On",
                        time=item.start,
                        value=item.pitch,
                        text="{}".format(item.pitch),
                    )
                )
                # duration: nearest duration bin
                duration = item.end - item.start
                index = np.argmin(abs(duration_bins - duration))
                events.append(
                    Event(
                        name="Note Duration",
                        time=item.start,
                        value=index,
                        text="{}/{}".format(duration, duration_bins[index]),
                    )
                )
            elif item.name == "Chord":
                events.append(
                    Event(
                        name="Chord",
                        time=item.start,
                        value=item.pitch,
                        text="{}".format(item.pitch),
                    )
                )
    return events

In [59]:
def detect_chord(chord_progression, beats_per_bar):
    """
        Locate chord changes in a slot-by-slot chord progression.
        Inputs:
            chord_progression: list of chord names, two slots per beat
            beats_per_bar: number of beats per bar
        Output:
            chord_idx: list of chord positions in bars (bar index + fraction)
            chord_name: list of lower-cased chord names, one per position
        Notes:
            A chord is emitted at the first slot of every bar and wherever it
            differs from the previously emitted chord.
            NOTE(review): np.array_split spreads leftover slots over the first
            bars when the length is not a multiple of the slots per bar.
    """
    slots_per_bar = beats_per_bar * 2
    measure_count = int(len(chord_progression) / slots_per_bar)
    bars = np.array_split(np.array(chord_progression), measure_count)

    positions, names = [], []
    for bar_number, bar in enumerate(bars):
        for slot, raw_chord in enumerate(bar):
            lowered = raw_chord.lower()
            starts_bar = slot == 0
            if starts_bar or lowered != names[-1]:
                positions.append(bar_number + slot / slots_per_bar)
                names.append(lowered)
    return positions, names


def insert_chord_on_event(
    events,
    chord_progression,
    tick_per_bar,
    num_measures,
    is_incomplete_measure,
    beats_per_bar):
    """
        Convert a chord progression into Bar/Position/Chord events and merge
        them, ordered by time, with the given events.
        Inputs:
            events: existing list of Event objects (not modified)
            chord_progression: list of chord progression
            tick_per_bar: number of ticks per bar
            num_measures: number of Bar events to emit
            is_incomplete_measure: when true, chord times are shifted forward
                by one bar
            beats_per_bar: number of beats per bar
        Output:
            new list with chord events and ``events``, stably sorted by time
            (chord events come first among equal times)
    """
    positions, chord_names = detect_chord(chord_progression, beats_per_bar)
    tick_offset = tick_per_bar * is_incomplete_measure
    cursor = 0
    chord_events = []
    for bar in range(num_measures):
        chord_events.append(
            Event(name="Bar", time=bar * tick_per_bar, value=None, text="{}".format(bar + 1))
        )
        bar_limit = bar + 1 - is_incomplete_measure
        # consume every chord that falls before the end of this bar
        while cursor < len(positions) and positions[cursor] < bar_limit:
            position, chord = positions[cursor], chord_names[cursor]
            cursor += 1
            chord_time = int(position * tick_per_bar + tick_offset)
            slot = int((position - bar + is_incomplete_measure) * DEFAULT_POSITION_RESOLUTION) + 1
            # drop slash-bass ("/...") and parenthesised extensions
            base_chord = chord.split("/")[0].split("(")[0]
            chord_events.append(
                Event(
                    name="Position",
                    time=chord_time,
                    value="{}/{}".format(slot, DEFAULT_POSITION_RESOLUTION),
                    text=chord_time,
                )
            )
            chord_events.append(
                Event(name="Chord", time=chord_time, value=base_chord, text=base_chord)
            )

    return sorted(chord_events + events, key=lambda event: event.time)

In [60]:
def encode_event_sequence(midi_path: Union[str, Path], sample_info: Dict) -> np.ndarray:
    """
    Encode the events of one MIDI file with EventSequenceEncoder.

    Args:
        midi_path: path of the MIDI file, passed to the encoder unchanged.
        sample_info: metadata dict forwarded as ``sample_info``.

    Returns:
        numpy array of the encoder's tokens.
    """
    # NOTE: an earlier (commented-out) version copied the file to a temporary
    # MIDI with instruments named "chord" removed before encoding.
    encoder = EventSequenceEncoder()
    tokens = encoder.encode(midi_path, sample_info=sample_info)
    return np.array(tokens)


def preprocess_midi(sample_info: Dict[str, Any], midi_path: Union[str, Path]) -> Optional[EncodingOutput]:
    """
    Encode the metadata and the event sequence of one MIDI sample.

    Args:
        sample_info: metadata dict, parsed by MetaParser.
        midi_path: path of the MIDI file to encode.

    Returns:
        EncodingOutput (metadata as an object array, events as int16), or
        None when MetaEncoder raises UnprocessableMidiError; in that case the
        error and path are printed.
    """
    parsed_meta = MetaParser().parse(meta_dict=sample_info)
    try:
        meta_tokens: List[Union[int, str]] = MetaEncoder().encode(parsed_meta)
    except UnprocessableMidiError as err:
        print(f"{err}: {midi_path}")
        return None

    return EncodingOutput(
        meta=np.array(meta_tokens, dtype=object),
        event_sequence=np.array(
            encode_event_sequence(midi_path, sample_info), dtype=np.int16
        ),
    )

In [61]:
def encode(midi_paths, sample_info=None, for_cp=False):
    """
        Encode a MIDI file into an array of REMI token ids.
        Inputs:
            midi_paths: path of the MIDI file to encode (a single file,
                despite the plural name)
            sample_info: metadata dict; reads "chord_progressions",
                "num_measures", "time_signature" and "is_incomplete_measure"
            for_cp: if True, return the extracted Event list instead of tokens
        Output:
            list of events (for_cp=True) or np.ndarray of token ids terminated
            by the EOS token
        Raises:
            TypeError: if sample_info is None
    """
    if sample_info is None:
        raise TypeError("encode() requires sample_info metadata")

    midi_file = miditoolkit.MidiFile(midi_paths)
    ticks_per_beat = midi_file.ticks_per_beat
    chord_progression = sample_info["chord_progressions"]
    num_measures = math.ceil(sample_info["num_measures"])
    time_signature = sample_info["time_signature"].split("/")
    numerator = int(time_signature[0])
    denominator = int(time_signature[1])
    is_incomplete_measure = sample_info["is_incomplete_measure"]

    # beats are counted in quarter notes, e.g. 6/8 -> 3 beats per bar
    beats_per_bar = numerator / denominator * 4
    ticks_per_bar = int(ticks_per_beat * beats_per_bar)
    # duration bins: multiples of one position step, up to a full bar
    step = int(ticks_per_bar / position_resolution)
    duration_bins = np.arange(step, ticks_per_bar + 1, step, dtype=int)

    events = encoder_utils.extract_events(
        midi_paths,
        duration_bins,
        ticks_per_bar=ticks_per_bar,
        ticks_per_beat=ticks_per_beat,
        chord_progression=chord_progression,
        num_measures=num_measures,
        is_incomplete_measure=is_incomplete_measure,
    )
    if for_cp:
        return events

    words = []
    for event in events:
        e = "{}_{}".format(event.name, event.value)
        if e in event2word:
            words.append(event2word[e])
        # OOV handling: the branches are exclusive, so a replaced velocity or
        # duration token is no longer also reported as unknown
        elif event.name == "Note Velocity":
            # replace with max velocity based on our training data
            words.append(event2word["Note Velocity_63"])
        elif event.name == "Note Duration":
            # replace with max duration
            words.append(event2word[f"Note Duration_{position_resolution-1}"])
        else:
            # something is wrong
            # you should handle it for your own purpose
            print("OOV {}".format(e))
    words.append(TOKEN_OFFSET.EOS.value)  # eos token
    return np.array(words)

# Data Preprocessing
- The MIDI representation is encoded with REMI
- 11 of the 12 metadata fields are placed before the REMI representation
- The chord progression is represented in REMI (extended to 108 chord patterns), with the tempo token removed
- The resolution of the position and duration tokens is increased from 1/32 notes to 1/128 notes

In [62]:
# set up the initial directory parameters
root_dir = Path("../dataset/musicMIDI").expanduser()
csv_path = "../dataset/midi_metadata_file_cleaned.csv"
# alternative dataset paths:
# root_dir = Path("../dataset/commu_midi").expanduser()
# csv_path = "../dataset/commu_meta.csv"
num_cores = 4

# create (if missing) and collect the sub-directory layout, both at the
# dataset root and under the "train" split
default_sub_dir, split_sub_dir = (
    get_sub_dir(root_dir, split=None),
    get_sub_dir(root_dir, split="train"),
)
default_sub_dir, split_sub_dir

(SubDirectory(raw=PosixPath('../dataset/musicMIDI/raw'), encode_npy=PosixPath('../dataset/musicMIDI/output_npy'), encode_tmp=PosixPath('../dataset/musicMIDI/npy_tmp'), augmented_tmp=PosixPath('../dataset/musicMIDI/augmented_tmp'), augmented=PosixPath('../dataset/musicMIDI/augmented')),
 SubDirectory(raw=PosixPath('../dataset/musicMIDI/train/raw'), encode_npy=PosixPath('../dataset/musicMIDI/train/output_npy'), encode_tmp=PosixPath('../dataset/musicMIDI/train/npy_tmp'), augmented_tmp=PosixPath('../dataset/musicMIDI/train/augmented_tmp'), augmented=PosixPath('../dataset/musicMIDI/train/augmented')))

In [63]:
# # # run data augmentation
# augment_data(
#     source_dir=split_sub_dir.raw,
#     augmented_dir=split_sub_dir.augmented,
#     augmented_tmp_dir=split_sub_dir.augmented_tmp,
#     num_cores=num_cores,
# )

In [64]:
# set up the encoded output directory
# temporary directory for the encoded outputs of the train split
encode_tmp_dir = Path(split_sub_dir.encode_tmp)
encode_tmp_dir

PosixPath('../dataset/musicMIDI/train/npy_tmp')

In [65]:
# # Gather the raw MIDI files
# sample_id_to_path = gather_sample_file(
#     *(split_sub_dir.raw, split_sub_dir.augmented))
# sample_id_to_path.keys()

# Gather the MIDI files of the raw directory only (the augmented directory is
# excluded here; see the commented-out variant above)
sample_id_to_path = gather_sample_file(split_sub_dir.raw)
sample_id_to_path.keys()

dict_keys(['0004806f96307e317d116040af5b7861_1', '0004806f96307e317d116040af5b7861_11', '0004806f96307e317d116040af5b7861_12', '0004806f96307e317d116040af5b7861_13', '0004806f96307e317d116040af5b7861_2', '0004806f96307e317d116040af5b7861_3', '0004806f96307e317d116040af5b7861_4', '0004806f96307e317d116040af5b7861_5', '0004806f96307e317d116040af5b7861_6', '0004806f96307e317d116040af5b7861_7', '0004806f96307e317d116040af5b7861_9', '001344339e5b6a6bf1bc16d70b7f91a2_10', '001344339e5b6a6bf1bc16d70b7f91a2_11', '001344339e5b6a6bf1bc16d70b7f91a2_12', '001344339e5b6a6bf1bc16d70b7f91a2_13', '001344339e5b6a6bf1bc16d70b7f91a2_14', '001344339e5b6a6bf1bc16d70b7f91a2_15', '001344339e5b6a6bf1bc16d70b7f91a2_16', '001344339e5b6a6bf1bc16d70b7f91a2_2', '001344339e5b6a6bf1bc16d70b7f91a2_3', '001344339e5b6a6bf1bc16d70b7f91a2_4', '001344339e5b6a6bf1bc16d70b7f91a2_5', '001344339e5b6a6bf1bc16d70b7f91a2_6', '001344339e5b6a6bf1bc16d70b7f91a2_7', '001344339e5b6a6bf1bc16d70b7f91a2_8', '001344339e5b6a6bf1bc16d70b7f

# Read In the MIDI Metadata Used for Encoding
- For this experiment, we try blanking out specific columns by replacing their values with the placeholder "unknown"

In [66]:
# read in the ComMU metadata; the chord progressions column holds Python
# literals, so it is parsed with literal_eval
commu_metadata = pd.read_csv(
    csv_path,
    index_col=[0],
    converters={"chord_progressions": literal_eval},
)
commu_metadata = commu_metadata.reset_index()
# commu_metadata['chord_progressions'] = commu_metadata['chord_progressions'].apply(lambda x: [literal_eval(x[0])])
commu_metadata.head()

Unnamed: 0,audio_key,pitch_range,num_measures,bpm,genre,track_roll,inst,sample_rhythm,time_signature,min_velocity,max_velocity,split_data,id,chord_progressions,track_role
0,cmajor,unknown,10,192,electronic,unknown,dulcimer,unknown,2/4,64,87,unknown,0004806f96307e317d116040af5b7861_11,"[[Am, Am, C, C, F, F, Am, Am, B, B, G, G, B, B...",unknown
1,fmajor,unknown,7,112,electronic,unknown,brass_section,unknown,2/4,127,127,unknown,0004806f96307e317d116040af5b7861_12,"[[F, F, F, F, A, A, A, A, C, C, G, G, F, F, F,...",unknown
2,fmajor,unknown,12,163,electronic,unknown,percussive_organ,unknown,2/4,31,85,unknown,0004806f96307e317d116040af5b7861_13,"[[Dm, Dm, Dm, Dm, F, F, F, F, B, B, G, G, Dm, ...",unknown
3,dminor,unknown,19,109,electronic,unknown,synthstrings_1,unknown,2/4,85,96,unknown,0004806f96307e317d116040af5b7861_2,"[[Dm, Dm, Dm, Dm, F, F, Am, Am, Am, Am, G, G, ...",unknown
4,aminor,unknown,11,179,electronic,unknown,lead_1_square,unknown,2/4,27,113,unknown,0004806f96307e317d116040af5b7861_3,"[[C, C, G, G, A, A, C, C, A, A, G, G, G, G, C,...",unknown


In [67]:
# summarise the metadata frame: columns, non-null counts and dtypes
commu_metadata.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 62062 entries, 0 to 62061
Data columns (total 15 columns):
 #   Column              Non-Null Count  Dtype 
---  ------              --------------  ----- 
 0   audio_key           62062 non-null  object
 1   pitch_range         62062 non-null  object
 2   num_measures        62062 non-null  int64 
 3   bpm                 62062 non-null  int64 
 4   genre               62062 non-null  object
 5   track_roll          62062 non-null  object
 6   inst                62062 non-null  object
 7   sample_rhythm       62062 non-null  object
 8   time_signature      62062 non-null  object
 9   min_velocity        62062 non-null  int64 
 10  max_velocity        62062 non-null  int64 
 11  split_data          62062 non-null  object
 12  id                  62062 non-null  object
 13  chord_progressions  62062 non-null  object
 14  track_role          62062 non-null  object
dtypes: int64(4), object(11)
memory usage: 7.1+ MB


In [68]:
# # # replace some columns with an empty string
# commu_metadata['track_role'] = 'unknown'
# commu_metadata['sample_rhythm'] = 'unknown'
# commu_metadata['pitch_range'] = 'unknown'
# commu_metadata.head()

In [69]:
# split the metadata records into num_cores chunks of (chunk index, records)
sample_infos_chunk = list(enumerate(
    chunk_records.tolist()
    for chunk_records in np.array_split(np.array(commu_metadata.to_dict('records')), num_cores)
))

# for every chunk: deep-copy its records, index them by sample id, and append
# stub entries for augmented files whose id prefix (before the first "_")
# matches a parent sample id of the chunk
sample_infos_inverted_index = []
copied_sample_infos_chunks = []
for data in sample_infos_chunk:
    idx, sample_info = data
    copied_sample_infos_chunk = copy.deepcopy(list(sample_info))
    parent_sample_ids_to_info = {
        sample["id"]: sample for sample in copied_sample_infos_chunk
    }
    parent_sample_ids = set(parent_sample_ids_to_info)
    # NOTE(review): ids here look like "0004806f96307e317d116040af5b7861_11",
    # so split("_")[0] never equals a parent id and no stubs are added
    copied_sample_infos_chunk += [
        {"id": sample_id, "augmented": True}
        for sample_id in sample_id_to_path
        if sample_id.split("_")[0] in parent_sample_ids
    ]
    sample_infos_inverted_index.append(parent_sample_ids_to_info)
    copied_sample_infos_chunks.append(copied_sample_infos_chunk)

len(copied_sample_infos_chunks), len(copied_sample_infos_chunks[0])

(4, 15516)

In [70]:
# spot-check the id of the first record in the first chunk
copied_sample_infos_chunks[0][0]['id']

'0004806f96307e317d116040af5b7861_11'

In [71]:
# metadata keys of the first record (one key per column of commu_metadata,
# since the records come from to_dict('records'))
all_meta_keys = copied_sample_infos_chunks[0][0].keys()
all_meta_keys

dict_keys(['audio_key', 'pitch_range', 'num_measures', 'bpm', 'genre', 'track_roll', 'inst', 'sample_rhythm', 'time_signature', 'min_velocity', 'max_velocity', 'split_data', 'id', 'chord_progressions', 'track_role'])

In [72]:
midi_meta_info = []
midi_files_path = []

# collect (metadata, MIDI path) pairs for the original (non-augmented) samples
for chunk in range(len(copied_sample_infos_chunks)):
    for sample_info in tqdm(copied_sample_infos_chunks[chunk]):
        # augmented stubs ({"id": ..., "augmented": True}) carry no metadata;
        # they are handled by the augmented-encoding cell further down
        if sample_info.get("augmented", False):
            continue

        # take a real copy: the chunk records are shared with
        # sample_infos_inverted_index and must not be mutated here
        copied_sample_info = dict(sample_info)

        audio_key = copied_sample_info.get("audio_key")
        chord_progression = copied_sample_info.get("chord_progressions")
        num_measures = copied_sample_info.get("num_measures")

        # skip samples whose audio key is not a KEY_MAP entry
        if audio_key not in KEY_MAP.keys():
            continue
        try:
            # NOTE(review): the key-synced progression only validates the chord
            # data here; it is not written back into copied_sample_info, so the
            # encoder receives the raw "chord_progressions" value. The augmented
            # cell below does store it — confirm which form the encoder expects.
            chord_progression = sync_key_augment(
                chord_progression[0],
                audio_key.replace("minor", "").replace("major", ""),
                audio_key[0],  # first character of the audio_key string
            )
        except IndexError:
            print("chord progression info is unknown")
            continue

        # assign sample rhythm
        copied_sample_info["rhythm"] = copied_sample_info.get("sample_rhythm")
        # flagged as incomplete when num_measures is not a multiple of 4
        copied_sample_info["is_incomplete_measure"] = bool(num_measures % 4)

        # pair the metadata with its MIDI path (None when the file is missing)
        midi_meta_info.append(copied_sample_info)
        midi_files_path.append(sample_id_to_path.get(copied_sample_info["id"]))

# midi_meta_info, midi_files_path

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [73]:
# encoded_events = []

# # loop through midi/mete info pairs
# for midi_path, sample_info in tqdm(zip(midi_files_path[:1000], midi_meta_info[:1000])): 
#     if midi_path is None:
#         continue
#     try:
#         encoded_event = encode(midi_path, sample_info, for_cp=True)
#         encoded_events.append(encoded_event)
#     except Exception as e:
#         print(e)
#         pass
    
# encoded_events[0]

In [74]:
# sanity check: both lists are appended together in the loop above, so their lengths must match
len(midi_files_path), len(midi_meta_info)

(62062, 62062)

In [76]:
encoded_data = []

# encode every (MIDI path, metadata) pair; failures are reported and skipped
for midi_path, sample_info in tqdm(
        zip(midi_files_path, midi_meta_info), total=len(midi_files_path)):
    if midi_path is None:
        continue
    try:
        encoding_output = preprocess_midi(sample_info, midi_path)
    except Exception as e:  # notebook boundary: report with context and keep going
        print(f"{type(e).__name__}: {e}: {midi_path}")
        continue
    # preprocess_midi returns None when the metadata cannot be encoded
    if encoding_output is not None:
        encoded_data.append(encoding_output)

encoded_data[0]

0it [00:00, ?it/s]

inst KeyError: synthstrings_1: ../dataset/musicMIDI/train/raw/0004806f96307e317d116040af5b7861_2.mid
inst KeyError: pad_3_polysynth: ../dataset/musicMIDI/train/raw/0004806f96307e317d116040af5b7861_9.mid
inst KeyError: synthstrings_2: ../dataset/musicMIDI/train/raw/001344339e5b6a6bf1bc16d70b7f91a2_10.mid
inst KeyError: synthstrings_2: ../dataset/musicMIDI/train/raw/001d789d86a9139c710af589a7a5f134_7.mid
inst KeyError: synthstrings_1: ../dataset/musicMIDI/train/raw/0048690121681712c62f5785d88ef755_5.mid
inst KeyError: synthstrings_1: ../dataset/musicMIDI/train/raw/0048690121681712c62f5785d88ef755_6.mid
OOV Note Velocity_-1
num measures ValueError: 1: ../dataset/musicMIDI/train/raw/005a933ca8e4b53a8ad03679c0f5ec73_10.mid
num measures ValueError: 1: ../dataset/musicMIDI/train/raw/006d122fe259c779d11c9731b26b4889_10.mid
OOV Note Velocity_-1
OOV Note Velocity_-1
OOV Note Velocity_-1
OOV Note Velocity_-1
OOV Note Velocity_-1
OOV Note Velocity_-1
OOV Note Velocity_-1
OOV Note Velocity_-1
OOV N

EncodingOutput(meta=array([598, 602, 631, 630, 640, 645, 665, 686, 698, 719, 726],
      dtype=object), event_sequence=array([  2, 432, 162, ...,  92, 315,   1], dtype=int16))

In [78]:
# drop samples that could not be encoded (preprocess_midi returned None)
encoded_data_clean = list(filter(lambda output: output is not None, encoded_data))
len(encoded_data_clean)

57367

In [27]:
# inspect the first successfully encoded sample
encoded_data_clean[0]

EncodingOutput(meta=array([598, 602, 631, 630, 640, 645, 665, 686, 698, 719, 726],
      dtype=object), event_sequence=array([  2, 432, 162, ...,  92, 315,   1], dtype=int16))

In [None]:
import joblib

# save the encoded data, dropping samples that failed to encode (None entries)
# so the saved list only holds EncodingOutput objects
joblib.dump(
    [output for output in encoded_data if output is not None],
    split_sub_dir.encode_tmp.joinpath("meta_midi_encoded_data.joblib"),
)

['../dataset/commu_midi/train/npy_tmp/encoded_data.joblib']

In [None]:
encoded_data = []

# encode the augmented samples, reusing each parent's metadata
for chunk in range(len(copied_sample_infos_chunks)):
    for sample_info in tqdm(copied_sample_infos_chunks[chunk]):
        # only augmented stub entries are handled in this cell
        if not sample_info.get("augmented", False):
            continue

        # augmented ids look like "<parent_id>_<audio_key>_<bpm>"; rsplit keeps
        # any underscores inside the parent id intact
        id_split = sample_info["id"].rsplit("_", 2)
        bpm = sample_info.get("bpm")
        audio_key = sample_info.get("audio_key")

        if len(id_split) == 3:
            parent_sample_id, audio_key, bpm = id_split
        elif len(id_split) == 1:
            parent_sample_id = id_split[0]
        else:
            print(f"unexpected augmented sample id: {sample_info['id']}")
            continue

        if bpm is None or audio_key is None:
            continue

        augmented_midi_path = sample_id_to_path[sample_info["id"]]
        copied_sample_info = copy.deepcopy(sample_infos_inverted_index[chunk][parent_sample_id])
        copied_sample_info["bpm"] = int(bpm)

        # only parents whose original key is C major or A minor are used
        if copied_sample_info["audio_key"] not in ["cmajor", "aminor"]:
            continue
        try:
            # sync the chord progression with the augmented audio key
            copied_sample_info["chord_progressions"] = sync_key_augment(
                copied_sample_info["chord_progressions"][0],
                audio_key.replace("minor", "").replace("major", ""),
                copied_sample_info["audio_key"][0],  # first character of the original audio_key
            )
        except IndexError:
            print(f"chord progression info is unknown: {augmented_midi_path}")
            continue

        copied_sample_info["audio_key"] = audio_key
        copied_sample_info["rhythm"] = copied_sample_info.get("sample_rhythm")

        # add the is_incomplete_measure column
        copied_sample_info["is_incomplete_measure"] = bool(copied_sample_info["num_measures"] % 4)

        # copied_sample_info now holds the parent's id: require the parent
        # MIDI file to be present before encoding the augmented file
        midi_path = sample_id_to_path.get(copied_sample_info["id"])
        if midi_path is None:
            continue
        try:
            encoding_output = preprocess_midi(
                sample_info=copied_sample_info, midi_path=augmented_midi_path
            )
        except (IndexError, TypeError) as e:
            print(f"{e}: {augmented_midi_path}")
            continue
        except ValueError:
            print(f"num measures not allowed: {augmented_midi_path}")
            continue

        # preprocess_midi returns None when the metadata cannot be encoded
        if encoding_output is not None:
            encoded_data.append(encoding_output)

# EncodingOutput is a dataclass and cannot be indexed; show the first result
encoded_data[0] if encoded_data else None

0it [00:00, ?it/s]

pitch range KeyError: : ../dataset/commu_midi/train/augmented/commu00001_abminor_110.mid
pitch range KeyError: : ../dataset/commu_midi/train/augmented/commu00001_abminor_115.mid
pitch range KeyError: : ../dataset/commu_midi/train/augmented/commu00001_abminor_120.mid
pitch range KeyError: : ../dataset/commu_midi/train/augmented/commu00001_abminor_125.mid
pitch range KeyError: : ../dataset/commu_midi/train/augmented/commu00001_abminor_130.mid
pitch range KeyError: : ../dataset/commu_midi/train/augmented/commu00001_aminor_110.mid
pitch range KeyError: : ../dataset/commu_midi/train/augmented/commu00001_aminor_115.mid
pitch range KeyError: : ../dataset/commu_midi/train/augmented/commu00001_aminor_120.mid
pitch range KeyError: : ../dataset/commu_midi/train/augmented/commu00001_aminor_125.mid
pitch range KeyError: : ../dataset/commu_midi/train/augmented/commu00001_aminor_130.mid
pitch range KeyError: : ../dataset/commu_midi/train/augmented/commu00001_bbminor_110.mid
pitch range KeyError: : ..

KeyboardInterrupt: 