In [1]:
%load_ext autoreload
%autoreload 2
from aria.utils import _load_weight
from aria.config import load_model_config
from aria.tokenizer import AbsTokenizer, SeparatedAbsTokenizer
from src.load_aria_weights import get_p2q
import argparse
import torch
from torch.nn.functional import log_softmax
from accelerate import Accelerator
import torch.nn.functional as F

from aria.data.midi import MidiDict
torch.set_printoptions(threshold=10_000)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from train_utils import load_config


M_PATH = "./inference/weights/model.safetensors"
tokenizer = AbsTokenizer()
tokenizer.vocab_size
tokenizer.add_tokens_to_vocab(['<SEP>'])
config = load_config("train_config.json")
model = get_p2q(config, tokenizer, M_PATH, False)
# model = get_p2q(config, tokenizer)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

accelerator = Accelerator()
model = accelerator.prepare(model)


In [3]:
midi_path = "datasets/paired-dataset-5/performance/audio-https%3A%2F%2Fwww.youtube.com%2Fwatch%3Fv%3DL2KdUIeYTQo.mid"
midi_dict = MidiDict.from_midi(midi_path)
tokenized_midi = tokenizer._tokenize_midi_dict(midi_dict=midi_dict)
encoded_midi_seq = tokenizer.encode(tokenized_midi)
input_midi_len = len(encoded_midi_seq)
input_midi_len

4101

In [10]:

import math


def encode_pair(model, input_sequence, seq_len, device):
    window_size = math.floor(seq_len)
    lr_overlap = math.floor(seq_len / 4)
    inputs = []
    encoded_outputs = []

    for start in range(0, len(input_sequence), lr_overlap):
        end = start + window_size
        chunk = input_sequence[start:end]
        
        if len(chunk) < window_size + lr_overlap:
            # Padding the last chunk if it's smaller than the window size
            chunk = chunk + [0] * (window_size - len(chunk))
        
        chunk_tensor = torch.tensor(chunk, dtype=torch.long).unsqueeze(0).to(device)
        
        encoded_chunk = model.encode(chunk_tensor)
        # encoded_chunk = torch.randn(1, 512, 768, device=device)
        inputs.extend(chunk_tensor)
        encoded_outputs.append(encoded_chunk)

    # Concatenate all encoded chunks along the sequence dimension
    return inputs, encoded_outputs

def encode_single(model, input_sequence, seq_len, device):
    inputs = []
    encoded_outputs = []

    for start in range(0, 128):
        end = start + seq_len
        chunk = input_sequence[start:end]
        
        if len(chunk) < seq_len:
            # Padding the last chunk if it's smaller than the window size
            chunk = chunk + [0] * (seq_len - len(chunk))
        
        chunk_tensor = torch.tensor(chunk, dtype=torch.long).unsqueeze(0).to(device)
        
        encoded_chunk = model.encode(chunk_tensor)
        # encoded_chunk = torch.randn(1, 512, 768, device=device)
        inputs.extend(chunk_tensor)
        encoded_outputs.append(encoded_chunk)

    # Concatenate all encoded chunks along the sequence dimension
    return inputs, encoded_outputs

inputs, encoder_output = encode_single(model, encoded_midi_seq, 512, device)

In [16]:


def single_greedy_search(model, encoder_output_list, max_length, device):

    start_token = tokenizer.encode(['<S>'])[0]
    end_token = tokenizer.encode(['<E>'])[0]
    pad_token = tokenizer.encode(['<P>'])[0]
    
    sequences = [start_token]  # Initialize with the start token
    initial_shift = 0
    encoder_output = encoder_output_list[initial_shift]


    for counter in range(128):
        encoder_output = encoder_output_list[counter]
        current_length = len(sequences)
        padded_sequences = sequences + [pad_token] * (max_length - current_length)
        
        # Convert the sequence list into a tensor wiaath the correct shape
        input_tensor = torch.tensor(padded_sequences, device=device).unsqueeze(0)  # Adding batch dimension

        with torch.amp.autocast('cuda'):
            logits = model.logits(input_tensor, encoder_output)
        
        logits = logits[:, -1, :]  # Take the logits of the last token
        next_token = torch.argmax(logits, dim=-1).item()  # Choose the token with the highest probability
        
        sequences.append(next_token)  # Append the chosen token to the sequence
        
        if next_token == end_token:  # Stop if the end token is generated
            break
    
    return sequences

def greedy_search(model, encoder_output_list, max_length, device):

    start_token = tokenizer.encode(['<S>'])[0]
    end_token = tokenizer.encode(['<E>'])[0]
    pad_token = tokenizer.encode(['<P>'])[0]
    
    sequences = [start_token]  # Initialize with the start token
    initial_shift = 0
    encoder_output = encoder_output_list[initial_shift]


    for counter in range(max_length):
        if counter % lr_overlap == 0:
            initial_shift += 1
            encoder_output = encoder_output_list[initial_shift]
        current_length = len(sequences)
        padded_sequences = sequences + [pad_token] * (max_length - current_length)
        
        # Convert the sequence list into a tensor wiaath the correct shape
        input_tensor = torch.tensor(padded_sequences, device=device).unsqueeze(0)  # Adding batch dimension

        with torch.amp.autocast('cuda'):
            logits = model.logits(input_tensor, encoder_output)
        
        logits = logits[:, -1, :]  # Take the logits of the last token
        next_token = torch.argmax(logits, dim=-1).item()  # Choose the token with the highest probability
        
        sequences.append(next_token)  # Append the chosen token to the sequence
        
        if next_token == end_token:  # Stop if the end token is generated
            break
    
    return sequences

def beam_search(model, encoder_output_list, max_length, device, beam_width=5):

    start_token = tokenizer.encode(['<S>'])[0]
    end_token = tokenizer.encode(['<E>'])[0]
    pad_token = tokenizer.encode(['<P>'])[0]
    
    sequences = [[start_token]]  # Initialize with the start token
    scores = torch.zeros(1, device=device)  # Keep track of the scores of the sequences
    initial_shift = 0
    encoder_output = encoder_output_list[initial_shift]

    for counter in range(max_length):
        if counter % lr_overlap == 0:
            initial_shift += 1
            encoder_output = encoder_output_list[initial_shift]

        all_candidates = []

        for i in range(len(sequences)):
            seq = sequences[i]
            score = scores[i]

            current_length = len(seq)
            padded_sequences = seq + [pad_token] * (max_length - current_length)
            
            # Convert the sequence list into a tensor with the correct shape
            input_tensor = torch.tensor(padded_sequences, device=device).unsqueeze(0)  # Adding batch dimension

            with torch.amp.autocast('cuda'):
                logits = model.logits(input_tensor, encoder_output)
            
            logits = logits[:, -1, :]  # Take the logits of the last token
            probabilities = torch.nn.functional.log_softmax(logits, dim=-1)  # Apply log softmax to get log-probs

            top_k_probs, top_k_tokens = probabilities.topk(beam_width)  # Get the top k tokens and their log-probs

            for j in range(beam_width):
                candidate = seq + [top_k_tokens[0, j].item()]
                candidate_score = score + top_k_probs[0, j]
                all_candidates.append((candidate_score, candidate))
        
        # Sort all candidates by score and select the best k
        all_candidates = sorted(all_candidates, key=lambda x: x[0], reverse=True)
        sequences = [x[1] for x in all_candidates[:beam_width]]
        scores = torch.tensor([x[0] for x in all_candidates[:beam_width]], device=device)
        
        # If all sequences have reached the end token, stop early
        if all(seq[-1] == end_token for seq in sequences):
            break
    
    # Return the sequence with the highest score
    best_sequence = sequences[0]
    return best_sequence



In [17]:
decoded_seq = single_greedy_search(model, encoder_output, 512, 'cuda')

In [18]:
raw_output = tokenizer.decode(decoded_seq)

In [19]:
raw_output

['<S>',
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('dur', 70),
 ('onset', 3950),
 ('dur', 70),
 ('onset', 3950),
 ('dur', 70),
 ('onset', 3950),
 ('onset', 3950),
 ('onset', 3950),
 ('onset', 3950),
 ('onset', 3950),
 ('dur', 70),
 ('onset', 3950),
 ('onset', 3950),
 ('onset', 3950),
 ('onset', 3950),
 ('onset', 3950),
 ('onset', 3950),
 ('onset', 3950),
 ('onset', 3950),
 ('onset', 3950),
 ('onset', 3950),
 ('dur', 70),
 ('onset', 3950),
 ('onset', 3950),
 ('onset', 3950),
 ('onset', 3950),
 ('onset', 3950),
 ('onset', 3950),
 ('onset', 3950),
