# MT3-Infer - Music Transcription

This notebook demonstrates how to use MT3-Infer for transcribing music audio to MIDI.

**Features:**
- Transcribe audio to MIDI automatically
- Multiple pre-trained models available (MR-MT3, MT3-PyTorch, YourMT3)
- High-quality music transcription using a transformer architecture

## 1. Installation

In [None]:
# Install mt3-infer with synthesis support
!pip install -q "mt3-infer[synthesis] @ git+https://github.com/openmirlab/mt3-infer.git" soundfile
!apt-get -qq install -y fluidsynth

In [None]:
# Verify installation
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
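
The check above only covers PyTorch. As a minimal sketch, the remaining pieces this notebook relies on can be probed the same way. The module names (`mt3_infer`, `pretty_midi`, `midi2audio`) are taken from the imports used in later cells and `fluidsynth` is the binary installed with `apt-get`; the helper `check_environment` is a hypothetical name introduced here, not part of mt3-infer.

```python
import importlib.util
import shutil


def check_environment(modules=("mt3_infer", "pretty_midi", "midi2audio"),
                      binaries=("fluidsynth",)):
    """Return {name: found} for top-level Python modules and command-line binaries."""
    status = {}
    for name in modules:
        # find_spec returns None when a top-level module cannot be located
        status[name] = importlib.util.find_spec(name) is not None
    for name in binaries:
        # shutil.which returns None when the executable is not on PATH
        status[name] = shutil.which(name) is not None
    return status


for name, found in check_environment().items():
    print(f"{name}: {'found' if found else 'MISSING'}")
```

Anything reported as `MISSING` suggests re-running the installation cell before continuing.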

## 2. Download Model

We'll use the **MR-MT3** model, which is optimized for speed (57x real-time).

In [None]:
# Create directories
!mkdir -p input_songs
!mkdir -p outputs

In [None]:
# Download and cache the model (models are automatically downloaded on first use)
from mt3_infer import download_model

model_name = "mr_mt3"  # Fast model (57x real-time)
print(f"Downloading model: {model_name}")
print("(This may take a few minutes on first run)\n")

checkpoint_path = download_model(model_name)

print("\nModel downloaded successfully!")
print(f"Checkpoint: {checkpoint_path}")

## 3. Upload Your Audio File

Upload a `.wav` file to transcribe to MIDI.

In [None]:
# Option 1: Upload from your computer
from google.colab import files

print("Upload your audio file (.wav format):")
uploaded = files.upload()

# Move uploaded file to input folder
for filename in uploaded.keys():
    !mv "{filename}" input_songs/
    print(f"Moved {filename} to input_songs/")

In [None]:
# Option 2: Use a sample audio (uncomment to use)
# !wget -q -O input_songs/sample.wav "YOUR_AUDIO_URL_HERE"

In [None]:
# Check input files
!ls -lh input_songs/
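
Before transcribing, it can help to confirm that each upload is a readable WAV file and to see its sample rate, since the Python API cell below requests audio at 16 kHz. The following is a sketch using only the standard-library `wave` module, which reads PCM WAV files and may reject other encodings such as 32-bit float; `describe_wav` is a hypothetical helper, not part of mt3-infer.

```python
import wave
from pathlib import Path


def describe_wav(path):
    """Return (sample_rate_hz, channels, duration_seconds) for a PCM WAV file."""
    with wave.open(str(path), "rb") as wav:
        rate = wav.getframerate()
        channels = wav.getnchannels()
        duration = wav.getnframes() / rate
    return rate, channels, duration


for wav_path in sorted(Path("input_songs").glob("*.wav")):
    try:
        rate, channels, duration = describe_wav(wav_path)
        print(f"{wav_path.name}: {rate} Hz, {channels} channel(s), {duration:.1f} s")
    except (wave.Error, EOFError) as exc:
        # Non-PCM or truncated files end up here
        print(f"{wav_path.name}: could not read WAV header ({exc})")
```

A file that fails this check is not necessarily unusable, because the audio loader used for transcription may support formats that `wave` does not.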

## 4. Run Music Transcription

Select your preferred method and run the transcription:

In [None]:
#@title Select Inference Options { display-mode: "form" }
#@markdown Choose inference method and device:

inference_method = "CLI (Command Line)" #@param ["CLI (Command Line)", "Python API"]
device = "cpu" #@param ["auto", "cuda", "cpu"]

print(f"Selected method: {inference_method}")
print(f"Selected device: {device}")

In [None]:
#@title Run Music Transcription { display-mode: "form" }
#@markdown Click the play button to run transcription with your selected method.

from pathlib import Path

if inference_method == "CLI (Command Line)":
    # ============================================
    # Option A: CLI (Command Line)
    # ============================================
    print("Running with CLI...\n")
    
    input_folder = Path("input_songs")
    for audio_file in input_folder.glob("*.wav"):
        output_file = Path("outputs") / f"{audio_file.stem}.mid"
        !mt3-infer transcribe "{audio_file}" -o "{output_file}" -m mr_mt3 --device {device}

else:
    # ============================================
    # Option B: Python API
    # ============================================
    print("Running with Python API...\n")

    from mt3_infer.utils.audio import load_audio
    from mt3_infer import download_model
    from mt3_infer.adapters.mr_mt3 import MRMT3Adapter

    # Load model
    checkpoint_path = download_model("mr_mt3")
    adapter = MRMT3Adapter()
    adapter.load_model(checkpoint_path, device=device)

    # Process files
    input_folder = Path("input_songs")
    output_folder = Path("outputs")
    output_folder.mkdir(exist_ok=True)

    for audio_path in input_folder.glob("*.wav"):
        print(f"Processing: {audio_path.name}")

        # Load audio
        audio, sr = load_audio(str(audio_path), sr=16000)

        # Run transcription
        midi = adapter.transcribe(audio, sr)

        # Save MIDI
        output_path = output_folder / f"{audio_path.stem}.mid"
        midi.save(str(output_path))
        print(f"  Saved: {output_path}")

print("\nDone!")

## 5. Check Output Files

In [None]:
# Check output files
!ls -lh outputs/
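
A directory listing does not show whether any input was skipped. As a minimal sketch, the file stems in `input_songs/` and `outputs/` can be compared to find `.wav` files without a matching `.mid`; `missing_transcriptions` is a hypothetical helper that relies on the `<stem>.mid` naming used by the transcription cell.

```python
from pathlib import Path


def missing_transcriptions(input_dir="input_songs", output_dir="outputs"):
    """Return the sorted stems of .wav inputs that have no matching .mid output."""
    inputs = {p.stem for p in Path(input_dir).glob("*.wav")}
    outputs = {p.stem for p in Path(output_dir).glob("*.mid")}
    return sorted(inputs - outputs)


missing = missing_transcriptions()
if missing:
    print("No MIDI output for:", ", ".join(missing))
else:
    print("No .wav input is missing a .mid output.")
```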

## 6. Preview Results

In [None]:
#@title Preview MIDI Results { display-mode: "form" }
#@markdown This cell displays MIDI info and synthesizes audio for playback.

import pretty_midi
import numpy as np
from pathlib import Path
import IPython.display as ipd

# Download a General MIDI soundfont if it is not already present
soundfont_path = "/usr/share/sounds/sf2/FluidR3_GM.sf2"
if not Path(soundfont_path).exists():
    print("Downloading soundfont...")
    !wget -q -O /tmp/FluidR3_GM.sf2 "https://keymusician01.s3.amazonaws.com/FluidR3_GM.sf2"
    !mkdir -p /usr/share/sounds/sf2
    !mv /tmp/FluidR3_GM.sf2 /usr/share/sounds/sf2/
    print("Soundfont ready!")

output_dir = Path("outputs")

# Find and display all output files
for midi_file in sorted(output_dir.glob("*.mid")):
    print(f"\n{'='*60}")
    print(f"File: {midi_file.name}")
    print(f"{'='*60}")
    
    # Load and display MIDI info
    midi_data = pretty_midi.PrettyMIDI(str(midi_file))
    print(f"Duration: {midi_data.get_end_time():.2f} seconds")
    print(f"Tempo: {midi_data.estimate_tempo():.1f} BPM")
    print(f"Instruments: {len(midi_data.instruments)}")
    
    total_notes = sum(len(inst.notes) for inst in midi_data.instruments)
    print(f"Total notes: {total_notes}")
    
    for inst in midi_data.instruments:
        if inst.is_drum:
            inst_name = "Drums"
        else:
            inst_name = pretty_midi.program_to_instrument_name(inst.program)
        print(f"  - {inst_name}: {len(inst.notes)} notes")
    
    # Synthesize MIDI to audio
    print("\nSynthesizing audio...")
    try:
        from midi2audio import FluidSynth
        
        audio_output = f"/tmp/{midi_file.stem}_synth.wav"
        fs = FluidSynth(soundfont_path)
        fs.midi_to_audio(str(midi_file), audio_output)
        
        print("Play synthesized MIDI:")
        display(ipd.Audio(audio_output))
    except Exception as e:
        print(f"Synthesis failed: {e}")
        print("Falling back to pretty_midi synthesis...")
        try:
            # Fallback: render via pretty_midi's fluidsynth() method, which needs
            # the pyfluidsynth package and uses pretty_midi's bundled soundfont
            audio_data = midi_data.fluidsynth(fs=44100)
            print("Play synthesized MIDI:")
            display(ipd.Audio(audio_data, rate=44100))
        except Exception as e2:
            print(f"Fallback also failed: {e2}")

## 7. Download Results

In [None]:
# Download all output files as a zip
!zip -r outputs.zip outputs/

from google.colab import files
files.download("outputs.zip")
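
If the `zip` command is unavailable, or a clean archive is wanted on every run, the archive can also be built from Python. This is a sketch using the standard-library `shutil.make_archive`, which rewrites the archive from scratch each time; `zip_outputs` is a hypothetical helper name.

```python
import shutil
from pathlib import Path


def zip_outputs(output_dir="outputs", archive_stem="outputs"):
    """Write <archive_stem>.zip containing output_dir and return the archive path."""
    output_dir = Path(output_dir)
    return shutil.make_archive(
        archive_stem,
        "zip",
        root_dir=str(output_dir.parent),  # directory the archive paths are relative to
        base_dir=output_dir.name,         # keeps the folder name as a prefix inside the zip
    )


if Path("outputs").is_dir():
    print("Wrote", zip_outputs())
```

The resulting `outputs.zip` can then be passed to `files.download` as in the cell above.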

---

## Available Models

| Model | Speed | Notes Detected | Size | Best For |
|-------|-------|----------------|------|----------|
| MR-MT3 | 57x real-time | 116 notes | 176 MB | Speed (recommended) |
| MT3-PyTorch | 12x real-time | 147 notes | 176 MB | Accuracy |
| YourMT3 | ~15x real-time | 118 notes | 536 MB | Multi-stem separation |
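
To make the Speed column concrete, the sketch below converts each factor into an approximate processing time, reading "Nx real-time" in the usual sense that audio is processed N times faster than it plays. The factors are copied from the table; the 180-second clip length is an arbitrary example, and actual throughput will depend on hardware and device.

```python
# Speed factors as listed in the table above (YourMT3 is approximate)
SPEED_FACTORS = {"MR-MT3": 57.0, "MT3-PyTorch": 12.0, "YourMT3": 15.0}


def estimated_seconds(audio_seconds, speed_factor):
    """Approximate processing time for a clip at a given real-time speed factor."""
    if speed_factor <= 0:
        raise ValueError("speed_factor must be positive")
    return audio_seconds / speed_factor


clip_seconds = 180.0  # example: a three-minute song
for model, factor in SPEED_FACTORS.items():
    seconds = estimated_seconds(clip_seconds, factor)
    print(f"{model}: about {seconds:.1f} s for a {clip_seconds:.0f} s clip")
```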

See the [GitHub repository](https://github.com/openmirlab/mt3-infer) for more options.