In [None]:
from functools import partial
from pathlib import Path

import muspy
import torch
from torch.utils.data import DataLoader

from data_processing.Datasets import get_maestro_dataset, collate_fn
from model.transformer import Transformer

# NOTE: this downloads the MuseScore soundfont, which muspy needs in order to
# generate (render) music.
muspy.download_musescore_soundfont()

In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# ---- Token representation ---------------------------------------------------
# Must be either "event" or "pitch". It is passed to the dataset loader and
# decides which muspy decoder turns generated tokens back into music.
REPRESENTATION = "event"

# ---- Training ---------------------------------------------------------------
EPOCHS, LEARNING_RATE = 5, 0.01

# ---- Data -------------------------------------------------------------------
# SEQ_LEN is handed to collate_fn; the two batch sizes are used by the
# training and validation dataloaders respectively.
SEQ_LEN = 700
BATCH_SIZE, EVAL_BATCH_SIZE = 32, 10

# ---- Model ------------------------------------------------------------------
# These are forwarded one-to-one to the Transformer constructor.
# NOTE(review): 388 presumably matches the size of muspy's default event
# vocabulary — confirm against the representation the model was trained on.
NUM_TOKENS = 388
DIM_MODEL = 512
NUM_HEADS = 2
NUM_ENCODER_LAYERS = NUM_DECODER_LAYERS = 3
DROPOUT_P = 0.1


In [None]:
# Select the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [None]:
# Load the MAESTRO train and test splits from data/maestro, encoded with the
# token representation chosen in the hyperparameter cell.
(train_data, test_data) = get_maestro_dataset(
    "data/maestro",
    representation=REPRESENTATION,
)


In [None]:
def _build_loader(split, batch_size):
    """Wrap a dataset split in an unshuffled DataLoader.

    Batches are assembled by the project's collate_fn, bound to the configured
    sequence length and target device.
    """
    return DataLoader(
        dataset=split,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=partial(collate_fn, seq_len=SEQ_LEN, device=device),
    )


# Both loaders iterate their split in dataset order (no shuffling).
train_dataloader = _build_loader(train_data, BATCH_SIZE)
val_dataloader = _build_loader(test_data, EVAL_BATCH_SIZE)

In [None]:
# Build the Transformer from the model hyperparameters and move it onto the
# selected device. (Only the model is created here; weights are loaded later.)
model_kwargs = {
    "num_tokens": NUM_TOKENS,
    "dim_model": DIM_MODEL,
    "num_heads": NUM_HEADS,
    "num_encoder_layers": NUM_ENCODER_LAYERS,
    "num_decoder_layers": NUM_DECODER_LAYERS,
    "dropout_p": DROPOUT_P,
}
model = Transformer(**model_kwargs).to(device)


In [None]:
# Generate a continuation of a validation primer with the trained model,
# decode the tokens with muspy, and render the result to a WAV file.

CHECKPOINT_PATH = "./results/transformer/model_params/local_trained_model_may-10.pth"
OUTPUT_WAV_PATH = Path("./results/transformer/music/test.wav")
PRIMER_LEN = 50        # number of leading positions kept as the generation seed
TARGET_SEQ_LEN = 1000  # forwarded to model.generate as target_seq_length

# Decoder for each supported token representation. Validating up front makes an
# unsupported value fail before the expensive generation step, instead of
# surfacing afterwards as a NameError on `music`.
decoders = {
    "pitch": muspy.from_pitch_representation,
    "event": muspy.from_event_representation,
}
if REPRESENTATION not in decoders:
    raise ValueError(
        "Please use either event or pitch based representation, depending on "
        "what the model was trained on. If you don't know what either of those "
        f"are, set REPRESENTATION='event' (got {REPRESENTATION!r})."
    )

# load saved model params
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU machine still loads on a CPU-only one.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

# set to test
model.eval()

with torch.no_grad():
    # get the first batch from val
    batch = next(iter(val_dataloader))

    # Keep the first PRIMER_LEN positions along dim 0 and only the first entry
    # along dim 1 (same selection as the previous permute / [:1] / permute
    # sequence, without copy-constructing via torch.tensor).
    # Assumes dim 0 is time and dim 1 is the batch — TODO confirm against
    # collate_fn. `.contiguous()` yields a standalone, densely laid-out tensor.
    primer = batch[0].long().to(device)[:PRIMER_LEN, :1].contiguous()

    # Previously this was `labels = primer[:50]`, which threw away the labels
    # prepared from batch[2]; the labels are now sliced from themselves.
    labels = batch[2].long().to(device)[:PRIMER_LEN, :1].contiguous()

    generated = model.generate(
        primer, device, labels, target_seq_length=TARGET_SEQ_LEN
    )

# decode the returned info
# `.cpu()` is required before `.numpy()` when generation ran on the GPU.
data = generated.detach().cpu().numpy()
music = decoders[REPRESENTATION](data)

# Make sure the output folder exists so the write does not fail after the
# (slow) generation step.
OUTPUT_WAV_PATH.parent.mkdir(parents=True, exist_ok=True)
muspy.write_audio(str(OUTPUT_WAV_PATH), music, "wav")


## Visualize the music

In [None]:
# Display the piano roll of the `music` object decoded from the generated tokens.
muspy.show_pianoroll(music)