In [None]:
import copy
import os
from functools import partial

import muspy
import torch
from torch.utils.data import DataLoader

from data_processing.Datasets import get_maestro_dataset, collate_fn, get_nes_dataset
from model.transformer import Transformer

# NOTE: this call downloads the MuseScore soundfont (per the function name).
# It is needed to generate audio with muspy, e.g. for the muspy.write_audio
# call later in this notebook.
muspy.download_musescore_soundfont()


In [None]:
# Load IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# --- Representation -------------------------------------------------------
# Forwarded to get_maestro_dataset; set to either 'event' or 'pitch'.
REPRESENTATION = "event"

# --- Data settings --------------------------------------------------------
SEQ_LEN = 1_500        # handed to collate_fn as seq_len
BATCH_SIZE = 32        # batch size of the training dataloader
EVAL_BATCH_SIZE = 10   # batch size of the validation dataloader

# --- Transformer settings (forwarded to the Transformer constructor) ------
# NOTE(review): NUM_TOKENS is presumably the vocabulary size of the chosen
# representation -- confirm it still matches if REPRESENTATION is changed.
NUM_TOKENS = 388
DIM_MODEL = 512
NUM_HEADS = 2
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
DROPOUT_P = 0.1


In [None]:
# Select the compute device: CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [None]:
# Load the MAESTRO data from data/maestro in the configured representation;
# the two returned datasets are bound to train_data and test_data.
train_data, test_data = get_maestro_dataset("data/maestro", representation=REPRESENTATION)


In [None]:
# Wrap both datasets in DataLoaders. Batches are assembled by the project's
# collate_fn, which is given the sequence length and the target device.
# Neither loader shuffles its dataset.
train_dataloader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=partial(collate_fn, seq_len=SEQ_LEN, device=device),
)

val_dataloader = DataLoader(
    test_data,
    batch_size=EVAL_BATCH_SIZE,
    shuffle=False,
    collate_fn=partial(collate_fn, seq_len=SEQ_LEN, device=device),
)


In [None]:
# Build the transformer from the hyperparameter constants, then move it to
# the selected device. Only the model is created in this cell; no optimizer
# or loss criterion is set up here.
model = Transformer(
    num_tokens=NUM_TOKENS,
    dim_model=DIM_MODEL,
    num_heads=NUM_HEADS,
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    dropout_p=DROPOUT_P,
)
model = model.to(device)


## Prepare Primer

The next cell takes a short excerpt from one piece in the test split (used here as validation data) and prepares it as the primer: the seed sequence that is passed to the transformer, which then continues it to generate new music.


In [None]:
# Prepare the primer: a short window of tokens taken from one item of the
# test dataset, used to seed generation.
PRIMER_INDEX = 10                      # which item of test_data to draw from
PRIMER_START, PRIMER_END = 200, 250    # token window used as the primer

# torch.tensor copies its input, so the slices below never alias the dataset.
piece = torch.tensor(test_data[PRIMER_INDEX], dtype=torch.long)

# NOTE(review): the earlier note here read "best so far has been item 10,
# with primer of 200 -> 225", but the code slices 200:250. The code's window
# is kept as-is -- confirm which window is actually intended.
primer = piece[PRIMER_START:PRIMER_END].to(device)

# labels covers the same token window as the primer, held as an independent
# copy so that modifying one can never affect the other.
labels = piece[PRIMER_START:PRIMER_END].clone().to(device)

# Fail here, with a clear message, rather than handing an empty primer to the
# model when the chosen piece is shorter than the requested window.
assert primer.numel() > 0, (
    f"test_data[{PRIMER_INDEX}] has no tokens in [{PRIMER_START}:{PRIMER_END}]"
)

print(primer.shape)
print(labels.shape)


In [None]:
# Checkpoint to restore, number of tokens to generate, and output location.
MODEL_PARAMS_PATH = "./results/transformer/model_params/model-event-may-11.pth"
OUTPUT_AUDIO_PATH = "./results/transformer/music/event-15-epoch.wav"
TARGET_SEQ_LENGTH = 800

# Restore the saved parameters, mapping their storage onto the active device.
# The model itself was already moved to `device` when it was created.
# NOTE(security): torch.load unpickles the file, so only load checkpoints
# that come from a trusted source.
state_dict = torch.load(MODEL_PARAMS_PATH, map_location=device)
model.load_state_dict(state_dict)

# Switch to evaluation mode (e.g. disables dropout) before generating.
model.eval()

# Generate a continuation of the primer. Gradients are not needed here.
# primer and labels were already moved to `device` when they were prepared.
with torch.no_grad():
    generated = model.generate(
        primer, device, labels, target_seq_length=TARGET_SEQ_LENGTH
    )

# Bring the generated tokens back to the CPU as a NumPy array and decode them
# into a muspy Music object.
data = generated.detach().cpu().numpy()
music = muspy.from_event_representation(data)

# Make sure the output folder exists so the rendered audio is not lost to a
# missing directory after the generation has already been done.
output_dir = os.path.dirname(OUTPUT_AUDIO_PATH)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

muspy.write_audio(OUTPUT_AUDIO_PATH, music, "wav")


## Visualize the music


In [None]:
# Show the `music` object decoded in the generation cell as a piano roll.
muspy.show_pianoroll(music)
