<a href="https://colab.research.google.com/github/hongminh54/BeatHeritage/blob/main/colab/beatheritage_v1_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BeatHeritage V1 - Beatmap Generator

An enhanced AI model for generating osu! beatmaps with improved stability and quality control.


### Instructions:
1. **Read and accept the rules** by clicking the checkbox in the first cell
2. **Ensure GPU runtime**: Go to __Runtime → Change Runtime Type → GPU__
3. **Execute cells in order**: Click ▶️ on each cell sequentially
4. **Upload your audio**: Choose an MP3, OGG, or WAV file when prompted
5. **Configure parameters**: Adjust settings to your preference
6. **Generate beatmap**: Run the generation cell and wait for results


In [None]:
#@title 🚀 Setup Environment { display-mode: "form" }
#@markdown ### ⚠️ Important: Please use this tool responsibly
#@markdown - Always disclose AI usage in your beatmap descriptions
#@markdown - Respect the original music artists and mappers
#@markdown - This tool is for educational and creative purposes

i_accept_the_rules = False #@param {type:"boolean"}
#@markdown ☑️ **I accept the rules and will use this tool responsibly**

import os
import sys

if not i_accept_the_rules:
    raise ValueError("Please read and accept the rules before proceeding!")

print("Installing BeatHeritage...")
print("="*50)

# Clone repository if not exists
if not os.path.exists('/content/BeatHeritage'):
    !git clone -q https://github.com/hongminh54/BeatHeritage.git
    print("✅ Repository cloned")
else:
    print("✅ Repository already exists")

%cd /content/BeatHeritage

# Install dependencies
print("\nInstalling dependencies...")
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q -r requirements.txt
!apt-get install -y ffmpeg > /dev/null 2>&1

print("\nSetup complete!")

# Import required libraries
import warnings
warnings.filterwarnings('ignore')

import torch
from google.colab import files
from IPython.display import display, HTML, Audio
from pathlib import Path
import json
import shlex
import subprocess
from datetime import datetime
import zipfile

# Check GPU availability
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"\nUsing device: {device}")
if device == 'cuda':
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_memory:.1f} GB")
else:
    print("No GPU detected! Generation will be VERY slow.")

# Initialize global variables
audio_path = ""
output_path = "/content/BeatHeritage/output"
os.makedirs(output_path, exist_ok=True)

In [None]:
#@title 🎵 Upload Audio File { display-mode: "form" }
#@markdown Upload your audio file (MP3, OGG, or WAV format)

def upload_and_validate_audio():
    """Upload and validate audio file with proper error handling"""
    global audio_path
    
    print("Please select an audio file to upload...")
    uploaded = files.upload()
    
    if not uploaded:
        print("No file uploaded")
        return None
    
    # Get the first uploaded file
    original_filename = list(uploaded.keys())[0]
    
    # Clean filename: replace every character other than letters, digits,
    # '.', '_' and '-' (spaces included) with an underscore
    import re
    clean_filename = re.sub(r'[^a-zA-Z0-9._-]', '_', original_filename)
    
    # Ensure proper extension
    if not any(clean_filename.lower().endswith(ext) for ext in ['.mp3', '.ogg', '.wav']):
        print(f"Invalid file format: {original_filename}")
        print("Please upload an MP3, OGG, or WAV file")
        return None
    
    # Save with cleaned filename
    audio_path = f'/content/BeatHeritage/{clean_filename}'
    
    # Write the uploaded content to the new path
    with open(audio_path, 'wb') as f:
        f.write(uploaded[original_filename])
    
    print("Audio uploaded successfully!")
    print(f"Original: {original_filename}")
    print(f"Saved as: {clean_filename}")
    print(f"Path: {audio_path}")
    
    # Display audio player
    display(Audio(audio_path))
    
    return audio_path

# Upload audio
audio_path = upload_and_validate_audio()

if not audio_path:
    print("\n⚠️ Please run this cell again and upload a valid audio file")

In [None]:
#@title ⚙️ Configure Generation Parameters { display-mode: "form" }

#@markdown ### 🎯 Basic Settings
#@markdown ---
#@markdown Choose the AI model version to use:
model_version = "BeatHeritage V1 (Enhanced)" #@param ["BeatHeritage V1 (Enhanced)", "Mapperatorinator V30", "Mapperatorinator V29", "Mapperatorinator V28"]

#@markdown Select the game mode for your beatmap:
gamemode = "Standard" #@param ["Standard", "Taiko", "Catch the Beat", "Mania"]

#@markdown Target difficulty (★ rating):
difficulty = 5.5 #@param {type:"slider", min:1, max:10, step:0.1}

#@markdown ### 🎨 Style Configuration
#@markdown ---
#@markdown Primary mapping style descriptor:
descriptor_1 = "clean" #@param ["clean", "tech", "jump aim", "stream", "aim", "speed", "flow", "complex", "simple", "modern", "classic", "slider tech", "alt", "precision", "stamina"]

#@markdown Secondary style descriptor (optional):
descriptor_2 = "" #@param ["", "clean", "tech", "jump aim", "stream", "aim", "speed", "flow", "complex", "simple", "modern", "classic", "slider tech", "alt", "precision", "stamina"]

#@markdown ### 🔧 Advanced Parameters
#@markdown ---
#@markdown Generation temperature (lower = more conservative):
temperature = 0.85 #@param {type:"slider", min:0.1, max:2.0, step:0.05}

#@markdown Top-p sampling (nucleus sampling):
top_p = 0.92 #@param {type:"slider", min:0.1, max:1.0, step:0.01}

#@markdown Classifier-free guidance scale:
cfg_scale = 7.5 #@param {type:"slider", min:1.0, max:20.0, step:0.5}

#@markdown ### 📊 Quality Control (BeatHeritage V1)
#@markdown ---
enable_auto_correction = True #@param {type:"boolean"}
enable_flow_optimization = True #@param {type:"boolean"}
enable_pattern_variety = True #@param {type:"boolean"}

#@markdown ### 🎯 Export Options
#@markdown ---
super_timing = False #@param {type:"boolean"}
#@markdown Enable for songs with variable BPM (slower generation)

export_osz = True #@param {type:"boolean"}
#@markdown Export as .osz package (includes audio)

# Map model names to config names
model_configs = {
    "BeatHeritage V1 (Enhanced)": "beatheritage_v1",
    "Mapperatorinator V30": "v30",
    "Mapperatorinator V29": "v29",
    "Mapperatorinator V28": "v28"
}

# Map gamemode names to indices
gamemode_indices = {
    "Standard": 0,
    "Taiko": 1,
    "Catch the Beat": 2,
    "Mania": 3
}

selected_model = model_configs[model_version]
selected_gamemode = gamemode_indices[gamemode]

# Build descriptor list
descriptors = [d for d in [descriptor_1, descriptor_2] if d]

# Display configuration summary
print("Configuration Summary")
print("="*50)
print(f"Model: {model_version}")
print(f"Game Mode: {gamemode}")
print(f"Difficulty: {difficulty}★")
print(f"Style: {', '.join(descriptors) if descriptors else 'Default'}")
print(f"Temperature: {temperature}")
print(f"Top-p: {top_p}")
print(f"CFG Scale: {cfg_scale}")

if selected_model == "beatheritage_v1":
    print("\nBeatHeritage V1 Features:")
    if enable_auto_correction:
        print("  ✓ Auto-correction enabled")
    if enable_flow_optimization:
        print("  ✓ Flow optimization enabled")
    if enable_pattern_variety:
        print("  ✓ Pattern variety enabled")

if super_timing:
    print("\nSuper timing enabled (for variable BPM)")

print("\nConfiguration ready!")

In [None]:
#@title 🎮 Generate Beatmap { display-mode: "form" }
#@markdown Click the play button to start generation. This may take a few minutes depending on song length.

def generate_beatmap():
    """Generate beatmap with proper error handling and progress tracking"""
    
    if not audio_path or not os.path.exists(audio_path):
        print("Error: No audio file found!")
        print("Please upload an audio file first.")
        return None
    
    print("Starting beatmap generation...")
    print("="*50)
    print(f"Audio: {os.path.basename(audio_path)}")
    print(f"Model: {model_version}")
    print(f"Mode: {gamemode}")
    print(f"Difficulty: {difficulty}★")
    print("="*50)
    print()
    
    # Build the command as an argument list. No shell is involved, so values
    # must not be shell-quoted: shlex.quote would add literal quote characters
    # that the override parser would receive as part of the value.
    cmd = [
        'python', 'inference.py',
        '-cn', selected_model,
        f'audio_path={audio_path}',
        f'output_path={output_path}',
        f'gamemode={selected_gamemode}',
        f'difficulty={difficulty}',
        f'temperature={temperature}',
        f'top_p={top_p}',
        f'cfg_scale={cfg_scale}',
        f'super_timing={str(super_timing).lower()}',
        f'export_osz={str(export_osz).lower()}',
    ]
    
    # Add descriptors if specified, as a list literal such as ["clean", "tech"]
    # (assumes inference.py parses Hydra-style overrides, as the -cn flag suggests)
    if descriptors:
        desc_str = json.dumps(descriptors)
        cmd.append(f'descriptors={desc_str}')
    
    # Add BeatHeritage V1 specific features
    if selected_model == "beatheritage_v1":
        if enable_auto_correction:
            cmd.append('quality_control.enable_auto_correction=true')
        if enable_flow_optimization:
            cmd.append('quality_control.enable_flow_optimization=true')
        if enable_pattern_variety:
            cmd.append('advanced_features.enable_pattern_variety=true')
        
        # Always enable these for V1
        cmd.extend([
            'advanced_features.enable_context_aware_generation=true',
            'advanced_features.enable_style_preservation=true',
            'generate_positions=true',
            'position_refinement=true'
        ])
    
    # Execute command
    try:
        print("⏳ Generating beatmap... (this may take several minutes)\n")
        
        # Run the command
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,  # decode output as str (universal_newlines is an alias)
            bufsize=1   # line-buffered so output streams as it is produced
        )
        
        # Stream output in real-time
        for line in process.stdout:
            print(line, end='')
        
        # Wait for completion
        return_code = process.wait()
        
        if return_code == 0:
            print("\n" + "="*50)
            print("Beatmap generation complete!")
            
            # List generated files
            generated_files = list(Path(output_path).glob('*'))
            if generated_files:
                print(f"\nGenerated {len(generated_files)} file(s):")
                for file in generated_files:
                    size_mb = file.stat().st_size / (1024 * 1024)
                    print(f"  • {file.name} ({size_mb:.2f} MB)")
            
            return generated_files
        else:
            print(f"\nGeneration failed with error code: {return_code}")
            return None
            
    except Exception as e:
        print(f"\nError during generation: {str(e)}")
        print("\nTroubleshooting tips:")
        print("1. Ensure the audio file is valid")
        print("2. Check if GPU memory is sufficient")
        print("3. Try reducing temperature or cfg_scale")
        print("4. Disable super_timing if enabled")
        return None

# Run generation
generated_files = generate_beatmap()

In [None]:
#@title 📥 Download Generated Files { display-mode: "form" }
#@markdown Download your generated beatmap files

def download_results():
    """Package and download generated beatmap files"""
    
    output_files = list(Path(output_path).glob('*'))
    
    if not output_files:
        print("No files to download")
        print("Please generate a beatmap first.")
        return
    
    print("Preparing files for download...")
    print("="*50)
    
    # Create timestamp for unique naming
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    
    # Check if we have .osz files
    osz_files = [f for f in output_files if f.suffix == '.osz']
    osu_files = [f for f in output_files if f.suffix == '.osu']
    
    # Download .osz files directly if available
    if osz_files:
        for osz_file in osz_files:
            print(f"\n📥 Downloading: {osz_file.name}")
            files.download(str(osz_file))
    
    # Download .osu files
    elif osu_files:
        if len(osu_files) == 1:
            # Single file - download directly
            osu_file = osu_files[0]
            print(f"\n📥 Downloading: {osu_file.name}")
            files.download(str(osu_file))
        else:
            # Multiple files - create zip
            # Replace spaces so "Catch the Beat" does not produce a filename with spaces
            zip_name = f"beatheritage_{gamemode.lower().replace(' ', '_')}_{timestamp}.zip"
            zip_path = f'/content/{zip_name}'
            
            with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
                for file in output_files:
                    zipf.write(file, file.name)
                    print(f"  • Added: {file.name}")
            
            print(f"\nDownloading: {zip_name}")
            files.download(zip_path)
    
    # Also handle other files
    other_files = [f for f in output_files if f.suffix not in ['.osz', '.osu']]
    if other_files:
        print("\nAdditional files generated:")
        for file in other_files:
            print(f"  • {file.name}")
    
    print("\nDownload complete!")
    print("\nTips:")
    print("• .osz files can be opened directly in osu!")
    print("• .osu files should be placed in your Songs folder")
    print("• Press F5 in osu! to refresh after adding files")

# Download files
download_results()

---

## Additional Information

### Tips for Best Results:
- **Audio Quality**: Use high-quality audio files (for example, 320 kbps MP3); the upload cell accepts MP3, OGG, and WAV only
- **Difficulty Matching**: Match the difficulty rating to song intensity
- **Style Descriptors**: Choose descriptors that match the music genre
- **Variable BPM**: Enable `super_timing` for songs with tempo changes
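
The generation cell passes these settings to `inference.py` as `key=value` overrides. The following is a minimal sketch of that assembly as a pure function, which can help when checking the exact arguments before launching a run. The helper name `build_overrides` is hypothetical; the keys mirror those in the generation cell, and the list syntax for `descriptors` assumes Hydra-style override parsing (suggested by the `-cn` flag, not confirmed here).

```python
import json


def build_overrides(audio_path, output_path, gamemode, difficulty,
                    descriptors=(), super_timing=False):
    """Return key=value overrides for inference.py (hypothetical helper)."""
    overrides = [
        f"audio_path={audio_path}",
        f"output_path={output_path}",
        f"gamemode={gamemode}",
        f"difficulty={difficulty}",
        # Booleans are lowercased, matching the generation cell.
        f"super_timing={str(super_timing).lower()}",
    ]
    if descriptors:
        # json.dumps yields a list literal; no shell quoting is needed
        # when the command is passed to subprocess as a list.
        overrides.append(f"descriptors={json.dumps(list(descriptors))}")
    return overrides


args = build_overrides("/content/BeatHeritage/song.mp3",
                       "/content/BeatHeritage/output",
                       gamemode=0, difficulty=5.5,
                       descriptors=["clean", "tech"], super_timing=True)
print(args)
```

Keeping the assembly separate from the subprocess call makes it easy to print and inspect the arguments without starting a generation.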

### Troubleshooting:

**Out of Memory:**
- Restart runtime to clear GPU memory
- Use shorter songs or segments
- Reduce cfg_scale value
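
One way to follow the "shorter songs or segments" suggestion is to cut a segment with ffmpeg before uploading. This is a sketch only: the helper name `build_trim_command` and the file names are hypothetical, and stream copy (`-c copy`) cuts on frame boundaries, so the segment edges may be slightly off the requested times.

```python
import subprocess


def build_trim_command(src, dst, start_s, duration_s):
    """Build an ffmpeg command that copies a segment without re-encoding."""
    return [
        "ffmpeg", "-y",            # overwrite dst if it exists
        "-ss", str(start_s),       # seek to the start offset (seconds)
        "-t", str(duration_s),     # keep this many seconds
        "-i", src,
        "-c", "copy",              # stream copy: fast, no quality loss
        dst,
    ]


cmd = build_trim_command("song.mp3", "song_90s.mp3", start_s=30, duration_s=90)
print(" ".join(cmd))
# To run it in Colab (the setup cell installs ffmpeg):
# subprocess.run(cmd, check=True)
```

A beatmap generated from a trimmed file is timed against that file, so the trimmed audio is the one to ship with the map.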

**Poor Quality Output:**
- Lower temperature (0.7-0.8) for stability
- Increase cfg_scale (10-15) for stronger guidance
- Use more specific descriptors
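
To build intuition for these three settings, here is a small pure-Python illustration of the textbook definitions of temperature scaling, top-p (nucleus) filtering, and classifier-free guidance. It is not BeatHeritage's sampler; the function names and toy logits are made up for the example, and the model's actual implementation may differ.

```python
import math


def apply_cfg(cond, uncond, scale):
    """Classifier-free guidance: move logits away from the unconditional ones."""
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax; lower temperature sharpens the distribution."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_filter(probs, top_p):
    """Keep the smallest set of tokens whose cumulative probability reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    total = sum(probs[i] for i in kept)
    # Renormalize over the kept tokens; all others get probability 0.
    return [probs[i] / total if i in kept else 0.0 for i in range(len(probs))]


logits = [2.0, 1.0, 0.0, -1.0]
sharp = softmax(logits, temperature=0.5)   # more mass on the top token
flat = softmax(logits, temperature=1.5)    # mass spread more evenly
nucleus = top_p_filter(softmax(logits), top_p=0.9)
guided = apply_cfg(cond=[2.0, 0.0], uncond=[1.0, 1.0], scale=2.0)
print(sharp[0] > flat[0], nucleus, guided)
```

Lower temperature and lower `top_p` both restrict sampling to the most likely choices, which is why they tend to stabilize output, while a larger guidance scale amplifies the difference between conditioned and unconditioned predictions.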

**Generation Errors:**
- Ensure the audio filename has no special characters (the upload cell replaces them with underscores)
- Check that a GPU is enabled in the runtime
- Try different model versions
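
A filename can be checked before uploading with the same rules the upload cell applies. The sketch below mirrors that cell's regular expression and extension list; the function name `sanitize_audio_name` is hypothetical.

```python
import re
from pathlib import Path

ALLOWED_EXTENSIONS = {".mp3", ".ogg", ".wav"}  # formats accepted by the upload cell


def sanitize_audio_name(filename):
    """Return a safe filename, or None if the extension is unsupported."""
    if Path(filename).suffix.lower() not in ALLOWED_EXTENSIONS:
        return None
    # Replace anything outside letters, digits, dot, underscore, and hyphen.
    return re.sub(r"[^a-zA-Z0-9._-]", "_", filename)


print(sanitize_audio_name("My Song (TV Size).mp3"))
print(sanitize_audio_name("track.flac"))
```

A `None` result means the file would be rejected by the upload cell and needs converting to a supported format first.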

### Resources:
- [GitHub Repository](https://github.com/hongminh54/BeatHeritage)
- [Documentation](https://github.com/hongminh54/BeatHeritage/blob/main/README.md)

### License & Credits:
- BeatHeritage V1 by hongminh54
- Based on Mapperatorinator by OliBomby
- Please credit AI usage in your beatmap descriptions

---