# BeatHeritage V1 - Enhanced Beatmap Generator

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hongminh54/BeatHeritage/blob/main/colab/beatheritage_v1_inference.ipynb)

BeatHeritage V1 is an enhanced version of Mapperatorinator V30 with improved stability, quality control, and pattern generation.

## Features
- 🎯 Enhanced stability with optimized sampling parameters
- 🎨 Better pattern variety and flow optimization
- 🎮 Support for all game modes (std, taiko, ctb, mania)
- 📊 Quality control with auto-correction
- 🚀 Performance optimizations with mixed precision

## 1. Setup Environment

In [None]:
# Check GPU availability
!nvidia-smi

# Clone repository
!git clone https://github.com/hongminh54/BeatHeritage.git
%cd BeatHeritage

# Install dependencies
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -r requirements.txt
!apt-get install -y ffmpeg

print("✅ Setup complete!")

## 2. Import Libraries

In [None]:
import os
import sys
import warnings

import torch

# Silence library warnings from here on so the notebook output stays
# readable (applied after torch is imported, before the UI libraries).
warnings.filterwarnings('ignore')

from pathlib import Path

import ipywidgets as widgets
from google.colab import files
from IPython.display import Audio, HTML, display

# Prefer the GPU whenever PyTorch can see one, otherwise run on the CPU.
cuda_ok = torch.cuda.is_available()
device = 'cuda' if cuda_ok else 'cpu'
print(f"Using device: {device}")
if cuda_ok:
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {total_gb:.1f} GB")

## 3. Upload Audio File

In [None]:
# Option 1: Upload from local
uploaded = files.upload()

if uploaded:
    audio_filename = list(uploaded.keys())[0]
    audio_path = f'/content/BeatHeritage/{audio_filename}'
    print(f"✅ Uploaded: {audio_filename}")
    display(Audio(audio_path))
else:
    # Option 2: Use demo audio
    !wget -O demo.mp3 'https://www.example.com/demo.mp3'  # Replace with actual demo URL
    audio_path = '/content/BeatHeritage/demo.mp3'
    print("Using demo audio")

## 4. Configure Generation Parameters

In [None]:
# Interactive configuration form. Every widget shares one label style so that
# long descriptions are shown in full rather than truncated.
style = {'description_width': 'initial'}


def _float_slider(label, value, lo, hi, step):
    """Build a FloatSlider that uses the shared label style."""
    return widgets.FloatSlider(
        value=value, min=lo, max=hi, step=step,
        description=label, style=style,
    )


def _checkbox(label, value):
    """Build a Checkbox that uses the shared label style."""
    return widgets.Checkbox(value=value, description=label, style=style)


# --- Basic settings ---------------------------------------------------------
model_selector = widgets.Dropdown(
    options=[
        ('BeatHeritage V1 (Enhanced)', 'beatheritage_v1'),
        ('Mapperatorinator V30', 'v30'),
        ('Mapperatorinator V29', 'v29'),
        ('Mapperatorinator V28', 'v28'),
    ],
    value='beatheritage_v1',
    description='Model:',
    style=style,
)

# Option values are the numeric game-mode ids later passed as `gamemode=`.
gamemode = widgets.Dropdown(
    options=list(zip(['Standard', 'Taiko', 'Catch the Beat', 'Mania'], range(4))),
    value=0,
    description='Game Mode:',
    style=style,
)

difficulty = _float_slider('Difficulty (★):', 5.5, 1.0, 10.0, 0.1)

# --- Style descriptors --------------------------------------------------------
descriptors = widgets.SelectMultiple(
    options=['jump aim', 'stream', 'tech', 'aim', 'speed', 'flow',
             'clean', 'complex', 'simple', 'modern', 'classic'],
    value=['clean'],
    description='Style Descriptors:',
    rows=5,
    style=style,
)

# --- Sampling parameters ------------------------------------------------------
temperature = _float_slider('Temperature:', 0.85, 0.1, 2.0, 0.05)
top_p = _float_slider('Top-p:', 0.92, 0.1, 1.0, 0.01)
cfg_scale = _float_slider('CFG Scale:', 7.5, 1.0, 20.0, 0.5)

# --- Quality-control and export toggles ---------------------------------------
enable_auto_correction = _checkbox('Enable Auto-correction', True)
enable_flow_optimization = _checkbox('Enable Flow Optimization', True)
super_timing = _checkbox('Super Timing (for variable BPM)', False)
export_osz = _checkbox('Export as .osz', True)

# Render each section heading followed by the widgets that belong to it.
form_sections = (
    ('<h3>Basic Settings</h3>', (model_selector, gamemode, difficulty)),
    ('<h3>Style Configuration</h3>', (descriptors,)),
    ('<h3>Advanced Settings</h3>', (temperature, top_p, cfg_scale)),
    ('<h3>Quality Control (BeatHeritage V1)</h3>',
     (enable_auto_correction, enable_flow_optimization)),
    ('<h3>Export Options</h3>', (super_timing, export_osz)),
)
for heading, section_widgets in form_sections:
    display(HTML(heading))
    display(*section_widgets)

## 5. Generate Beatmap

In [None]:
import subprocess
import json

# Prepare command
output_path = '/content/BeatHeritage/output'
os.makedirs(output_path, exist_ok=True)


def _hydra_str(value):
    """Quote *value* as a single-quoted Hydra override string.

    inference.py is launched with ``-cn`` (Hydra's --config-name flag), so each
    ``key=value`` argument goes through Hydra's override grammar. Unquoted
    values cannot contain characters such as parentheses, brackets, braces,
    commas or quotes, which are common in song file names, so paths are always
    quoted. Embedded single quotes are backslash-escaped.
    """
    return "'" + str(value).replace("'", "\\'") + "'"


def _hydra_bool(flag):
    """Render a Python bool as Hydra's lowercase ``true``/``false``."""
    return str(bool(flag)).lower()


# Build command arguments. sys.executable runs inference.py with the same
# interpreter as this notebook's kernel.
cmd = [
    sys.executable, 'inference.py',
    '-cn', model_selector.value,
    f'audio_path={_hydra_str(audio_path)}',
    f'output_path={_hydra_str(output_path)}',
    f'gamemode={gamemode.value}',
    f'difficulty={difficulty.value}',
    f'temperature={temperature.value}',
    f'top_p={top_p.value}',
    f'cfg_scale={cfg_scale.value}',
    f'super_timing={_hydra_bool(super_timing.value)}',
    f'export_osz={_hydra_bool(export_osz.value)}',
]

# Add descriptors if selected (JSON list syntax is also valid Hydra list syntax)
if descriptors.value:
    desc_list = json.dumps(list(descriptors.value))
    cmd.append(f'descriptors={desc_list}')

# Add BeatHeritage V1 specific features
if model_selector.value == 'beatheritage_v1':
    cmd.extend([
        f'quality_control.enable_auto_correction={_hydra_bool(enable_auto_correction.value)}',
        f'quality_control.enable_flow_optimization={_hydra_bool(enable_flow_optimization.value)}',
        'advanced_features.enable_context_aware_generation=true',
        'advanced_features.enable_style_preservation=true',
        'advanced_features.enable_pattern_variety=true',
        'generate_positions=true',
        'position_refinement=true',
    ])

print("🚀 Starting beatmap generation...")
print(f"Model: {model_selector.label}")
print(f"Game Mode: {gamemode.label}")
print(f"Difficulty: {difficulty.value}★")
print(f"Descriptors: {', '.join(descriptors.value) if descriptors.value else 'None'}")
print("\n" + "=" * 50 + "\n")

# Run inference
try:
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    print(result.stdout)
    print("\n✅ Beatmap generation complete!")

    # List generated files (sorted so the listing is stable between runs)
    generated_files = sorted(Path(output_path).glob('*'))
    print(f"\n📁 Generated {len(generated_files)} files:")
    for file in generated_files:
        print(f"  - {file.name}")

except subprocess.CalledProcessError as e:
    print(f"❌ Error during generation (exit code {e.returncode}):")
    # Errors normally go to stderr; fall back to stdout when stderr is empty.
    print(e.stderr or e.stdout)
    print("\nTroubleshooting tips:")
    print("1. Check if the audio file is valid")
    print("2. Try reducing temperature or cfg_scale")
    print("3. Ensure sufficient GPU memory is available")

## 6. Download Results

In [None]:
import zipfile
from datetime import datetime

# Collect every generated file (including any in sub-folders) in a stable order.
output_root = Path(output_path)
output_files = sorted(p for p in output_root.rglob('*') if p.is_file())

if output_files:
    # Create a compressed zip archive; entries keep their path relative to the
    # output folder so nested outputs do not collide.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    zip_name = f'beatheritage_v1_output_{timestamp}.zip'
    zip_path = f'/content/{zip_name}'

    with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
        for file in output_files:
            zipf.write(file, file.relative_to(output_root))

    print(f"📦 Created archive: {zip_name}")

    # Download options
    download_button = widgets.Button(
        description='Download All Files',
        button_style='success',
        icon='download'
    )

    def download_files(b):
        files.download(zip_path)

    download_button.on_click(download_files)
    display(download_button)

    def _make_download_handler(path):
        """Return a click handler bound to *path* (avoids late-binding closures)."""
        def _handler(_button):
            files.download(path)
        return _handler

    # Offer individual beatmap downloads on demand instead of triggering a
    # browser download for every file as soon as the cell runs.
    beatmap_files = [f for f in output_files if f.suffix in ('.osu', '.osz')]
    if beatmap_files:
        print("\n📄 Or download individual files:")
        for file in beatmap_files:
            file_button = widgets.Button(
                description=file.name,
                tooltip=str(file),
                icon='download'
            )
            file_button.on_click(_make_download_handler(str(file)))
            display(file_button)
else:
    print("❌ No output files found")

## 7. Advanced: Batch Processing

In [None]:
# Batch processing for multiple audio files
def batch_generate(audio_files, config):
    """
    Generate beatmaps for multiple audio files with the same configuration.

    Each file is processed by a separate ``inference.py`` run started from the
    current working directory. A failure on one file is recorded and the loop
    moves on to the next file.

    Args:
        audio_files: Iterable of audio file paths, processed in order.
        config: Dict with keys 'model', 'output_dir', 'gamemode', 'difficulty',
            'temperature' and 'top_p'. An optional 'cfg_scale' key is
            forwarded when present.

    Returns:
        A list with one dict per input file, in input order:
        ``{'file': ..., 'status': 'success'}`` or
        ``{'file': ..., 'status': 'failed', 'error': ...}``. Failures caused by
        a non-zero exit also carry the last lines of stderr under 'stderr'.
    """
    import os
    import subprocess
    import sys

    def hydra_str(value):
        # inference.py parses key=value arguments with Hydra's override
        # grammar; quote paths so spaces, parentheses, commas or brackets in
        # file names stay part of one string. Single quotes are escaped.
        return "'" + str(value).replace("'", "\\'") + "'"

    results = []

    for audio_file in audio_files:
        print(f"\n🎵 Processing: {audio_file}")

        cmd = [
            sys.executable, 'inference.py',
            '-cn', config['model'],
            f'audio_path={hydra_str(audio_file)}',
            f'output_path={hydra_str(config["output_dir"])}',
            f'gamemode={config["gamemode"]}',
            f'difficulty={config["difficulty"]}',
            f'temperature={config["temperature"]}',
            f'top_p={config["top_p"]}',
        ]
        if 'cfg_scale' in config:
            cmd.append(f'cfg_scale={config["cfg_scale"]}')

        try:
            # Create the destination before the run tries to write into it.
            os.makedirs(config['output_dir'], exist_ok=True)
            subprocess.run(cmd, capture_output=True, text=True, check=True)
        except subprocess.CalledProcessError as e:
            # str(e) only carries the exit status; keep the stderr tail so the
            # cause is visible in the returned results.
            stderr_tail = '\n'.join((e.stderr or '').strip().splitlines()[-5:])
            results.append({'file': audio_file, 'status': 'failed',
                            'error': str(e), 'stderr': stderr_tail})
            print(f"  ❌ Failed: {e}")
            if stderr_tail:
                print(stderr_tail)
        except OSError as e:
            # e.g. the output directory cannot be created or the process
            # cannot be started.
            results.append({'file': audio_file, 'status': 'failed', 'error': str(e)})
            print(f"  ❌ Failed: {e}")
        else:
            results.append({'file': audio_file, 'status': 'success'})
            print("  ✅ Success")

    return results

# Example usage (uncomment to use)
# audio_files = ['song1.mp3', 'song2.mp3', 'song3.mp3']
# config = {
#     'model': 'beatheritage_v1',
#     'output_dir': '/content/batch_output',
#     'gamemode': 0,
#     'difficulty': 5.5,
#     'temperature': 0.85,
#     'top_p': 0.92
# }
# results = batch_generate(audio_files, config)

## 8. Visualize Results (Optional)

In [None]:
# Simple beatmap visualization
import matplotlib.pyplot as plt
import numpy as np

def visualize_beatmap(osu_file_path):
    """
    Plot the hit-object positions of a .osu beatmap.

    Draws two side-by-side panels on the 512x384 playfield: a scatter plot of
    every parsed (x, y) position and a 20x20-bin density heatmap of the same
    positions. Both panels use a downward-growing Y axis so they line up.
    Positions outside the playfield are not shown.

    Args:
        osu_file_path: Path (str or Path) to a .osu file.

    Returns:
        The list of (x, y) float positions read from the [HitObjects]
        section. It is empty when none were found, and nothing is plotted
        in that case.
    """
    hit_objects = []

    # utf-8-sig drops a leading BOM; errors='replace' keeps legacy files with
    # non-UTF-8 metadata from aborting the parse.
    with open(osu_file_path, 'r', encoding='utf-8-sig', errors='replace') as f:
        in_hit_objects = False
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue
            if stripped.startswith('['):
                # A section header: only lines under [HitObjects] are parsed.
                in_hit_objects = stripped == '[HitObjects]'
                continue
            if not in_hit_objects:
                continue
            parts = stripped.split(',')
            if len(parts) < 2:
                continue
            try:
                x = float(parts[0])
                y = float(parts[1])
            except ValueError:
                continue  # malformed line: skip it, keep parsing the rest
            hit_objects.append((x, y))

    if hit_objects:
        # Create visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # Scatter plot of hit objects
        xs, ys = zip(*hit_objects)
        ax1.scatter(xs, ys, alpha=0.6, s=20)
        ax1.set_xlim(0, 512)
        ax1.set_ylim(384, 0)  # Inverted Y axis
        ax1.set_aspect('equal')
        ax1.set_title('Hit Object Positions')
        ax1.set_xlabel('X')
        ax1.set_ylabel('Y')
        ax1.grid(True, alpha=0.3)

        # Heatmap over the fixed playfield. heatmap.T has one row per y bin;
        # origin='upper' with top=0 in the extent draws y growing downward,
        # matching the inverted axis of the scatter plot.
        heatmap, _, _ = np.histogram2d(xs, ys, bins=20, range=[[0, 512], [0, 384]])
        im = ax2.imshow(heatmap.T, extent=[0, 512, 384, 0], origin='upper',
                        cmap='hot', aspect='equal')
        ax2.set_title('Density Heatmap')
        ax2.set_xlabel('X')
        ax2.set_ylabel('Y')
        plt.colorbar(im, ax=ax2)

        plt.tight_layout()
        plt.show()

        print(f"Total hit objects: {len(hit_objects)}")
    else:
        print("No hit objects found in the beatmap")

    return hit_objects

# Find and visualize the first generated .osu file. Sort first because glob
# order is arbitrary, so "first" is stable between runs.
osu_files = sorted(Path(output_path).glob('*.osu'))
if osu_files:
    first_osu = osu_files[0]
    print(f"Visualizing: {first_osu.name}")
    visualize_beatmap(first_osu)
else:
    print(f"No .osu files found in {output_path}")

## Tips & Troubleshooting

### For Best Results:
- Use high-quality audio files (MP3 320kbps or FLAC)
- Match difficulty to song intensity
- Use descriptors that match the song style
- Enable super_timing for songs with variable BPM

### Common Issues:

**1. Out of Memory:**
- Reduce the batch size if the model config exposes one (this notebook's Advanced Settings do not include a batch-size control)
- Use shorter audio segments
- Restart runtime to clear memory

**2. Poor Quality Output:**
- Lower temperature (0.7-0.8) for more stable output
- Increase cfg_scale (10-15) for stronger guidance
- Use more specific descriptors

**3. Repetitive Patterns:**
- Use the BeatHeritage V1 model; this notebook always runs it with pattern variety enabled (`advanced_features.enable_pattern_variety=true`)
- Add diverse descriptors
- Increase top_p value

### Support
- GitHub: https://github.com/hongminh54/BeatHeritage
- Documentation: See README.md
- Discord: Join the community for help