In [11]:

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data
import tqdm
import muspy
import src.utils as utils
import src.representation as representation
import src.dataset as dataset
import src.music_x_transformers as music_x_transformers
import pathlib
import src.advUtils as advUtils


In [12]:
#load configurations
train_args = utils.load_json(".\pre_trained_models\mmt_sod_ape_training_logs.json")
encoding = representation.load_encoding("encoding.json")

sos = encoding["type_code_map"]["start-of-song"]
eos = encoding["type_code_map"]["end-of-song"]
beat_0 = encoding["beat_code_map"][0]
beat_4 = encoding["beat_code_map"][4]
beat_16 = encoding["beat_code_map"][16]

In [13]:
# Load training/testing/demo Data
testdata = advUtils.convert_extract_load(train_args,encoding, json_dir = "./data/test/json",repr_dir="./data/test/repr")

test_loader = torch.utils.data.DataLoader(
    testdata,
    shuffle=True,
    num_workers=1,
    collate_fn=dataset.MusicDataset.collate,
)

12 12
12 12


In [14]:
# Load model
device = torch.device("cpu")
print(f"Creating the model...")
model = music_x_transformers.MusicXTransformer(
    dim=train_args["dim"],
    encoding=encoding,
    depth=train_args["layers"],
    heads=train_args["heads"],
    max_seq_len=train_args["max_seq_len"],
    max_beat=train_args["max_beat"],
    rotary_pos_emb=train_args["rel_pos_emb"],
    use_abs_pos_emb=train_args["abs_pos_emb"],
    emb_dropout=train_args["dropout"],
    attn_dropout=train_args["dropout"],
    ff_dropout=train_args["dropout"],
).to(device)

model.load_state_dict(torch.load("./pre_trained_models/mmt_sod_ape_best_model.pt", map_location=device))
print(f"Loaded the pretrained model weights")
model.eval()

Creating the model...


Loaded the pretrained model weights


MusicXTransformer(
  (decoder): MusicAutoregressiveWrapper(
    (net): MusicTransformerWrapper(
      (token_emb): ModuleList(
        (0): TokenEmbedding(
          (emb): Embedding(5, 512)
        )
        (1): TokenEmbedding(
          (emb): Embedding(257, 512)
        )
        (2): TokenEmbedding(
          (emb): Embedding(13, 512)
        )
        (3): TokenEmbedding(
          (emb): Embedding(129, 512)
        )
        (4): TokenEmbedding(
          (emb): Embedding(33, 512)
        )
        (5): TokenEmbedding(
          (emb): Embedding(65, 512)
        )
      )
      (pos_emb): AbsolutePositionalEmbedding(
        (emb): Embedding(1024, 512)
      )
      (emb_dropout): Dropout(p=0.2, inplace=False)
      (project_emb): Identity()
      (attn_layers): Decoder(
        (layers): ModuleList(
          (0): ModuleList(
            (0): ModuleList(
              (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (1): None
              (2): None
    

In [15]:
def generate(n,sample_dir,modes=["unconditioned"],seq_len=1024,temperature=1,filter_logits="top_k",filter_thresh=0.9):
    with torch.no_grad():
        data_iter = iter(test_loader)
        for i in tqdm.tqdm(range(n), ncols=80):
            batch = next(data_iter)
            print("Generating based on",batch['name'])
            for mode in modes:
                if(mode=="unconditioned"):
                    tgt_start = torch.zeros((1, 1, 6), dtype=torch.long, device=device)
                    tgt_start[:, 0, 0] = sos
                elif(mode=="instrument_informed"):
                    prefix_len = int(np.argmax(batch["seq"][0, :, 1] >= beat_0))
                    tgt_start = batch["seq"][:1, :prefix_len].to(device)
                elif(mode=="4_beat"):
                    cond_len = int(np.argmax(batch["seq"][0, :, 1] >= beat_4))
                    tgt_start = batch["seq"][:1, :cond_len].to(device)
                elif(mode=="16_beat"):
                    cond_len = int(np.argmax(batch["seq"][0, :, 1] >= beat_16))
                    tgt_start = batch["seq"][:1, :cond_len].to(device)
                # Generate new samples
                generated = model.generate(
                    tgt_start,
                    seq_len,
                    eos_token=eos,
                    temperature=temperature,
                    filter_logits_fn=filter_logits,
                    filter_thres=filter_thresh,
                    monotonicity_dim=("type", "beat"),
                )
                generated_np = torch.cat((tgt_start, generated), 1).cpu().numpy()

                # Save the results
                advUtils.save_result(
                    f"{i}_{mode}", generated_np[0], sample_dir, encoding,savecsv=False,savetxt=False,savenpy=False,savepng=False,savejson=False
                )

In [16]:
#Test generation
generate(2,"./samples",seq_len=50,modes=["unconditioned","instrument_informed"])

  0%|                                                     | 0/2 [00:00<?, ?it/s]

Generating based on ['anglebert_fugue_3_(c)mccoy']


 50%|██████████████████████▌                      | 1/2 [00:39<00:39, 39.06s/it]

Generating based on ['albinoni_sonate_da_chiesa_6_(c)icking-archive']


100%|█████████████████████████████████████████████| 2/2 [01:03<00:00, 31.54s/it]


<h1>Transfer Learning</h1>