In [17]:
%load_ext autoreload
%autoreload 2
from aria.utils import _load_weight
from aria.config import load_model_config
from aria.tokenizer import AbsTokenizer, SeparatedAbsTokenizer
from src.load_aria_weights import get_p2q
import argparse
import torch
from torch.nn.functional import log_softmax
from accelerate import Accelerator
import torch.nn.functional as F

from aria.data.midi import MidiDict
torch.set_printoptions(threshold=10_000)

import math

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
from train_utils import load_config

# Checkpoint loaded into the model built by `get_p2q`.
M_PATH = "./inference/weights/model.safetensors"

# Absolute-pitch tokenizer extended with a '<SEP>' token.
# NOTE(review): presumably this must match the vocabulary the checkpoint was
# trained with — confirm against the training setup.
tokenizer = AbsTokenizer()
tokenizer.add_tokens_to_vocab(['<SEP>'])

config = load_config("train_config.json")
# NOTE(review): the meaning of the positional `False` flag passed to get_p2q is
# not visible here — confirm in src/load_aria_weights.py.
model = get_p2q(config, tokenizer, M_PATH, False)

accelerator = Accelerator()
model = accelerator.prepare(model)
# Use the device Accelerate actually placed the model on, so input tensors
# built later land on the same device as the model.
device = accelerator.device


In [36]:
def single_greedy_search(model, enc_input, max_length, device, window_size=1096, tok=None):
    """Greedily decode one output token sequence for a single encoded input.

    At decoding step ``pos`` the encoder is fed ``enc_input[pos:pos + window_size]``,
    right-padded with ``<P>`` up to ``max_length`` if shorter. The decoder is fed
    the tokens generated so far, also right-padded with ``<P>`` to ``max_length``.
    The next token is the argmax of the logits at the position of the last real
    (non-pad) decoder token.

    Args:
        model: Object exposing ``encode(tensor)`` and
            ``logits(tensor, encoder_output)``. ``logits`` is indexed as
            ``(batch, position, vocab)``.
        enc_input: Encoded input token ids (a list of ints).
        max_length: Padded length of the decoder input. Also the maximum length
            of the returned sequence, including the start token.
        device: Device (string or ``torch.device``) for the input tensors.
            Mixed-precision autocast is enabled only for CUDA devices.
        window_size: Number of encoder tokens fed per step. The default of 1096
            matches the previously hard-coded value.
        tok: Tokenizer whose ``encode`` maps ``'<S>'``/``'<E>'``/``'<P>'`` to ids.
            Defaults to the notebook-level ``tokenizer``.

    Returns:
        List of token ids. It starts with ``<S>`` and ends with ``<E>`` if one
        was generated before ``max_length`` tokens were reached.
    """
    if tok is None:
        tok = tokenizer

    start_token = tok.encode(['<S>'])[0]
    end_token = tok.encode(['<E>'])[0]
    pad_token = tok.encode(['<P>'])[0]

    device_type = torch.device(device).type
    sequences = [start_token]

    # Pure inference: skip building autograd graphs on every step.
    with torch.no_grad():
        # Stop at max_length so the returned sequence never exceeds the padded length.
        for pos in range(max_length - 1):
            chunk = enc_input[pos:pos + window_size]
            if len(chunk) < max_length:
                # NOTE(review): pads to max_length rather than window_size —
                # presumably the encoder expects a fixed max_length input; confirm.
                chunk = chunk + [pad_token] * (max_length - len(chunk))
            chunk_tensor = torch.tensor(chunk, dtype=torch.long, device=device).unsqueeze(0)
            encoder_output = model.encode(chunk_tensor)

            current_length = len(sequences)
            padded_sequences = sequences + [pad_token] * (max_length - current_length)
            input_tensor = torch.tensor(padded_sequences, dtype=torch.long, device=device).unsqueeze(0)

            with torch.amp.autocast(device_type, enabled=(device_type == "cuda")):
                logits = model.logits(input_tensor, encoder_output)

            # The decoder input is right-padded, so the next-token prediction is
            # at the last *real* position. Index -1 is a pad slot whenever
            # current_length < max_length.
            next_logits = logits[:, current_length - 1, :]
            next_token = torch.argmax(next_logits, dim=-1).item()

            sequences.append(next_token)
            if next_token == end_token:
                break

    return sequences

def inference(midi_path, max_length=4096, device=None):
    """Tokenize a MIDI file and greedily decode a model output sequence for it.

    Uses the notebook-level ``tokenizer`` and ``model``.

    Args:
        midi_path: Path to the input MIDI file.
        max_length: Length budget passed to ``single_greedy_search``.
        device: Device for the input tensors. Defaults to the device of the
            model's first parameter, so inputs land wherever the model lives
            instead of assuming CUDA is available.

    Returns:
        ``(tokenized_midi, raw_output)``: the tokenizer's token list for the
        input MIDI and the result of ``tokenizer.decode`` on the decoded ids.
    """
    if device is None:
        device = next(model.parameters()).device

    midi_dict = MidiDict.from_midi(midi_path)
    # NOTE(review): `_tokenize_midi_dict` is a private tokenizer method; check
    # whether a public equivalent exists in aria.tokenizer.
    tokenized_midi = tokenizer._tokenize_midi_dict(midi_dict=midi_dict)
    encoded_midi_seq = tokenizer.encode(tokenized_midi)
    decoded_seq = single_greedy_search(model, encoded_midi_seq, max_length, device)
    raw_output = tokenizer.decode(decoded_seq)
    return tokenized_midi, raw_output

In [37]:
# Run greedy decoding on one performance MIDI from the paired dataset, then
# preview the start of the input tokens and of the decoded output.
midi_path = (
    "datasets/paired-dataset-5/performance/"
    "audio-https%3A%2F%2Fwww.youtube.com%2Fwatch%3Fv%3DL2KdUIeYTQo.mid"
)
# NOTE(review): `input` shadows the builtin input(); kept because later cells may use it.
input, output = inference(midi_path)
print(input[1:100], output[:100], sep="\n")

['<S>', ('piano', 70, 45), ('onset', 0), ('dur', 4420), ('piano', 79, 45), ('onset', 950), ('dur', 3470), ('piano', 39, 30), ('onset', 970), ('dur', 1740), ('piano', 63, 45), ('onset', 1610), ('dur', 1100), ('piano', 55, 30), ('onset', 1620), ('dur', 1090), ('piano', 67, 60), ('onset', 2140), ('dur', 570), ('piano', 58, 45), ('onset', 2150), ('dur', 560), ('piano', 63, 60), ('onset', 2170), ('dur', 540), ('piano', 51, 45), ('onset', 2720), ('dur', 1700), ('piano', 77, 60), ('onset', 3250), ('dur', 1170), ('piano', 62, 60), ('onset', 3260), ('dur', 1160), ('piano', 56, 45), ('onset', 3270), ('dur', 1150), ('piano', 68, 60), ('onset', 3780), ('dur', 640), ('piano', 59, 45), ('onset', 3810), ('dur', 610), ('piano', 79, 75), ('onset', 3810), ('dur', 3860), ('piano', 62, 45), ('onset', 3830), ('dur', 590), ('piano', 77, 60), ('onset', 4350), ('dur', 3320), ('piano', 39, 60), ('onset', 4370), ('dur', 1730), ('piano', 63, 60), ('onset', 4970), ('dur', 1130), ('piano', 55, 45), ('onset', 4990)