## Train a MIDI Generator Model from Scratch

This notebook demonstrates training a simple language model from scratch to generate MIDI-like sequences in the "p:v:d:t" format.
We will use a basic Transformer model and a custom tokenizer tailored for this format.

### 1. Install Libraries

```python
!pip install datasets transformers[torch]
```

### 2. Import Libraries

In [1]:
import os
import warnings

os.environ["WANDB_DISABLED"] = "true"
warnings.filterwarnings('ignore')

In [2]:
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    Trainer,
    TrainingArguments,
    PreTrainedTokenizerFast,
    PreTrainedModel,
    GPT2LMHeadModel,
    GPT2Config
)
from typing import List, Dict, Tuple
from torch.utils.data import Dataset
from tqdm.auto import tqdm
import re
import numpy as np

### 3. Tokenizer

In [3]:
# We define a simple tokenizer that recognizes the "p:v:d:t" format and special tokens.

class SimpleMidiTokenizer:
    def __init__(self):
        self.vocab = {
            '<pad>': 0, '<s>': 1, '</s>': 2, '<unk>': 3,
            'p': 4, 'v': 5, 'd': 6, 't': 7, ':': 8,
            '0': 9, '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17, '9': 18, ' ':19
        }
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.vocab_size = len(self.vocab)
        self.pad_token_id = self.vocab['<pad>']
        self.bos_token_id = self.vocab['<s>']
        self.eos_token_id = self.vocab['</s>']

    def tokenize(self, text: str) -> List[int]:
        tokens = []
        tokens.append(self.bos_token_id) # Start of sequence token
        for element in text.split(' '): # Split by space to get 'p:v:d:t' blocks
            if not element: # Skip empty strings if any
                continue
            for char in element:
                if char in self.vocab:
                    tokens.append(self.vocab[char])
                else:
                    tokens.append(self.vocab['<unk>']) # Unknown token
            tokens.append(self.vocab[' ']) # Space token to separate note blocks
        tokens.pop() # Remove last space token if present
        tokens.append(self.eos_token_id) # End of sequence token
        return tokens

    def decode(self, token_ids: List[int], skip_special_tokens=True) -> str:
        text = ""
        for token_id in token_ids:
            if token_id in self.inv_vocab:
                token = self.inv_vocab[token_id]
                if skip_special_tokens:
                    if token not in ['<s>', '</s>', '<pad>']: # Exclude special tokens from decoded text
                        text += token
                else:
                    text += token
            else:
                text += '<unk>' # Handle unknown token ids if any
        return text.replace('  ', ' ') # Clean up double spaces


class BinnedQuadrupletMidiTokenizer:
    def __init__(self,
                 pitch_range: Tuple[int, int] = (0, 127),
                 volume_range: Tuple[int, int] = (0, 127),
                 duration_range: Tuple[int, int] = (0, 4000), # Example max
                 time_range: Tuple[int, int] = (0, 10000),  # Example max
                 duration_bins: int = 128, # Number of bins for duration
                 time_bins: int = 128      # Number of bins for time
                 ):

        self.pitch_range = pitch_range
        self.volume_range = volume_range
        # Store raw ranges and bin counts
        self.duration_info = {'range': duration_range, 'bins': duration_bins}
        self.time_info = {'range': time_range, 'bins': time_bins}

        self.vocab = {
            '<pad>': 0, '<s>': 1, '</s>': 2, '<unk>': 3, '<note>': 4
        }
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.next_token_id = len(self.vocab)

        # Store prefixes
        self.prefixes = {
            'pitch': '<pitch_', 'volume': '<volume_',
            'duration': '<duration_bin_', 'time': '<time_bin_'
        }

        # --- Add tokens ---
        # Pitch and Volume (exact values)
        self._add_exact_value_tokens(pitch_range, self.prefixes['pitch'])
        self._add_exact_value_tokens(volume_range, self.prefixes['volume'])

        # Duration and Time (binned values)
        self.duration_boundaries = self._calculate_log_boundaries(duration_range, duration_bins)
        self.time_boundaries = self._calculate_log_boundaries(time_range, time_bins)
        self._add_bin_tokens(duration_bins, self.prefixes['duration'])
        self._add_bin_tokens(time_bins, self.prefixes['time'])

        self.vocab_size = len(self.vocab)
        self.pad_token_id = self.vocab['<pad>']
        self.bos_token_id = self.vocab['<s>']
        self.eos_token_id = self.vocab['</s>']
        self.unk_token_id = self.vocab['<unk>']
        self.note_token_id = self.vocab['<note>']

        # Precompile regex
        self.note_pattern = re.compile(r'p(\d+):v(\d+):d(\d+):t(\d+)')

    def _calculate_log_boundaries(self, value_range: Tuple[int, int], num_bins: int) -> np.ndarray:
        """Calculates logarithmic bin boundaries."""
        min_val, max_val = value_range
        if min_val < 0: min_val = 0 # Ensure non-negative for log
        # Add 1 before log, subtract 1 after to handle 0 correctly
        # Use logspace from log10(min_val+1) to log10(max_val+1)
        boundaries = np.logspace(
            np.log10(min_val + 1),
            np.log10(max_val + 1),
            num=num_bins + 1 # Need num_bins + 1 boundaries for num_bins bins
        ) - 1
        # Ensure the first boundary is exactly the minimum value if it was >= 0
        if min_val >= 0:
             boundaries[0] = min_val
        return boundaries

    def _add_exact_value_tokens(self, value_range: Tuple[int, int], prefix: str):
        """Adds tokens for each exact value in the range."""
        for i in range(value_range[0], value_range[1] + 1):
            token_name = f'{prefix}{i}>'
            self.vocab[token_name] = self.next_token_id
            self.inv_vocab[self.next_token_id] = token_name
            self.next_token_id += 1

    def _add_bin_tokens(self, num_bins: int, prefix: str):
        """Adds tokens for each bin index."""
        for i in range(num_bins):
            token_name = f'{prefix}{i}>'
            self.vocab[token_name] = self.next_token_id
            self.inv_vocab[self.next_token_id] = token_name
            self.next_token_id += 1

    def _get_bin_index(self, value: int, boundaries: np.ndarray) -> int:
        """Finds the appropriate bin index for a value."""
        # np.digitize returns the index of the bin (starting from 1)
        # boundaries[i-1] <= x < boundaries[i]
        bin_index = np.digitize(value, boundaries[1:], right=False) # Use boundaries[1:] because digitize checks < boundary
        # Ensure index is within bounds [0, num_bins-1]
        return min(bin_index, len(boundaries) - 2) # len(boundaries) - 2 is the max bin index

    def _get_value_token_id(self, value: int, value_range: Tuple[int, int], prefix: str) -> int:
        """Gets the token ID for an exact value."""
        if value_range[0] <= value <= value_range[1]:
            token_name = f'{prefix}{value}>'
            return self.vocab.get(token_name, self.unk_token_id)
        return self.unk_token_id

    def _get_bin_token_id(self, value: int, boundaries: np.ndarray, prefix: str) -> int:
        """Gets the token ID for a binned value."""
        min_val = boundaries[0]
        max_val = boundaries[-1]
        if min_val <= value <= max_val:
             bin_index = self._get_bin_index(value, boundaries)
             token_name = f'{prefix}{bin_index}>'
             return self.vocab.get(token_name, self.unk_token_id) # Should exist, but fallback
        return self.unk_token_id

    def tokenize(self, text: str) -> List[int]:
        tokens = [self.bos_token_id]
        note_blocks = text.strip().split(' ')
        for block in note_blocks:
            if not block: continue
            match = self.note_pattern.match(block)
            if match:
                try:
                    p_val = int(match.group(1))
                    v_val = int(match.group(2))
                    d_val = int(match.group(3))
                    t_val = int(match.group(4))

                    p_token_id = self._get_value_token_id(p_val, self.pitch_range, self.prefixes['pitch'])
                    v_token_id = self._get_value_token_id(v_val, self.volume_range, self.prefixes['volume'])
                    d_token_id = self._get_bin_token_id(d_val, self.duration_boundaries, self.prefixes['duration'])
                    t_token_id = self._get_bin_token_id(t_val, self.time_boundaries, self.prefixes['time'])

                    # Check if any tokenization resulted in UNK
                    if all(tid != self.unk_token_id for tid in [p_token_id, v_token_id, d_token_id, t_token_id]):
                        tokens.append(self.note_token_id)
                        tokens.append(p_token_id)
                        tokens.append(v_token_id)
                        tokens.append(d_token_id)
                        tokens.append(t_token_id)
                    else:
                        # Handle UNK during tokenization (e.g., value truly out of range)
                        tokens.extend([self.unk_token_id] * 5) # Or just one UNK? Depends.
                except (ValueError, IndexError):
                    tokens.extend([self.unk_token_id] * 5) # Malformed block
            else:
                 tokens.append(self.unk_token_id) # Non-matching block

        tokens.append(self.eos_token_id)
        return tokens

    def _get_value_from_exact_token(self, token_id: int, prefix: str) -> int | None:
        """Helper to extract value from exact value token."""
        token = self.inv_vocab.get(token_id)
        if token and token.startswith(prefix) and token.endswith('>'):
            try:
                return int(token[len(prefix):-1])
            except ValueError:
                return None
        return None

    def _get_approx_value_from_bin_token(self, token_id: int, prefix: str, boundaries: np.ndarray) -> str | None:
         """Helper to get an approximate representation from bin token."""
         token = self.inv_vocab.get(token_id)
         if token and token.startswith(prefix) and token.endswith('>'):
             try:
                 bin_index = int(token[len(prefix):-1])
                 if 0 <= bin_index < len(boundaries) - 1:
                     # Return a string representing the bin range or midpoint
                     lower_bound = boundaries[bin_index]
                     upper_bound = boundaries[bin_index + 1]
                     # Simple representation: midpoint rounded
                     midpoint = int(round((lower_bound + upper_bound) / 2))
                     # Or return range string: f"({int(lower_bound)}-{int(upper_bound)})"
                     return str(midpoint) # Return midpoint as string
                 else: return "<bin_idx_err>"
             except ValueError:
                 return "<bin_parse_err>"
         return None


    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        notes_str = []
        i = 0
        while i < len(token_ids):
            token_id = token_ids[i]
            token = self.inv_vocab.get(token_id)

            special_tokens_to_handle = ['<s>', '</s>', '<pad>', '<unk>']

            if skip_special_tokens and token in special_tokens_to_handle:
                i += 1
                continue

            if not skip_special_tokens and token in special_tokens_to_handle:
                 notes_str.append(token)
                 i += 1
                 continue

            if token == '<note>':
                if i + 4 < len(token_ids): # Check for full quadruplet
                    p_token_id = token_ids[i+1]
                    v_token_id = token_ids[i+2]
                    d_token_id = token_ids[i+3]
                    t_token_id = token_ids[i+4]

                    p_val = self._get_value_from_exact_token(p_token_id, self.prefixes['pitch'])
                    v_val = self._get_value_from_exact_token(v_token_id, self.prefixes['volume'])
                    # Get approximate value string for duration and time
                    d_val_approx = self._get_approx_value_from_bin_token(d_token_id, self.prefixes['duration'], self.duration_boundaries)
                    t_val_approx = self._get_approx_value_from_bin_token(t_token_id, self.prefixes['time'], self.time_boundaries)


                    if all(v is not None for v in [p_val, v_val, d_val_approx, t_val_approx]):
                        notes_str.append(f"p{p_val}:v{v_val}:d{d_val_approx}:t{t_val_approx}")
                        i += 5 # Move past the <note> and its 4 value tokens
                    else:
                        # Malformed note sequence after <note>
                        if not skip_special_tokens: notes_str.append("<unk_note>")
                        i += 1 # Move past the <note> token only
                else:
                    # Not enough tokens after <note>
                    if not skip_special_tokens: notes_str.append("<partial_note>")
                    i += 1
            else:
                 # Unexpected token
                 if not skip_special_tokens: notes_str.append(token if token else "<unk_decode>")
                 i += 1

        return " ".join(notes_str)


### 4. Midi Dataset

In [4]:
class MidiDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]['text']
        midi_sequence = self._preprocess_text(text) # Preprocess inside dataset
        tokenized_sequence = self.tokenizer.tokenize(midi_sequence)
        padded_sequence = self._pad_sequence(tokenized_sequence)
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.tokenizer.pad_token_id else -100 for label in padded_sequence]
        return {
            'input_ids': torch.tensor(padded_sequence),
            'labels': torch.tensor(labels) # Labels are same as input_ids for LM
        }

    def _preprocess_text(self, text):
        """Removes the system prompt and keeps only the MIDI sequence."""
        try:
            midi_sequence_start = text.find("[/INST]") + len("[/INST]")
            midi_sequence_end = text.find("</s>")
            if midi_sequence_end != -1: # Handle case where </s> is present
                midi_sequence = text[midi_sequence_start:midi_sequence_end].strip()
            else: # If </s> is not found, take until the end
                midi_sequence = text[midi_sequence_start:].strip()
            return midi_sequence
        except:
            print(f"Warning: Could not process example: {text[:100]}...")
            return text # Fallback, keep original if processing fails

    def _pad_sequence(self, tokenized_sequence):
        """Pads or truncates a tokenized sequence to max_length."""
        if len(tokenized_sequence) > self.max_length:
            tokenized_sequence = tokenized_sequence[:self.max_length] # Truncate if longer
        padding_length = self.max_length - len(tokenized_sequence)
        tokenized_sequence.extend([self.tokenizer.pad_token_id] * padding_length)
        return tokenized_sequence

### 5. Preprocess data

In [5]:
def calculate_dt_ranges(dataset):
    """
    Calculates the minimum and maximum duration (d) and time (t) values
    found in the 'text' column of the dataset after preprocessing.
    """
    min_d, max_d = float('inf'), float('-inf')
    min_t, max_t = float('inf'), float('-inf')

    # Precompile regex for faster parsing
    note_pattern = re.compile(r'p(\d+):v(\d+):d(\d+):t(\d+)') # Slightly more robust

    print("Calculating duration (d) and time (t) ranges from dataset...")
    for example in tqdm(dataset):
        text = example['text']
        # --- Preprocessing logic (mirrors MidiDataset._preprocess_text) ---
        try:
            midi_sequence_start = text.find("[/INST]") + len("[/INST]")
            midi_sequence_end = text.find("</s>")
            if midi_sequence_end != -1:
                midi_sequence = text[midi_sequence_start:midi_sequence_end].strip()
            else:
                midi_sequence = text[midi_sequence_start:].strip()
        except:
            # Skip problematic examples during range calculation
            continue
        # --- End Preprocessing ---

        # Find all note matches in the preprocessed sequence
        # We can split by space first for potentially better performance on long strings
        note_blocks = midi_sequence.split(' ')
        for block in note_blocks:
            match = note_pattern.match(block) # Use match since we expect it at the start of the block
            if match:
                try:
                    # Extract d and t values as integers
                    d_val = int(match.group(3))
                    t_val = int(match.group(4))

                    # Update min/max
                    min_d = min(min_d, d_val)
                    max_d = max(max_d, d_val)
                    min_t = min(min_t, t_val)
                    max_t = max(max_t, t_val)
                except (ValueError, IndexError):
                    continue

    # Handle cases where no valid values were found
    if min_d == float('inf'): min_d = 0
    if max_d == float('-inf'): max_d = 2000 # Sensible default max if none found
    if min_t == float('inf'): min_t = 0
    if max_t == float('-inf'): max_t = 2000 # Sensible default max if none found

    print("Calculation complete.")
    # Ensure min is not greater than max if only one value was found
    if min_d > max_d: max_d = min_d
    if min_t > max_t: max_t = min_t

    return (min_d, max_d), (min_t, max_t)

In [6]:
def filter_by_max_t(example, max_t_threshold):
    """
    Checks if all 't' values in a single dataset example are <= max_t_threshold.
    Returns True to keep the example, False to discard it.
    """
    text = example['text']
    # --- Preprocessing logic ---
    try:
        midi_sequence_start = text.find("[/INST]") + len("[/INST]")
        midi_sequence_end = text.find("</s>")
        if midi_sequence_end != -1: midi_sequence = text[midi_sequence_start:midi_sequence_end].strip()
        else: midi_sequence = text[midi_sequence_start:].strip()
    except: return True # Keep examples where preprocessing fails
    # --- End Preprocessing ---

    t_pattern = re.compile(r':t(\d+)')
    matches = t_pattern.finditer(midi_sequence)
    try:
        for match in matches:
            t_val = int(match.group(1))
            if t_val > max_t_threshold: return False # Discard example
    except ValueError: return True # Keep examples with parsing errors
    except Exception: return True # Keep on other unexpected errors
    return True # Keep if all checks pass

In [7]:
dataset_name = "fegounna/GMP_4K"
raw_dataset_split = "train[:40000]" # Adjust size as needed
dataset = load_dataset(dataset_name, split=raw_dataset_split)
MAX_T_THRESHOLD = 10000

# --- Load Raw Dataset ---
print(f"Loading raw dataset split: {raw_dataset_split}...")
raw_dataset = load_dataset(dataset_name, split=raw_dataset_split)
print(f"Original dataset size: {len(raw_dataset)}")

# --- Filter Dataset ---
print(f"Filtering dataset with max_t_threshold = {MAX_T_THRESHOLD}...")
filtered_dataset = raw_dataset.filter(
    lambda example: filter_by_max_t(example, MAX_T_THRESHOLD)
)
print(f"Filtered dataset size: {len(filtered_dataset)}")

# --- Calculate Ranges on FILTERED data ---
print("Calculating ranges on filtered dataset...")
duration_range, pause_range = calculate_dt_ranges(filtered_dataset) # Use filtered_dataset
print(f"Calculated Duration (d) Range: {duration_range}")
print(f"Calculated Pause (t) Range: {pause_range}") # This range should now respect the threshold


# Instantiate BINNED Tokenizer using calculated ranges and desired bins
DURATION_BINS = 2048 # Control vocab size vs precision trade-off
TIME_BINS = 4096   # Control vocab size vs precision trade-off

tokenizer = BinnedQuadrupletMidiTokenizer(
    duration_range=duration_range,
    time_range=pause_range,
    duration_bins=DURATION_BINS,
    time_bins=TIME_BINS
)

Loading raw dataset split: train[:40000]...
Original dataset size: 40000
Filtering dataset with max_t_threshold = 10000...
Filtered dataset size: 39839
Calculating ranges on filtered dataset...
Calculating duration (d) and time (t) ranges from dataset...


  0%|          | 0/39839 [00:00<?, ?it/s]

Calculation complete.
Calculated Duration (d) Range: (0, 4613)
Calculated Pause (t) Range: (0, 9993)


In [8]:
# Create the training dataset using the tokenizer
train_dataset = MidiDataset(dataset, tokenizer, max_length=512)

print("\nTokenizer vocab size:", tokenizer.vocab_size)
print("Dataset size:", len(train_dataset))
example_dataset = train_dataset[0]
print("\nExample from dataset:")
print("Input IDs (first 50):", example_dataset['input_ids'][:50])
print("Labels (first 50):", example_dataset['labels'][:50])
decoded_example = tokenizer.decode(example_dataset['input_ids'].tolist())
print("\nDecoded example (first part):", decoded_example[:200] + "...")


Tokenizer vocab size: 6405
Dataset size: 40000

Example from dataset:
Input IDs (first 50): tensor([   1,    4,   86,  211, 1709, 4430,    4,   81,  201, 1489, 3105,    4,
          62,  193, 1714, 4552,    4,   89,  220, 1719, 2617,    4,   50,  212,
        1718, 2617,    4,   81,  200, 1593, 4445,    4,   86,  212, 1637, 2617,
           4,   57,  193, 1718, 4452,    4,   62,  202, 1717, 2925,    4,   81,
         206, 1497])
Labels (first 50): tensor([   1,    4,   86,  211, 1709, 4430,    4,   81,  201, 1489, 3105,    4,
          62,  193, 1714, 4552,    4,   89,  220, 1719, 2617,    4,   50,  212,
        1718, 2617,    4,   81,  200, 1593, 4445,    4,   86,  212, 1637, 2617,
           4,   57,  193, 1718, 4452,    4,   62,  202, 1717, 2925,    4,   81,
         206, 1497])

Decoded example (first part): p81:v78:d389:t117 p76:v68:d157:t5 p57:v60:d398:t154 p84:v87:d406:t1 p45:v79:d404:t1 p76:v67:d241:t121 p81:v79:d289:t1 p52:v60:d404:t123 p57:v69:d402:t3 p76:v73:d162:t159 p84:v

### 6. Build a Simple Model from Scratch

In [9]:
# We create a small GPT-2 like model configuration and instantiate the model.

config = GPT2Config(
    vocab_size=tokenizer.vocab_size,
    n_embd=384,      # Embedding dimension
    n_head=6,        # Number of attention heads
    n_layer=6,       # Number of layers
    n_positions=512, # Max sequence length - adjust if needed
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id
)

model = GPT2LMHeadModel(config)
print(f"Model parameters: {model.num_parameters()}")


Model parameters: 13303680


### 7. Set up Training

In [10]:
# We define training arguments and the Trainer.

output_dir = "./models"
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=6,
    per_device_train_batch_size=16,
    save_steps=500,
    save_total_limit=2,
    logging_steps=500,
    learning_rate=5e-4,
    # lr_scheduler_type='cosine_with_restarts', # Optional learning rate scheduler
    warmup_steps=100,
    weight_decay=0.01,
    report_to=None, # No wandb for simple notebook
    fp16=False,      # Set to True if using GPU with FP16 support
    bf16=False,      # Set to True if using GPU with BF16 support (e.g., Ampere GPUs)
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [None]:
# ### 8. Train the Model

trainer.train(resume_from_checkpoint=False)

# ### 9. Save Trained Model and Tokenizer

model.save_pretrained(output_dir)

print(f"Model saved to {output_dir}")

### 10. Example Generation

In [17]:
import torch

# Assuming 'tokenizer', 'trainer' are already defined and the model is loaded

def generate_midi_sequence_nucleus(
    prompt_text: str = "",
    max_new_tokens: int = 256, # Renamed for clarity: generates this many NEW tokens
    top_p: float = 0.9,        # Nucleus sampling probability threshold (e.g., 0.85, 0.9, 0.95)
    temperature: float = 0.8,  # Optional: slight temperature can smooth probabilities before top_p
    device = None
):
    """Generates MIDI sequence using nucleus sampling."""

    if device is None:
      device = trainer.model.device # Use the model's device by default

    input_tokens = tokenizer.tokenize(prompt_text)
    input_tensor = torch.tensor([input_tokens]).to(device)
    input_length = input_tensor.shape[1] # Length of the prompt tokens

    # Calculate the total max_length for the generate function
    total_max_length = input_length + max_new_tokens

    print(f"Generating sequence with Nucleus Sampling (top_p={top_p}, temp={temperature})...")
    print(f"Prompt length: {input_length} tokens, Max new tokens: {max_new_tokens}, Total max length: {total_max_length}")

    generated_token_ids = trainer.model.generate(
        input_tensor,
        max_length=total_max_length,    # Total desired sequence length
        do_sample=True,                 # MUST be True for sampling strategies
        top_p=top_p,                    # Enable nucleus sampling
        temperature=temperature,        # Control randomness (1.0 = no change)
        num_return_sequences=1,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

    # Extract only the generated tokens (excluding the prompt)
    generated_part_ids = generated_token_ids[0, input_length:]

    # Decode the generated part
    generated_sequence = tokenizer.decode(generated_part_ids.tolist(), skip_special_tokens=True) # Skip special tokens for cleaner output

    return generated_sequence.strip() # Remove leading/trailing whitespace

# --- Example Usage ---

prompt = "" # Start generation from scratch
# Let's generate approximately 510 *new* tokens
# Note: Tokenizer adds BOS/EOS, so actual number might vary slightly depending on tokenizer implementation
max_new_tokens_to_generate = 510

generated_midi = generate_midi_sequence_nucleus(
    prompt_text=prompt,
    max_new_tokens=max_new_tokens_to_generate,
    top_p=0.95,       # Common value for nucleus sampling
    temperature=0.9  # Slightly reduced temperature often works well with top_p
)

print(f"\n--- Generated MIDI Sequence ---")
print(generated_midi)
print(f"--- End of Sequence ---")

# To generate the midi file use text_to_midi.py

Generating sequence with Nucleus Sampling (top_p=0.95, temp=0.9)...
Prompt length: 2 tokens, Max new tokens: 510, Total max length: 512

--- Generated MIDI Sequence ---
p63:v55:d4:t1 p70:v77:d11:t2 p58:v61:d57:t5 p76:v65:d11:t5 p59:v63:d4:t676 p71:v91:d4604:t3 p71:v81:d4604:t0 p59:v70:d4604:t1 p59:v67:d4604:t4 p59:v69:d4604:t1 p62:v58:d4604:t1 p59:v57:d4:t613 p54:v47:d4604:t1 p59:v46:d9:t1 p54:v51:d4604:t5 p66:v52:d4170:t3 p83:v74:d4585:t5 p63:v68:d3840:t1 p66:v70:d4274:t0 p66:v70:d4187:t3 p87:v73:d4604:t2 p71:v70:d4604:t0 p59:v67:d4604:t1 p78:v70:d4604:t3 p59:v69:d4604:t1 p54:v61:d11:t0 p58:v67:d4547:t0 p64:v67:d5:t2 p66:v73:d4604:t8 p58:v65:d4604:t1197 p61:v66:d4604:t1 p54:v66:d4604:t8 p54:v66:d4604:t6 p70:v82:d4604:t0 p58:v63:d4604:t1 p64:v71:d4604:t0 p54:v67:d4604:t1 p73:v85:d4604:t0 p58:v73:d4604:t1 p54:v65:d4604:t5 p61:v66:d4604:t4 p54:v68:d4604:t5 p59:v70:d4454:t658 p78:v100:d3777:t5 p64:v73:d4292:t2 p54:v62:d3888:t5 p47:v58:d3640:t0 p54:v58:d4604:t7 p59:v69:d4604:t4 p54:v65:d46