In [1]:
%%capture
# %%capture silences this cell's output, so the clone/setup logs are not shown.
# Clone the BachGen repository
# Start from a clean checkout: remove any previous ./BachGen first; the clone
# only runs if the removal succeeded (&&).
!rm -rf BachGen && git clone https://github.com/gomar0801/BachGen.git
# Mark the repository's setup script as executable, then run it.
# NOTE(review): what setup.sh installs is not visible from this notebook — check the script.
!chmod +x ./BachGen/scripts/setup.sh
!./BachGen/scripts/setup.sh

In [3]:
!pip install ./BachGen

Processing ./BachGen
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hBuilding wheels for collected packages: bachgen
  Building wheel for bachgen (pyproject.toml) ... [?25ldone
[?25h  Created wheel for bachgen: filename=bachgen-0.1.0-py3-none-any.whl size=11474 sha256=5e65ad29115680f731e36df603a54b79550855a5eca62b173bc91d4d3f1d2817
  Stored in directory: /private/var/folders/n1/bdjyqqwn5t5cglz8mdg_13900000gn/T/pip-ephem-wheel-cache-jtv832c8/wheels/64/3a/c7/28ae8ad8901358706742190d440cd5c543b38d6094faabcc5c
Successfully built bachgen
Installing collected packages: bachgen
  Attempting uninstall: bachgen
    Found existing installation: bachgen 0.1.0
    Uninstalling bachgen-0.1.0:
      Successfully uninstalled bachgen-0.1.0
Successfully installed bachgen-0.1.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.

In [2]:

from pathlib import Path
import json

# Our custom modules
from bachgen.display_and_play_partition import display_and_play
from bachgen.midi_bpe_token_style.music_xml_to_midi import music_xml_to_midi
from bachgen.midi_bpe_token_style.midi_to_music_xml import midi_to_music_xml
from bachgen.midi_bpe_token_style.midi_to_token import midi_to_token
from bachgen.midi_bpe_token_style.token_to_midi import token_to_midi

# Additional imports for MIDI display
from music21 import converter, stream
from symusic import Score
import pretty_midi
import matplotlib.pyplot as plt
import numpy as np

# Notebook banner: title line followed by a 50-character rule.
print("🎵 MIDI BPE Pipeline Demo", "=" * 50, sep="\n")


ModuleNotFoundError: No module named 'bachgen.midi_bpe_token_style'

In [None]:

# ==================== CELL 2: Helper Functions ====================

def _show_piano_roll(midi_path):
    """Render a piano-roll image of the MIDI file; report (not raise) any failure."""
    try:
        midi = pretty_midi.PrettyMIDI(str(midi_path))

        # Sample the roll at 10 Hz, so each column covers 0.1 s
        roll = midi.get_piano_roll(fs=10)

        plt.figure(figsize=(12, 6))
        plt.imshow(roll, aspect='auto', origin='lower', cmap='Blues')
        plt.colorbar(label='Velocity')
        plt.xlabel('Time (0.1s)')
        plt.ylabel('MIDI Note')
        plt.title(f'Piano Roll: {Path(midi_path).name}')
        plt.tight_layout()
        plt.show()

    except Exception as plot_error:
        print(f"  Note: Could not create piano roll visualization: {plot_error}")


def display_midi_info(midi_path, show_notes=True):
    """
    Print a summary of a MIDI file and, optionally, draw its piano roll.

    The summary covers the duration (ticks and beats), the time division,
    the number of tracks, one line per track and the overall note count.
    Any exception raised while analyzing the file is caught and printed
    rather than propagated.

    Args:
        midi_path (str): Path to MIDI file
        show_notes (bool): Whether to display piano roll visualization
            (only drawn when the file contains at least one note)
    """
    print(f"\n🎹 MIDI FILE ANALYSIS: {midi_path}")
    print("-" * 40)

    try:
        # symusic provides the tick-level details
        score = Score(str(midi_path))
        end_tick = score.end()
        print(f"Duration: {end_tick} ticks ({end_tick / score.tpq:.2f} beats)")
        print(f"Time division: {score.tpq} ticks per quarter note")
        print(f"Number of tracks: {len(score.tracks)}")

        # One line per track while accumulating the overall note count
        note_total = 0
        for index, track in enumerate(score.tracks):
            note_count = len(track.notes)
            print(f"  Track {index}: {note_count} notes, '{track.name}', program {track.program}")
            note_total += note_count

        print(f"Total notes: {note_total}")

        if show_notes and note_total > 0:
            _show_piano_roll(midi_path)

    except Exception as analysis_error:
        print(f"❌ Error analyzing MIDI: {analysis_error}")

def _extract_tokens(sequence):
    """Return the tokens of a dict ('tokens' key), an object with ``.tokens``, or the value itself."""
    if isinstance(sequence, dict):
        return sequence['tokens']
    return sequence.tokens if hasattr(sequence, 'tokens') else sequence


def _is_flat_token_list(token_data):
    """Return True for a non-empty list made only of bare tokens (str or int)."""
    return (
        isinstance(token_data, list)
        and len(token_data) > 0
        and all(isinstance(item, (str, int)) for item in token_data)
    )


def _print_token_preview(tokens, max_tokens, indent=""):
    """Print the first ``max_tokens`` tokens and how many were left out."""
    print(f"{indent}First {min(max_tokens, len(tokens))} tokens:")
    for position, token in enumerate(tokens[:max_tokens]):
        print(f"{indent}  {position:2d}: {token}")
    if len(tokens) > max_tokens:
        print(f"{indent}  ... and {len(tokens) - max_tokens} more tokens")


def display_tokens(tokens_or_path, max_tokens=50):
    """
    Display tokenization results

    Accepted inputs:
      * a path (``str`` or ``Path``) to a JSON file with a ``'token_data'``
        entry (``'source_midi'`` and ``'tokenizer_model'`` are shown if present);
      * a list of per-track items, each a dict with a ``'tokens'`` key, an
        object with a ``.tokens`` attribute, or a raw token sequence;
      * a single sequence: a dict, an object with ``.tokens``, or a flat list
        of tokens (strings or integer ids).

    A flat list of tokens is reported as ONE sequence; it is not split into
    one "track" per token. Tokens are converted with ``str()`` for the type
    statistics, so integer ids are supported. The token type is the text
    before the first underscore.

    Any exception is caught and printed rather than propagated.

    Args:
        tokens_or_path: Either token data or path to token JSON file
        max_tokens (int): Maximum number of tokens to display per track
    """
    print("\n🔤 TOKEN ANALYSIS")
    print("-" * 40)

    try:
        if isinstance(tokens_or_path, (str, Path)):
            # Load from file; JSON is read as UTF-8 regardless of the platform default
            with open(tokens_or_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            token_data = data['token_data']
            print(f"Source: {data.get('source_midi', 'unknown')}")
            print(f"Model: {data.get('tokenizer_model', 'unknown')}")
        else:
            # Direct token data
            token_data = tokens_or_path

        if isinstance(token_data, list) and not _is_flat_token_list(token_data):
            # Multiple tracks: preview each one and pool the tokens for the statistics
            all_tokens = []
            for i, track_data in enumerate(token_data):
                tokens = _extract_tokens(track_data)
                all_tokens.extend(tokens)
                print(f"\nTrack {i}: {len(tokens)} tokens")
                _print_token_preview(tokens, max_tokens, indent="  ")

            print(f"\nTotal tokens across all tracks: {len(all_tokens)}")

        else:
            # Single sequence (including a flat list of tokens)
            all_tokens = _extract_tokens(token_data)
            print(f"Single sequence: {len(all_tokens)} tokens")
            _print_token_preview(all_tokens, max_tokens)

        # Count tokens per type; str() keeps integer ids from raising on the split
        token_types = {}
        for token in all_tokens:
            token_type = str(token).split('_', 1)[0]
            token_types[token_type] = token_types.get(token_type, 0) + 1

        print("\nToken type distribution:")
        for token_type, count in sorted(token_types.items(), key=lambda x: x[1], reverse=True):
            percentage = (count / len(all_tokens)) * 100
            print(f"  {token_type:15s}: {count:4d} ({percentage:5.1f}%)")

    except Exception as e:
        print(f"❌ Error displaying tokens: {e}")


In [None]:

# ==================== CELL 3: Pipeline 1 - Basic MIDI Roundtrip ====================

print("\n" + "=" * 60)
print("PIPELINE 1: MusicXML → MIDI → MusicXML (No Tokenization)")
print("=" * 60)

# Every intermediate file below is written under temp/. Create it up front so a
# fresh Restart Kernel → Run All does not depend on the directory already
# existing (whether the conversion helpers create it is not visible here).
Path("temp").mkdir(parents=True, exist_ok=True)

# Step 1: Load and display original MusicXML
original_xml = "../musicxml_sample/minimal.musicxml"
print("\n📄 Step 1: Display original MusicXML")
display_and_play(original_xml, show_score=True, midi=False)

# Step 2: Convert to MIDI
print("\n🎵 Step 2: Convert MusicXML to MIDI")
midi_from_xml = music_xml_to_midi(original_xml, "temp/minimal_from_xml.mid")

# Step 3: Display MIDI info and visualization
print("\n🎹 Step 3: Analyze converted MIDI")
display_midi_info(midi_from_xml, show_notes=True)

# Step 4: Convert MIDI back to MusicXML
print("\n📄 Step 4: Convert MIDI back to MusicXML")
xml_from_midi = midi_to_music_xml(midi_from_xml, "temp/minimal_roundtrip.musicxml")

# Step 5: Display reconstructed MusicXML
print("\n📄 Step 5: Display reconstructed MusicXML")
display_and_play(str(xml_from_midi), show_score=True, midi=False)

# Summary of the files involved in this roundtrip
print("\n✅ Pipeline 1 completed!")
print(f"Original:     {original_xml}")
print(f"MIDI:         {midi_from_xml}")
print(f"Reconstructed: {xml_from_midi}")

In [None]:


# ==================== CELL 4: Pipeline 2 - Full BPE Token Pipeline ====================

# Section banner
print(f"\n{'=' * 60}")
print("PIPELINE 2: MusicXML → Tokens → MIDI → MusicXML (Full BPE)")
print("=" * 60)

# Step 1: the source file is the one defined for Pipeline 1
print("\n📄 Step 1: Original MusicXML")
print(f"File: {original_xml}")

# Step 2: the MIDI produced by Pipeline 1 is reused as-is
print("\n🎵 Step 2: Using MIDI from previous pipeline")
print(f"MIDI: {midi_from_xml}")

# Step 3: MIDI -> tokens (the token file is written next to the other temp outputs)
print("\n🔤 Step 3: Convert MIDI to tokens using pretrained REMI")
tokens, token_file = midi_to_token(midi_from_xml, "temp/minimal_tokens.json")

# Step 4: inspect the first tokens and the token-type distribution
print("\n🔍 Step 4: Analyze tokens")
display_tokens(token_file, max_tokens=30)

# Step 5: tokens -> MIDI
print("\n🎵 Step 5: Convert tokens back to MIDI")
midi_from_tokens = token_to_midi(token_file, "temp/minimal_from_tokens.mid")

# Step 6: inspect the MIDI rebuilt from tokens
print("\n🎹 Step 6: Analyze reconstructed MIDI")
display_midi_info(midi_from_tokens, show_notes=True)

# Step 7: rebuilt MIDI -> MusicXML
print("\n📄 Step 7: Convert final MIDI to MusicXML")
final_xml = midi_to_music_xml(midi_from_tokens, "temp/minimal_from_tokens.musicxml")

# Step 8: render the final score
print("\n📄 Step 8: Display final MusicXML")
display_and_play(str(final_xml), show_score=True, midi=False)

# Summary of every artifact produced along the way
print("\n✅ Pipeline 2 completed!")
print(
    f"Original XML:   {original_xml}",
    f"Intermediate MIDI: {midi_from_xml}",
    f"Tokens:         {token_file}",
    f"Reconstructed MIDI: {midi_from_tokens}",
    f"Final XML:      {final_xml}",
    sep="\n",
)

In [None]:

# Section banner: blank line, rule, title, rule.
print(f"\n{'=' * 60}", "COMPARISON ANALYSIS", "=" * 60, sep="\n")

def compare_files(file1, file2, file_type="MIDI"):
    """Print a side-by-side summary of two music files.

    For MIDI files the duration in ticks, the track count and the total note
    count of each file are printed, followed by the absolute duration
    difference. For MusicXML files the part count and the duration in
    quarter lengths of each file are printed.

    Args:
        file1: Path to the first file.
        file2: Path to the second file.
        file_type (str): "MIDI" or "MusicXML", matched case-insensitively.
            Any other value is reported as unsupported instead of being
            silently ignored.

    Returns:
        None. Exceptions raised while loading or inspecting the files are
        caught and printed rather than propagated.
    """
    kind = str(file_type).strip().lower()
    if kind not in ("midi", "musicxml"):
        print(f"❌ Unsupported file type: {file_type!r} (expected 'MIDI' or 'MusicXML')")
        return

    try:
        if kind == "midi":
            # Load both files before printing anything, so a load failure
            # does not leave a half-printed report.
            loaded = [(path, Score(str(path))) for path in (file1, file2)]

            print("\n📊 Comparing MIDI files:")
            for index, (path, score) in enumerate(loaded, start=1):
                print(f"  File {index}: {path}")
                print(f"    Duration: {score.end()} ticks")
                print(f"    Tracks: {len(score.tracks)}")
                print(f"    Total notes: {sum(len(track.notes) for track in score.tracks)}")

            # Simple comparison
            duration_diff = abs(loaded[0][1].end() - loaded[1][1].end())
            print(f"  Duration difference: {duration_diff} ticks")

        else:
            loaded = [(path, converter.parse(str(path))) for path in (file1, file2)]

            print("\n📊 Comparing MusicXML files:")
            for index, (path, score) in enumerate(loaded, start=1):
                print(f"  File {index}: {path}")
                print(f"    Parts: {len(score.parts)}")
                print(f"    Duration: {score.duration.quarterLength} quarters")

    except Exception as e:
        print(f"❌ Error comparing files: {e}")

# MIDI level: direct conversion vs. the file rebuilt from tokens
print("\n🎹 MIDI Comparison:")
compare_files(midi_from_xml, midi_from_tokens, file_type="MIDI")

# Score level: the original MusicXML vs. the end of the token pipeline
print("\n📄 MusicXML Comparison:")
compare_files(original_xml, final_xml, file_type="MusicXML")
