# Music Generation with xLSTM + REMIGEN

This notebook generates music using our trained xLSTM model and decodes
REMIGEN tokens to MIDI.

Once the `xlstm` conda environment is set up, make sure to install the `midiProcessor`.

```python
cd ./repos/MidiProcessor
pip install .
pip install miditoolkit==0.1.16 numpy scipy pretty_midi mido tqdm
```

In [1]:
import sys
sys.path.append("/scratch1/e20-fyp-xlstm-music-generation/e20fyptemp1/fyp-musicgen/repos/helibrunna")

from source.languagemodel import LanguageModel
import midiprocessor as mp
import os
from pathlib import Path

print("‚úì Imports successful")

  from .autonotebook import tqdm as notebook_tqdm


‚úì Imports successful


## Load xLSTM Model

In [6]:
model_path = "/scratch1/e20-fyp-xlstm-music-generation/e20fyptemp1/fyp-musicgen/repos/helibrunna/output/lmd_remigen_xlstm/run_20260115-1028"

model = LanguageModel(
    model_path,
    config_overrides={"context_length": 16_384},  # Use full context
    device="cuda"  # or "cpu" if no GPU
)

model.summary()

[32m[1m   ‚ñÑ‚ñà    ‚ñà‚ñÑ       ‚ñÑ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà  ‚ñÑ‚ñà        ‚ñÑ‚ñà  ‚ñÄ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñÑ     ‚ñÑ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà ‚ñà‚ñà‚ñà    ‚ñà‚ñÑ  ‚ñà‚ñà‚ñà‚ñÑ‚ñÑ‚ñÑ‚ñÑ   ‚ñà‚ñà‚ñà‚ñÑ‚ñÑ‚ñÑ‚ñÑ      ‚ñÑ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà 
[32m[22m  ‚ñà‚ñà‚ñà    ‚ñà‚ñà‚ñà     ‚ñà‚ñà‚ñà    ‚ñà‚ñà‚ñà ‚ñà‚ñà‚ñà       ‚ñà‚ñà‚ñà    ‚ñà‚ñà‚ñà    ‚ñà‚ñà‚ñà   ‚ñà‚ñà‚ñà    ‚ñà‚ñà‚ñà ‚ñà‚ñà‚ñà    ‚ñà‚ñà‚ñà ‚ñà‚ñà‚ñà‚ñÄ‚ñÄ‚ñÄ‚ñà‚ñà‚ñÑ ‚ñà‚ñà‚ñà‚ñÄ‚ñÄ‚ñÄ‚ñà‚ñà‚ñÑ   ‚ñà‚ñà‚ñà    ‚ñà‚ñà‚ñà 
[32m[1m  ‚ñà‚ñà‚ñà    ‚ñà‚ñà‚ñà     ‚ñà‚ñà‚ñà    ‚ñà‚ñÄ  ‚ñà‚ñà‚ñà       ‚ñà‚ñà‚ñà‚ñå   ‚ñà‚ñà‚ñà    ‚ñà‚ñà‚ñà   ‚ñà‚ñà‚ñà    ‚ñà‚ñà‚ñà ‚ñà‚ñà‚ñà    ‚ñà‚ñà‚ñà ‚ñà‚ñà‚ñà   ‚ñà‚ñà‚ñà ‚ñà‚ñà‚ñà   ‚ñà‚ñà‚ñà   ‚ñà‚ñà‚ñà    ‚ñà‚ñà‚ñà 
[32m[22m ‚ñÑ‚ñà‚ñà‚ñà‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñà‚ñà‚ñà‚ñÑ‚ñÑ  ‚ñÑ‚ñà‚ñà‚ñà‚ñÑ‚ñÑ‚ñÑ     ‚ñà‚ñà‚ñà       ‚ñà‚ñà‚ñà‚ñå  ‚ñÑ‚ñà‚ñà‚ñà‚ñÑ‚ñÑ‚ñÑ‚ñà‚ñà‚ñÄ   ‚ñÑ‚ñà‚ñà‚ñà‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñà‚ñà‚ñÄ ‚ñà‚ñà‚ñà    ‚ñà‚ñà‚ñà ‚ñà‚ñà‚ñà   ‚ñà‚ñà‚ñà ‚ñà‚ñà‚ñà   ‚ñà‚ñà‚ñà   ‚ñà‚ñà‚ñà    ‚ñà‚ñà‚ñà 
[32

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'PreTrainedTokenizerFast'.


xLSTMLMModel(
  (xlstm_block_stack): xLSTMBlockStack(
    (blocks): ModuleList(
      (0-2): 3 x mLSTMBlock(
        (xlstm_norm): LayerNorm()
        (xlstm): mLSTMLayer(
          (proj_up): Linear(in_features=256, out_features=1024, bias=False)
          (q_proj): LinearHeadwiseExpand(in_features=512, num_heads=128, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
          (k_proj): LinearHeadwiseExpand(in_features=512, num_heads=128, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
          (v_proj): LinearHeadwiseExpand(in_features=512, num_heads=128, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
          (conv1d): CausalConv1d(
            (conv): Conv1d(512, 512, kernel_size=(4,), stride=(1,), padding=(3,), groups=512)
          )
          (conv_act_fn): SiLU()
          (mlstm_cell): mLSTMCell(
            (igate): Linear(in_features=1536, out_features=4, bias=True)
            (fgate)

## Generation Parameters

Settings for controlling the music generation.

In [16]:
# start_prompt = "s-9 o-0 t-38 i-35 p-62 d-2 v-22 o-6 t-38 i-35 p-62 d-2 v-17"

# Force specific instruments in the prompt
# start_prompt = "s-9 o-0 t-38 i-0 p-60 d-4 v-20 i-33 p-48 d-4 v-20 i-128 p-170 d-2 v-20"
#                        ^^^ Piano  ^^^ Bass       ^^^ Drums

In [7]:
# Temperature: controls randomness (0.5-1.5 recommended)
# Lower = more predictable, Higher = more creative
temperature = 0.8

# Maximum length in tokens
max_length = 5096  # Full context length

# Number of songs to generate
num_songs = 2

# Starting prompt (standard REMIGEN opening)
# This tells the model: "Start a new song in 9/8 time at tempo 38"
# start_prompt = "s-9 o-0 t-38"
start_prompt = "s-9 o-0 t-35 i-128 p-170 d-3 v-31 o-12 t-35 i-128 p-170 d-3 v-25 o-24 t-35 i-128 p-170 d-3 v-25 o-36 t-35 i-128 p-170 d-3 v-25 b-1 s-9 o-0 t-35 i-30 p-65 d-15 v-25 p-60 d-15 v-25 p-53 d-15 v-25 p-41"

# Output directory
output_dir = Path("./generated_music")
output_dir.mkdir(exist_ok=True)

print(f"Temperature: {temperature}")
print(f"Max length: {max_length} tokens")
print(f"Starting prompt: {start_prompt}")
print(f"Output directory: {output_dir}")

Temperature: 0.8
Max length: 5096 tokens
Starting prompt: s-9 o-0 t-35 i-128 p-170 d-3 v-31 o-12 t-35 i-128 p-170 d-3 v-25 o-24 t-35 i-128 p-170 d-3 v-25 o-36 t-35 i-128 p-170 d-3 v-25 b-1 s-9 o-0 t-35 i-30 p-65 d-15 v-25 p-60 d-15 v-25 p-53 d-15 v-25 p-41
Output directory: generated_music


## Token Generation Function

Generates REMIGEN tokens from the model.

In [8]:
def generate_remigen_tokens(
    model,
    prompt="s-9 o-0 t-38",
    temperature=0.8,
    max_length=2048,
    stop_at_bars=None,
    verbose=True
):
    """
    Generate REMIGEN tokens from the xLSTM model.
    
    Args:
        model: The LanguageModel instance
        prompt: Starting prompt (should start with s-X o-0 t-X)
        temperature: Sampling temperature
        max_length: Maximum number of tokens to generate
        stop_at_bars: If set, stop after generating N bars (b-1 tokens)
        verbose: Print progress information
    
    Returns:
        (tokens_string, info_dict)
    """
    
    if verbose:
        print(f"üéµ Starting generation...")
        print(f"   Prompt: {prompt}")
        print(f"   Max tokens: {max_length:,}")
        print(f"   Temperature: {temperature}")
        if stop_at_bars:
            print(f"   Target bars: {stop_at_bars}")
        print()
    
    # Generate with the model
    import time
    start_time = time.time()
    
    if verbose:
        print("‚è≥ Generating tokens...", end="", flush=True)
    
    output_dict = model.generate(
        prompt=prompt,
        temperature=temperature,
        max_length=max_length,
        end_tokens=[],
        forbidden_tokens=["[PAD]", "[EOS]"],
        return_structured_output=True
    )
    
    elapsed = time.time() - start_time
    
    if verbose:
        print(f" Done! ({elapsed:.1f}s)")
    
    # Extract generated tokens
    tokens = output_dict["output"]
    token_list = tokens.split()
    
    if verbose:
        print(f"üìä Generated {len(token_list):,} raw tokens")
    
    # FILTER OUT INVALID TOKENS
    valid_tokens = []
    invalid_count = 0
    
    for token in token_list:
        # Only keep tokens with format: prefix-value
        if '-' in token and not token.startswith('['):
            valid_tokens.append(token)
        else:
            invalid_count += 1
            if verbose and invalid_count <= 5:  # Show first 5
                print(f"‚ö†Ô∏è  Filtered invalid token: {token}")
    
    if verbose and invalid_count > 5:
        print(f"‚ö†Ô∏è  Filtered {invalid_count - 5} more invalid tokens...")
    
    # Count bars in valid tokens
    bar_count = sum(1 for t in valid_tokens if t == "b-1")
    
    if verbose:
        print(f"‚úì Valid tokens: {len(valid_tokens):,}")
        print(f"‚úì Bars generated: {bar_count}")
    
    # Rejoin
    tokens = " ".join(valid_tokens)
    
    # Optional: truncate at bar limit
    if stop_at_bars is not None:
        if verbose:
            print(f"‚úÇÔ∏è  Truncating to {stop_at_bars} bars...")
        
        truncated = []
        bars_seen = 0
        
        for token in valid_tokens:
            truncated.append(token)
            if token == "b-1":
                bars_seen += 1
                if bars_seen >= stop_at_bars:
                    break
        
        tokens = " ".join(truncated)
        
        if verbose:
            print(f"‚úì Truncated to {len(truncated):,} tokens ({bars_seen} bars)")
    
    # Add extra info to output dict
    output_dict.update({
        "valid_tokens": len(valid_tokens),
        "invalid_tokens": invalid_count,
        "bars": bar_count,
        "elapsed_time": elapsed
    })
    
    if verbose:
        print(f"‚ö° Speed: {len(valid_tokens)/elapsed:.1f} tokens/sec")
        print()
    
    return tokens, output_dict

In [5]:
# Test generation
print("Testing generation...")
test_tokens, test_info = generate_remigen_tokens(
    model,
    prompt=start_prompt,
    temperature=temperature,
    max_length=200,  # Short test
    verbose=True
)

print(f"Generated {len(test_tokens.split())} tokens")
print(f"Speed: {test_info['tokens_per_second']:.2f} tokens/sec")
print(f"\nFirst 100 chars: {test_tokens[:100]}")

Testing generation...
üéµ Starting generation...
   Prompt: s-9 o-0 t-35 i-128 p-170 d-3 v-31 o-12 t-35 i-128 p-170 d-3 v-25 o-24 t-35 i-128 p-170 d-3 v-25 o-36 t-35 i-128 p-170 d-3 v-25 b-1 s-9 o-0 t-35 i-30 p-65 d-15 v-25 p-60 d-15 v-25 p-53 d-15 v-25 p-41
   Max tokens: 200
   Temperature: 0.8

‚è≥ Generating tokens... Done! (7.6s)
üìä Generated 200 raw tokens
‚úì Valid tokens: 200
‚úì Bars generated: 2
‚ö° Speed: 26.4 tokens/sec

Generated 200 tokens
Speed: 21.15 tokens/sec

First 100 chars: s-9 o-0 t-35 i-128 p-170 d-3 v-31 o-12 t-35 i-128 p-170 d-3 v-25 o-24 t-35 i-128 p-170 d-3 v-25 o-36


## Decode Function

REMIGEN ‚Üí MIDI Decoder

Converts generated REMIGEN tokens to MIDI files.

In [9]:
def decode_remigen_to_midi(token_string, output_path):
    """
    Decode REMIGEN tokens to MIDI file.
    
    Args:
        token_string: Space-separated REMIGEN tokens
        output_path: Path to save .mid file
    
    Returns:
        True if successful, False otherwise
    """
    try:
        # Split tokens
        tokens = token_string.strip().split()
        
        # Initialize decoder
        decoder = mp.MidiDecoder('REMIGEN')
        
        # Decode to MIDI object
        midi_obj = decoder.decode_from_token_str_list(tokens)
        
        # Save
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        midi_obj.dump(output_path)
        
        return True
    
    except Exception as e:
        print(f"Error decoding: {e}")
        return False

In [10]:
# Test decoding
test_output_path = output_dir / "test_decode.mid"
success = decode_remigen_to_midi(test_tokens, str(test_output_path))

if success:
    print(f"‚úì Test decode successful: {test_output_path}")
else:
    print("‚úó Test decode failed")

‚úì Test decode successful: generated_music/test_decode.mid


## Generate Multiple Songs
Generate a batch of music samples.

In [None]:
print(f"Generating {num_songs} songs...\n")

generated_songs = []

for i in range(num_songs):
    print(f"Generating song {i+1}/{num_songs}...")
    
    # Generate tokens
    tokens, info = generate_remigen_tokens(
        model,
        prompt=start_prompt,
        temperature=temperature,
        max_length=max_length,
        # stop_at_bars=32  # Generate 32 bars per song
    )
    
    # Save info
    song_data = {
        "id": i,
        "tokens": tokens,
        "num_tokens": len(tokens.split()),
        "generation_time": info["elapsed_time"],
        "tokens_per_sec": info["tokens_per_second"]
    }
    generated_songs.append(song_data)
    
    print(f"  Generated {song_data['num_tokens']} tokens in {song_data['generation_time']:.2f}s")
    print(f"  Speed: {song_data['tokens_per_sec']:.2f} tokens/sec\n")

print(f"‚úì Generated {len(generated_songs)} songs")

# Iterative Generation for Long Sequences

In [11]:
def generate_long_remigen_iterative(
    model,
    start_prompt="s-9 o-0 t-38",
    temperature=0.8,
    chunk_size=1500,
    max_iterations=10,
    stop_at_bars=None
):
    """
    Generate long REMIGEN sequences iteratively.
    """
    
    output = start_prompt
    total_bars = 0
    
    print(f"üéµ Starting iterative generation...")
    print(f"   Chunk size: {chunk_size} tokens")
    print(f"   Max iterations: {max_iterations}\n")
    
    for iteration in range(max_iterations):
        print(f"üìù Iteration {iteration + 1}/{max_iterations}")
        print(f"   Current length: {len(output.split())} tokens")
        
        # Generate continuation
        chunk, info = generate_remigen_tokens(
            model,
            prompt=output,  # Use full context
            temperature=temperature,
            max_length=len(output.split()) + chunk_size,  # ‚Üê IMPORTANT: context + new tokens
            verbose=False
        )
        
        # Extract only NEW tokens (after the prompt)
        chunk_tokens = chunk.split()
        output_tokens = output.split()
        
        # Find where new tokens start
        if len(chunk_tokens) > len(output_tokens):
            new_tokens = chunk_tokens[len(output_tokens):]  # Get only new tokens
            output = output + " " + " ".join(new_tokens)  # APPEND new tokens
        else:
            print(f"‚ö†Ô∏è  No new tokens generated, stopping.")
            break
        
        # Count bars
        bars_in_output = output.count("b-1")
        new_bars = bars_in_output - total_bars
        total_bars = bars_in_output
        
        print(f"   Added {len(new_tokens)} new tokens ({new_bars} new bars)")
        print(f"   Total: {len(output.split())} tokens, {total_bars} bars\n")
        
        # Check stopping conditions
        if stop_at_bars and total_bars >= stop_at_bars:
            print(f"‚úì Reached target of {stop_at_bars} bars!")
            break
        
        # Clear CUDA cache
        import torch
        torch.cuda.empty_cache()
    
    print(f"‚úì Generation complete!")
    print(f"   Final: {len(output.split())} tokens, {total_bars} bars")
    
    return output

In [12]:
# Use it!
print("Generating long sequence iteratively...\n")

tokens = generate_long_remigen_iterative(
    model,
    start_prompt=start_prompt,
    temperature=0.8,
    chunk_size=1500,  # Safe size per iteration
    max_iterations=5,  # 5 iterations x 1500 = ~7500 tokens
    stop_at_bars=64  # Or set number of bars
)

print(f"\n‚úì Generated {len(tokens.split())} tokens!")

Generating long sequence iteratively...

üéµ Starting iterative generation...
   Chunk size: 1500 tokens
   Max iterations: 5

üìù Iteration 1/5
   Current length: 40 tokens
   Added 1500 new tokens (12 new bars)
   Total: 1540 tokens, 12 bars

üìù Iteration 2/5
   Current length: 1540 tokens


OutOfMemoryError: CUDA out of memory. Tried to allocate 114.00 MiB. GPU 0 has a total capacity of 47.37 GiB of which 108.31 MiB is free. Process 869539 has 890.00 MiB memory in use. Process 3386576 has 27.32 GiB memory in use. Process 3652956 has 8.60 GiB memory in use. Process 3665233 has 2.03 GiB memory in use. Including non-PyTorch memory, this process has 8.41 GiB memory in use. Of the allocated memory 7.54 GiB is allocated by PyTorch, and 235.81 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

### Decode All Generated Songs to MIDI

Convert all generated token sequences to MIDI files.

In [None]:
print("Decoding to MIDI files...\n")

successful = 0
failed = 0

for song_data in generated_songs:
    song_id = song_data["id"]
    tokens = song_data["tokens"]
    
    # Output path
    midi_path = output_dir / f"generated_song_{song_id:03d}.mid"
    
    # Decode
    success = decode_remigen_to_midi(tokens, str(midi_path))
    
    if success:
        successful += 1
        print(f"‚úì Song {song_id}: {midi_path.name}")
    else:
        failed += 1
        print(f"‚úó Song {song_id}: Failed to decode")

print(f"\n‚úì Successfully decoded: {successful}/{len(generated_songs)}")
if failed > 0:
    print(f"‚úó Failed: {failed}/{len(generated_songs)}")


Decoding to MIDI files...

‚úì Song 0: generated_song_000.mid
‚úì Song 1: generated_song_001.mid

‚úì Successfully decoded: 2/2


## Analyze Generated MIDI Files

In [None]:
import pretty_midi

print("Analyzing generated music...\n")

for song_data in generated_songs[:3]:  # Analyze first 3 songs
    song_id = song_data["id"]
    midi_path = output_dir / f"generated_song_{song_id:03d}.mid"
    
    if not midi_path.exists():
        continue
    
    # Load MIDI
    midi = pretty_midi.PrettyMIDI(str(midi_path))
    
    print(f"Song {song_id}:")
    print(f"  Duration: {midi.get_end_time():.2f}s")
    print(f"  Instruments: {len(midi.instruments)}")
    print(f"  Total notes: {sum(len(inst.notes) for inst in midi.instruments)}")
    
    # Show instruments
    for inst in midi.instruments:
        inst_type = "Drums" if inst.is_drum else f"Program {inst.program}"
        print(f"    {inst.name}: {len(inst.notes)} notes ({inst_type})")
    print()

Analyzing generated music...

Song 0:
  Duration: 17.75s
  Instruments: 5
  Total notes: 555
    28: 62 notes (Program 28)
    29: 261 notes (Program 29)
    30: 145 notes (Program 30)
    35: 38 notes (Program 35)
    128: 49 notes (Drums)

Song 1:
  Duration: 24.85s
  Instruments: 3
  Total notes: 490
    30: 218 notes (Program 30)
    33: 95 notes (Program 33)
    128: 177 notes (Drums)



## Save Generated Tokens

Save raw token sequences for analysis.

In [None]:


tokens_dir = output_dir / "tokens"
tokens_dir.mkdir(exist_ok=True)

for song_data in generated_songs:
    song_id = song_data["id"]
    tokens = song_data["tokens"]
    
    token_path = tokens_dir / f"generated_song_{song_id:03d}.txt"
    
    with open(token_path, 'w') as f:
        f.write(tokens)

print(f"‚úì Saved token files to {tokens_dir}")

‚úì Saved token files to generated_music/tokens


## Generation Summary

Summary of the music generation session.

In [None]:
print("=" * 60)
print("MUSIC GENERATION SUMMARY")
print("=" * 60)
print(f"Model: {model_path}")
print(f"Songs generated: {len(generated_songs)}")
print(f"Temperature: {temperature}")
print(f"Max tokens: {max_length}")
print(f"Output directory: {output_dir}")
print(f"\nGenerated files:")
print(f"  MIDI files: {output_dir}/*.mid")
print(f"  Token files: {tokens_dir}/*.txt")
print("=" * 60)

MUSIC GENERATION SUMMARY
Model: /scratch1/e20-fyp-xlstm-music-generation/e20fyptemp1/fyp-musicgen/repos/helibrunna/output/lmd_remigen_xlstm/run_20260115-1028
Songs generated: 2
Temperature: 0.8
Max tokens: 2048
Output directory: generated_music

Generated files:
  MIDI files: generated_music/*.mid
  Token files: generated_music/tokens/*.txt
