# Data

In [None]:
# Pull assets: initialise and update every git submodule of this repository,
# recursing into nested submodules.
# NOTE(review): presumably this is what provides the `aimgef-assets/` folder
# read by the CSQ processing cell — confirm against .gitmodules.
!git submodule update --init --recursive

## Process MIDI files
Convert each MIDI file into a sequence of events, as described in the paper, and save it as a `.npy` file.

### Classical String Quartets

In [None]:
import os
import json
import numpy as np
from tqdm import tqdm
from utilities.midi_io import MIDI
from utilities import mkdir

# Specify the time quantization set and time unit
from models import CSQ_TIME_QUANTIZATION
time_unit = "crotchet"

# Data augmentation parameters. To enable augmentation, use these lists
# instead of the single-value ones below.
# pitch_shifts = [-3, -2, -1, 0, 1, 2, 3]
# time_scales = [0.95, 0.975, 1.0, 1.025, 1.05]

# No augmentation (active setting): a single value for each parameter
pitch_shifts = [0]
time_scales = [1.0]

# IDs of the pieces held out for validation; every other piece goes to the
# training split. Built once, outside the loop, as a set for membership tests.
valid_set = {1211, 1219, 1240, 1827, 1893, 2322, 2368}

# Get the list of midi files (JSON is read explicitly as UTF-8 so the result
# does not depend on the platform's default encoding)
index_filename = "aimgef-assets/CSQ/fast_first/index.json"
with open(index_filename, mode="r", encoding="utf-8") as f:
    midi_list = json.load(f)

# Start processing: one .npy file per (piece, pitch shift, time scale)
for midi in tqdm(midi_list):
    split = "validation" if midi["ID"] in valid_set else "train"

    # The source path does not depend on the augmentation parameters,
    # so it is computed once per piece.
    midi_filename = os.path.join(
        "aimgef-assets/CSQ/fast_first/midi",
        f'{midi["ID"]}.mid'
    )

    for ps in pitch_shifts:
        for ts in time_scales:
            # A new MIDI object is created for every encoding, as before
            mid = MIDI(
                time_quantization=CSQ_TIME_QUANTIZATION,
                time_unit=time_unit
            )
            # The dimensions of sequence is [num_tracks, num_tokens]
            # num_tracks is 1 if merge_tracks is True
            sequence = mid.process(
                filename=midi_filename,
                ignore_velocity=True,
                merge_tracks=True,
                pitch_shift=ps,
                time_scale=ts
            )
            filename = os.path.join(
                "datasets",
                "CSQ",
                split,
                f'{midi["ID"]}_p[{ps}]_t[{ts}].npy'
            )
            # NOTE(review): presumably creates the parent directory of
            # `filename` — confirm in utilities.mkdir
            mkdir(filename)
            # Only the first (merged) track is stored
            np.save(filename, np.array(sequence[0], dtype=int))


### Sanity check
For several reasons, the decoded MIDI file will not always be identical to the original:
- If tracks are merged during encoding, note-on and note-off pairs could be mismatched.
- The defined `TIME_QUANTIZATION` may not be fine enough.

In [None]:
import os

import numpy as np
from utilities.midi_io import MIDI
from models import CSQ_TIME_QUANTIZATION

# Same quantization set and time unit as used when encoding the CSQ data
mid = MIDI(
    time_quantization=CSQ_TIME_QUANTIZATION,
    time_unit="crotchet",
)

# Load one encoded validation piece and decode it back to a MIDI object
seq = np.load("datasets/CSQ/validation/1240_p[0]_t[1.0].npy")
pm = mid.to_midi(seq, tempo=128)

# Nothing earlier in the notebook creates `scratch/`, so make sure it exists
# before writing; otherwise this cell can fail on a fresh checkout.
os.makedirs("scratch", exist_ok=True)
pm.write("scratch/1240_p[0]_t[1.0].mid")

### Classical Piano Improvisations

In [None]:
# Download maestro dataset (zip)
import os
import wget

maestro_url = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip"
# Location the unzip cell reads the archive from
maestro_zip = "scratch/maestro-v3.0.0-midi.zip"

# wget.download only treats `out` as a target directory when that directory
# already exists; otherwise the archive would be saved to a file literally
# named "scratch". Create the directory first.
os.makedirs("scratch", exist_ok=True)

# Skip the download when the archive is already present, so re-running the
# notebook does not fetch it again or leave a suffixed duplicate behind.
if not os.path.exists(maestro_zip):
    wget.download(maestro_url, out="scratch")

In [None]:
# Extract the downloaded MAESTRO archive into the scratch folder
import zipfile

maestro_archive = zipfile.ZipFile("scratch/maestro-v3.0.0-midi.zip", 'r')
# The context manager closes the archive once extraction has finished
with maestro_archive as archive:
    archive.extractall("scratch")

In [None]:
import os
import numpy as np
from pathlib import Path
from tqdm import tqdm
from utilities.midi_io import MIDI
from utilities import mkdir, get_maestro_midi_list

# Time quantization set and time unit used for the CPI data
from models import CPI_TIME_QUANTIZATION
time_unit = "time"

# Data augmentation parameters. To enable augmentation, use these lists
# instead of the single-value ones below.
# pitch_shifts = [-3, -2, -1, 0, 1, 2, 3]
# time_scales = [0.95, 0.975, 1.0, 1.025, 1.05]

# No augmentation (active setting): a single value for each parameter
pitch_shifts = [0]
time_scales = [1.0]

# Handle the three MAESTRO splits one after another
for split in ["test", "validation", "train"]:
    # MIDI paths of this split, as listed in the MAESTRO metadata file
    midi_list = get_maestro_midi_list(
        "scratch/maestro-v3.0.0/maestro-v3.0.0.json", split
    )

    # One .npy file per (midi, pitch shift, time scale) combination
    for midi in tqdm(midi_list):
        for ps in pitch_shifts:
            for ts in time_scales:
                mid = MIDI(
                    time_quantization=CPI_TIME_QUANTIZATION,
                    time_unit=time_unit,
                )
                source_path = os.path.join("scratch/maestro-v3.0.0", midi)
                # sequence is laid out as [num_tracks, num_tokens];
                # with merge_tracks=True there is a single track
                sequence = mid.process(
                    filename=source_path,
                    ignore_velocity=True,
                    merge_tracks=True,
                    pitch_shift=ps,
                    time_scale=ts,
                )
                npy_name = f'{Path(midi).stem}_p[{ps}]_t[{ts}].npy'
                filename = os.path.join("datasets", "CPI", split, npy_name)
                # NOTE(review): presumably creates the parent directory of
                # `filename` — confirm in utilities.mkdir
                mkdir(filename)
                # NOTE(review): the whole `sequence` is saved here, whereas
                # the CSQ cell saves `sequence[0]` — confirm that the CPI
                # data loader expects the extra leading track axis.
                encoded = np.array(sequence, dtype=int)
                np.save(filename, encoded)


# Model Training

## CSQ-Transformer

In [None]:
import torch
from models.trainer import MTTrainer

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the trainer from the CSQ-Transformer config and start training
cfg_path = "configs/csq_transformer.yaml"
trainer = MTTrainer(device=device, cfg_path=cfg_path)
trainer.train()