In [12]:
import copy
import librosa
import numpy as np
import pretty_midi
from transformers import Pop2PianoForConditionalGeneration, Pop2PianoProcessor, Pop2PianoTokenizer
from encoder import encode_plus
import sys
sys.path.append("./pop2piano")

In [13]:
import copy
def crop_midi(midi, start_beat, end_beat, extrapolated_beatsteps):
    start = extrapolated_beatsteps[start_beat]
    end = extrapolated_beatsteps[end_beat]
    out = copy.deepcopy(midi)
    for note in out.instruments[0].notes.copy():
        if note.start > end or note.start < start:
            out.instruments[0].notes.remove(note)
        # interpolate index of start note

        # lower = len(extrapolated_beatsteps[extrapolated_beatsteps <= note.start]) - 1
        lower = np.searchsorted(extrapolated_beatsteps, note.start, side='left') - 1
        note.start = lower
        note.start = int(note.start - start_beat)

        lower = np.searchsorted(extrapolated_beatsteps, note.end, side='left') - 1
        # lower = len(extrapolated_beatsteps[extrapolated_beatsteps <= note.end]) - 1
        note.end = lower
        note.end = int(note.end - start_beat)
        if note.end == note.start:
            note.end += 1
    return out

In [14]:
model = Pop2PianoForConditionalGeneration.from_pretrained("./cache/model")
processor = Pop2PianoProcessor.from_pretrained("./cache/processor")
tokenizer = Pop2PianoTokenizer.from_pretrained("./cache/tokenizer")

print("Loaded pretrained model, processor, and tokenizer.")
# cache the model, processor, and tokenizer to avoid downloading them again
# model.save_pretrained("./cache/model")
# processor.save_pretrained("./cache/processor")
# tokenizer.save_pretrained("./cache/tokenizer")



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model, processor, and tokenizer.


In [15]:
# load an example audio file and corresponding ground truth midi file
audio_path = "./processed/audio/Pat Benatar - Hit Me with Your Best Shot.ogg"
# audio_path = "./processed/audio/Aerosmith - Same Old Song & Dance.ogg"
audio, sr = librosa.load(audio_path, sr=44100)  # feel free to change the sr to a suitable value.

# convert the audio file to tokens
inputs = processor(audio=audio, sampling_rate=sr, return_tensors="pt")


# load ground truth midi file
# midi = pretty_midi.PrettyMIDI("./processed/midi/Mountain - Mississippi Queen.mid")
# ground_truth_midi_path = "./processed/midi/Mountain - Mississippi Queen.mid"
# ground_truth_midi_path = "mountain_out_gen.mid"
ground_truth_midi_path = "./processed/piano_midi/Pat Benatar - Hit Me with Your Best Shot.mid"
midi = pretty_midi.PrettyMIDI(ground_truth_midi_path)



In [16]:
inputs.beatsteps[0].shape

torch.Size([747])

In [17]:
# # convert the midi file to tokens
batches = [crop_midi(midi, i, i+8, inputs.extrapolated_beatstep[0]).instruments[0].notes for i in range(2, len(inputs.extrapolated_beatstep[0])-10, 8)]
# # remove empty batches
# batches = [batch for batch in batches if len(batch) > 0]

In [18]:
batches[-4]

[Note(start=0.000000, end=1.000000, pitch=40, velocity=77),
 Note(start=0.000000, end=1.000000, pitch=52, velocity=77),
 Note(start=0.000000, end=1.000000, pitch=59, velocity=77),
 Note(start=0.000000, end=1.000000, pitch=64, velocity=77),
 Note(start=0.000000, end=1.000000, pitch=68, velocity=77),
 Note(start=0.000000, end=1.000000, pitch=71, velocity=77)]

In [19]:
len(inputs.beatsteps[0])

747

In [20]:
711/8

88.875

In [21]:
len(batches)

94

In [22]:
labels = []
offset = 0
for batch in batches:
    # print(f"outer offset: {offset}")
    label, offset = encode_plus(tokenizer, batch, return_tensors="pt", time_offset=0)
    labels.append(label["token_ids"])
labels = [np.append([0], np.append(label, [1, 0])) for label in labels]
gt_longest_length = max([len(label) for label in labels])
model_output = model.generate(inputs["input_features"], generation_config=model.generation_config, return_dict_in_generate=True, output_logits=True, min_new_tokens=gt_longest_length)
longest_length = len(model_output.sequences[0])
padded_labels = np.array([np.pad(label, (0, longest_length - len(label))) for label in labels])


In [23]:
padded_labels.shape

(94, 127)

In [24]:
def one_hot_convert(t_labels, vocab_size):
    # Your vocabulary size
    vocab_size = 2400

    # Create a tensor to hold the one-hot encoded versions
    one_hot_tensor = torch.zeros((*t_labels.shape, vocab_size))

    # Iterate over each element of the original tensor
    for i in range(t_labels.size(0)):
        for j in range(t_labels.size(1)):
            # Get the value from the original tensor
            value = int(t_labels[i, j])
            # One-hot encode the value
            one_hot = torch.zeros(vocab_size)
            one_hot[value] = 1
            # Assign it to the corresponding position in the new tensor
            one_hot_tensor[i, j] = one_hot
    return one_hot_tensor

In [42]:
from torch.nn import CrossEntropyLoss
import torch
loss_fct = CrossEntropyLoss()
logits = torch.stack(model_output.logits).transpose(0,1)
logits = torch.nan_to_num(logits, nan=0.0, posinf=5, neginf=-5)
print(logits.transpose(0,1).shape)
t_labels = torch.tensor(padded_labels)
t_labels = t_labels[:,1:]
one_hot = one_hot_convert(t_labels, 2400)
# generate one hot from t_labels

# print(t_labels.shape)
loss = loss_fct(logits, one_hot)

torch.Size([126, 94, 2400])


In [40]:
logits

tensor([[[-4.2818e+00, -1.0000e+07, -1.4853e+01,  ..., -1.5166e+01,
          -1.4933e+01, -1.5089e+01],
         [-1.0468e+01, -1.0000e+07, -1.5300e+01,  ..., -1.5634e+01,
          -1.5360e+01, -1.5520e+01],
         [-8.7273e+00, -1.0000e+07, -1.5466e+01,  ..., -1.5696e+01,
          -1.5534e+01, -1.5634e+01],
         ...,
         [ 4.3484e+01,  3.3260e+00, -4.7254e+00,  ..., -4.8456e+00,
          -4.6384e+00, -4.8502e+00],
         [ 4.3511e+01,  3.2668e+00, -4.7422e+00,  ..., -4.8624e+00,
          -4.6527e+00, -4.8638e+00],
         [ 4.3643e+01,  3.2227e+00, -4.7192e+00,  ..., -4.8386e+00,
          -4.6297e+00, -4.8388e+00]],

        [[-4.8721e+00, -1.0000e+07, -1.4158e+01,  ..., -1.4365e+01,
          -1.4191e+01, -1.4358e+01],
         [-1.2420e+01, -1.0000e+07, -1.8735e+01,  ..., -1.9031e+01,
          -1.8883e+01, -1.9039e+01],
         [-1.0057e+01, -1.0000e+07, -1.7442e+01,  ..., -1.7573e+01,
          -1.7482e+01, -1.7578e+01],
         ...,
         [ 4.2634e+01,  5

In [43]:
loss

tensor(1.3154)

In [28]:
tokenizer.num_bars = 2
output = tokenizer.batch_decode(np.array(padded_labels),feature_extractor_output=inputs)

In [29]:
output['pretty_midi_objects'][0].write("mountain_out_sanity_check.mid")