# Module to Generate Inference Sample
- Load model final trained checkpoint
- randomly select some parameters and generate several midi based on the select paramters
- convert midi to wav
- run FAD score

In [11]:
# import all the necessary modules
import os
import argparse
import torch
import pandas as pd
import numpy as np

from ast import literal_eval
from typing import Dict

from tqdm.notebook import tqdm

# load COMMU modules
from commu.midi_generator.generate_pipeline import MidiGenerationPipeline
from commu.preprocessor.utils import constants
from commu.midi_generator.container import ModelArguments
from commu.midi_generator.model_initializer import ModelInitializeTask
from commu.midi_generator.info_preprocessor import PreprocessTask
from commu.midi_generator.midi_inferrer import InferenceTask
from commu.midi_generator.sequence_postprocessor import PostprocessTask

In [12]:
map_location = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(map_location)
map_location

'cuda'

In [13]:
# declare path variables
MODEL_DIR = "./models"
DATA_DIR = "./dataset"

In [14]:
# # load torch model checkpoint
# checkpoint = torch.load(f"{MODEL_DIR}/good_checkpoint_best.pt", map_location='cpu')
# checkpoint

# Helpher Functions

In [15]:

def parse_args() -> Dict[str, argparse.ArgumentParser]:
    model_arg_parser = argparse.ArgumentParser(description="Model Arguments")
    input_arg_parser = argparse.ArgumentParser(description="Input Arguments")

    # Model Arguments
    model_arg_parser.add_argument("--checkpoint_dir", type=str)

    # Input Arguments
    input_arg_parser.add_argument("--output_dir", type=str, required=True)

    ## Input meta
    input_arg_parser.add_argument("--bpm", type=int)
    input_arg_parser.add_argument("--audio_key", type=str, choices=list(constants.KEY_MAP.keys()))
    input_arg_parser.add_argument("--time_signature", type=str, choices=list(constants.TIME_SIG_MAP.keys()))
    input_arg_parser.add_argument("--pitch_range", type=str, choices=list(constants.PITCH_RANGE_MAP.keys()))
    input_arg_parser.add_argument("--num_measures", type=float)
    input_arg_parser.add_argument(
        "--inst", type=str, choices=list(constants.INST_MAP.keys()),
    )
    input_arg_parser.add_argument(
        "--genre", type=str, default="cinematic", choices=list(constants.GENRE_MAP.keys())
    )
    input_arg_parser.add_argument(
        "--track_role", type=str, choices=list(constants.TRACK_ROLE_MAP.keys())
    )
    input_arg_parser.add_argument(
        "--rhythm", type=str, default="standard", choices=list(constants.RHYTHM_MAP.keys())
    )
    input_arg_parser.add_argument("--min_velocity", type=int, choices=range(1, 128))
    input_arg_parser.add_argument("--max_velocity", type=int, choices=range(1, 128))
    input_arg_parser.add_argument(
        "--chord_progression", type=str, help='Chord progression ex) C-C-E-E-G-G ...'
    )
    # Inference 시 필요 정보
    input_arg_parser.add_argument("--num_generate", type=int)
    input_arg_parser.add_argument("--top_k", type=int, default=32)
    input_arg_parser.add_argument("--temperature", type=float, default=0.95)

    arg_dict = {
        "model_args": model_arg_parser,
        "input_args": input_arg_parser
    }
    return arg_dict

# Load Model Paramaters

In [17]:
# generate some MIDI file for FAD scoring
# load in the metadata
metadata_df = pd.read_csv(f"{DATA_DIR}/meta_filtered_commu_full_inst_genre.csv",
                          index_col=[0], 
                          converters={"chord_progressions": literal_eval}).reset_index()

# get val and commu id
metadata_df = metadata_df[(metadata_df["id"].str.contains('commu')) & (metadata_df["split_data"] == "val")]
# metadata_df = metadata_df[(metadata_df["split_data"] == "val")]
metadata_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 763 entries, 30696 to 33718
Data columns (total 20 columns):
 #   Column               Non-Null Count  Dtype 
---  ------               --------------  ----- 
 0   audio_key            763 non-null    object
 1   chord_progressions   763 non-null    object
 2   pitch_range          763 non-null    object
 3   num_measures         763 non-null    int64 
 4   bpm                  763 non-null    int64 
 5   genre                763 non-null    object
 6   track_role           763 non-null    object
 7   inst                 763 non-null    object
 8   sample_rhythm        763 non-null    object
 9   time_signature       763 non-null    object
 10  min_velocity         763 non-null    int64 
 11  max_velocity         763 non-null    int64 
 12  split_data           763 non-null    object
 13  id                   763 non-null    object
 14  track_roll           0 non-null      object
 15  unique_chord_n_note  763 non-null    object
 16  inst_ma

In [8]:
# filter out the instruments
filtered_inst = [
    "accordion",
    "acoustic_bass",
    "acoustic_grand_piano",
    "acoustic_guitar_nylon",
    "acoustic_piano",
    "agogo",
    "alto_sax",
    "choir_aahs",
    "electric_guitar_muted",
    "electric_piano_1",
    "ensemble",
    "honky_tonk_piano",
    "melodic_tom",
    "overdriven_guitar",
    "piccolo",
    "reverse_cymbal",
    "soprano_sax",
    "steel_drums",
    "string_cello",
    "string_ensemble",
    "string_ensemble_1",
    "string_ensemble_2",
    "string_violin",
    "synth_drum",
    "synthstrings_2",
    "taiko_drum",
    "tango_accordion",
    "timpani",
    "trumpet",
    "viola",
    "voice_oohs",
    "woodblock"
]
# filtered out the instrument and grab only inst with at least 20 songs
inst_filtered_meta_df = metadata_df[metadata_df['inst'].isin(filtered_inst)]
inst_count = metadata_df.groupby("inst").size().to_frame().reset_index().rename(columns={0: "count"})
inst_count = inst_count[inst_count['count'] >= 20]
inst_filtered_meta_df = inst_filtered_meta_df[inst_filtered_meta_df['inst'].isin(inst_count.inst)]
inst_filtered_meta_df.groupby("inst").size().to_frame().reset_index().rename(columns={0: "count"})


Unnamed: 0,inst,count
0,acoustic_bass,117
1,acoustic_grand_piano,40
2,acoustic_guitar_nylon,129
3,acoustic_piano,180
4,alto_sax,24
5,choir_aahs,47
6,electric_guitar_muted,111
7,ensemble,174
8,overdriven_guitar,107
9,piccolo,40


In [19]:
# # move files to it respected location
# for i, row in tqdm(inst_filtered_meta_df.iterrows()):
#     # create wav directory if not exists
#     if not os.path.exists(f"{DATA_DIR}/raw_validation_data/{row['inst']}/midi"):
#         os.makedirs(f"{DATA_DIR}/raw_validation_data/{row['inst']}/midi")

#     file_name = row["id"]+".mid"
#     os.rename(f"{DATA_DIR}/raw_validation_data/raw/{file_name}", f"{DATA_DIR}/raw_validation_data/{row['inst']}/midi/{os.path.basename(file_name)}")
    

0it [00:00, ?it/s]

In [9]:
# filtered out the genre and grab only inst with at least 20 songs
genre_count = metadata_df.groupby("genre").size().to_frame().reset_index().rename(columns={0: "count"})
genre_count = genre_count[genre_count['count'] >= 20]
genre_filtered_meta_df = metadata_df[metadata_df['genre'].isin(genre_count.genre)]
genre_filtered_meta_df.groupby("genre").size().to_frame().reset_index().rename(columns={0: "count"})

Unnamed: 0,genre,count
0,blues,23
1,cinematic,571
2,classical,118
3,country,118
4,dance,99
5,electronic,102
6,folk,33
7,folk_world_country,27
8,instrumental,38
9,jazz,67


In [16]:
# initialize model and input variables

for i, row in tqdm(inst_filtered_meta_df.iterrows()):

    # for the three unknown columns, set it to default
    if row["pitch_range"] == "unknown":
        row["pitch_range"] = "mid"
        
    if row["track_role"] == "unknown":
        row["track_role"] = "sub_melody"

    # concat chord progression with hyphen
    chord_progression = "-".join(row["chord_progressions"][0])

    try:
        # initialize model and input variables
        model_args, _ = parse_args()["model_args"].parse_known_args(args=["--checkpoint_dir", f"{MODEL_DIR}/commu_meta_reduced_checkpoint_best_lowest_val_NLL_perturb.pt"])
        input_args, _ = parse_args()["input_args"].parse_known_args(args=["--output_dir", f"{DATA_DIR}/reduced_encoding_val/best_val_NLL_perturb/{row['inst']}/midi/",
                                                                        "--bpm", f"{row['bpm']}",
                                                                        "--audio_key", f"{row['audio_key']}",
                                                                        "--time_signature", f"{row['time_signature']}",
                                                                        "--pitch_range", f"{row['pitch_range']}",
                                                                        "--num_measures", f"{row['num_measures']}",
                                                                        "--inst", f"{row['inst']}",
                                                                        "--genre", f"{row['genre']}",
                                                                        "--min_velocity", f"{row['min_velocity']}",
                                                                        "--max_velocity", f"{row['max_velocity']}",
                                                                        "--track_role", f"{row['track_role']}",
                                                                        "--chord_progression", F"{chord_progression}",
                                                                        "--num_generate", "3"])
        # instantiate model pipeline
        pipeline = MidiGenerationPipeline()
        pipeline.initialize_model(vars(model_args))
        pipeline.initialize_generation()

        # initialize model for interences
        inference_cfg = pipeline.model_initialize_task.inference_cfg
        model = pipeline.model_initialize_task.execute()

        # encode input data
        encoded_meta = pipeline.preprocess_task.execute(vars(input_args))
        input_data = pipeline.preprocess_task.input_data
        
        # generate the MIDI
        pipeline.inference_task(
            model=model,
            input_data=input_data,
            inference_cfg=inference_cfg
        )
        sequences = pipeline.inference_task.execute(encoded_meta)

        # postprocess the generated MIDI
        pipeline.postprocess_task(input_data=input_data)
        midi_dict = pipeline.postprocess_task.execute(sequences=sequences)
    except:
        pass

0it [00:00, ?it/s]

2023-11-28 17:11:41,513 | INFO | ComMU | Generating the idx: 1
2023-11-28 17:11:45,063 | INFO | ComMU | correct_length: 4
2023-11-28 17:11:45,065 | INFO | ComMU | [0, 584, 602, 627, 635, 638, 646, 651, 721, 725, 735, 737, 2, 432, 222, 432, 191, 75, 311, 440, 191, 82, 311, 448, 191, 84, 311, 456, 191, 87, 311, 464, 191, 82, 311, 472, 191, 84, 311, 480, 191, 87, 311, 488, 191, 79, 311, 496, 191, 79, 311, 504, 191, 82, 311, 512, 191, 84, 311, 520, 191, 87, 311, 528, 191, 82, 311, 536, 191, 84, 311, 544, 191, 87, 311, 552, 191, 79, 311, 2, 432, 285, 432, 191, 70, 311, 440, 191, 74, 311, 448, 191, 77, 311, 456, 191, 79, 311, 464, 191, 82, 311, 472, 191, 86, 311, 480, 191, 82, 311, 488, 191, 86, 311, 496, 191, 82, 311, 504, 191, 86, 311, 512, 191, 82, 311, 520, 191, 86, 311, 528, 191, 82, 311, 536, 191, 86, 311, 544, 191, 82, 311, 552, 191, 86, 311, 2, 432, 199, 432, 191, 75, 311, 440, 191, 79, 311, 448, 191, 84, 311, 456, 191, 87, 311, 464, 191, 79, 311, 472, 191, 79, 311, 480, 191, 84, 311

OOV: 701


2023-11-28 17:55:36,207 | INFO | ComMU | Generating the idx: 1
2023-11-28 17:55:37,205 | INFO | ComMU | correct_length: 16
2023-11-28 17:55:37,206 | INFO | ComMU | [0, 584, 611, 628, 634, 640, 643, 651, 671, 728, 731, 737, 2, 432, 195, 2, 432, 258, 2, 432, 195, 2, 432, 262, 496, 258, 2, 432, 195, 2, 432, 258, 2, 432, 195, 2, 432, 258, 2, 432, 199, 496, 195, 2, 432, 276, 538, 185, 67, 347, 2, 432, 195, 496, 217, 2, 432, 217, 496, 276, 539, 140, 67, 431, 2, 432, 195, 2, 432, 258, 2, 432, 195, 2, 432, 258, 496, 195, 1]
2023-11-28 17:55:37,207 | INFO | ComMU | Generating the idx: 2
2023-11-28 17:55:39,586 | INFO | ComMU | correct_length: 16
2023-11-28 17:55:39,587 | INFO | ComMU | [0, 584, 611, 628, 634, 640, 643, 651, 671, 728, 731, 737, 2, 432, 195, 2, 432, 258, 2, 432, 195, 2, 432, 262, 496, 258, 2, 432, 195, 2, 432, 258, 2, 432, 195, 2, 432, 258, 2, 432, 199, 496, 195, 537, 139, 76, 307, 540, 139, 76, 307, 544, 139, 72, 306, 546, 143, 74, 306, 548, 143, 76, 306, 552, 142, 76, 307, 555,

OOV: 610


2023-11-28 18:25:53,943 | INFO | ComMU | Generating the idx: 1
2023-11-28 18:25:54,825 | INFO | ComMU | correct_length: 8
2023-11-28 18:25:54,826 | INFO | ComMU | [0, 576, 602, 627, 635, 639, 642, 651, 675, 679, 731, 737, 2, 432, 222, 432, 142, 82, 367, 496, 142, 79, 367, 2, 432, 262, 432, 144, 82, 367, 496, 144, 82, 367, 2, 432, 267, 496, 145, 77, 367, 2, 432, 222, 432, 142, 82, 367, 496, 285, 496, 142, 82, 367, 2, 432, 222, 432, 145, 82, 367, 496, 143, 79, 367, 2, 432, 262, 496, 145, 82, 367, 2, 432, 267, 432, 145, 80, 367, 496, 145, 84, 367, 2, 432, 285, 1]
2023-11-28 18:25:54,827 | INFO | ComMU | Generating the idx: 2
2023-11-28 18:25:56,717 | INFO | ComMU | correct_length: 8
2023-11-28 18:25:56,718 | INFO | ComMU | [0, 576, 602, 627, 635, 639, 642, 651, 675, 679, 731, 737, 2, 432, 222, 432, 142, 82, 335, 464, 144, 82, 335, 496, 144, 82, 319, 528, 143, 74, 319, 544, 144, 75, 319, 2, 432, 262, 432, 145, 82, 335, 464, 144, 82, 335, 496, 143, 82, 319, 512, 143, 82, 319, 528, 145, 82, 

In [18]:
# initialize model and input variables

for i, row in tqdm(metadata_df.iterrows()):

    # concat chord progression with hyphen
    chord_progression = "-".join(row["chord_progressions"][0])

    # initialize model and input variables
    model_args, _ = parse_args()["model_args"].parse_known_args(args=["--checkpoint_dir", f"{MODEL_DIR}/commu_meta_reduced_checkpoint_best_lowest_val_NLL_perturb.pt"])
    input_args, _ = parse_args()["input_args"].parse_known_args(args=["--output_dir", f"{DATA_DIR}/reduced_encoding_val/midi/perturb_mini_batch_45K",
                                                                    "--bpm", f"{row['bpm']}",
                                                                    "--audio_key", f"{row['audio_key']}",
                                                                    "--time_signature", f"{row['time_signature']}",
                                                                    "--pitch_range", f"{row['pitch_range']}",
                                                                    "--num_measures", f"{row['num_measures']}",
                                                                    "--inst", f"{row['inst']}",
                                                                    "--genre", f"{row['genre']}",
                                                                    "--min_velocity", f"{row['min_velocity']}",
                                                                    "--max_velocity", f"{row['max_velocity']}",
                                                                    "--track_role", f"{row['track_role']}",
                                                                    "--chord_progression", F"{chord_progression}",
                                                                    "--num_generate", "3"])
    # instantiate model pipeline
    pipeline = MidiGenerationPipeline()
    pipeline.initialize_model(vars(model_args))
    pipeline.initialize_generation()

    # initialize model for interences
    inference_cfg = pipeline.model_initialize_task.inference_cfg
    model = pipeline.model_initialize_task.execute()

    # encode input data
    encoded_meta = pipeline.preprocess_task.execute(vars(input_args))
    input_data = pipeline.preprocess_task.input_data
    
    # generate the MIDI
    pipeline.inference_task(
        model=model,
        input_data=input_data,
        inference_cfg=inference_cfg
    )
    sequences = pipeline.inference_task.execute(encoded_meta)

    # postprocess the generated MIDI
    pipeline.postprocess_task(input_data=input_data)
    midi_dict = pipeline.postprocess_task.execute(sequences=sequences)

0it [00:00, ?it/s]

2023-11-29 10:19:28,949 | INFO | ComMU | Generating the idx: 1
2023-11-29 10:19:32,573 | INFO | ComMU | correct_length: 4
2023-11-29 10:19:32,575 | INFO | ComMU | [0, 584, 602, 627, 635, 638, 646, 651, 721, 725, 735, 737, 2, 432, 222, 432, 191, 82, 311, 440, 191, 70, 311, 448, 191, 79, 311, 456, 191, 70, 311, 464, 191, 79, 311, 472, 191, 70, 311, 480, 191, 79, 311, 488, 191, 70, 311, 496, 191, 79, 311, 504, 191, 70, 311, 512, 191, 79, 311, 520, 191, 70, 311, 528, 191, 79, 311, 536, 191, 70, 311, 544, 191, 79, 311, 552, 191, 70, 311, 2, 432, 285, 432, 191, 82, 311, 440, 191, 70, 311, 448, 191, 77, 311, 456, 191, 70, 311, 464, 191, 82, 311, 472, 191, 70, 311, 480, 191, 77, 311, 488, 191, 70, 311, 496, 191, 77, 311, 504, 191, 70, 311, 512, 191, 77, 311, 520, 191, 70, 311, 528, 191, 82, 311, 536, 191, 70, 311, 544, 191, 77, 311, 552, 191, 70, 311, 2, 432, 199, 432, 191, 84, 311, 440, 191, 72, 311, 448, 191, 79, 311, 456, 191, 70, 311, 464, 191, 84, 311, 472, 191, 72, 311, 480, 191, 79, 311

OOV: 646


2023-11-29 10:39:24,690 | INFO | ComMU | Generating the idx: 1
2023-11-29 10:39:26,510 | INFO | ComMU | correct_length: 8
2023-11-29 10:39:26,511 | INFO | ComMU | [0, 576, 602, 627, 635, 639, 646, 651, 672, 673, 734, 737, 2, 432, 229, 2, 432, 274, 496, 143, 84, 319, 496, 144, 96, 319, 512, 146, 84, 319, 512, 145, 96, 319, 528, 144, 80, 319, 528, 145, 92, 319, 2, 432, 229, 496, 143, 86, 319, 496, 145, 98, 319, 512, 146, 84, 319, 512, 145, 96, 319, 528, 145, 82, 319, 528, 145, 94, 319, 2, 432, 274, 496, 139, 84, 319, 496, 141, 96, 319, 512, 137, 84, 319, 512, 139, 96, 319, 528, 141, 80, 319, 528, 139, 92, 319, 2, 432, 262, 2, 432, 274, 2, 432, 229, 496, 142, 82, 319, 496, 144, 94, 319, 512, 142, 87, 319, 512, 143, 99, 319, 528, 139, 82, 319, 528, 139, 94, 319, 544, 142, 80, 319, 544, 142, 92, 319, 2, 432, 274, 432, 139, 84, 319, 432, 141, 96, 319, 464, 139, 84, 319, 464, 139, 96, 319, 496, 139, 80, 319, 496, 139, 92, 319, 1]
2023-11-29 10:39:26,514 | INFO | ComMU | Generating the idx: 2


OOV: 646


2023-11-29 11:14:00,173 | INFO | ComMU | Generating the idx: 1
2023-11-29 11:14:02,639 | INFO | ComMU | correct_length: 8
2023-11-29 11:14:02,641 | INFO | ComMU | [0, 573, 623, 629, 633, 639, 646, 651, 704, 705, 733, 737, 2, 432, 245, 496, 285, 559, 171, 51, 305, 2, 432, 229, 474, 171, 51, 314, 496, 171, 51, 326, 517, 171, 43, 326, 538, 171, 39, 326, 2, 432, 245, 432, 171, 53, 374, 496, 285, 516, 171, 53, 327, 539, 171, 46, 326, 2, 432, 229, 473, 171, 51, 328, 496, 171, 51, 326, 517, 171, 43, 326, 538, 171, 38, 326, 2, 432, 274, 432, 171, 44, 370, 496, 285, 496, 171, 46, 370, 558, 171, 51, 326, 2, 432, 263, 473, 171, 50, 326, 496, 200, 496, 171, 48, 326, 517, 171, 46, 327, 539, 171, 38, 326, 2, 432, 282, 432, 171, 45, 373, 496, 171, 45, 370, 2, 432, 285, 496, 259, 1]
2023-11-29 11:14:02,644 | INFO | ComMU | Generating the idx: 2
2023-11-29 11:14:04,286 | INFO | ComMU | correct_length: 8
2023-11-29 11:14:04,288 | INFO | ComMU | [0, 573, 623, 629, 633, 639, 646, 651, 704, 705, 733, 737, 

# Model Inference - Deconstruction
- Same code as the main for loop above just desconstructed for debugging purposes

In [8]:
# initialize model and input variables
model_args, _ = parse_args()["model_args"].parse_known_args(args=["--checkpoint_dir", f"{MODEL_DIR}/good_checkpoint_best.pt"])
input_args, _ = parse_args()["input_args"].parse_known_args(args=["--output_dir", f"{DATA_DIR}/original-commu/midi",
                                                                  "--bpm", "120",
                                                                  "--audio_key", "cmajor",
                                                                  "--time_signature", "4/4",
                                                                  "--pitch_range", "mid_high",
                                                                  "--num_measures", "8",
                                                                  "--inst", "violin",
                                                                  "--genre", "cinematic",
                                                                  "--min_velocity", "60",
                                                                  "--max_velocity", "80",
                                                                  "--track_role", "main_melody",
                                                                  "--chord_progression", "Am-Am-Am-Am-Am-Am-Am-Am-G-G-G-G-G-G-G-G-F-F-F-F-F-F-F-F-E-E-E-E-E-E-E-E-Am-Am-Am-Am-Am-Am-Am-Am-G-G-G-G-G-G-G-G-F-F-F-F-F-F-F-F-E-E-E-E-E-E-E-E",
                                                                  "--num_generate", "1"])



In [9]:
# instantiate model pipeline
pipeline = MidiGenerationPipeline()
pipeline.initialize_model(vars(model_args))
pipeline.initialize_generation()

In [10]:
# initialize model for interences
inference_cfg = pipeline.model_initialize_task.inference_cfg
model = pipeline.model_initialize_task.execute()

In [11]:
# encode input data
encoded_meta = pipeline.preprocess_task.execute(vars(input_args))
input_data = pipeline.preprocess_task.input_data
encoded_meta, input_data

([584, 602, 627, 635, 639, 646, 651, 694, 704, 730, 737],
 TransXlInputData(bpm=120, audio_key='cmajor', time_signature='4/4', pitch_range='mid_high', num_measures=8.0, inst='violin', genre='cinematic', min_velocity=60, max_velocity=80, track_role='main_melody', rhythm='standard', output_dir=WindowsPath('dataset/reduced_encoding/midi/52K'), num_generate=1, top_k=32, temperature=0.95, chord_progression=['Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'F', 'F', 'F', 'F', 'F', 'F', 'F', 'F', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'F', 'F', 'F', 'F', 'F', 'F', 'F', 'F', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'E']))

In [12]:
# visualize the encoded output
num_conditional_tokens = len(encoded_meta)

seq = [0]
ctx = np.array(seq + encoded_meta[: num_conditional_tokens - 1], dtype=np.int32)[
    :, np.newaxis
]
context = torch.from_numpy(ctx).to(device).type(torch.long)
_, init_mems = model.forward_generate(context, mems=None)
init_seq = seq + encoded_meta[:num_conditional_tokens]
init_seq, init_mems.shape

([0, 584, 602, 627, 635, 639, 646, 651, 694, 704, 730, 737],
 torch.Size([7, 11, 1, 500]))

In [13]:
# generate the MIDI
pipeline.inference_task(
    model=model,
    input_data=input_data,
    inference_cfg=inference_cfg
)
sequences = pipeline.inference_task.execute(encoded_meta)
sequences, type(sequences)

2023-11-22 19:35:34,655 | INFO | ComMU | Generating the idx: 1
2023-11-22 19:35:36,694 | INFO | ComMU | correct_length: 8
2023-11-22 19:35:36,696 | INFO | ComMU | [0, 584, 602, 627, 635, 639, 646, 651, 694, 704, 730, 737, 2, 432, 199, 432, 165, 75, 431, 2, 432, 285, 464, 162, 77, 338, 496, 162, 79, 335, 528, 168, 82, 335, 2, 432, 267, 432, 165, 84, 367, 496, 162, 87, 367, 2, 432, 258, 432, 170, 91, 335, 464, 170, 91, 367, 528, 170, 91, 335, 2, 432, 199, 432, 170, 84, 431, 2, 432, 285, 2, 432, 267, 2, 432, 258, 1]


([[0,
   584,
   602,
   627,
   635,
   639,
   646,
   651,
   694,
   704,
   730,
   737,
   2,
   432,
   199,
   432,
   165,
   75,
   431,
   2,
   432,
   285,
   464,
   162,
   77,
   338,
   496,
   162,
   79,
   335,
   528,
   168,
   82,
   335,
   2,
   432,
   267,
   432,
   165,
   84,
   367,
   496,
   162,
   87,
   367,
   2,
   432,
   258,
   432,
   170,
   91,
   335,
   464,
   170,
   91,
   367,
   528,
   170,
   91,
   335,
   2,
   432,
   199,
   432,
   170,
   84,
   431,
   2,
   432,
   285,
   2,
   432,
   267,
   2,
   432,
   258,
   1]],
 list)

In [14]:
# postprocess the generated MIDI
pipeline.postprocess_task(input_data=input_data)
midi_dict = pipeline.postprocess_task.execute(sequences=sequences)

In [19]:
DATA_DIR

'./dataset'

In [6]:


midi_file = f"{DATA_DIR}/reduced_encoding/midi/52K/main_melody_violin_mid_high/main_melody_violin_mid_high_000.mid"
# midi = mid_parser.MidiFile(f"{DATA_DIR}/original-commu/midi/sub_melody_flute_mid_high_002_d55c4ef1-b270-4b74-89e1-c9a7b3ac23dc.mid")



# from mido import MidiFile, MidiTrack, Message

# def change_instrument(midi_file_path, new_instrument):
#     # Load MIDI file
#     midi_file = MidiFile(midi_file_path)

#     # Iterate through each track in the MIDI file
#     for i, track in enumerate(midi_file.tracks):
#         for msg in track:
#             # Check if the message is a program change message
#             if msg.type == 'program_change':
#                 # Update the program (instrument) number
#                 msg.program = new_instrument

#     # Save the modified MIDI file
#     midi_file.save(midi_file_path)

# # Replace 'your_file.mid' with the path to your MIDI file
# change_instrument(midi_file, new_instrument=41)  


In [11]:
# using miditoolkit to read in an midi file and check the instrument
from miditoolkit.midi import parser as mid_parser
from miditoolkit.midi import containers as ct
from miditoolkit.midi import utils as mid_utils

midi_file = f"{DATA_DIR}/reduced_encoding/midi/52K/main_melody_violin_mid_high/main_melody_violin_mid_high_000_test.mid"

midi = mid_parser.MidiFile(midi_file)
midi.instruments[0] #.program = 0
# save the midi file
# midi.dump(f"{DATA_DIR}/reduced_encoding/midi/52K/main_melody_violin_mid_high/main_melody_violin_mid_high_000_test.mid")

Instrument(program=0, is_drum=False, name="")

In [15]:
# from pathlib import Path
from typing import List

from miditoolkit import MidiFile

# from commu.midi_generator.container import TransXlInputData
# from commu.preprocessor.encoder import EventSequenceEncoder
from commu.preprocessor.utils.container import MidiInfo


In [16]:
import copy
from typing import Dict

import miditoolkit
import numpy as np


from commu.preprocessor.encoder.event_tokens import base_event, TOKEN_OFFSET
from commu.preprocessor.utils.constants import (
    BPM_INTERVAL,
    DEFAULT_POSITION_RESOLUTION,
    DEFAULT_TICKS_PER_BEAT,
    VELOCITY_INTERVAL,
    SIG_TIME_MAP,
    KEY_NUM_MAP
)

NUM_VELOCITY_BINS = int(128 / VELOCITY_INTERVAL)
DEFAULT_VELOCITY_BINS = np.linspace(2, 127, NUM_VELOCITY_BINS, dtype=int)

class Item(object):
    def __init__(self, name, start, end, velocity, pitch):
        self.name = name
        self.start = start
        self.end = end
        self.velocity = velocity
        self.pitch = pitch

    def __repr__(self):
        return "Item(name={}, start={}, end={}, velocity={}, pitch={})".format(
            self.name, self.start, self.end, self.velocity, self.pitch
        )


class Event(object):
    def __init__(self, name, time, value, text):
        self.name = name
        self.time = time
        self.value = value
        self.text = text

    def __repr__(self):
        return "Event(name={}, time={}, value={}, text={})".format(
            self.name, self.time, self.value, self.text
        )


def mk_remi_map():
    event = copy.deepcopy(base_event)
    for i in range(DEFAULT_POSITION_RESOLUTION):
        event.append(f"Note Duration_{i}")
    for i in range(1, DEFAULT_POSITION_RESOLUTION + 1):
        event.append(f"Position_{i}/{DEFAULT_POSITION_RESOLUTION}")

    event2word = {k: v for k, v in zip(event, range(2, len(event) + 2))}
    word2event = {v: k for k, v in zip(event, range(2, len(event) + 2))}

    return event2word, word2event

def add_flat_chord2map(event2word: Dict):
    flat_chord = ["Chord_ab:", "Chord_bb:", "Chord_db:", "Chord_eb:", "Chord_gb:"]
    scale = [
        "",
        "maj",
        "maj7",
        "7",
        "dim",
        "dim7",
        "+",
        "m",
        "m7",
        "sus4",
        "7sus4",
        "m6",
        "m7b5",
        "sus2",
        "add2",
        "6",
        "madd2",
        "mM7",
    ]

    flat_chords = []
    for c in flat_chord:
        for s in scale:
            flat_chords.append(c + s)

    for c in flat_chords:
        scale = c.split(":")[1]
        key = c.split(":")[0].split("_")[1][0]
        c = c.replace(":", "")
        if c.startswith("Chord_ab"):
            if scale == "" or scale == "maj" or scale == "6":
                event2word[c] = event2word["Chord_g#"]
            elif scale == "maj7" or scale == "add2" or scale == "sus2":
                event2word[c] = event2word["Chord_g#maj7"]
            elif scale == "7":
                event2word[c] = event2word["Chord_g#7"]
            elif scale == "dim" or scale == "dim7":
                event2word[c] = event2word["Chord_g#dim"]
            elif scale == "+":
                event2word[c] = event2word["Chord_g#+"]
            elif scale == "m" or scale == "m6" or scale == "mM7":
                event2word[c] = event2word["Chord_g#m"]
            elif scale == "m7" or scale == "madd2":
                event2word[c] = event2word["Chord_g#m7"]
            elif scale == "sus4" or scale == "7sus4":
                event2word[c] = event2word["Chord_g#sus4"]
            elif scale == "m7b5":
                event2word[c] = event2word["Chord_g#m7b5"]
        else:
            if scale == "" or scale == "maj" or scale == "6":
                new_key = chr(ord(key) - 1)
                word = "Chord_" + new_key + "#"
                event2word[c] = event2word[word]
            elif scale == "maj7" or scale == "add2" or scale == "sus2":
                new_key = chr(ord(key) - 1)
                word = "Chord_" + new_key + "#maj7"
                event2word[c] = event2word[word]
            elif scale == "7":
                new_key = chr(ord(key) - 1)
                word = "Chord_" + new_key + "#7"
                event2word[c] = event2word[word]
            elif scale == "dim" or scale == "dim7":
                new_key = chr(ord(key) - 1)
                word = "Chord_" + new_key + "#dim"
                event2word[c] = event2word[word]
            elif scale == "+":
                new_key = chr(ord(key) - 1)
                word = "Chord_" + new_key + "#+"
                event2word[c] = event2word[word]
            elif scale == "m" or scale == "m6" or scale == "mM7":
                new_key = chr(ord(key) - 1)
                word = "Chord_" + new_key + "#m"
                event2word[c] = event2word[word]
            elif scale == "m7" or scale == "madd2":
                new_key = chr(ord(key) - 1)
                word = "Chord_" + new_key + "#m7"
                event2word[c] = event2word[word]
            elif scale == "sus4" or scale == "7sus4":
                new_key = chr(ord(key) - 1)
                word = "Chord_" + new_key + "#sus4"
                event2word[c] = event2word[word]
            elif scale == "m7b5":
                new_key = chr(ord(key) - 1)
                word = "Chord_" + new_key + "#m7b5"
                event2word[c] = event2word[word]

    return event2word

def abstract_chord_types(event2word):
    chord = ["Chord_a:", "Chord_b:", "Chord_c:", "Chord_d:", "Chord_e:", "Chord_f:", "Chord_g:"]
    scale = ["7sus4", "m6", "sus2", "add2", "dim7", "6", "madd2", "mM7", ]

    chords = []
    for c in chord:
        for s in scale:
            chords.append(c + s)

    for c in chords:
        scale = c.split(":")[1]
        key = c.split(":")[0].split("_")[1][0]
        c = c.replace(":", "")
        if scale == "7sus4":
            word = "Chord_" + key + "sus4"
            event2word[c] = event2word[word]
        if scale == "m6":
            word = "Chord_" + key + "m"
            event2word[c] = event2word[word]
        if scale == "sus2" or scale == "add2":
            word = "Chord_" + key + "maj7"
            event2word[c] = event2word[word]
        if scale == "6":
            word = "Chord_" + key
            event2word[c] = event2word[word]
        if scale == "dim7":
            word = "Chord_" + key + "dim"
            event2word[c] = event2word[word]
        if scale == "madd2" or scale == "mM7":
            word = "Chord_" + key + "m7"
            event2word[c] = event2word[word]

    return event2word

def extract_events(
    input_path,
    duration_bins,
    ticks_per_bar=None,
    ticks_per_beat=None,
    chord_progression=None,
    num_measures=None,
    is_incomplete_measure=None,
):
    note_items = read_items(input_path)
    max_time = note_items[-1].end
    if not chord_progression[0]:
        return None
    else:
        items = note_items
    groups = group_items(items, max_time, ticks_per_bar)
    events = item2event(groups, duration_bins)
    beats_per_bar = int(ticks_per_bar/ticks_per_beat)

    if chord_progression:
        new_chords = chord_progression[0]
        events = insert_chord_on_event(
            events,
            new_chords,
            ticks_per_bar,
            num_measures,
            is_incomplete_measure,
            beats_per_bar,
        )

    return events

def read_items(file_path):
    midi_obj = miditoolkit.midi.parser.MidiFile(file_path)
    note_items = []
    notes = midi_obj.instruments[0].notes
    notes.sort(key=lambda x: (x.start, x.pitch))
    for note in notes:
        note_items.append(
            Item(
                name="Note",
                start=note.start,
                end=note.end,
                velocity=note.velocity,
                pitch=note.pitch,
            )
        )
    note_items.sort(key=lambda x: x.start)
    return note_items

def group_items(items, max_time, ticks_per_bar):
    items.sort(key=lambda x: x.start)
    downbeats = np.arange(0, max_time + ticks_per_bar, ticks_per_bar)
    groups = []
    for db1, db2 in zip(downbeats[:-1], downbeats[1:]):
        insiders = []
        for item in items:
            if (item.start >= db1) and (item.start < db2):
                insiders.append(item)
        if not insiders:
            insiders.append(Item(name="None", start=None, end=None, velocity=None, pitch="NN"))
        overall = [db1] + insiders + [db2]
        groups.append(overall)
    return groups

def item2event(groups, duration_bins):
    events = []
    n_downbeat = 0
    for i in range(len(groups)):
        if "NN" in [item.pitch for item in groups[i][1:-1]]:
            continue
        bar_st, bar_et = groups[i][0], groups[i][-1]
        n_downbeat += 1
        if groups[i][1].name == "Chord":
            events.append(Event(name="Bar", time=bar_st, value=None, text="{}".format(n_downbeat)))
        for item in groups[i][1:-1]:
            # position
            flags = np.linspace(bar_st, bar_et, DEFAULT_POSITION_RESOLUTION, endpoint=False)
            index = np.argmin(abs(flags - item.start))
            events.append(
                Event(
                    name="Position",
                    time=item.start,
                    value="{}/{}".format(index + 1, DEFAULT_POSITION_RESOLUTION),
                    text="{}".format(item.start),
                )
            )
            if item.name == "Note":
                # velocity
                velocity_index = (
                    np.searchsorted(DEFAULT_VELOCITY_BINS, item.velocity, side="right") - 1
                )
                events.append(
                    Event(
                        name="Note Velocity",
                        time=item.start,
                        value=velocity_index,
                        text="{}/{}".format(item.velocity, DEFAULT_VELOCITY_BINS[velocity_index]),
                    )
                )
                # pitch
                events.append(
                    Event(
                        name="Note On",
                        time=item.start,
                        value=item.pitch,
                        text="{}".format(item.pitch),
                    )
                )
                # duration
                duration = item.end - item.start
                index = np.argmin(abs(duration_bins - duration))
                events.append(
                    Event(
                        name="Note Duration",
                        time=item.start,
                        value=index,
                        text="{}/{}".format(duration, duration_bins[index]),
                    )
                )
            elif item.name == "Chord":
                events.append(
                    Event(
                        name="Chord",
                        time=item.start,
                        value=item.pitch,
                        text="{}".format(item.pitch),
                    )
                )
    return events

def insert_chord_on_event(
    events,
    chord_progression,
    tick_per_bar,
    num_measures,
    is_incomplete_measure,
    beats_per_bar,
):
    chord_idx_lst, chords = detect_chord(chord_progression, beats_per_bar)
    start_time = tick_per_bar * is_incomplete_measure
    chord_events = []
    for i in range(num_measures):
        chord_events.append(
            Event(name="Bar", time=i * tick_per_bar, value=None, text="{}".format(i + 1))
        )
        while chord_idx_lst and chord_idx_lst[0] < i + 1 - is_incomplete_measure:
            chord_position = chord_idx_lst.pop(0)
            chord_time = int(chord_position * tick_per_bar + start_time)
            chord = chords.pop(0)
            chord_events.append(
                Event(
                    name="Position",
                    time=chord_time,
                    value="{}/{}".format(
                        int((chord_position - i + is_incomplete_measure) * DEFAULT_POSITION_RESOLUTION) + 1,
                        DEFAULT_POSITION_RESOLUTION
                    ),
                    text=chord_time,
                )
            )
            chord_events.append(
                Event(name="Chord",
                      time=chord_time,
                      value=chord.split("/")[0].split("(")[0],
                      text=chord.split("/")[0].split("(")[0])
            )

    inserted_events = chord_events + events
    inserted_events.sort(key=lambda x: x.time)
    return inserted_events

def detect_chord(chord_progression, beats_per_bar):
    chords_per_bar = beats_per_bar * 2
    num_measures = int(len(chord_progression)/chords_per_bar)
    split_by_bar = np.array_split(np.array(chord_progression), num_measures)
    chord_idx = []
    chord_name = []
    for bar_idx, bar in enumerate(split_by_bar):
        for c_idx, chord in enumerate(bar):
            chord = chord.lower()
            if c_idx == 0 or chord != chord_name[-1]:
                chord_idx.append(bar_idx + c_idx / chords_per_bar)
                chord_name.append(chord)
    return chord_idx, chord_name

def word_to_event(words, word2event):
    events = []
    for word in words:
        try:
            event_name, event_value = word2event[word].split("_")
        except KeyError:
            if word == 1:
                # 따로 디코딩 되지 않는 EOS
                continue
            else:
                print(f"OOV: {word}")
            continue
        events.append(Event(event_name, None, event_value, None))
    return events

def write_midi(
    midi_info,
    word2event,
    duration_bins,
    beats_per_bar,
):
    events = word_to_event(midi_info.event_seq, word2event)
    # get downbeat and note (no time)
    temp_notes = []
    temp_chords = []
    for i in range(len(events) - 3):
        if events[i].name == "Bar" and i > 0:
            temp_notes.append("Bar")
            temp_chords.append("Bar")
        elif (
            events[i].name == "Position"
            and events[i + 1].name == "Note Velocity"
            and events[i + 2].name == "Note On"
            and events[i + 3].name == "Note Duration"
        ):
            # start time and end time from position
            position = int(events[i].value.split("/")[0]) - 1
            # velocity
            index = int(events[i + 1].value)
            velocity = int(DEFAULT_VELOCITY_BINS[index])
            # pitch
            pitch = int(events[i + 2].value)
            # duration
            index = int(events[i + 3].value)
            duration = duration_bins[index]
            # adding
            temp_notes.append([position, velocity, pitch, duration])
        elif events[i].name == "Position" and events[i + 1].name == "Chord":
            position = int(events[i].value.split("/")[0]) - 1
            temp_chords.append([position, events[i + 1].value])
    # get specific time for notes
    ticks_per_beat = DEFAULT_TICKS_PER_BEAT
    ticks_per_bar = ticks_per_beat * beats_per_bar
    notes = []
    current_bar = 0
    for note in temp_notes:
        if note == "Bar":
            current_bar += 1
        else:
            position, velocity, pitch, duration = note
            current_bar_st = current_bar * ticks_per_bar
            current_bar_et = (current_bar + 1) * ticks_per_bar
            flags = np.linspace(
                int(current_bar_st),
                int(current_bar_et),
                int(DEFAULT_POSITION_RESOLUTION),
                endpoint=False,
                dtype=int,
            )
            st = flags[position]
            # duration (end time)
            et = st + duration
            notes.append(miditoolkit.Note(velocity, pitch, st, et))
    print(notes)
    # get specific time for chords
    if len(temp_chords) > 0:
        chords = []
        current_bar = 0
        for chord in temp_chords:
            if chord == "Bar":
                current_bar += 1
            else:
                position, value = chord
                # position (start time)
                current_bar_st = current_bar * ticks_per_bar
                current_bar_et = (current_bar + 1) * ticks_per_bar
                flags = np.linspace(
                    current_bar_st, current_bar_et, DEFAULT_POSITION_RESOLUTION, endpoint=False, dtype=int
                )
                st = flags[position]
                chords.append([st, value])
    midi = miditoolkit.midi.parser.MidiFile()
    numerator, denominator = SIG_TIME_MAP[
        midi_info.time_signature
        - (TOKEN_OFFSET.TS.value + 1)
    ].split("/")
    ts = miditoolkit.midi.containers.TimeSignature(
        numerator=int(numerator), denominator=int(denominator), time=0
    )
    key_num = midi_info.audio_key - (TOKEN_OFFSET.KEY.value + 1)
    ks = miditoolkit.KeySignature(
        key_name=KEY_NUM_MAP[key_num],
        time=0)
    midi.time_signature_changes.append(ts)
    midi.key_signature_changes.append(ks)
    midi.ticks_per_beat = DEFAULT_TICKS_PER_BEAT
    # write instrument
    inst = miditoolkit.midi.containers.Instrument(0, is_drum=False)
    inst.notes = notes
    midi.instruments.append(inst)
    # write bpm info
    tempo_changes = []
    tempo_changes.append(
        miditoolkit.midi.containers.TempoChange(
            (midi_info.bpm - TOKEN_OFFSET.BPM.value)
            * BPM_INTERVAL,
            0,
        )
    )
    midi.tempo_changes = tempo_changes

    # write chord into marker
    if len(temp_chords) > 0:
        for c in chords:
            midi.markers.append(miditoolkit.midi.containers.Marker(text=c[1], time=c[0]))

    return midi


In [17]:
import math

import miditoolkit
import numpy as np

from commu.preprocessor.encoder import encoder_utils
from commu.preprocessor.encoder.event_tokens import TOKEN_OFFSET
from commu.preprocessor.utils.constants import (
    DEFAULT_POSITION_RESOLUTION,
    DEFAULT_TICKS_PER_BEAT,
    SIG_TIME_MAP
)

class EventSequenceEncoder:
    def __init__(self):
        self.event2word, self.word2event = encoder_utils.mk_remi_map()
        self.event2word = encoder_utils.add_flat_chord2map(self.event2word)
        self.event2word = encoder_utils.abstract_chord_types(self.event2word)
        self.position_resolution = DEFAULT_POSITION_RESOLUTION

    def encode(self, midi_paths, sample_info=None, for_cp=False):
        midi_file = miditoolkit.MidiFile(midi_paths)
        ticks_per_beat = midi_file.ticks_per_beat
        chord_progression = sample_info["chord_progressions"]
        num_measures = math.ceil(sample_info["num_measures"])
        numerator = int(sample_info["time_signature"].split("/")[0])
        denominator = int(sample_info["time_signature"].split("/")[1])
        is_incomplete_measure = sample_info["is_incomplete_measure"]

        beats_per_bar = numerator / denominator * 4
        ticks_per_bar = int(ticks_per_beat * beats_per_bar)
        duration_bins = np.arange(
            int(ticks_per_bar / self.position_resolution),
            ticks_per_bar + 1,
            int(ticks_per_bar / self.position_resolution),
            dtype=int,
        )

        events = encoder_utils.extract_events(
            midi_paths,
            duration_bins,
            ticks_per_bar=ticks_per_bar,
            ticks_per_beat=ticks_per_beat,
            chord_progression=chord_progression,
            num_measures=num_measures,
            is_incomplete_measure=is_incomplete_measure,
        )
        if for_cp:
            return events

        words = []
        for event in events:
            e = "{}_{}".format(event.name, event.value)
            if e in self.event2word:
                words.append(self.event2word[e])
            else:
                # OOV
                if event.name == "Note Velocity":
                    # replace with max velocity based on our training data
                    words.append(self.event2word["Note Velocity_63"])
                if event.name == "Note Duration":
                    # replace with max duration
                    words.append(self.event2word[f"Note Duration_{self.position_resolution-1}"])
                else:
                    # something is wrong
                    # you should handle it for your own purpose
                    print("OOV {}".format(e))
        words.append(TOKEN_OFFSET.EOS.value)  # eos token
        return np.array(words)

    def decode(
        self,
        midi_info,
    ):
        time_sig_word = midi_info.time_signature
        time_sig = SIG_TIME_MAP[time_sig_word - TOKEN_OFFSET.TS.value - 1]
        numerator = int(time_sig.split("/")[0])
        denominator = int(time_sig.split("/")[1])
        beats_per_bar = int(numerator/denominator * 4)

        ticks_per_bar = DEFAULT_TICKS_PER_BEAT * beats_per_bar

        duration_bins = np.arange(
            int(ticks_per_bar / self.position_resolution),
            ticks_per_bar + 1,
            int(ticks_per_bar / self.position_resolution),
            dtype=int,
        )

        decoded_midi = write_midi(
            midi_info,
            self.word2event,
            duration_bins=duration_bins,
            beats_per_bar=beats_per_bar,
        )


        return decoded_midi

In [18]:
def decode_event_sequence(
        generation_result: List[int],
        num_meta: int = 11
):
    encoded_meta = generation_result[1: num_meta + 1]
    event_sequence = generation_result[num_meta + 2:]
    decoder = EventSequenceEncoder()
    decoded_midi = decoder.decode(
        midi_info=MidiInfo(*encoded_meta, event_seq=event_sequence),
    )

    return decoded_midi

for idx, seq in enumerate(sequences):
    decoded_midi = decode_event_sequence(
        generation_result=seq,
    )


[Note(start=0, end=1920, pitch=72, velocity=69), Note(start=2400, end=2925, pitch=74, velocity=63), Note(start=2880, end=3360, pitch=76, velocity=63), Note(start=3360, end=3840, pitch=79, velocity=75), Note(start=3840, end=4800, pitch=81, velocity=69), Note(start=4800, end=5760, pitch=84, velocity=63), Note(start=5760, end=6240, pitch=88, velocity=79), Note(start=6240, end=7200, pitch=88, velocity=79), Note(start=7200, end=7680, pitch=88, velocity=79), Note(start=7680, end=9600, pitch=81, velocity=79)]
