In [35]:
import torch
#import torch.nn as nn
#from model3 import MusicTransformer
from config import get_config, get_weights_file_path, latest_weights_file_path
from train import get_model, top_k_decode, tokenizer, midi_to_events, arrangeSequence
import altair as alt
import pandas as pd
import numpy as np
import warnings
import music21
warnings.filterwarnings("ignore")

In [9]:
# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


In [10]:
config = get_config()
model = get_model(config, len(list(tokenizer.keys())) ).to(device)

# Load the pretrained weights
preload = config["preload"]
model_filename = latest_weights_file_path(config) if preload == 'latest' else get_weights_file_path(config, preload) if preload else None
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state['model_state_dict'])

Loading latest weights


<All keys matched successfully>

In [11]:
def getOutputs(filename, notes=int, start=int):
    seq_len = get_config()["seq_len"]-2
    fileNotes = midi_to_events(filename)[start:]
    target = fileNotes[start:]
    #Return the sequence with the number of notes, but ignore timesteps
    notecount = 0
    for i, element in enumerate(fileNotes):
        if str(element).startswith("t") == False: #Not a timestep so must be a note
            notecount+=1
        if notecount >= notes:
            fileNotes = fileNotes[:(i+1)]
            target = target[(i+1):seq_len]
            break
    input = list(map(lambda x: tokenizer[x], fileNotes))

    #Create special tokens for padding
    sos_token = torch.tensor([tokenizer["EOS"]], dtype=torch.int64)
    eos_token = torch.tensor([tokenizer["SOS"]], dtype=torch.int64)
    pad_token = torch.tensor([tokenizer["PAD"]], dtype=torch.int64)

    num_padding_tokens = seq_len - len(input)
    input = torch.cat([
        sos_token, #1 additional
        torch.tensor(input, dtype=torch.int64),
        eos_token, #2 additional
        torch.tensor([pad_token] * num_padding_tokens, dtype=torch.int64),])
    #Return (input, expected, predicted)
    return fileNotes, target, top_k_decode(model, input, get_config(), device, 3)

In [34]:
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    return pd.DataFrame(
        [
            (
                r,
                c,
                float(m[r, c]),
                "%.3d %s" % (r, row_tokens[r] if len(row_tokens) > r else "<blank>"),
                "%.3d %s" % (c, col_tokens[c] if len(col_tokens) > c else "<blank>"),
            )
            for r in range(m.shape[0])
            for c in range(m.shape[1])
            if r < max_row and c < max_col
        ],
        columns=["row", "column", "value", "row_token", "col_token"],
    )

def get_attn_map(layer: int, head: int):
    attn = model.decoder_blocks[layer].self_attn.attention_scores
    return attn[0, head].data

def attn_map(layer, head, row_tokens, col_tokens, max_sentence_len):
    df = mtx2df(
        get_attn_map(layer, head),
        max_sentence_len,
        max_sentence_len,
        row_tokens,
        col_tokens,
    )
    return (
        alt.Chart(data=df)
        .mark_rect()
        .encode(
            x=alt.X("col_token", axis=alt.Axis(title="")),
            y=alt.Y("row_token", axis=alt.Axis(title="")),
            color="value",
            tooltip=["row", "column", "value", "row_token", "col_token"],
        )
        #.title(f"Layer {layer} Head {head}")
        .properties(height=400, width=400, title=f"Layer {layer} Head {head}")
        .interactive()
    )

def get_all_attention_maps(layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int):
    charts = []
    for layer in layers:
        rowCharts = []
        for head in heads:
            rowCharts.append(attn_map(layer, head, row_tokens, col_tokens, max_sentence_len))
        charts.append(alt.hconcat(*rowCharts))
    return alt.vconcat(*charts)

In [45]:
def readableTokens(events):
    durationDic = {
        "1": "32nd",
        "2": "16th",
        "4": "8th",
        "8": "Quarter",
        "16": "Half",
        "32": "Whole"
    }
    readable = []
    if isinstance(events, str):
        events = [events]
    for id in events:
        if id.startswith("t"): #Timestep
            readable.append(durationDic[id[1:]] + " rest")
        else:
            midiPitch, duration = id.split("-")
            readable.append(music21.pitch.Pitch(midi=int(midiPitch)).nameWithOctave + " " + durationDic[duration])
    return readable

In [46]:
#I tokenized everything already in BarDataset so i need to untokenize it
source, target, predicted = getOutputs(r"C:\Users\alexm\Downloads\Sound Testing 2\midi songs\Debussy_Reverie.mid", 10, 30)
func = np.vectorize(lambda x: list(tokenizer.keys())[x])
print("Source:"," ".join(source))
print("Target:"," ".join(target))
print("Predicted:", " ".join(list(filter(lambda x: x != "SOS", func(predicted.detach().cpu().numpy()).tolist()))))
input_tokens = readableTokens(source)#list(map(lambda x: tokenizer[x], source))
output_tokens = readableTokens(target)#list(map(lambda x: tokenizer[x], target))
if "[PAD]" in source: sentence_len = source.index("[PAD]")
else: sentence_len = config["seq_len"] - 2

Source: t4 67-16 t8 74-32 t4 62-4 t4 60-2 t4 58-16 t8 60-4 t4 76-4 t1 62-4 t4 77-4 t1 67-32
Target: t4 58-8 t8 60-4 t4 62-2 t4 67-16 t4 70-8 t4 62-4 t4 74-8 t1 60-4 t4 58-4 t4 76-8 t1 57-16 t4 58-4 t4 77-16 t1 60-8 t4 65-8 t8 72-32 t4 60-2 t4 58-4 t4 57-2 t4 55-16 t4 57-4 t4 58-4 t4 64-8 t4 67-16 t4 58-4 t4 57-1 t4 55-1 t4 69-16 t1 41-32 t8 48-16 t4 57-4 t4 65-4 t4 62-4 t4 57-2 t4 60-2 t4 65-4 t4 41-32 t4 69-32 t1 48-16 t4 57-8 t4 65-8 t4 62-4 t4 57-2 t4 60-2 t8 65-2 t8 81-16 t1 38-8 t8 45-4 t4 50-2 t4 53-4 t4 76-32 t1 57-2 t4 60-2 t4 64-2 t4 60-4 t4 57-4 t4 53-2 t4 72-4 t1 57-1 t4 76-8 t1 60-1 t4 43-8 74-8 t4 50-2 t4 70-4 t1 55-1 t4 58-1 t1 67-2 t8 81-16 t1 38-8 t4 45-2 t4 50-1 t4 53-2 t4 76-32 t1 57-2 t4 60-4 t4 64-4 t4 60-4 t4 57-4 t4 53-2 t4 72-4 t1 57-1 t4 76-8 t1 60-1 t4 74-8 t1 43-8 t4 50-4 t4 70-4 t1 55-1 t4
Predicted: 82-2 t4 74-2 62-4 66-2 55-16 50-2 66-2 82-2 55-16 55-16 55-16 70-4 66-2 66-2 55-16 66-2 70-4 55-16 55-16 82-2 39-4 39-4 57-4 55-16 74-2 55-16 74-2 60-4 58-16 60-

In [47]:
# Decoder Self-Attention
layers = [0,1,2,3,4,5]
heads = [0,1,2,3,4,5,6,7]
# get_all_attention_maps("decoder", layers, heads, input_tokens, input_tokens, min(50, sentence_len))
# min(len(input_tokens), sentence_len)
get_all_attention_maps(layers, heads, input_tokens, output_tokens, 20)

In [31]:
def top_k_tokens(model, input, config, device, k=3):
    sos_idx = tokenizer['SOS']
    eos_idx = tokenizer['EOS']

    output = torch.empty(1, 1).fill_(sos_idx).type_as(input["input"]).to(device)

    while True:
        if output.size(1) == config["seq_len"]:
            break

        out = model(input["input"].unsqueeze(0))

        _, top_k_words = torch.topk(out[:, -1], k=k, dim=1)

        next_word = top_k_words[:, torch.randint(k, (1,), dtype=torch.long)].squeeze()

        output = torch.cat(
            [output, torch.empty(1, 1).type_as(input["input"]).fill_(next_word.item()).to(device)], dim=1
        )

        if next_word == eos_idx:
            break

    # Get the values and probabilities for the top-k predictions
    values_and_probs = []
    for word_idx in top_k_words.squeeze().tolist():
        word = list(tokenizer.keys())[word_idx]
        probability = torch.nn.functional.softmax(out[:, -1], dim=1)[0, word_idx].item()
        values_and_probs.extend([word+" : "+str(round(probability*100, 2))+"%"])

    # Construct the output string
    output_str = f"Input: {readableTokens(input['src_seq'])}\nTop {k} next predictions: {', '.join(map(str, values_and_probs))}"

    return output_str#output.squeeze(0), output_str
print(top_k_tokens(model, arrangeSequence(r"C:\Users\alexm\Downloads\Sound Testing 2\midi songs\Debussy_Reverie.mid", 10, 30), get_config(), device, 3) )

Input Token: ['t4', '67-16', 't8', '74-32', 't4', '62-4', 't4', '60-2', 't4', '58-16', 't8', '60-4', 't4', '76-4', 't1', '62-4', 't4', '77-4', 't1', '67-32']
Top 3 next predictions: 82-2 : 26.09%, 50-2 : 18.71%, 55-16 : 13.1%
