# MT3 Batch Transcription for Clef Baseline

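This notebook batch-transcribes audio files to MIDI using the official Google Magenta MT3 checkpoints: either the piano-only `ismir2021` model or the multi-instrument `mt3` model (selected in step 3).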

---

## Attribution

**Original notebook:** [Music Transcription with Transformers](https://github.com/magenta/mt3/blob/main/mt3/colab/music_transcription_with_transformers.ipynb)  
**Original authors:** Google Magenta Team  
**License:** Apache License 2.0

**Modifications by:** Clef Project (2025)  
**Changes made:**
- Added batch transcription support for multiple audio files
- Added Google Drive integration for input/output
- Removed Google Analytics tracking code
- Added progress tracking and error handling
- Added skip logic for resumable transcription

---

## Usage

1. Upload audio files (WAV, MP3, or FLAC; subfolders are searched recursively) to Google Drive: `clef_baseline/audio/`
2. Run all cells
3. Download MIDI files from: `clef_baseline/midi/`

**Reference:** Gardner et al. (2022), "MT3: Multi-Task Multitrack Music Transcription," ICLR 2022.

In [None]:
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modifications Copyright 2025 Clef Project
# Modifications are licensed under the Apache License, Version 2.0
# ==============================================================================

#@title 1. Setup Environment
#@markdown Install MT3 and its dependencies (may take a few minutes).

# System packages: the FluidSynth runtime library, a compiler toolchain, and
# the ALSA/JACK development headers.
!apt-get update -qq && apt-get install -qq libfluidsynth3 build-essential libasound2-dev libjack-dev

# Install mt3
# Clone the main branch, move the repository contents into the current
# directory, then install it in editable mode (`-e .`) together with JAX
# (CUDA 12 build), nest-asyncio and a pinned pyfluidsynth.
# NOTE(review): this sequence does not look safe to re-run in the same
# runtime -- after the first run an `mt3` directory presumably already exists
# in the working directory, so a second `git clone` would fail and the `mv`
# commands would then act on that existing directory. Confirm, and restart the
# runtime before re-running this cell.
!git clone --branch=main https://github.com/magenta/mt3
!mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp
!python3 -m pip install jax[cuda12] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Copy checkpoints
# Recursively copies the whole `checkpoints` folder from the public bucket
# into the current directory.
!gsutil -q -m cp -r gs://mt3/checkpoints .

# Copy soundfont (originally from https://sites.google.com/site/soundfonts4u)
!gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .

print("Setup complete!")

In [None]:
#@title 2. Imports and Definitions

import functools
import os

import numpy as np
import tensorflow.compat.v2 as tf

import functools
import gin
import jax
import librosa
import note_seq
import seqio
import t5
import t5x

from mt3 import metrics_utils
from mt3 import models
from mt3 import network
from mt3 import note_sequences
from mt3 import preprocessors
from mt3 import spectrograms
from mt3 import vocabularies

import nest_asyncio
nest_asyncio.apply()

# Sample rate (Hz) requested from librosa when input audio is loaded for
# transcription.
SAMPLE_RATE = 16000
# File name of the soundfont downloaded by the setup cell.
# NOTE(review): no other code visible in this notebook reads SF2_PATH --
# presumably a leftover from the original notebook's audio synthesis; confirm
# before removing.
SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'


class InferenceModel(object):
  """Wrapper of T5X model for music transcription.

  Construction parses the MT3 gin configs, builds the T5X model and restores
  its weights from a checkpoint. Calling the instance with a 1-d array of audio
  samples returns the transcribed note sequence (see `__call__`).
  """

  def __init__(self, checkpoint_path, model_type='mt3'):
    """Build the model for `model_type` and restore it from a checkpoint.

    Args:
      checkpoint_path: Checkpoint location, passed unchanged to
        `t5x.utils.RestoreCheckpointConfig(path=...)`.
      model_type: Either 'ismir2021' or 'mt3'. Selects the number of velocity
        bins, the note encoding spec, the encoder input length and the
        model-specific gin file.

    Raises:
      ValueError: If `model_type` is neither 'ismir2021' nor 'mt3'.
    """

    # Model Constants.
    if model_type == 'ismir2021':
      num_velocity_bins = 127
      self.encoding_spec = note_sequences.NoteEncodingSpec
      self.inputs_length = 512
    elif model_type == 'mt3':
      num_velocity_bins = 1
      self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
      self.inputs_length = 256
    else:
      raise ValueError('unknown model_type: %s' % model_type)

    # Absolute paths: assumes the mt3 package directory lives at /content/mt3
    # (the Colab working directory) -- TODO confirm before running elsewhere.
    gin_files = ['/content/mt3/gin/model.gin',
                 f'/content/mt3/gin/{model_type}.gin']

    self.batch_size = 8
    self.outputs_length = 1024
    # Feature lengths handed to the seqio preprocessors and feature converter.
    self.sequence_length = {'inputs': self.inputs_length,
                            'targets': self.outputs_length}

    self.partitioner = t5x.partitioning.PjitPartitioner(
        num_partitions=1)

    # Build Codecs and Vocabularies.
    self.spectrogram_config = spectrograms.SpectrogramConfig()
    self.codec = vocabularies.build_codec(
        vocab_config=vocabularies.VocabularyConfig(
            num_velocity_bins=num_velocity_bins))
    self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
    self.output_features = {
        'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
        'targets': seqio.Feature(vocabulary=self.vocabulary),
    }

    # Create a T5X model.
    self._parse_gin(gin_files)
    self.model = self._load_model()

    # Restore from checkpoint.
    self.restore_from_checkpoint(checkpoint_path)

  @property
  def input_shapes(self):
    """Encoder/decoder token shapes given to the train-state initializer."""
    return {
          'encoder_input_tokens': (self.batch_size, self.inputs_length),
          'decoder_input_tokens': (self.batch_size, self.outputs_length)
    }

  def _parse_gin(self, gin_files):
    """Parse gin files used to train the model."""
    # Extra bindings applied on top of the files; the config is left
    # unfinalized (finalize_config=False).
    gin_bindings = [
        'from __gin__ import dynamic_registration',
        'from mt3 import vocabularies',
        'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
        'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
    ]
    with gin.unlock_config():
      gin.parse_config_files_and_bindings(
          gin_files, gin_bindings, finalize_config=False)

  def _load_model(self):
    """Load up a T5X `Model` after parsing training gin config."""
    # The network config comes from gin, so `_parse_gin` must have run first.
    model_config = gin.get_configurable(network.T5Config)()
    module = network.Transformer(config=model_config)
    return models.ContinuousInputsEncoderDecoderModel(
        module=module,
        input_vocabulary=self.output_features['inputs'].vocabulary,
        output_vocabulary=self.output_features['targets'].vocabulary,
        optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
        input_depth=spectrograms.input_depth(self.spectrogram_config))


  def restore_from_checkpoint(self, checkpoint_path):
    """Restore training state from checkpoint, resets self._predict_fn()."""
    train_state_initializer = t5x.utils.TrainStateInitializer(
      optimizer_def=self.model.optimizer_def,
      init_fn=self.model.get_initial_variables,
      input_shapes=self.input_shapes,
      partitioner=self.partitioner)

    restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
        path=checkpoint_path, mode='specific', dtype='float32')

    # The prediction function is built from the state axes first; the state
    # itself is restored afterwards.
    train_state_axes = train_state_initializer.train_state_axes
    self._predict_fn = self._get_predict_fn(train_state_axes)
    self._train_state = train_state_initializer.from_checkpoint_or_scratch(
        [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))

  # NOTE(review): lru_cache on an instance method also keys on `self`, so the
  # cache holds a reference to the instance for as long as the entry lives --
  # confirm this is acceptable if several models are created in one session.
  @functools.lru_cache()
  def _get_predict_fn(self, train_state_axes):
    """Generate a partitioned prediction function for decoding."""
    def partial_predict_fn(params, batch, decode_rng):
      # NOTE(review): `decode_rng` is accepted but not forwarded -- the call
      # below always passes {'decode_rng': None} -- so the seed given to
      # `predict_tokens` has no effect in the code visible here.
      return self.model.predict_batch_with_aux(
          params, batch, decoder_params={'decode_rng': None})
    return self.partitioner.partition(
        partial_predict_fn,
        in_axis_resources=(
            train_state_axes.params,
            t5x.partitioning.PartitionSpec('data',), None),
        out_axis_resources=t5x.partitioning.PartitionSpec('data',)
    )

  def predict_tokens(self, batch, seed=0):
    """Predict tokens from preprocessed dataset batch."""
    # Only the predictions are used; the auxiliary output is discarded.
    prediction, _ = self._predict_fn(
        self._train_state.params, batch, jax.random.PRNGKey(seed))
    return self.vocabulary.decode_tf(prediction).numpy()

  def __call__(self, audio):
    """Infer note sequence from audio samples.

    Args:
      audio: 1-d numpy array of audio samples (16kHz) for a single example.

    Returns:
      A note_sequence of the transcribed audio.
    """
    ds = self.audio_to_dataset(audio)
    ds = self.preprocess(ds)

    # Convert to model features and group examples into fixed-size batches.
    model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)(
        ds, task_feature_lengths=self.sequence_length)
    model_ds = model_ds.batch(self.batch_size)

    # Lazily run the model batch by batch, yielding one token array per
    # example.
    inferences = (tokens for batch in model_ds.as_numpy_iterator()
                  for tokens in self.predict_tokens(batch))

    # Pair every preprocessed example with its predicted tokens; this assumes
    # the batched dataset yields examples in the same order as `ds`.
    predictions = []
    for example, tokens in zip(ds.as_numpy_iterator(), inferences):
      predictions.append(self.postprocess(tokens, example))

    result = metrics_utils.event_predictions_to_ns(
        predictions, codec=self.codec, encoding_spec=self.encoding_spec)
    return result['est_ns']

  def audio_to_dataset(self, audio):
    """Create a TF Dataset of spectrograms from input audio."""
    frames, frame_times = self._audio_to_frames(audio)
    # Single-element dataset holding all frames of the recording and their
    # start times.
    return tf.data.Dataset.from_tensors({
        'inputs': frames,
        'input_times': frame_times,
    })

  def _audio_to_frames(self, audio):
    """Compute spectrogram frames from audio."""
    frame_size = self.spectrogram_config.hop_width
    # Zero-pad at the end to a whole number of frames.
    # NOTE(review): the pad length is always between 1 and frame_size, so
    # audio whose length is already a multiple of frame_size gains one full
    # extra frame of zeros -- confirm this is intended.
    padding = [0, frame_size - len(audio) % frame_size]
    audio = np.pad(audio, padding, mode='constant')
    frames = spectrograms.split_audio(audio, self.spectrogram_config)
    num_frames = len(audio) // frame_size
    # Time of each frame: frame index divided by the frame rate.
    times = np.arange(num_frames) / self.spectrogram_config.frames_per_second
    return frames, times

  def preprocess(self, ds):
    """Apply the inference preprocessing chain to `ds` and return the result.

    The steps run in order: split the 'inputs' feature (together with
    'input_times') to the configured input length, add dummy targets, then
    compute spectrograms.
    """
    pp_chain = [
        functools.partial(
            t5.data.preprocessors.split_tokens_to_inputs_length,
            sequence_length=self.sequence_length,
            output_features=self.output_features,
            feature_key='inputs',
            additional_feature_keys=['input_times']),
        # Cache occurs here during training.
        preprocessors.add_dummy_targets,
        functools.partial(
            preprocessors.compute_spectrograms,
            spectrogram_config=self.spectrogram_config)
    ]
    for pp in pp_chain:
      ds = pp(ds)
    return ds

  def postprocess(self, tokens, example):
    """Turn predicted tokens for one example into a prediction record.

    Args:
      tokens: Predicted tokens for the example; everything from the first EOS
        onwards is dropped.
      example: The preprocessed example; only `example['input_times'][0]` is
        read.

    Returns:
      A dict with 'est_tokens', 'start_time' and an empty 'raw_inputs' list.
    """
    tokens = self._trim_eos(tokens)
    start_time = example['input_times'][0]
    # Round down to nearest symbolic token step.
    start_time -= start_time % (1 / self.codec.steps_per_second)
    return {
        'est_tokens': tokens,
        'start_time': start_time,
        # Internal MT3 code expects raw inputs, not used here.
        'raw_inputs': []
    }

  @staticmethod
  def _trim_eos(tokens):
    """Return `tokens` as an int32 array truncated before the first EOS id.

    If `vocabularies.DECODED_EOS_ID` does not occur, the full array is
    returned.
    """
    tokens = np.array(tokens, np.int32)
    if vocabularies.DECODED_EOS_ID in tokens:
      # argmax over the boolean mask gives the index of the first EOS.
      tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
    return tokens

# Reached only when every import and the class definition above ran cleanly.
print("Imports complete!")

In [None]:
#@title 3. Load Model
#@markdown The `ismir2021` model transcribes piano only, with note velocities.
#@markdown The `mt3` model transcribes multiple simultaneous instruments,
#@markdown but without velocities.

MODEL = "ismir2021" #@param["ismir2021", "mt3"]

# Each model type is expected in its own folder under /content/checkpoints/.
checkpoint_path = '/content/checkpoints/' + MODEL + '/'

# Building the wrapper parses the gin configs and restores the weights.
print("Loading " + MODEL + " model...")
inference_model = InferenceModel(checkpoint_path, model_type=MODEL)
print("Model loaded!")

In [None]:
#@title 4. Mount Google Drive
#@markdown Configure paths for batch processing.

from google.colab import drive
drive.mount('/content/drive')

# Root folder on Drive: audio is read from <root>/audio and MIDI is written
# to <root>/midi.
BASE_DIR = '/content/drive/MyDrive/clef_baseline'  #@param {type:"string"}
INPUT_DIR = BASE_DIR + '/audio'
OUTPUT_DIR = BASE_DIR + '/midi'

# Make sure both folders exist before anything reads from or writes to them.
for _folder in (INPUT_DIR, OUTPUT_DIR):
    os.makedirs(_folder, exist_ok=True)

print("Input directory:  " + INPUT_DIR)
print("Output directory: " + OUTPUT_DIR)

In [None]:
#@title 5. List Audio Files

import glob

# Supported audio formats. Extensions are compared case-insensitively so that
# mixed-case names such as `take.Wav` or `song.Mp3` are found as well, not
# only the all-lowercase and all-uppercase spellings.
AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac')

# Traverse INPUT_DIR (including subfolders) once and keep every path whose
# name ends in a supported extension. Sorting keeps the processing order
# stable between runs.
audio_files = sorted(
    path for path in glob.glob(f'{INPUT_DIR}/**/*', recursive=True)
    if path.lower().endswith(AUDIO_EXTENSIONS)
)
print(f"Found {len(audio_files)} audio files")

# Show first 10
for f in audio_files[:10]:
    print(f"  - {os.path.relpath(f, INPUT_DIR)}")
if len(audio_files) > 10:
    print(f"  ... and {len(audio_files) - 10} more")

In [None]:
#@title 6. Batch Transcribe
#@markdown This will transcribe all audio files.
#@markdown Files that already have MIDI output will be skipped.

import time
from tqdm.notebook import tqdm

def transcribe_file(audio_path, input_dir, output_dir):
    """Transcribe one audio file and write the result as a MIDI file.

    The MIDI file is placed under `output_dir` at the same relative location
    the audio file has under `input_dir`, with the extension replaced by
    `.mid`.

    Args:
      audio_path: Path of the audio file to transcribe.
      input_dir: Root directory that `audio_path` is made relative to.
      output_dir: Root directory for the generated MIDI files.

    Returns:
      A `(midi_path, status)` tuple where status is 'skipped' when the MIDI
      file already exists, 'success' after a completed transcription, or
      'error: <message>' when loading, transcribing or saving raised.
    """
    # Mirror the input's relative location below the output root.
    relative_stem = os.path.splitext(os.path.relpath(audio_path, input_dir))[0]
    midi_path = os.path.join(output_dir, relative_stem + '.mid')

    # An existing output means this file was handled in an earlier run.
    if os.path.exists(midi_path):
        return midi_path, 'skipped'

    os.makedirs(os.path.dirname(midi_path), exist_ok=True)

    # Failures are reported through the status string so that one bad file
    # does not abort the whole batch.
    try:
        samples, _ = librosa.load(audio_path, sr=SAMPLE_RATE)
        transcription = inference_model(samples)
        note_seq.sequence_proto_to_midi_file(transcription, midi_path)
    except Exception as exc:
        status = f'error: {str(exc)}'
    else:
        status = 'success'
    return midi_path, status

# Transcribe all files
# One record per input file is collected in `results`: the audio and MIDI
# paths (relative to their root folders) and the status string returned by
# `transcribe_file`.
results = []
start_time = time.time()

for audio_path in tqdm(audio_files, desc="Transcribing"):
    midi_path, status = transcribe_file(audio_path, INPUT_DIR, OUTPUT_DIR)
    results.append({
        'audio': os.path.relpath(audio_path, INPUT_DIR),
        'midi': os.path.relpath(midi_path, OUTPUT_DIR),
        'status': status
    })
    
    # Print progress for long-running jobs
    # The average covers every file handled so far, skipped ones included,
    # so the ETA is only a rough estimate when resuming a partial run.
    if len(results) % 10 == 0:
        elapsed = time.time() - start_time
        avg_time = elapsed / len(results)
        remaining = avg_time * (len(audio_files) - len(results))
        print(f"  Progress: {len(results)}/{len(audio_files)} | "
              f"Avg: {avg_time:.1f}s/file | "
              f"ETA: {remaining/60:.1f} min")

elapsed = time.time() - start_time
print(f"\nCompleted in {elapsed/60:.1f} minutes")
# Guard the per-file average: with no input files the division would raise
# ZeroDivisionError and abort the cell.
if audio_files:
    print(f"Average: {elapsed/len(audio_files):.1f} seconds per file")
else:
    print("No audio files were found; nothing was transcribed.")

In [None]:
#@title 7. Summary

import pandas as pd

# Name the columns explicitly: with an empty `results` list pandas would
# otherwise build a frame without any columns and `df['status']` below would
# raise KeyError.
df = pd.DataFrame(results, columns=['audio', 'midi', 'status'])

# Every status other than 'success' or 'skipped' is an 'error: ...' message.
errors = df[~df['status'].isin(['success', 'skipped'])]

print("\n" + "="*50)
print("TRANSCRIPTION SUMMARY")
print("="*50)
print(f"Total files:  {len(df)}")
print(f"Success:      {len(df[df['status'] == 'success'])}")
print(f"Skipped:      {len(df[df['status'] == 'skipped'])}")
print(f"Errors:       {len(errors)}")
print("="*50)

# Show errors if any
if len(errors) > 0:
    print("\nErrors:")
    for _, row in errors.iterrows():
        print(f"  {row['audio']}: {row['status']}")

# Save results log
# The log is written next to the MIDI files and is overwritten on every run.
log_path = f'{OUTPUT_DIR}/transcription_log.csv'
df.to_csv(log_path, index=False)
print(f"\nLog saved to: {log_path}")

In [None]:
#@title 8. List Output Files

# Collect every MIDI file below the output folder, at any depth.
midi_files = glob.glob(f'{OUTPUT_DIR}/**/*.mid', recursive=True)
midi_count = len(midi_files)
print(f"Generated {midi_count} MIDI files in {OUTPUT_DIR}")
print()

# Preview at most this many entries, in sorted order, with their sizes.
preview_limit = 20
for midi_file in sorted(midi_files)[:preview_limit]:
    relative_name = os.path.relpath(midi_file, OUTPUT_DIR)
    kilobytes = os.path.getsize(midi_file) / 1024
    print(f"  {relative_name:50} ({kilobytes:.1f} KB)")

not_shown = midi_count - preview_limit
if not_shown > 0:
    print(f"  ... and {not_shown} more")

## Next Steps

1. Download MIDI files from Google Drive: `clef_baseline/midi/`
2. Run music21 quantization locally: `python scripts/batch_quantize.py`
3. Run MV2H evaluation: `bash scripts/batch_evaluate.sh`

See `docs/baseline-mt3-plan.md` for full instructions.