In [13]:

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data
import tqdm
import muspy
import src.utils as utils
import src.representation as representation
import src.dataset as dataset
import src.music_x_transformers as music_x_transformers
import pathlib
import src.advUtils as advUtils
import shutil
import os
import ast
import re

In [4]:
# Load configurations.
# The path uses forward slashes: the previous backslash separator formed the
# invalid escape sequence "\m" and only resolved on Windows.
train_args = utils.load_json("./pre_trained_models/mmt_sod_ape_training_logs.json")
encoding = representation.load_encoding("encoding.json")

# Codes for the special start/end tokens, looked up in the encoding's type map.
sos = encoding["type_code_map"]["start-of-song"]
eos = encoding["type_code_map"]["end-of-song"]

# Codes for beats 0, 4 and 16, looked up in the encoding's beat map.
beat_0 = encoding["beat_code_map"][0]
beat_4 = encoding["beat_code_map"][4]
beat_16 = encoding["beat_code_map"][16]

In [22]:
# Build the test dataset from the prepared representation directory.
data_set = advUtils.convert_extract_load(
    train_args,
    encoding,
    repr_dir="../preparedData/repr/test",
)

# Wrap the dataset in a shuffling loader that uses the project's collate function.
data_loader = torch.utils.data.DataLoader(
    dataset=data_set,
    collate_fn=dataset.MusicDataset.collate,
    num_workers=1,
    shuffle=True,
)

<h1>Sequence Prediction</h1>

In [25]:
# Load model
device = torch.device("cpu")
print("Creating the model...")

# Hyper-parameters come from the training log so the architecture matches
# the checkpoint that is loaded below.
model = music_x_transformers.MusicXTransformer(
    dim=train_args["dim"],
    encoding=encoding,
    depth=train_args["layers"],
    heads=train_args["heads"],
    max_seq_len=train_args["max_seq_len"],
    max_beat=train_args["max_beat"],
    rotary_pos_emb=train_args["rel_pos_emb"],
    use_abs_pos_emb=train_args["abs_pos_emb"],
    emb_dropout=train_args["dropout"],
    attn_dropout=train_args["dropout"],
    ff_dropout=train_args["dropout"],
).to(device)

# Switch to inference mode so the dropout layers configured above are disabled
# while sampling. NOTE(review): model.generate may already do this internally;
# calling it here is harmless either way.
model.eval()

# Kept as the last expression so the cell still displays the key-matching report.
model.load_state_dict(torch.load("../pre_trained_models/prediction/model_6000.pt", map_location=device))


Creating the model...


<All keys matched successfully>

In [26]:
# Mode names understood by generate(); the aliases map alternative spellings
# onto the canonical ones.
_GENERATION_MODES = ("unconditioned", "instrument_informed", "4_beat", "16_beat")
_GENERATION_MODE_ALIASES = {"beat_4": "4_beat", "beat_16": "16_beat"}


def _build_prompt(mode, batch):
    """Return the start tokens fed to the model for one generation mode.

    Raises:
        ValueError: if ``mode`` is not a known mode or alias.
    """
    canonical = _GENERATION_MODE_ALIASES.get(mode, mode)

    if canonical == "unconditioned":
        # A single all-zero token whose first field is the start-of-song code.
        prompt = torch.zeros((1, 1, 6), dtype=torch.long, device=device)
        prompt[:, 0, 0] = sos
        return prompt

    # Built at call time so re-running the configuration cell is picked up.
    beat_thresholds = {
        "instrument_informed": beat_0,
        "4_beat": beat_4,
        "16_beat": beat_16,
    }
    if canonical not in beat_thresholds:
        raise ValueError(
            f"Unknown generation mode {mode!r}; expected one of "
            f"{sorted(set(_GENERATION_MODES) | set(_GENERATION_MODE_ALIASES))}"
        )

    # Keep every token before the first one whose column-1 value (compared
    # against the beat codes) reaches the threshold.
    # NOTE(review): np.argmax returns 0 when no token reaches the threshold,
    # which yields an empty prompt -- confirm how model.generate handles that.
    prefix_len = int(np.argmax(batch["seq"][0, :, 1] >= beat_thresholds[canonical]))
    return batch["seq"][:1, :prefix_len].to(device)


def generate(n,sample_dir,model=model,modes=("unconditioned",),seq_len=1024,temperature=1,filter_logits="top_k",filter_thresh=0.9):
    """Generate samples for ``n`` batches drawn from ``data_loader`` and save them.

    For every drawn batch and every requested mode, a prompt is built, the
    model continues it, and prompt + continuation are passed to
    ``advUtils.save_result`` under the name ``"{i}_{mode}"``.

    Args:
        n: Number of batches to draw from ``data_loader``.
        sample_dir: Directory handed to ``advUtils.save_result``.
        model: Model whose ``generate`` method is called.
        modes: Iterable of mode names (a single string is also accepted).
            Supported: "unconditioned", "instrument_informed", "4_beat",
            "16_beat"; "beat_4" / "beat_16" are accepted as aliases.
        seq_len: Forwarded to ``model.generate`` as the generation length.
        temperature: Forwarded to ``model.generate``.
        filter_logits: Forwarded as ``filter_logits_fn``.
        filter_thresh: Forwarded as ``filter_thres``.

    Raises:
        ValueError: if any requested mode is unknown. Checked before any
            generation starts, so a typo cannot silently reuse the prompt
            of the previous mode.
        StopIteration: if ``data_loader`` yields fewer than ``n`` batches.
    """
    # Accept a bare string so it is not iterated character by character.
    modes = [modes] if isinstance(modes, str) else list(modes)

    # Fail fast: each sample can take minutes, so reject bad modes up front.
    unknown = [
        mode for mode in modes
        if mode not in _GENERATION_MODES and mode not in _GENERATION_MODE_ALIASES
    ]
    if unknown:
        raise ValueError(
            f"Unknown generation mode(s) {unknown}; expected one of "
            f"{sorted(set(_GENERATION_MODES) | set(_GENERATION_MODE_ALIASES))}"
        )

    with torch.no_grad():
        data_iter = iter(data_loader)
        for i in tqdm.tqdm(range(n), ncols=80):
            batch = next(data_iter)
            print("Generating based on",batch['name'])
            for mode in modes:
                tgt_start = _build_prompt(mode, batch)

                # Generate new samples
                generated = model.generate(
                    tgt_start,
                    seq_len,
                    eos_token=eos,
                    temperature=temperature,
                    filter_logits_fn=filter_logits,
                    filter_thres=filter_thresh,
                    monotonicity_dim=("type", "beat"),
                )
                # Prepend the prompt so the saved result holds the full sequence.
                generated_np = torch.cat((tgt_start, generated), 1).cpu().numpy()

                # Save the results
                advUtils.save_result(
                    f"{i}_{mode}", generated_np[0], sample_dir, encoding,savecsv=False,savetxt=False,savenpy=False,savepng=False,savejson=False
                )

In [27]:
# Use the mode spellings that generate() handles explicitly ("4_beat" / "16_beat").
generate(
    10,
    "../generatedSamples",
    seq_len=100,
    modes=["unconditioned", "instrument_informed", "4_beat", "16_beat"],
)

  0%|                                                    | 0/10 [00:00<?, ?it/s]

Generating based on ['BeckerD-SS-7']


 10%|████▎                                      | 1/10 [01:58<17:47, 118.57s/it]

Generating based on ['BachJS-BWV861-2']


 20%|████████▊                                   | 2/10 [03:19<12:53, 96.63s/it]

Generating based on ['BuxtehudeD-BuxWV256-6']


 30%|█████████████▏                              | 3/10 [04:25<09:39, 82.72s/it]

Generating based on ['BachJS-BWV774']


 40%|█████████████████▌                          | 4/10 [05:39<07:53, 78.92s/it]

Generating based on ['LullyJB-LWV60-A1S01A']


 50%|██████████████████████                      | 5/10 [07:31<07:35, 91.17s/it]

Generating based on ['BeckerD-SS-28']


 60%|██████████████████████████▍                 | 6/10 [09:20<06:28, 97.24s/it]

Generating based on ['BuxtehudeD-BuxWV254-2']


 70%|██████████████████████████████▊             | 7/10 [11:03<04:57, 99.11s/it]

Generating based on ['BachJS-BWV776']


 80%|██████████████████████████████████▍        | 8/10 [12:50<03:22, 101.43s/it]

Generating based on ['BachJS-BWV988-V05']


 90%|██████████████████████████████████████▋    | 9/10 [14:30<01:41, 101.12s/it]

Generating based on ['BachJS-BWV1080-C16']


100%|███████████████████████████████████████████| 10/10 [16:17<00:00, 97.77s/it]
