In [None]:
from data_processing.Datasets import get_maestro_dataset, collate_fn
from torch.utils.data import DataLoader
import torch
from model.transformer import Transformer
from functools import partial
import muspy
import copy

# NOTE: this call downloads the MuseScore soundfont. It is needed to generate
# audio with MusPy (see the `muspy.write_audio` call later in this notebook).
muspy.download_musescore_soundfont()


In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up before each cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# data hyperparams
REPRESENTATION = "pitch"  # set to either 'event' or 'pitch'
SEQ_LEN = 1000
BATCH_SIZE = 32
EVAL_BATCH_SIZE = 10

# model hyperparams
# The vocabulary size depends on the data representation. Looking it up in a
# table (instead of an if/else) means a typo in REPRESENTATION fails here,
# rather than silently building a model with the event vocabulary.
# NOTE(review): 130 and 388 are presumably the token counts of MusPy's pitch
# and event encodings as used by this project -- confirm against the dataset.
NUM_TOKENS_BY_REPRESENTATION = {"pitch": 130, "event": 388}
if REPRESENTATION not in NUM_TOKENS_BY_REPRESENTATION:
    raise ValueError(
        f"Unsupported REPRESENTATION {REPRESENTATION!r}; "
        f"expected one of {sorted(NUM_TOKENS_BY_REPRESENTATION)}."
    )
NUM_TOKENS = NUM_TOKENS_BY_REPRESENTATION[REPRESENTATION]
DIM_MODEL = 512
NUM_HEADS = 2
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
DROPOUT_P = 0.1


In [None]:
# Pick the compute device: use the GPU when CUDA is available, otherwise the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)


In [None]:
# Load the MAESTRO data from disk in the configured REPRESENTATION; the result
# is unpacked into a train part and a test part.
maestro_root = "data/maestro"
train_data, test_data = get_maestro_dataset(maestro_root, representation=REPRESENTATION)


In [None]:
# Build the transformer from the model hyperparameters and move it to the
# selected device. (No optimizer or loss is created in this notebook.)
transformer_config = dict(
    num_tokens=NUM_TOKENS,
    dim_model=DIM_MODEL,
    num_heads=NUM_HEADS,
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    dropout_p=DROPOUT_P,
)
model = Transformer(**transformer_config).to(device)


## Prepare Primer

The next cell loads a short excerpt from the test dataset. This excerpt (the primer) is passed into the transformer as the starting point from which it generates more music.


In [None]:
# prepare primer
# Which test item to draw from, and the token window cut out of it.
PRIMER_INDEX = 10
PRIMER_START, PRIMER_END = 100, 150

# Choose the decoder first, so an unsupported REPRESENTATION fails here with a
# clear message instead of a NameError on `music_primer` at the plotting step.
if REPRESENTATION == "pitch":
    decode_primer = muspy.from_pitch_representation
elif REPRESENTATION == "event":
    decode_primer = muspy.from_event_representation
else:
    raise ValueError(
        f"Unsupported REPRESENTATION {REPRESENTATION!r}; use 'pitch' or 'event'."
    )

full_sequence = torch.tensor(test_data[PRIMER_INDEX], dtype=torch.long)
print(full_sequence.shape)

# `primer` and `labels` hold the same window but must not share storage, so
# each gets its own copy of the slice.
primer = full_sequence[PRIMER_START:PRIMER_END].clone()
print(primer.shape)
labels = full_sequence[PRIMER_START:PRIMER_END].clone()
print(labels.shape)

# display primer pianoroll
# Convert a copy to NumPy so MusPy never sees memory shared with `primer`.
primer_arr = primer.clone().numpy()
music_primer = decode_primer(primer_arr)

muspy.show_pianoroll(music_primer)


In [None]:
# Choose the decoder before any expensive work. An unsupported REPRESENTATION
# then raises immediately, instead of printing a hint after generation and
# crashing with a NameError on `music` when the audio is written.
if REPRESENTATION == "pitch":
    decode_tokens = muspy.from_pitch_representation
elif REPRESENTATION == "event":
    decode_tokens = muspy.from_event_representation
else:
    raise ValueError(
        "Please use either event or pitch based representation, depending on what the model was trained on. "
        f"Got REPRESENTATION={REPRESENTATION!r}."
    )

MODEL_PARAMS_PATH = "./results/transformer/model_params/model-pitch-may-11.pth"
OUTPUT_AUDIO_PATH = "./results/transformer/music/test-pitch.wav"
TARGET_SEQ_LENGTH = 1000  # passed to model.generate as target_seq_length

# load saved model params
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
# NOTE(review): the checkpoint name suggests it was trained on the pitch
# representation -- confirm it matches REPRESENTATION before switching.
state_dict = torch.load(MODEL_PARAMS_PATH, map_location=device)
model.load_state_dict(state_dict)
model.to(device)

# set to test
model.eval()

# pass in primer
with torch.no_grad():
    generated = model.generate(
        primer, device, labels, target_seq_length=TARGET_SEQ_LENGTH
    )

# decode the returned info
data = generated.to("cpu").detach().numpy()
music = decode_tokens(data)

muspy.write_audio(OUTPUT_AUDIO_PATH, music, "wav")


## Visualize the music


In [None]:
# Plot the piano roll of `music`, the piece decoded from the model's output.
muspy.show_pianoroll(music)
