# xLSTM Music Generation - Clean Pipeline

This notebook provides a clean, modular approach to generating music with your trained xLSTM model.

## Key Fixes from Your Original Code:

1. **Memory Issue Fixed**: The problem was `max_length` growing with each iteration
   - **Wrong**: `max_length = len(output.split()) + chunk_size` (creates quadratic memory growth)
   - **Right**: Use a fixed `max_length` or a sliding-window context

2. **Context Length**: You can use larger context during inference than training
   - Trained with 2048 → Can infer with 4096 or more
   - But memory grows with context² in mLSTM

3. **Modular Design**: Clean separation of generation and conversion logic
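
To make the first fix concrete, the sketch below compares the sequence length the model sees on each iteration under the two strategies. It is plain arithmetic, not the notebook's generation code: `growing_limits` and `windowed_limits` are hypothetical helper names, all sizes are in tokens, and the model is assumed to fill its whole budget on every iteration.

```python
def growing_limits(prompt_len: int, chunk_size: int, iterations: int) -> list[int]:
    """Sequence length per iteration when max_length is recomputed from the full output."""
    limits = []
    total = prompt_len
    for _ in range(iterations):
        limit = total + chunk_size  # mirrors len(output.split()) + chunk_size
        limits.append(limit)
        total = limit               # assumption: the model fills the whole budget
    return limits


def windowed_limits(prompt_len: int, chunk_size: int, iterations: int, window: int) -> list[int]:
    """Sequence length per iteration when only the last `window` tokens are kept as context."""
    limits = []
    total = prompt_len
    for _ in range(iterations):
        context = min(total, window)  # the context never exceeds the window
        limits.append(context + chunk_size)
        total += chunk_size
    return limits


print(growing_limits(26, 1024, 4))                 # [1050, 2074, 3098, 4122]
print(windowed_limits(26, 1024, 4, window=1024))   # [1050, 2048, 2048, 2048]
```

With the growing limit, every iteration processes a longer sequence than the last; with the window, the length plateaus at `window + chunk_size`.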


## Setup

In [2]:
import os
import sys
from pathlib import Path

# Set build-related environment variables before importing the model code,
# so any CUDA kernels compiled on import or load pick them up.
os.environ['TORCH_CUDA_ARCH_LIST'] = '8.0;8.6;8.9'
os.environ['MAX_JOBS'] = '4'

sys.path.append("/scratch1/e20-fyp-xlstm-music-generation/e20fyptemp1/fyp-musicgen/repos/helibrunna")

from xlstm_music_generation import MusicGenerator, MIDIConverter, generate_music


## Initialize Generator

In [19]:
MODEL_PATH = "/scratch1/e20-fyp-xlstm-music-generation/e20fyptemp1/fyp-musicgen/repos/helibrunna/output/lmd_remigen_xlstm/run_20260115-1028"

# Generator for short, single-shot pieces
generator = MusicGenerator(
    model_path=MODEL_PATH,
    context_length=8192,  # Larger than the 2048-token training context
    device="cuda"
)

converter = MIDIConverter()

Loading model from: /scratch1/e20-fyp-xlstm-music-generation/e20fyptemp1/fyp-musicgen/repos/helibrunna/output/lmd_remigen_xlstm/run_20260115-1028

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'PreTrainedTokenizerFast'.


✓ Model loaded (context: 8192 tokens)


In [20]:
# Try single-shot generation with a larger max_tokens budget
result = generator.generate(
    prompt="s-9 o-0 t-35",
    temperature=0.9,
    max_tokens=2000,
    verbose=True
)


🎵 Generating...
   Prompt: s-9 o-0 t-35...
   Max tokens: 2000
   Temperature: 0.9
✓ Generated 2000 tokens (7 bars)


In [None]:
# Save
output_path = "./output/sample_8_ct4096-len1000.mid"
success = converter.tokens_to_midi(result['tokens'], output_path, clean=True)

if success:
    print(f"✓ Saved: {output_path}")
    print(f"Generated {result['num_tokens']} tokens, {result['bars']} bars")
else:
    print("Decoding failed")

✓ Saved: ./output/sample_8_ct4096-len100.mid
Generated 1999 tokens, 15 bars


In [5]:
# Debug: see what token is causing the error
tokens = result['tokens'].split()
print(f"Total tokens: {len(tokens)}")
print(f"Last 30 tokens: {tokens[-30:]}")

# Find incomplete note triplets (pitch, duration, velocity) in the last 30 tokens
start = max(0, len(tokens) - 30)  # guard against sequences shorter than 30 tokens
for i, token in enumerate(tokens[start:], start=start):
    if token.startswith('p-'):
        if i+1 >= len(tokens) or not tokens[i+1].startswith('d-'):
            print(f"Incomplete at {i}: {token} (no duration)")
        elif i+2 >= len(tokens) or not tokens[i+2].startswith('v-'):
            print(f"Incomplete at {i}: {token} {tokens[i+1]} (no velocity)")

Total tokens: 3056
Last 30 tokens: ['d-3', 'v-12', 'd-3', 'v-12', 'i-128', 'p-166', 'd-3', 'v-12', 'b-1', 'd-3', 'v-12', 'd-3', 'v-12', 'd-2', 'd-3', 'd-14', 'd-2', 'd-2', 'd-2', 'p-77', 'd-2', 'v-12', 'o-6', 't-33', 'i-34', 'v-8', 'd-2', 'd-3', 'd-2', 'd-3']


## Example 1: Generate Short Piece

In [10]:
# Simple generation
result = generator.generate(
    prompt="s-9 o-0 t-38",
    temperature=0.8,
    max_tokens=1000,
    verbose=True
)

print(f"\nGenerated {result['num_tokens']} tokens, {result['bars']} bars")

# Convert to MIDI
output_path = "./output/test_song.mid"
converter.tokens_to_midi(result['tokens'], output_path)
print(f"Saved to: {output_path}")

🎵 Generating...
   Prompt: s-9 o-0 t-38...
   Max tokens: 1000
   Temperature: 0.8
✓ Generated 1000 tokens (8 bars)

Generated 1000 tokens, 8 bars
Saved to: ./output/test_song.mid


## Example 2: Generate Long Piece (Chunked)

This uses a **sliding-window** approach to avoid memory issues.

In [11]:
# For long generation, use larger context
long_generator = MusicGenerator(
    model_path=MODEL_PATH,
    context_length=4096,  # Larger than training
    device="cuda"
)

result = long_generator.generate_long(
    prompt="s-9 o-30 t-33 i-128 p-176 d-6 v-23 o-36 t-33 i-128 p-173 d-6 v-23 o-42 t-33 i-128 p-171 d-6 v-23 b-1 s-9 o-0 t-33 i-4 p-81 d-25",
    temperature=0.8,
    target_bars=64,       # Aim for 64 bars
    chunk_tokens=1024,    # 1024 new tokens per iteration
    max_iterations=2,     # Hard cap on iterations; at this chunk size it stops well short of 64 bars
    verbose=True
)

# Save
output_path = "./output/long_song.mid"
converter.tokens_to_midi(result['tokens'], output_path)
print(f"\nSaved {result['bars']} bars to: {output_path}")

Loading model from: /scratch1/e20-fyp-xlstm-music-generation/e20fyptemp1/fyp-musicgen/repos/helibrunna/output/lmd_remigen_xlstm/run_20260115-1028

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'PreTrainedTokenizerFast'.


✓ Model loaded (context: 4096 tokens)
🎵 Long generation (chunked)...
   Target: 64 bars
   Chunk size: 1024 tokens

📝 Iteration 1/2
   Context: 26 tokens
   Added: 1024 tokens (3 bars)
   Total: 1050 tokens (4 bars)

📝 Iteration 2/2
   Context: 1050 tokens
   Added: 1023 tokens (4 bars)
   Total: 2073 tokens (8 bars)

✓ Generation complete!
   Final: 2073 tokens, 8 bars

Saved 8 bars to: ./output/long_song.mid


In [5]:
# Add detailed error reporting
cleaned = converter.clean_tokens(result['tokens'])
print(f"Cleaned: {len(cleaned.split())} tokens")
print(f"Last 20 cleaned: {cleaned.split()[-20:]}")

try:
    midi_obj = converter.decoder.decode_from_token_str_list(cleaned.split())
    print("✓ Decoding successful!")
except Exception as e:
    print(f"Error: {type(e).__name__}: {e}")
    import traceback
    traceback.print_exc()

Cleaned: 3509 tokens
Last 20 cleaned: ['i-128', 'o-47', 'p-66', 'd-2', 'v-22', 'b-1', 'o-47', 'b-1', 'o-2', 'o-47', 'b-1', 'o-47', 'b-1', 'o-47', 'o-47', 'i-37', 'p-27', 'd-5', 'i-128', 'o-47']
Error: AssertionError: 


Traceback (most recent call last):
  File "/tmp/ipykernel_3928866/513666804.py", line 7, in <module>
    midi_obj = converter.decoder.decode_from_token_str_list(cleaned.split())
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/e20037/miniconda/envs/xlstm/lib/python3.11/site-packages/midiprocessor/midi_decoding.py", line 195, in decode_from_token_str_list
    midi_obj = self.decode_from_token_list(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/e20037/miniconda/envs/xlstm/lib/python3.11/site-packages/midiprocessor/midi_decoding.py", line 244, in decode_from_token_list
    return enc_remigen_utils.generate_midi_obj_from_remigen_token_list(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/e20037/miniconda/envs/xlstm/lib/python3.11/site-packages/midiprocessor/enc_remigen_utils.py", line 250, in generate_midi_obj_from_remigen_token_list
    assert last_item_type == const.DURATION_ABBR
           ^^^^^^
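
The failing assertion (`last_item_type == const.DURATION_ABBR`) suggests the decoder expects every velocity token to come directly after a duration token, and the cleaned tail above contains note groups that break this pattern. The sketch below is one way to filter such groups before decoding. It is not the implementation of `converter.clean_tokens`, the helper name `drop_malformed_notes` is hypothetical, and the real decoder may enforce further rules that this filter does not check.

```python
def drop_malformed_notes(tokens: list[str]) -> list[str]:
    """Keep complete p-/d-/v- triplets and non-note tokens; drop orphaned note tokens."""
    cleaned = []
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok.startswith("p-"):
            is_complete = (
                i + 2 < len(tokens)
                and tokens[i + 1].startswith("d-")
                and tokens[i + 2].startswith("v-")
            )
            if is_complete:
                cleaned.extend(tokens[i:i + 3])
                i += 3
            else:
                i += 1  # pitch without duration and velocity: drop it
        elif tok.startswith(("d-", "v-")):
            i += 1      # duration or velocity not attached to a pitch: drop it
        else:
            cleaned.append(tok)  # s-, o-, t-, i-, b- tokens pass through unchanged
            i += 1
    return cleaned


sample = "i-37 p-27 d-5 i-128 o-47 p-66 d-2 v-22 b-1 d-3 v-12".split()
print(drop_malformed_notes(sample))
# ['i-37', 'i-128', 'o-47', 'p-66', 'd-2', 'v-22', 'b-1']
```

Dropping tokens changes the music slightly, so it is worth logging how many tokens were removed before trusting the resulting MIDI.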

## Example 3: Batch Generation with Different Temperatures

In [None]:
temperatures = [0.5, 0.8, 1.0, 1.2]
output_dir = Path("./output/temp_comparison")
output_dir.mkdir(exist_ok=True, parents=True)

for temp in temperatures:
    print(f"\n{'='*60}")
    print(f"Temperature: {temp}")
    print('='*60)
    
    result = generator.generate(
        prompt="s-9 o-0 t-35",
        temperature=temp,
        max_tokens=2000,
        verbose=True
    )
    
    midi_path = output_dir / f"temp_{temp:.1f}.mid"
    converter.tokens_to_midi(result['tokens'], str(midi_path))
    print(f"✓ Saved: {midi_path}")

## Example 4: Simple API (One Function)

In [None]:
# Generate 5 songs with one function call
outputs = generate_music(
    model_path=MODEL_PATH,
    num_songs=5,
    max_tokens=2048,
    temperature=0.8,
    output_dir="./output/batch"
)

print(f"\nGenerated {len(outputs)} songs")

## Example 5: Long Mode with Simple API

In [None]:
# Generate long pieces
outputs = generate_music(
    model_path=MODEL_PATH,
    num_songs=2,
    temperature=0.8,
    output_dir="./output/long_batch",
    long_mode=True,
    target_bars=64
)

print(f"\nGenerated {len(outputs)} long songs")

## Analyzing Generated MIDI

In [10]:
import pretty_midi

def analyze_midi(midi_path):
    midi = pretty_midi.PrettyMIDI(str(midi_path))
    
    print(f"File: {midi_path.name}")
    print(f"Duration: {midi.get_end_time():.2f}s")
    print(f"Instruments: {len(midi.instruments)}")
    print(f"Total notes: {sum(len(inst.notes) for inst in midi.instruments)}")
    
    for inst in midi.instruments:
        inst_type = "Drums" if inst.is_drum else f"Program {inst.program}"
        print(f"  - {inst.name}: {len(inst.notes)} notes ({inst_type})")
    print()

# Analyze all generated files
for midi_file in Path("./output").glob("**/*.mid"):
    analyze_midi(midi_file)

File: single_shot_test.mid
Duration: 28.80s
Instruments: 6
Total notes: 457
  - 27: 113 notes (Program 27)
  - 35: 113 notes (Program 35)
  - 52: 83 notes (Program 52)
  - 53: 45 notes (Program 53)
  - 54: 46 notes (Program 54)
  - 128: 57 notes (Drums)

File: long_song.mid
Duration: 33.30s
Instruments: 5
Total notes: 507
  - 4: 57 notes (Program 4)
  - 24: 158 notes (Program 24)
  - 33: 54 notes (Program 33)
  - 49: 59 notes (Program 49)
  - 128: 179 notes (Drums)

File: test_song.mid
Duration: 14.20s
Instruments: 4
Total notes: 244
  - 16: 40 notes (Program 16)
  - 25: 170 notes (Program 25)
  - 35: 30 notes (Program 35)
  - 128: 4 notes (Drums)



## Understanding Context Length

### Training vs Inference:
- **Training**: Model was trained with `context_length=2048`
- **Inference**: You can use `context_length=4096` or higher
  - The model can handle longer sequences
  - But memory usage grows quadratically (N² for mLSTM)

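### Memory Usage:
Rough estimates that follow the quadratic scaling above; actual usage depends on the GPU, the kernels in use, and the generation settings:
- `context_length=2048` → ~10GB VRAM
- `context_length=4096` → ~40GB VRAM  
- `context_length=8192` → ~160GB VRAM (likely OOM)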
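
The figures above can be reproduced with a short calculation. This is a back-of-the-envelope sketch that assumes memory scales purely with the square of the context length and takes the ~10GB figure at 2048 tokens as the baseline; `estimate_vram_gb` is a hypothetical helper, not part of the pipeline, and it ignores fixed costs such as model weights.

```python
def estimate_vram_gb(context_length: int, base_context: int = 2048, base_gb: float = 10.0) -> float:
    """Scale a baseline VRAM figure by (context_length / base_context) squared."""
    return base_gb * (context_length / base_context) ** 2


for ctx in (2048, 4096, 8192):
    print(f"context_length={ctx}: ~{estimate_vram_gb(ctx):.0f} GB")
```

Doubling the context quadruples the estimate, which is why the jump from 4096 to 8192 is so much more expensive than the jump from 2048 to 4096.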

### Solution for Long Generation:
Use a **sliding window** (implemented in `generate_long()`):
- Keep only the last N tokens as context
- Generate new chunk
- Slide window forward
- Repeat

This keeps memory constant while generating arbitrarily long sequences.
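
The loop described above can be sketched as follows. This is an illustration of the idea rather than the code inside `generate_long()`: the function name, its parameters, and the `generate_chunk` callback are assumptions, the callback is assumed to return only newly generated tokens, and bars are counted by `b-1` tokens. A stand-in function replaces the model so the example runs without a GPU.

```python
from typing import Callable


def generate_with_sliding_window(
    prompt: str,
    generate_chunk: Callable[[str, int], str],
    window_tokens: int,
    chunk_tokens: int,
    target_bars: int,
    max_iterations: int,
) -> str:
    """Generate in chunks while feeding the model only the last `window_tokens` tokens."""
    tokens = prompt.split()
    for _ in range(max_iterations):
        if tokens.count("b-1") >= target_bars:
            break  # enough bars generated
        context = " ".join(tokens[-window_tokens:])  # keep only the tail as context
        new_tokens = generate_chunk(context, chunk_tokens).split()
        if not new_tokens:
            break  # nothing new was produced; avoid looping forever
        tokens.extend(new_tokens)
    return " ".join(tokens)


def fake_chunk(context: str, n: int) -> str:
    """Stand-in for the model: repeats one note followed by a bar marker."""
    pattern = ["o-0", "i-0", "p-60", "d-4", "v-20", "b-1"]
    return " ".join(pattern[i % len(pattern)] for i in range(n))


result = generate_with_sliding_window(
    "s-9 o-0 t-35", fake_chunk,
    window_tokens=12, chunk_tokens=12, target_bars=4, max_iterations=10,
)
print(len(result.split()), result.split().count("b-1"))  # 27 4
```

The full token list keeps growing, but the model only ever receives `window_tokens` tokens of context, so the per-iteration cost stays bounded.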

## Custom Prompts

REMIGEN format: `s-X o-Y t-Z i-A p-B d-C v-D ...`

- `s-X`: Time signature
- `o-Y`: Offset (position within the bar)
- `t-Z`: Tempo
- `i-A`: Instrument
- `p-B`: Pitch
- `d-C`: Duration
- `v-D`: Velocity
- `b-1`: Bar marker
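
As a quick way to inspect a prompt or a generated sequence, the sketch below splits a token string into `(type, value)` pairs using the prefixes listed above. `parse_remigen` and `TOKEN_TYPES` are hypothetical names introduced for illustration; they are not part of `xlstm_music_generation` or `midiprocessor`, and the example string is made up.

```python
from collections import Counter

# Prefix-to-name mapping taken from the list above.
TOKEN_TYPES = {
    "s": "signature", "o": "offset", "t": "tempo", "i": "instrument",
    "p": "pitch", "d": "duration", "v": "velocity", "b": "bar",
}


def parse_remigen(token_str: str) -> list[tuple[str, int]]:
    """Split a REMIGEN-style string into (type, value) pairs, rejecting unknown tokens."""
    parsed = []
    for token in token_str.split():
        prefix, _, value = token.partition("-")
        if prefix not in TOKEN_TYPES or not value.isdigit():
            raise ValueError(f"Unrecognised token: {token!r}")
        parsed.append((TOKEN_TYPES[prefix], int(value)))
    return parsed


example = "s-9 o-0 t-33 i-4 p-81 d-25 v-23 b-1"
parsed = parse_remigen(example)
print(parsed[:3])  # [('signature', 9), ('offset', 0), ('tempo', 33)]
print(Counter(kind for kind, _ in parsed))
```

Counting token types this way makes it easy to spot a prompt that ends mid-note, for example one with more `pitch` entries than `velocity` entries.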

In [None]:
# Try different starting prompts
prompts = [
    "s-9 o-0 t-35",  # Slow tempo
    "s-9 o-0 t-120", # Fast tempo
    "s-9 o-0 t-60 i-0",  # With piano
]

for i, prompt in enumerate(prompts):
    result = generator.generate(
        prompt=prompt,
        temperature=0.8,
        max_tokens=1500,
        verbose=True
    )
    
    midi_path = f"./output/custom_prompt_{i}.mid"
    converter.tokens_to_midi(result['tokens'], midi_path)
    print(f"Saved: {midi_path}\n")

## Comparison with Museformer

For your research comparison, you can now:

1. Generate the same number of pieces from both models
2. Use the same prompts and seeds
3. Compare:
   - Musicality
   - Coherence over long sequences
   - Diversity
   - Computational requirements

In [None]:
# Generate evaluation dataset
eval_output = Path("./evaluation/xlstm_samples")
eval_output.mkdir(exist_ok=True, parents=True)

for i in range(20):  # Generate 20 samples
    result = generator.generate(
        prompt="s-9 o-0 t-35",
        temperature=0.8,
        max_tokens=2048,
        verbose=False
    )
    
    midi_path = eval_output / f"xlstm_{i:03d}.mid"
    converter.tokens_to_midi(result['tokens'], str(midi_path))
    
    if (i + 1) % 5 == 0:
        print(f"Generated {i+1}/20 samples")

print(f"\n✓ Evaluation dataset ready: {eval_output}")