In [1]:
import sys
sys.path.append("..") # Adds higher directory to python modules path.

In [2]:
import argparse
from typing import Dict

from commu.midi_generator.generate_pipeline import MidiGenerationPipeline
from commu.preprocessor.utils import constants

In [3]:

def parse_args() -> Dict[str, argparse.ArgumentParser]:
    """Build the argument parsers used to configure MIDI generation.

    Two independent parsers are created: one for the model option
    (``--checkpoint_dir``) and one for the generation inputs (output
    directory, musical metadata and sampling settings). Nothing is parsed
    here; callers invoke the returned parsers themselves.

    Returns:
        A dict holding the model parser under ``"model_args"`` and the input
        parser under ``"input_args"``.
    """
    # Model arguments
    model_parser = argparse.ArgumentParser(description="Model Arguments")
    model_parser.add_argument("--checkpoint_dir", type=str)

    # Input arguments
    input_parser = argparse.ArgumentParser(description="Input Arguments")
    add_input = input_parser.add_argument
    add_input("--output_dir", type=str, required=True)

    # Input metadata; categorical options are restricted to the keys of the
    # corresponding maps in ``constants``.
    add_input("--bpm", type=int)
    add_input("--audio_key", type=str, choices=list(constants.KEY_MAP.keys()))
    add_input("--time_signature", type=str, choices=list(constants.TIME_SIG_MAP.keys()))
    add_input("--pitch_range", type=str, choices=list(constants.PITCH_RANGE_MAP.keys()))
    add_input("--num_measures", type=float)
    add_input("--inst", type=str, choices=list(constants.INST_MAP.keys()))
    add_input("--genre", type=str, default="cinematic", choices=list(constants.GENRE_MAP.keys()))
    add_input("--track_role", type=str, choices=list(constants.TRACK_ROLE_MAP.keys()))
    add_input("--rhythm", type=str, default="standard", choices=list(constants.RHYTHM_MAP.keys()))

    # Both velocity bounds accept the integers 1..127.
    velocity_choices = range(1, 128)
    add_input("--min_velocity", type=int, choices=velocity_choices)
    add_input("--max_velocity", type=int, choices=velocity_choices)
    add_input("--chord_progression", type=str, help="Chord progression ex) C-C-E-E-G-G ...")

    # Settings needed at inference time
    add_input("--num_generate", type=int)
    add_input("--top_k", type=int, default=32)
    add_input("--temperature", type=float, default=0.95)

    return {"model_args": model_parser, "input_args": input_parser}

In [4]:
# Initialize model and input variables.
# The flags are passed explicitly, so the strict `parse_args` is used: a
# misspelled flag raises instead of being silently dropped. Nothing is
# unpacked into `_`, which IPython uses for the previous cell's output.
arg_parsers = parse_args()
model_args = arg_parsers["model_args"].parse_args(
    args=["--checkpoint_dir", "../checkpoints/checkpoint_best.pt"]
)
input_args = arg_parsers["input_args"].parse_args(
    args=[
        "--output_dir", "../output",
        "--bpm", "120",
        "--audio_key", "aminor",
        "--time_signature", "4/4",
        "--pitch_range", "mid_high",
        "--num_measures", "8",
        "--inst", "acoustic_piano",
        "--genre", "cinematic",
        "--min_velocity", "60",
        "--max_velocity", "80",
        "--track_role", "main_melody",
        "--chord_progression", "Am-Am-Am-Am-Am-Am-Am-Am-G-G-G-G-G-G-G-G-F-F-F-F-F-F-F-F-E-E-E-E-E-E-E-E-Am-Am-Am-Am-Am-Am-Am-Am-G-G-G-G-G-G-G-G-F-F-F-F-F-F-F-F-E-E-E-E-E-E-E-E",
        "--num_generate", "3",
    ]
)
# Display the parsed input namespace.
input_args

Namespace(audio_key='aminor', bpm=120, chord_progression='Am-Am-Am-Am-Am-Am-Am-Am-G-G-G-G-G-G-G-G-F-F-F-F-F-F-F-F-E-E-E-E-E-E-E-E-Am-Am-Am-Am-Am-Am-Am-Am-G-G-G-G-G-G-G-G-F-F-F-F-F-F-F-F-E-E-E-E-E-E-E-E', genre='cinematic', inst='acoustic_piano', max_velocity=80, min_velocity=60, num_generate=3, num_measures=8.0, output_dir='../output', pitch_range='mid_high', rhythm='standard', temperature=0.95, time_signature='4/4', top_k=32, track_role='main_melody')

In [5]:
# check model argument
import torch

from commu.midi_generator.container import ModelArguments
from commu.midi_generator.model_initializer import ModelInitializeTask
from commu.midi_generator.info_preprocessor import PreprocessTask
from commu.midi_generator.midi_inferrer import InferenceTask
from commu.midi_generator.sequence_postprocessor import PostprocessTask

In [6]:
# Instantiate the generation pipeline, initialize its model step from the
# parsed model arguments (passed as a plain dict), then initialize the
# generation step. The calls are kept in this order because each one acts on
# the same pipeline object.
pipeline = MidiGenerationPipeline()
pipeline.initialize_model(vars(model_args))
pipeline.initialize_generation()

In [7]:
# Initialize the model for inference: read the inference config exposed by the
# model-initialize task, then execute the task to obtain the model object.
inference_cfg = pipeline.model_initialize_task.inference_cfg
model = pipeline.model_initialize_task.execute()

In [8]:
# Encode the input arguments (passed as a plain dict) into metadata tokens,
# then read back the input data object held by the preprocess task.
encoded_meta = pipeline.preprocess_task.execute(vars(input_args))
input_data = pipeline.preprocess_task.input_data
# Display both results.
encoded_meta, input_data

([584, 623, 627, 635, 639, 642, 652, 684, 694, 720, 727],
 TransXlInputData(bpm=120, audio_key='aminor', time_signature='4/4', pitch_range='mid_high', num_measures=8.0, inst='acoustic_piano', genre='cinematic', min_velocity=60, max_velocity=80, track_role='main_melody', rhythm='standard', output_dir=WindowsPath('../output'), num_generate=3, top_k=32, temperature=0.95, chord_progression=['Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'F', 'F', 'F', 'F', 'F', 'F', 'F', 'F', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'F', 'F', 'F', 'F', 'F', 'F', 'F', 'F', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'E']))

In [10]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    map_location = "cuda"
else:
    map_location = "cpu"
device = torch.device(map_location)

In [11]:
import numpy as np

# Reproduce by hand the sequence/memory initialisation that precedes
# generation, so the intermediate values can be inspected.
num_conditional_tokens = len(encoded_meta)

# The sequence starts with token 0.
seq = [0]
# Context = leading token plus all metadata tokens except the last one,
# reshaped to a column (one token per row, a single column).
ctx = np.array(seq + encoded_meta[: num_conditional_tokens - 1], dtype=np.int32)[
    :, np.newaxis
]
context = torch.from_numpy(ctx).to(device).type(torch.long)
# Inference only: disable autograd so no graph is built for this forward pass
# (the reference code below does the same). The first return value is not
# needed and is bound to a throwaway name rather than `_`, which IPython uses
# for the previous cell's output.
with torch.no_grad():
    _unused_output, init_mems = model.forward_generate(context, mems=None)
# The initial sequence keeps every metadata token, including the last one.
init_seq = seq + encoded_meta[:num_conditional_tokens]
init_seq, init_mems.shape


# Reference (not executed): the generation loop this cell is exploring —
# presumably the inference task's `execute` method; confirm against the library.
# def execute(self, encoded_meta) -> List[List[int]]:
#     num_conditional_tokens = len(encoded_meta)
#     idx = 0
#     sequences = []
#     while idx != self.input_data.num_generate:
#         with torch.no_grad():
#             logger.info("Generating the idx: " + str(idx + 1))
#             seq, mems = self.init_seq_and_mems(encoded_meta, num_conditional_tokens)
#             ## apply teaching reinforcement in this step   
#             seq = self.generate_sequence(seq, mems)
#             if seq is None:
#                 continue
#             if not self.validate_generated_sequence(seq):
#                 logger.error("Empty sequence generated")
#                 continue
#         sequences.append(seq)
#         idx += 1
#     return sequences


([0, 584, 623, 627, 635, 639, 642, 652, 684, 694, 720, 727],
 torch.Size([7, 11, 1, 500]))

In [12]:

def generate_sequence(self, seq, mems):
    """Extend ``seq`` one token at a time while applying teacher forcing.

    NOTE(review): this takes ``self`` but is defined at notebook top level and
    is not called in the visible cells; it looks like a copy of a method kept
    for study. ``TeacherForceTask``, ``TOKEN_OFFSET`` and ``logger`` are not
    imported in the visible cells — confirm they are in scope before calling.

    Args:
        seq: List of tokens generated so far; new tokens are appended in place.
        mems: Model memory handed to ``self.calc_logits_and_mems``.

    Returns:
        The extended ``seq``, or ``None`` when token sampling raises
        ``RuntimeError`` or ``teacher.validate_teacher_forced_sequence`` raises.
    """
    logits = None
    teacher = TeacherForceTask(self.input_data)
    first_loop = True
    # At most `generation_length` iterations; not every iteration appends a token.
    for _ in range(self.inference_cfg.GENERATION.generation_length):
        # Stop once the last token is 1 (presumably the end-of-sequence id — confirm).
        if seq[-1] == 1:
            break

        # Tokens queued by the teacher are consumed first, one per iteration,
        # and logits/mems are refreshed after each forced token.
        if teacher.next_tokens_forced:
            next_token = teacher.next_tokens_forced.pop(0)
            seq.append(next_token)
            logits, mems = self.calc_logits_and_mems(seq, mems)
            continue

        if teacher.no_sequence_appended:
            # Reuse the logits from the previous pass instead of recomputing
            # (presumably because no token was appended — confirm), then clear the flag.
            assert logits is not None
            teacher.no_sequence_appended = False
        elif first_loop:
            # First pass: compute logits but discard the returned mems.
            logits, _ = self.calc_logits_and_mems(seq, mems)
            first_loop = False
        else:
            logits, mems = self.calc_logits_and_mems(seq, mems)

        probs = self.calc_probs(logits)
        probs = self.apply_sampling(probs, teacher.wrong_tokens)

        # teacher forcing
        # in case with incomplete measure, trigger a flag after second bar token
        if not teacher.incomplete_filled:
            teacher.incomplete_filled = True if seq.count(TOKEN_OFFSET.BAR.value) > 1 else False

        # Each check below that fires asks the teacher to act and restarts the
        # loop without sampling a token in this iteration.
        # forcefully assign position 1/128 right after bar token
        if teacher.check_first_position(seq):
            teacher.teach_first_position()
            continue

        # in case there is one chord per bar
        if teacher.check_one_chord_per_bar_case(seq):
            teacher.teach_chord_token()
            continue

        # in case the chord changes within a bar
        if teacher.check_mul_chord_per_bar_case(seq):
            teacher.teach_chord_token()
            continue

        # teacher forcing followed by token inference so that we can check if the wrong token was generated
        try:
            token = self.infer_token(probs)
        except RuntimeError as e:
            # Sampling failed: log it and give up on this sequence.
            logger.error(f"Sampling Error: {e}")
            seq = None
            break

        # The remaining checks inspect the sampled token; when one fires the
        # token is not appended and the loop restarts.
        # generated token skipped necessary position
        if teacher.check_chord_position_passed(token):
            teacher.teach_chord_position()
            continue

        # wrong chord token generated
        if teacher.check_wrong_chord_token_generated(token):
            teacher.teach_wrong_chord_token(token)
            continue

        # eos generated but we got more chords to write
        if teacher.check_wrong_eos_generated(token):
            teacher.teach_remnant_chord()
            continue

        # bar token generated but num measures exceed
        if teacher.check_wrong_bar_token_generated(token):
            teacher.teach_eos()
            continue

        # No check fired: accept the sampled token.
        seq.append(token)

    # Validate the final sequence; any exception is logged and turns the result
    # into None. Note this also runs with seq=None after a sampling error.
    try:
        teacher.validate_teacher_forced_sequence(seq)
    except Exception as error_message:
        logger.error(error_message)
        seq = None
    return seq

In [13]:
# Call the inference task with the model, input data and inference config
# (presumably this binds them on the task — confirm against its __call__),
# then execute it on the encoded metadata to generate the sequences.
# NOTE(review): no random seed is set in the visible cells, and the sequences
# printed by different cells of this notebook differ, so expect results to
# vary between runs.
pipeline.inference_task(
    model=model,
    input_data=input_data,
    inference_cfg=inference_cfg
)
sequences = pipeline.inference_task.execute(encoded_meta)
# Display the generated sequences.
sequences

2023-10-31 22:29:24,230 | INFO | ComMU | Generating the idx: 1
2023-10-31 22:29:25,973 | INFO | ComMU | correct_length: 8
2023-10-31 22:29:25,975 | INFO | ComMU | [0, 584, 623, 627, 635, 639, 642, 652, 684, 694, 720, 727, 2, 432, 199, 496, 162, 79, 335, 496, 162, 91, 335, 528, 165, 72, 335, 528, 162, 84, 335, 2, 432, 285, 432, 165, 77, 367, 496, 162, 86, 335, 496, 164, 98, 335, 528, 162, 77, 335, 528, 163, 89, 335, 2, 432, 267, 432, 165, 75, 367, 432, 165, 87, 367, 496, 165, 72, 359, 496, 165, 84, 359, 2, 432, 258, 432, 166, 79, 367, 432, 166, 91, 367, 496, 162, 79, 335, 496, 165, 91, 335, 528, 169, 79, 335, 528, 170, 91, 335, 2, 432, 199, 432, 165, 67, 431, 432, 169, 79, 431, 2, 432, 285, 432, 165, 74, 367, 432, 165, 86, 367, 496, 162, 82, 335, 496, 165, 94, 335, 528, 162, 77, 335, 528, 165, 89, 335, 2, 432, 267, 432, 165, 68, 399, 432, 165, 80, 399, 2, 432, 258, 432, 171, 71, 367, 432, 170, 83, 367, 1]
2023-10-31 22:29:25,976 | INFO | ComMU | Generating the idx: 2


[0, 584, 623, 627, 635, 639, 642, 652, 684, 694, 720, 727, 2, 432, 199, 496, 162, 79, 335, 496, 162, 91, 335, 528, 165, 72, 335, 528, 162, 84, 335, 2, 432, 285, 432, 165, 77, 367, 496, 162, 86, 335, 496, 164, 98, 335, 528, 162, 77, 335, 528, 163, 89, 335, 2, 432, 267, 432, 165, 75, 367, 432, 165, 87, 367, 496, 165, 72, 359, 496, 165, 84, 359, 2, 432, 258, 432, 166, 79, 367, 432, 166, 91, 367, 496, 162, 79, 335, 496, 165, 91, 335, 528, 169, 79, 335, 528, 170, 91, 335, 2, 432, 199, 432, 165, 67, 431, 432, 169, 79, 431, 2, 432, 285, 432, 165, 74, 367, 432, 165, 86, 367, 496, 162, 82, 335, 496, 165, 94, 335, 528, 162, 77, 335, 528, 165, 89, 335, 2, 432, 267, 432, 165, 68, 399, 432, 165, 80, 399, 2, 432, 258, 432, 171, 71, 367, 432, 170, 83, 367, 1]
[]


2023-10-31 22:29:29,868 | INFO | ComMU | correct_length: 8
2023-10-31 22:29:29,869 | INFO | ComMU | [0, 584, 623, 627, 635, 639, 642, 652, 684, 694, 720, 727, 2, 432, 199, 464, 168, 72, 319, 464, 169, 84, 319, 480, 170, 70, 319, 480, 170, 82, 319, 496, 170, 72, 319, 496, 170, 84, 319, 512, 170, 75, 319, 512, 170, 87, 319, 528, 170, 79, 319, 528, 170, 91, 319, 544, 170, 77, 319, 544, 170, 89, 319, 2, 432, 285, 464, 164, 70, 319, 464, 170, 82, 319, 480, 170, 74, 319, 480, 170, 86, 319, 496, 170, 72, 319, 496, 170, 84, 319, 512, 170, 70, 319, 512, 170, 82, 319, 528, 170, 72, 319, 528, 170, 84, 319, 544, 170, 74, 319, 544, 170, 86, 319, 2, 432, 267, 464, 164, 68, 319, 464, 170, 80, 319, 480, 170, 72, 319, 480, 170, 84, 319, 496, 170, 72, 319, 496, 170, 84, 319, 512, 170, 75, 319, 512, 170, 87, 319, 528, 170, 75, 319, 528, 170, 87, 319, 544, 170, 77, 319, 544, 170, 89, 319, 2, 432, 258, 464, 164, 67, 319, 464, 170, 79, 319, 480, 170, 70, 319, 480, 170, 82, 319, 496, 170, 72, 319, 496, 170, 

[0, 584, 623, 627, 635, 639, 642, 652, 684, 694, 720, 727, 2, 432, 199, 464, 168, 72, 319, 464, 169, 84, 319, 480, 170, 70, 319, 480, 170, 82, 319, 496, 170, 72, 319, 496, 170, 84, 319, 512, 170, 75, 319, 512, 170, 87, 319, 528, 170, 79, 319, 528, 170, 91, 319, 544, 170, 77, 319, 544, 170, 89, 319, 2, 432, 285, 464, 164, 70, 319, 464, 170, 82, 319, 480, 170, 74, 319, 480, 170, 86, 319, 496, 170, 72, 319, 496, 170, 84, 319, 512, 170, 70, 319, 512, 170, 82, 319, 528, 170, 72, 319, 528, 170, 84, 319, 544, 170, 74, 319, 544, 170, 86, 319, 2, 432, 267, 464, 164, 68, 319, 464, 170, 80, 319, 480, 170, 72, 319, 480, 170, 84, 319, 496, 170, 72, 319, 496, 170, 84, 319, 512, 170, 75, 319, 512, 170, 87, 319, 528, 170, 75, 319, 528, 170, 87, 319, 544, 170, 77, 319, 544, 170, 89, 319, 2, 432, 258, 464, 164, 67, 319, 464, 170, 79, 319, 480, 170, 70, 319, 480, 170, 82, 319, 496, 170, 72, 319, 496, 170, 84, 319, 512, 170, 72, 319, 512, 170, 84, 319, 528, 170, 74, 319, 528, 170, 86, 319, 544, 170, 77, 3

2023-10-31 22:29:30,981 | INFO | ComMU | correct_length: 8
2023-10-31 22:29:30,983 | INFO | ComMU | [0, 584, 623, 627, 635, 639, 642, 652, 684, 694, 720, 727, 2, 432, 199, 496, 162, 84, 335, 528, 162, 87, 331, 2, 432, 285, 432, 162, 89, 399, 528, 166, 91, 334, 2, 432, 267, 432, 166, 84, 431, 2, 432, 258, 432, 163, 86, 399, 528, 161, 87, 331, 2, 432, 199, 432, 161, 84, 335, 496, 162, 87, 334, 528, 162, 91, 331, 2, 432, 285, 432, 161, 89, 399, 2, 432, 267, 2, 432, 258, 432, 170, 86, 399, 1]


[0, 584, 623, 627, 635, 639, 642, 652, 684, 694, 720, 727, 2, 432, 199, 496, 162, 84, 335, 528, 162, 87, 331, 2, 432, 285, 432, 162, 89, 399, 528, 166, 91, 334, 2, 432, 267, 432, 166, 84, 431, 2, 432, 258, 432, 163, 86, 399, 528, 161, 87, 331, 2, 432, 199, 432, 161, 84, 335, 496, 162, 87, 334, 528, 162, 91, 331, 2, 432, 285, 432, 161, 89, 399, 2, 432, 267, 2, 432, 258, 432, 170, 86, 399, 1]
[]


[[0,
  584,
  623,
  627,
  635,
  639,
  642,
  652,
  684,
  694,
  720,
  727,
  2,
  432,
  199,
  496,
  162,
  79,
  335,
  496,
  162,
  91,
  335,
  528,
  165,
  72,
  335,
  528,
  162,
  84,
  335,
  2,
  432,
  285,
  432,
  165,
  77,
  367,
  496,
  162,
  86,
  335,
  496,
  164,
  98,
  335,
  528,
  162,
  77,
  335,
  528,
  163,
  89,
  335,
  2,
  432,
  267,
  432,
  165,
  75,
  367,
  432,
  165,
  87,
  367,
  496,
  165,
  72,
  359,
  496,
  165,
  84,
  359,
  2,
  432,
  258,
  432,
  166,
  79,
  367,
  432,
  166,
  91,
  367,
  496,
  162,
  79,
  335,
  496,
  165,
  91,
  335,
  528,
  169,
  79,
  335,
  528,
  170,
  91,
  335,
  2,
  432,
  199,
  432,
  165,
  67,
  431,
  432,
  169,
  79,
  431,
  2,
  432,
  285,
  432,
  165,
  74,
  367,
  432,
  165,
  86,
  367,
  496,
  162,
  82,
  335,
  496,
  165,
  94,
  335,
  528,
  162,
  77,
  335,
  528,
  165,
  89,
  335,
  2,
  432,
  267,
  432,
  165,
  68,
  399,
  432,
  165,
  80,
  399,
  

In [13]:
from commu.midi_generator.container import TransXlInputData
from commu.preprocessor.encoder import EventSequenceEncoder
from commu.preprocessor.utils.container import MidiInfo

In [14]:
# Number of metadata tokens at the front of each sequence (the encoded
# metadata displayed earlier in this notebook has 11 entries).
num_meta = 11
# Work on the first generated sequence only.
generation_result = sequences[0]

In [15]:
# Break down the generated result into the metadata tokens and the event
# sequence. Index 0 (the leading token) is skipped, and so is the single token
# sitting between the metadata and the events (index num_meta + 1).
meta_stop = num_meta + 1
encoded_meta = generation_result[1:meta_stop]
event_sequence = generation_result[meta_stop + 1:]
# Decoder used to turn the tokens back into MIDI.
decoder = EventSequenceEncoder()


In [16]:
# Pack the metadata tokens (positionally) and the event tokens into a MidiInfo,
# decode it, and display the resulting MIDI object.
midi_info = MidiInfo(*encoded_meta, event_seq=event_sequence)
decoded_midi = decoder.decode(midi_info=midi_info)
decoded_midi

ticks per beat: 480
max tick: 0
tempo changes: 1
time sig: 1
key sig: 1
markers: 8
lyrics: False
instruments: 1

In [17]:
# Print every generated token sequence next to its index.
for idx, seq in enumerate(sequences):
    print(f"{idx} {seq}")


0 [0, 584, 623, 627, 635, 639, 642, 652, 684, 694, 720, 727, 2, 432, 199, 496, 162, 79, 322, 496, 162, 91, 322, 512, 162, 79, 322, 512, 162, 91, 322, 528, 162, 75, 322, 528, 162, 87, 322, 544, 162, 77, 319, 544, 162, 89, 319, 2, 432, 285, 432, 162, 74, 367, 432, 162, 86, 367, 496, 162, 77, 335, 496, 162, 89, 335, 528, 162, 74, 335, 528, 162, 86, 335, 2, 432, 267, 432, 162, 75, 367, 432, 162, 87, 367, 496, 162, 72, 324, 496, 162, 84, 324, 512, 162, 75, 322, 512, 162, 87, 322, 528, 162, 72, 319, 528, 162, 84, 319, 544, 162, 75, 319, 544, 162, 87, 319, 2, 432, 258, 432, 162, 74, 367, 432, 162, 86, 367, 496, 162, 72, 367, 496, 162, 84, 367, 2, 432, 199, 432, 163, 75, 335, 432, 163, 87, 335, 464, 163, 77, 335, 464, 163, 89, 335, 496, 162, 75, 335, 496, 162, 87, 335, 528, 162, 72, 319, 528, 163, 84, 319, 544, 162, 74, 319, 544, 162, 86, 319, 2, 432, 285, 432, 162, 74, 367, 432, 162, 86, 367, 496, 162, 70, 367, 496, 162, 82, 367, 2, 432, 267, 432, 162, 75, 431, 432, 162, 87, 431, 2, 432, 258,

In [18]:

# Call the post-processing task with the input data (presumably this binds it
# on the task — confirm against its __call__), then execute it on all
# generated sequences. The result is indexed by position in later cells.
pipeline.postprocess_task(input_data=input_data)
midi_dict = pipeline.postprocess_task.execute(
    sequences=sequences,
)

In [55]:
import json

# Show the first post-processed result together with its type.
first_midi = midi_dict[0]
first_midi, type(first_midi)

(ticks per beat: 480
 max tick: 0
 tempo changes: 1
 time sig: 1
 key sig: 1
 markers: 8
 lyrics: False
 instruments: 1,
 miditoolkit.midi.parser.MidiFile)

In [62]:
import miditoolkit


In [76]:
# Load one raw MIDI file from the training dataset directory and display its summary.
reference_midi_path = "../dataset/commu_midi/train/raw/commu00001.mid"
midi_file = miditoolkit.midi.parser.MidiFile(reference_midi_path)
midi_file

ticks per beat: 480
max tick: 15362
tempo changes: 1
time sig: 1
key sig: 1
markers: 8
lyrics: False
instruments: 1

In [80]:

import tempfile

In [82]:
# Keep a handle on the first post-processed MIDI object; a later cell writes it to disk.
original_midi_file = midi_dict[0]

In [84]:
import io

# Reserve a named temporary ".mid" path. `delete=False` keeps the file on disk
# after the handle is closed. The `with` block closes the handle immediately,
# so it is not leaked and the file is not held open while it is rewritten by name.
with tempfile.NamedTemporaryFile(delete=False, suffix=".mid") as temp_midi_file:
    pass

# Write the MIDI data to the reserved path.
original_midi_file.dump(temp_midi_file.name)

# NOTE(review): leftover from what appears to be an in-memory buffer attempt;
# `midi_buffer` is not defined in the visible cells.
# # Move the cursor to the start of the file-like object
# midi_buffer.seek(0)

# # display the MIDI file
# midi_buffer,midi_buffer

# Path of the written file (it is not removed automatically).
temp_midi_file.name

'C:\\Users\\ktrin\\AppData\\Local\\Temp\\tmpi415_0hw.mid'