# augment dataset

Augment a given collection of MIDI files by applying every combination of the following transformations to each file:

1. transposition by 0–11 semitones (12 options)
2. time shift by 0 to NUM_BEATS − 1 beats (NUM_BEATS = 8 options), wrapping notes around the loop
3. velocity scaling (5 evenly spaced factors on [0.75, 1.25])

total augmentation factor (including the identity transformation): $12 \times 8 \times 5 = 480$

## setup

In [1]:
import os
import hashlib
from pathlib import Path
import pretty_midi
import numpy as np
import itertools
import redis
from concurrent.futures import ProcessPoolExecutor, as_completed
from rich.progress import track

from typing import List

In [None]:
# filesystem
redis_url = "redis://localhost:6379"
dataset_name = "careful"
input_dir = os.path.join("..", "data", "datasets", dataset_name)
output_dir = os.path.join("..", "data", "outputs")
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so
# that the per-process chunks below are reproducible from run to run
files = sorted(f for f in os.listdir(input_dir) if f.endswith(".mid"))

num_processes = 8  # os.cpu_count()
# one contiguous chunk of file names per worker process
split_files = np.array_split(files, num_processes)  # type: ignore

# transformations
NUM_SEMITONES = 12
NUM_BEATS = 8
semis = list(range(NUM_SEMITONES))
beats = list(range(NUM_BEATS))
# 5 evenly spaced velocity scale factors on [0.75, 1.25]
vels = [0.75, 0.875, 1.0, 1.125, 1.25]
# every [semitones, beats, velocity] combination (12 * 8 * 5 = 480 entries),
# including the identity [0, 0, 1.0]
transformation_table = [list(p) for p in itertools.product(semis, beats, vels)]

## helper functions

In [None]:
from matplotlib import pyplot as plt


def plot_images(
    images,
    titles,
    shape=None,
    main_title=None,
    set_axis: str = "off",
    bpm: int = 80,
    beats: int = NUM_BEATS,
    draw_beats=True,
    fs: float = 100,
) -> None:
    """Plot images in a grid (one per row by default) with a vertical line at each beat.

    Args:
        images: A list of images to plot; each is squeezed before display and
            drawn with a colour scale fixed to [0, 127].
        titles: Titles for each subplot, indexed in the same order as ``images``.
        shape: ``[rows, cols]`` of the subplot grid. Defaults to
            ``[len(images), 1]``.
        main_title: Optional title for the whole figure.
        set_axis: Value passed to ``plt.axis`` (e.g. ``'on'`` or ``'off'``).
        bpm: Beats per minute of the MIDI file.
        beats: Total number of beats in the MIDI file; ``beats + 1`` beat
            lines are drawn (including both ends).
        draw_beats: Whether to draw the dotted beat lines.
        fs: Image columns per second. Defaults to 100, the default sampling
            rate of ``PrettyMIDI.get_piano_roll``; pass the ``fs`` used to
            render the images if it differs, or the beat lines will be misplaced.
    """
    # NOTE: this changes the global matplotlib style for the rest of the session
    plt.style.use("dark_background")

    if shape is None:
        shape = [len(images), 1]

    # seconds per beat, converted to image columns via the sampling rate
    beat_interval = 60.0 / bpm
    columns_per_beat = beat_interval * fs

    plt.figure(figsize=(12, 12))

    if main_title:
        plt.suptitle(main_title)
    for num_plot, image in enumerate(images):
        plt.subplot(shape[0], shape[1], num_plot + 1)
        plt.imshow(
            np.squeeze(image),
            aspect="auto",
            origin="lower",
            cmap="magma",
            vmin=0.0,
            vmax=127.0,
            interpolation="nearest",
        )

        if draw_beats:
            for beat in range(beats + 1):
                plt.axvline(x=beat * columns_per_beat, color="green", linestyle=":")

        plt.title(titles[num_plot])
        plt.axis(set_axis)

    plt.tight_layout()
    plt.show()

In [None]:
def transform(
    midi_path: str, semitones: int, beats: int, vel: float
) -> pretty_midi.PrettyMIDI:
    """
    Transpose, time-shift and velocity-scale a looped MIDI file.

    The file is treated as a loop of ``NUM_BEATS`` beats. Notes shifted past
    the end of the loop wrap around to its start and keep their original
    duration, so a wrapped note may extend past the loop end.

    Args:
        midi_path (str): Path to the MIDI file. The tempo (BPM) is parsed as an
            int from the second ``-``-separated field of the file stem.
        semitones (int): Number of semitones to transpose by. Notes whose
            transposed pitch falls outside the MIDI range [0, 127] are dropped.
        beats (int): Number of beats to shift the notes forward by.
        vel (float): Factor to scale each note's velocity by. The result is
            rounded and clamped to the MIDI velocity range [1, 127].

    Returns:
        pretty_midi.PrettyMIDI: A new MIDI object holding the transformed notes.
    """
    midi_data = pretty_midi.PrettyMIDI(midi_path)
    tempo = int(Path(midi_path).stem.split("-")[1])
    seconds_per_beat = 60.0 / tempo
    shift_seconds = seconds_per_beat * beats
    loop_seconds = seconds_per_beat * NUM_BEATS
    new_midi = pretty_midi.PrettyMIDI()

    for instrument in midi_data.instruments:
        new_inst = pretty_midi.Instrument(
            program=instrument.program, is_drum=instrument.is_drum
        )
        for note in instrument.notes:
            pitch = note.pitch + semitones
            if not 0 <= pitch <= 127:
                # get_piano_roll indexes a 128-row array by pitch, so an
                # out-of-range pitch would raise IndexError; drop the note
                continue

            # wrap the onset into the loop and keep the duration unchanged.
            # This matches the previous mod-and-unwrap logic for notes shorter
            # than the loop, but does not collapse a full-loop note to zero length.
            duration = note.end - note.start
            shifted_start = (note.start + shift_seconds) % loop_seconds
            shifted_end = shifted_start + duration

            # MIDI velocities are ints in 0..127 (0 means note-off), so round
            # and clamp the scaled value to keep the object writable
            velocity = min(127, max(1, int(round(note.velocity * vel))))

            new_inst.notes.append(
                pretty_midi.Note(
                    velocity=velocity,
                    pitch=pitch,
                    start=shifted_start,
                    end=shifted_end,
                )
            )
        new_midi.instruments.append(new_inst)

    return new_midi

In [None]:
def augment(
    redis_url, midi_paths: List[str], output_path: str, transformations, index: int
) -> int:
    """Render every transformation of each MIDI file to a piano roll and store it in Redis.

    Each name in ``midi_paths`` is resolved relative to the notebook-global
    ``input_dir``. For every ``(semi, beat, vel)`` triple in ``transformations``
    the key ``prs:<stem>:s{semi:02d}b{beat:02d}v{vel:.02f}`` is set to the raw
    bytes of the transformed file's ``get_piano_roll()``. Only ``tobytes()`` is
    stored, so the array shape and dtype are not recorded in Redis.

    The untransformed piano roll is stored under ``s00b00v1.00`` only when no
    transformation produces that key. Previously the identity transform in the
    table always overwrote it, so rendering the raw file was wasted work.

    Args:
        redis_url: Redis connection URL.
        midi_paths: MIDI file names inside ``input_dir``.
        output_path: Unused; kept for call-site compatibility.
        transformations: Iterable of ``(semitones, beats, vel)`` triples.
        index: Worker index, used only in log and progress labels.

    Returns:
        int: ``index``, so a caller collecting futures can tell which worker finished.
    """
    r = redis.Redis.from_url(redis_url)

    # materialize once: it is re-iterated for every file
    transformations = list(transformations)

    print(f"SUBR{index:02d} processing {len(midi_paths)} files")

    for file_name in track(
        midi_paths,
        description=f"[{index:02d}] augmenting files...",
        refresh_per_second=1,
        update_period=1.0,
    ):
        stem = file_name[:-4]  # drop the ".mid" extension
        file_path = os.path.join(input_dir, file_name)

        # key -> transform params; as with sequential SETs, the last triple
        # mapping to a given key wins, but each key is rendered only once
        jobs = {
            f"prs:{stem}:s{semi:02d}b{beat:02d}v{vel:.02f}": (semi, beat, vel)
            for semi, beat, vel in transformations
        }
        original_key = f"prs:{stem}:s00b00v1.00"

        # NOTE: the pipeline buffers every piano roll for this file in memory
        # until execute()
        with r.pipeline() as pipeline:
            if original_key not in jobs:
                pipeline.set(
                    original_key,
                    pretty_midi.PrettyMIDI(file_path).get_piano_roll().tobytes(),
                )

            for key, (semi, beat, vel) in jobs.items():
                pipeline.set(
                    key,
                    transform(file_path, semi, beat, vel).get_piano_roll().tobytes(),
                )

            pipeline.execute()

    return index

## run


In [None]:
# set to True to augment the whole dataset in parallel; by default only the
# first file is processed, as a quick smoke test
run_full_dataset = False

if run_full_dataset:
    # NOTE(review): worker processes must be able to import augment/transform.
    # Per the concurrent.futures docs this does not work from an interactive
    # __main__ (e.g. a notebook under the 'spawn' start method); move the helpers
    # into a module if workers fail to start.
    # one worker per chunk, so every chunk of split_files runs concurrently
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        futures = {
            executor.submit(
                augment, redis_url, chunk, output_dir, transformation_table, i
            ): chunk
            for i, chunk in enumerate(split_files)
        }

        for future in as_completed(futures):
            result = future.result()
            print(f"process {result} complete")
else:
    augment(redis_url, files[:1], output_dir, transformation_table, 0)

In [4]:
from collections import Counter

# histogram: piano-roll length (number of time columns from get_piano_roll
# at its default sampling rate) -> number of output MIDI files with that length
lens = Counter(
    pretty_midi.PrettyMIDI(os.path.join(output_dir, name)).get_piano_roll().shape[1]
    for name in os.listdir(output_dir)
    # skip non-MIDI entries (e.g. .DS_Store) that PrettyMIDI cannot parse
    if name.endswith((".mid", ".midi"))
)
lens

{706: 1488,
 916: 252,
 1059: 24,
 632: 1284,
 744: 900,
 683: 1752,
 597: 1020,
 732: 972,
 961: 72,
 855: 492,
 827: 1008,
 769: 600,
 738: 792,
 721: 996,
 904: 336,
 664: 564,
 772: 648,
 532: 216,
 820: 936,
 760: 1152,
 750: 3348,
 800: 540,
 839: 252,
 705: 1488,
 680: 1632,
 511: 624,
 753: 1896,
 629: 1656,
 600: 2004,
 641: 792,
 498: 264,
 907: 372,
 610: 1836,
 739: 936,
 722: 1092,
 871: 696,
 747: 2184,
 719: 1188,
 746: 936,
 864: 360,
 869: 636,
 697: 1716,
 813: 384,
 698: 3096,
 876: 468,
 694: 1380,
 581: 372,
 755: 1404,
 823: 864,
 745: 1260,
 657: 864,
 595: 936,
 591: 804,
 981: 276,
 687: 1932,
 576: 564,
 773: 816,
 538: 756,
 783: 516,
 623: 408,
 714: 1032,
 704: 1656,
 550: 408,
 561: 132,
 631: 972,
 663: 492,
 792: 1212,
 573: 396,
 768: 624,
 840: 420,
 603: 1416,
 540: 936,
 589: 972,
 572: 540,
 936: 300,
 794: 912,
 673: 1368,
 492: 708,
 862: 324,
 676: 1812,
 607: 1848,
 994: 84,
 679: 1440,
 672: 1140,
 675: 2220,
 598: 1140,
 756: 1320,
 512: 456,


In [5]:
sorted(lens)  # distinct piano-roll lengths, ascending

[107,
 201,
 216,
 229,
 266,
 289,
 291,
 294,
 296,
 297,
 302,
 306,
 314,
 315,
 317,
 324,
 327,
 349,
 357,
 364,
 366,
 371,
 377,
 383,
 384,
 385,
 388,
 390,
 392,
 395,
 397,
 398,
 400,
 407,
 409,
 413,
 414,
 416,
 417,
 418,
 419,
 420,
 427,
 429,
 431,
 432,
 437,
 439,
 441,
 443,
 444,
 445,
 446,
 447,
 448,
 449,
 450,
 451,
 452,
 453,
 454,
 455,
 456,
 457,
 458,
 459,
 460,
 461,
 462,
 463,
 464,
 465,
 466,
 467,
 468,
 469,
 470,
 471,
 472,
 473,
 474,
 475,
 476,
 477,
 478,
 479,
 480,
 481,
 482,
 483,
 484,
 485,
 486,
 487,
 488,
 489,
 490,
 491,
 492,
 493,
 494,
 495,
 496,
 497,
 498,
 499,
 500,
 501,
 502,
 503,
 504,
 505,
 506,
 507,
 508,
 509,
 510,
 511,
 512,
 513,
 514,
 515,
 516,
 517,
 518,
 519,
 520,
 521,
 522,
 523,
 524,
 525,
 526,
 527,
 528,
 529,
 530,
 531,
 532,
 533,
 534,
 535,
 536,
 537,
 538,
 539,
 540,
 541,
 542,
 543,
 544,
 545,
 546,
 547,
 548,
 549,
 550,
 551,
 552,
 553,
 554,
 555,
 556,
 557,
 558,
 559,
 560