In [None]:

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data
import tqdm
import muspy
import src.utils as utils
import src.representation as representation
import src.dataset as dataset
import src.music_x_transformers as music_x_transformers
import pathlib
import src.advUtils as advUtils
import shutil


In [None]:
#load configurations
train_args = utils.load_json(".\pre_trained_models\mmt_sod_ape_training_logs.json")
encoding = representation.load_encoding("encoding.json")

sos = encoding["type_code_map"]["start-of-song"]
eos = encoding["type_code_map"]["end-of-song"]
beat_0 = encoding["beat_code_map"][0]
beat_4 = encoding["beat_code_map"][4]
beat_16 = encoding["beat_code_map"][16]

In [54]:
# Load training/testing/demo Data
data_set = advUtils.convert_extract_load(train_args,encoding, json_dir = "./data/test/json",repr_dir="./data/test/repr")

data_loader = torch.utils.data.DataLoader(
    data_set,
    shuffle=True,
    num_workers=1,
    collate_fn=dataset.MusicDataset.collate,
)

test_loader = data_loader
train_loader = data_loader
valid_loader = data_loader

In [63]:
# Load model
device = torch.device("cpu")
print(f"Creating the model...")
model = music_x_transformers.MusicXTransformer(
    dim=train_args["dim"],
    encoding=encoding,
    depth=train_args["layers"],
    heads=train_args["heads"],
    max_seq_len=train_args["max_seq_len"],
    max_beat=train_args["max_beat"],
    rotary_pos_emb=train_args["rel_pos_emb"],
    use_abs_pos_emb=train_args["abs_pos_emb"],
    emb_dropout=train_args["dropout"],
    attn_dropout=train_args["dropout"],
    ff_dropout=train_args["dropout"],
).to(device)

origModel = music_x_transformers.MusicXTransformer(
    dim=train_args["dim"],
    encoding=encoding,
    depth=train_args["layers"],
    heads=train_args["heads"],
    max_seq_len=train_args["max_seq_len"],
    max_beat=train_args["max_beat"],
    rotary_pos_emb=train_args["rel_pos_emb"],
    use_abs_pos_emb=train_args["abs_pos_emb"],
    emb_dropout=train_args["dropout"],
    attn_dropout=train_args["dropout"],
    ff_dropout=train_args["dropout"],
).to(device)

firstTrain = model
origModel.load_state_dict(torch.load("./pre_trained_models/mmt_sod_ape_best_model.pt", map_location=device))
model.load_state_dict(torch.load("./pre_trained_models/mmt_sod_ape_best_model.pt", map_location=device))

print(f"Loaded the pretrained model weights")
model.eval()

Creating the model...
Loaded the pretrained model weights


MusicXTransformer(
  (decoder): MusicAutoregressiveWrapper(
    (net): MusicTransformerWrapper(
      (token_emb): ModuleList(
        (0): TokenEmbedding(
          (emb): Embedding(5, 512)
        )
        (1): TokenEmbedding(
          (emb): Embedding(257, 512)
        )
        (2): TokenEmbedding(
          (emb): Embedding(13, 512)
        )
        (3): TokenEmbedding(
          (emb): Embedding(129, 512)
        )
        (4): TokenEmbedding(
          (emb): Embedding(33, 512)
        )
        (5): TokenEmbedding(
          (emb): Embedding(65, 512)
        )
      )
      (pos_emb): AbsolutePositionalEmbedding(
        (emb): Embedding(1024, 512)
      )
      (emb_dropout): Dropout(p=0.2, inplace=False)
      (project_emb): Identity()
      (attn_layers): Decoder(
        (layers): ModuleList(
          (0): ModuleList(
            (0): ModuleList(
              (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (1): None
              (2): None
    

In [7]:
def generate(n,sample_dir,model=model,modes=["unconditioned"],seq_len=1024,temperature=1,filter_logits="top_k",filter_thresh=0.9):
    with torch.no_grad():
        data_iter = iter(test_loader)
        for i in tqdm.tqdm(range(n), ncols=80):
            batch = next(data_iter)
            print("Generating based on",batch['name'])
            for mode in modes:
                if(mode=="unconditioned"):
                    tgt_start = torch.zeros((1, 1, 6), dtype=torch.long, device=device)
                    tgt_start[:, 0, 0] = sos
                elif(mode=="instrument_informed"):
                    prefix_len = int(np.argmax(batch["seq"][0, :, 1] >= beat_0))
                    tgt_start = batch["seq"][:1, :prefix_len].to(device)
                elif(mode=="4_beat"):
                    cond_len = int(np.argmax(batch["seq"][0, :, 1] >= beat_4))
                    tgt_start = batch["seq"][:1, :cond_len].to(device)
                elif(mode=="16_beat"):
                    cond_len = int(np.argmax(batch["seq"][0, :, 1] >= beat_16))
                    tgt_start = batch["seq"][:1, :cond_len].to(device)
                # Generate new samples
                generated = model.generate(
                    tgt_start,
                    seq_len,
                    eos_token=eos,
                    temperature=temperature,
                    filter_logits_fn=filter_logits,
                    filter_thres=filter_thresh,
                    monotonicity_dim=("type", "beat"),
                )
                generated_np = torch.cat((tgt_start, generated), 1).cpu().numpy()

                # Save the results
                advUtils.save_result(
                    f"{i}_{mode}", generated_np[0], sample_dir, encoding,savecsv=False,savetxt=False,savenpy=False,savepng=False,savejson=False
                )

In [8]:
#Test generation
#generate(2,"./samples",seq_len=50,modes=["unconditioned","instrument_informed"])

<h1>Transfer Learning</h1>

In [69]:
model.decoder.net.to_logits

ModuleList(
  (0): Linear(in_features=512, out_features=5, bias=True)
  (1): Linear(in_features=512, out_features=257, bias=True)
  (2): Linear(in_features=512, out_features=13, bias=True)
  (3): Linear(in_features=512, out_features=129, bias=True)
  (4): Linear(in_features=512, out_features=33, bias=True)
  (5): Linear(in_features=512, out_features=65, bias=True)
)

In [77]:
#Freeze All network layers
for m in model.parameters():
    m.requires_grad = False

new_heads = model.decoder.net.to_logits
new_heads.requires_grad = True

for child in new_heads.children():
    child.reset_parameters()
    for param in child.parameters():
        param.requires_grad = True

print(model.decoder.net.to_logits)

# layers_to_unfreeze = ["decoder.net.to_logits.5.weight","decoder.net.to_logits.4.weight","decoder.net.to_logits.3.weight","decoder.net.to_logits.2.weight","decoder.net.to_logits.1.weight","decoder.net.to_logits.0.weight","decoder.net.norm.weight"]

# for m in model.named_parameters():
#     state_dict = model.state_dict()
#     if(m[0] in layers_to_unfreeze):
#       newLayer = torch.rand(m[1].shape,requires_grad=True)
#       if(m[0]=="decoder.net.norm.weight"):
#          newLayer=torch.add(newLayer,1)
#       else:
#          newLayer=torch.add(newLayer,-0.5)
#       state_dict[m[0]] = newLayer
#       model.load_state_dict(state_dict)

# for m in model.named_parameters():
#     if(m[0] in layers_to_unfreeze):
#        m[1].requires_grad=True

n_parameters = sum(p.numel() for p in model.parameters())
n_trainables = sum(
   p.numel() for p in model.parameters() if p.requires_grad
)
print(f"Number of parameters: {n_parameters}")
print(f"Number of trainable parameters: {n_trainables}")




ModuleList(
  (0): Linear(in_features=512, out_features=5, bias=True)
  (1): Linear(in_features=512, out_features=257, bias=True)
  (2): Linear(in_features=512, out_features=13, bias=True)
  (3): Linear(in_features=512, out_features=129, bias=True)
  (4): Linear(in_features=512, out_features=33, bias=True)
  (5): Linear(in_features=512, out_features=65, bias=True)
)
Number of parameters: 19944950
Number of trainable parameters: 257526


In [64]:
#comparative generation:
generate(2,"./samples",seq_len=50,model=origModel,modes=["unconditioned","instrument_informed","4_beat","16_beat"])
generate(2,"./samples2",seq_len=50,model=firstTrain,modes=["unconditioned","instrument_informed","4_beat","16_beat"])

  0%|                                                     | 0/2 [00:00<?, ?it/s]

Generating based on ['albinoni_sonate_da_chiesa_6_(c)icking-archive']


 50%|██████████████████████▌                      | 1/2 [00:48<00:48, 48.43s/it]

Generating based on ['anglebert_fugue_3_(c)mccoy']


100%|█████████████████████████████████████████████| 2/2 [01:49<00:00, 54.67s/it]
  0%|                                                     | 0/2 [00:00<?, ?it/s]

Generating based on ['anglebert_fugue_3_(c)mccoy']


 50%|██████████████████████▌                      | 1/2 [00:21<00:21, 21.22s/it]

Generating based on ['albinoni_sonate_da_chiesa_6_(c)icking-archive']


100%|█████████████████████████████████████████████| 2/2 [00:34<00:00, 17.04s/it]


In [78]:
#Training the model
def get_lr_multiplier(
    step, warmup_steps, decay_end_steps, decay_end_multiplier
):
    """Return the learning rate multiplier with a warmup and decay schedule.

    The learning rate multiplier starts from 0 and linearly increases to 1
    after `warmup_steps`. After that, it linearly decreases to
    `decay_end_multiplier` until `decay_end_steps` is reached.

    """
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    if step > decay_end_steps:
        return decay_end_multiplier
    position = (step - warmup_steps) / (decay_end_steps - warmup_steps)
    return 1 - (1 - decay_end_multiplier) * position


def train(out_dir,model):
    """Main function."""
    # Parse the command-line arguments

    # Make sure the output directory exists
    pathlib.Path(out_dir).mkdir(exist_ok=True)
    pathlib.Path(out_dir+"/checkpoints").mkdir(exist_ok=True)
    # Get the specified device
    device = torch.device("cpu")
    

    # Summarize the model
    n_parameters = sum(p.numel() for p in model.parameters())
    n_trainables = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    print(f"Number of parameters: {n_parameters}")
    print(f"Number of trainable parameters: {n_trainables}")

    # Create the optimizer
    optimizer = torch.optim.Adam(model.parameters(), train_args["learning_rate"])
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: get_lr_multiplier(
            step,
            train_args["lr_warmup_steps"],
            train_args["lr_decay_steps"],
            train_args["lr_decay_multiplier"],
        ),
    )

    # Create a file to record losses
    loss_csv = open(out_dir+"/loss.csv", "w")
    loss_csv.write(
        "step,train_loss,valid_loss,type_loss,beat_loss,position_loss,"
        "pitch_loss,duration_loss,instrument_loss\n"
    )
    # Initialize variables
    step = 0
    min_val_loss = float("inf")
    if train_args["early_stopping"]:
        count_early_stopping = 0

    # Iterate for the specified number of steps
    train_iterator = iter(train_loader)
    while step < train_args["steps"]:

        # Training
        print(f"Training...")
        model.train()
        recent_losses = []

        for batch in (pbar := tqdm.tqdm(range(train_args["valid_steps"]), ncols=80)):
            # Get next batch
            try:
                batch = next(train_iterator)
            except StopIteration:
                # Reinitialize dataset iterator
                train_iterator = iter(train_loader)
                batch = next(train_iterator)

            # Get input and output pair
            seq = batch["seq"].to(device)
            mask = batch["mask"].to(device)

            # Update the model parameters
            optimizer.zero_grad()
            loss = model(seq, mask=mask)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), train_args["grad_norm_clip"]
            )
            optimizer.step()
            scheduler.step()

            # Compute the moving average of the loss
            recent_losses.append(float(loss))
            if len(recent_losses) > 10:
                del recent_losses[0]
            train_loss = np.mean(recent_losses)
            pbar.set_postfix(loss=f"{train_loss:8.4f}")

            step += 1

        # Release GPU memory right away
        del seq, mask

        # Validation
        print(f"Validating...")
        model.eval()
        with torch.no_grad():
            total_loss = 0
            total_losses = [0] * 6
            count = 0
            for batch in valid_loader:
                # Get input and output pair
                seq = batch["seq"].to(device)
                mask = batch["mask"].to(device)

                # Pass through the model
                loss, losses = model(seq, return_list=True, mask=mask)

                # Accumulate validation loss
                count += len(batch)
                total_loss += len(batch) * float(loss)
                for idx in range(6):
                    total_losses[idx] += float(losses[idx])
        val_loss = total_loss / count
        individual_losses = [l / count for l in total_losses]
        print(f"Validation loss: {val_loss:.4f}")
        print(
            f"Individual losses: type={individual_losses[0]:.4f}, "
            f"beat: {individual_losses[1]:.4f}, "
            f"position: {individual_losses[2]:.4f}, "
            f"pitch: {individual_losses[3]:.4f}, "
            f"duration: {individual_losses[4]:.4f}, "
            f"instrument: {individual_losses[5]:.4f}"
        )

        # Release GPU memory right away
        del seq, mask

        # Write losses to file
        loss_csv.write(
            f"{step},{train_loss},{val_loss},{individual_losses[0]},"
            f"{individual_losses[1]},{individual_losses[2]},"
            f"{individual_losses[3]},{individual_losses[4]},"
            f"{individual_losses[5]}\n"
        )

        # Save the model
        checkpoint_filename = out_dir+"/checkpoints/"+f"model_{step}.pt"
        torch.save(model.state_dict(), checkpoint_filename)
        print(f"Saved the model to: {checkpoint_filename}")

        # Copy the model if it is the best model so far
        if val_loss < min_val_loss:
            min_val_loss = val_loss
            shutil.copyfile(
                checkpoint_filename,
                out_dir+"/checkpoints/"+"best_model.pt",
            )
            # Reset the early stopping counter if we found a better model
            if train_args["early_stopping"]:
                count_early_stopping = 0
        elif train_args["early_stopping"]:
            # Increment the early stopping counter if no improvement is found
            count_early_stopping += 1

        # Early stopping
        if (
            train_args["early_stopping"]
            and count_early_stopping > train_args["early_stopping_tolerance"]
        ):
            print(
                "Stopped the training for no improvements in "
                f"{train_args['early_stopping_tolerance']} rounds."
            )
            break

    # Log minimum validation loss
    print(f"Minimum validation loss achieved: {min_val_loss}")

    # Save the optimizer states
    optimizer_filename = out_dir+"/checkpoints/"+f"optimizer_{step}.pt"
    torch.save(optimizer.state_dict(), optimizer_filename)
    print(f"Saved the optimizer state to: {optimizer_filename}")

    # Save the scheduler states
    scheduler_filename = out_dir+"checkpoints/"+f"scheduler_{step}.pt"
    torch.save(scheduler.state_dict(), scheduler_filename)
    print(f"Saved the scheduler state to: {scheduler_filename}")

    # Close the file
    loss_csv.close()

In [79]:
train("experiment",model)

Number of parameters: 19944950
Number of trainable parameters: 257526
Training...


100%|███████████████████████| 1000/1000 [21:22:02<00:00, 76.92s/it, loss=1.7796]


Validating...
Validation loss: 1.3894
Individual losses: type=0.0010, beat: 0.0136, position: 0.0387, pitch: 0.1665, duration: 0.1195, instrument: 0.0080
Saved the model to: experiment/checkpoints/model_1000.pt
Training...


 19%|████▋                    | 186/1000 [38:38<2:49:06, 12.47s/it, loss=1.4610]


KeyboardInterrupt: 