# 🎶 **YourMT3+**  
"YourMT3+: Multi-instrument Music Transcription with Enhanced Transformer Architectures and Cross-dataset Stem Augmentation" <br>
<i>Sungkyun Chang, Emmanouil Benetos, Holger Kirchhoff and Simon Dixon,  <a href="https://arxiv.org/abs/2407.04822">IEEE MLSP 2024</a> (to appear)</i>

<div>
<img src="https://i.imgur.com/yfa53Xn.jpeg" width="800" />
</div>



#### 🥁 How to use:
1. Run the code cells below in order.
2. In the Gradio interface, select the "Upload audio" or "From YouTube" tab.
3. Click "Transcribe". In the "From YouTube" tab, you can first click "Get Audio from YouTube" to embed the video player.

#### Known Issues:
- After changing the model checkpoint, a restart is required.
- When transcribing music from sources outside the training datasets, models trained with pitch-shift augmentation (PS) often transcribe segments a semitone too high or too low. This issue is not observed with the YPTF-S (noPS) model, as seen in the example at https://youtu.be/9E82wwNc7r8?si=I-WyfwJXCBDY2reh.

In [1]:
# @title Setup
!pip install awscli
!mkdir -p amt
!aws s3 cp s3://amt-deploy-public/amt/ /content/amt --no-sign-request --recursive
!aws s3 cp s3://amt-deploy-public/examples/ /content/examples --no-sign-request --recursive
%cd /content/amt/src
!pip install -r requirements.txt
!pip install transformers==4.45.1
!apt-get install -y sox
!pip install gradio
!pip install pytube
!python -m pip install -U "yt-dlp[default]"

Collecting awscli
  Downloading awscli-1.40.0-py3-none-any.whl.metadata (11 kB)
Collecting botocore==1.38.1 (from awscli)
  Downloading botocore-1.38.1-py3-none-any.whl.metadata (5.7 kB)
Collecting docutils<=0.19,>=0.18.1 (from awscli)
  Downloading docutils-0.19-py3-none-any.whl.metadata (2.7 kB)
Collecting s3transfer<0.13.0,>=0.12.0 (from awscli)
  Downloading s3transfer-0.12.0-py3-none-any.whl.metadata (1.7 kB)
Collecting colorama<0.4.7,>=0.2.5 (from awscli)
  Downloading colorama-0.4.6-py2.py3-none-any.whl.metadata (17 kB)
Collecting rsa<4.8,>=3.1.2 (from awscli)
  Downloading rsa-4.7.2-py3-none-any.whl.metadata (3.6 kB)
Collecting jmespath<2.0.0,>=0.7.1 (from botocore==1.38.1->awscli)
  Downloading jmespath-1.0.1-py3-none-any.whl.metadata (7.6 kB)
Downloading awscli-1.40.0-py3-none-any.whl (4.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.7/4.7 MB[0m [31m67.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading botocore-1.38.1-py3-none-any.whl (13.5 MB)
[2K   

In [2]:
!pip install numpy==1.23.5

Collecting numpy==1.23.5
  Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)
Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.1/17.1 MB[0m [31m61.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.26.4
    Uninstalling numpy-1.26.4:
      Successfully uninstalled numpy-1.26.4
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
scikit-image 0.25.2 requires numpy>=1.24, but you have numpy 1.23.5 which is incompatible.
xarray 2025.1.2 requires numpy>=1.24, but you have numpy 1.23.5 which is incompatible.
tensorflow 2.18.0 requires numpy<2.1.0,>=1.26.0, but you have numpy 1.23.5 which is incompatible.
albucore
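The setup cells pin `transformers==4.45.1` and `numpy==1.23.5`, and the notebook relies on the `sox`, `ffmpeg` and `yt-dlp` command-line tools. A quick check can catch a mismatch before the model cells fail in a confusing way. The following is a minimal sketch using only the standard library; `check_environment` is a hypothetical helper, and the assumption is that these pins are the versions the model code needs. If `numpy` was already imported in the running session, the downgraded version only takes effect after a runtime restart.

```python
import shutil
from importlib import metadata
from typing import Dict, Iterable, List


def check_environment(pins: Dict[str, str], tools: Iterable[str]) -> List[str]:
    """Return human-readable problems; an empty list means every check passed."""
    problems = []
    for package, wanted in pins.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            problems.append(f"{package}: not installed (want {wanted})")
            continue
        if installed != wanted:
            problems.append(f"{package}: found {installed}, want {wanted}")
    for tool in tools:
        # shutil.which returns None when the executable is not on PATH
        if shutil.which(tool) is None:
            problems.append(f"{tool}: not found on PATH")
    return problems


issues = check_environment({"numpy": "1.23.5", "transformers": "4.45.1"},
                           ["sox", "ffmpeg", "yt-dlp"])
print("\n".join(issues) if issues else "environment looks OK")
```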

In [1]:
# @title Model helper
%cd /content/amt/src
import os
from collections import Counter
import argparse
import torch
import torchaudio
import numpy as np

from model.init_train import initialize_trainer, update_config
from utils.task_manager import TaskManager
from config.vocabulary import drum_vocab_presets
from utils.utils import str2bool
from utils.utils import Timer
from utils.audio import slice_padded_array
from utils.note2event import mix_notes
from utils.event2note import merge_zipped_note_events_and_ties_to_notes
from utils.utils import write_model_output_as_midi, write_err_cnt_as_json
from model.ymt3 import YourMT3

def load_model_checkpoint(args=None):
    parser = argparse.ArgumentParser(description="YourMT3")
    # General
    parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. The "@" symbol can be used to load a specific checkpoint.')
    parser.add_argument('-p', '--project', type=str, default='ymt3', help='project name')
    parser.add_argument('-ac', '--audio-codec', type=str, default=None, help='audio codec (default=None). {"spec", "melspec"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-hop', '--hop-length', type=int, default=None, help='hop length in frames (default=None). {128, 300}: 128 for MT3, 300 for PerceiverTF. If None, default value defined in config.py will be used.')
    parser.add_argument('-nmel', '--n-mels', type=int, default=None, help='number of mel bins (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-if', '--input-frames', type=int, default=None, help='number of audio frames for input segment (default=None). If None, default value defined in config.py will be used.')
    # Model configurations
    parser.add_argument('-sqr', '--sca-use-query-residual', type=str2bool, default=None, help='sca use query residual flag. Default follows config.py')
    parser.add_argument('-enc', '--encoder-type', type=str, default=None, help="Encoder type. 't5' or 'perceiver-tf' or 'conformer'. Default is 't5', following config.py.")
    parser.add_argument('-dec', '--decoder-type', type=str, default=None, help="Decoder type. 't5' or 'multi-t5'. Default is 't5', following config.py.")
    parser.add_argument('-preenc', '--pre-encoder-type', type=str, default='default', help="Pre-encoder type. None or 'conv' or 'default'. By default, t5_enc:None, perceiver_tf_enc:conv, conformer:None")
    parser.add_argument('-predec', '--pre-decoder-type', type=str, default='default', help="Pre-decoder type. {None, 'linear', 'conv1', 'mlp', 'group_linear'} or 'default'. Default is {'t5': None, 'perceiver-tf': 'linear', 'conformer': None}.")
    parser.add_argument('-cout', '--conv-out-channels', type=int, default=None, help='Number of filters for pre-encoder conv layer. Default follows "model_cfg" of config.py.')
    parser.add_argument('-tenc', '--task-cond-encoder', type=str2bool, default=True, help='task conditional encoder (default=True). True or False')
    parser.add_argument('-tdec', '--task-cond-decoder', type=str2bool, default=True, help='task conditional decoder (default=True). True or False')
    parser.add_argument('-df', '--d-feat', type=int, default=None, help='Audio feature will be projected to this dimension for Q,K,V of T5 or K,V of Perceiver (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-pt', '--pretrained', type=str2bool, default=False, help='pretrained T5(default=False). True or False')
    parser.add_argument('-b', '--base-name', type=str, default="google/t5-v1_1-small", help='base model name (default="google/t5-v1_1-small")')
    parser.add_argument('-epe', '--encoder-position-encoding-type', type=str, default='default', help="Positional encoding type of encoder. By default, pre-defined PE for T5 or Perceiver-TF encoder in config.py. For T5: {'sinusoidal', 'trainable'}, conformer: {'rotary', 'trainable'}, Perceiver-TF: {'trainable', 'rope', 'alibi', 'alibit', 'None', '0', 'none', 'tkd', 'td', 'tk', 'kdt'}.")
    parser.add_argument('-dpe', '--decoder-position-encoding-type', type=str, default='default', help="Positional encoding type of decoder. By default, pre-defined PE for T5 in config.py. {'sinusoidal', 'trainable'}.")
    parser.add_argument('-twe', '--tie-word-embedding', type=str2bool, default=None, help='tie word embedding (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-el', '--event-length', type=int, default=None, help='event length (default=None). If None, default value defined in model cfg of config.py will be used.')
    # Perceiver-TF configurations
    parser.add_argument('-dl', '--d-latent', type=int, default=None, help='Latent dimension of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-nl', '--num-latents', type=int, default=None, help='Number of latents of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-dpm', '--perceiver-tf-d-model', type=int, default=None, help='Perceiver-TF d_model (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-npb', '--num-perceiver-tf-blocks', type=int, default=None, help='Number of blocks of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py.')
    parser.add_argument('-npl', '--num-perceiver-tf-local-transformers-per-block', type=int, default=None, help='Number of local layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-npt', '--num-perceiver-tf-temporal-transformers-per-block', type=int, default=None, help='Number of temporal layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-atc', '--attention-to-channel', type=str2bool, default=None, help='Attention to channel flag of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-ln', '--layer-norm-type', type=str, default=None, help='Layer normalization type (default=None). {"layer_norm", "rms_norm"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-ff', '--ff-layer-type', type=str, default=None, help='Feed forward layer type (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-wf', '--ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-nmoe', '--moe-num-experts', type=int, default=None, help='Number of experts for MoE (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-kmoe', '--moe-topk', type=int, default=None, help='Top-k for MoE (default=None). If None, default value defined in config.py will be used.')
    parser.add_argument('-act', '--hidden-act', type=str, default=None, help='Hidden activation function (default=None). {"gelu", "silu", "relu", "tanh"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-rt', '--rotary-type', type=str, default=None, help='Rotary embedding type expressed in three letters. e.g. ppl: "pixel" for SCA and latents, "lang" for temporal transformer. If None, use config.')
    parser.add_argument('-rk', '--rope-apply-to-keys', type=str2bool, default=None, help='Apply rope to keys (default=None). If None, use config.')
    parser.add_argument('-rp', '--rope-partial-pe', type=str2bool, default=None, help='Whether to apply RoPE to partial positions (default=None). If None, use config.')
    # Decoder configurations
    parser.add_argument('-dff', '--decoder-ff-layer-type', type=str, default=None, help='Feed forward layer type of decoder (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
    parser.add_argument('-dwf', '--decoder-ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for decoder MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
    # Task and Evaluation configurations
    parser.add_argument('-tk', '--task', type=str, default='mt3_full_plus', help='tokenizer type (default=mt3_full_plus). See config/task.py for more options.')
    parser.add_argument('-epv', '--eval-program-vocab', type=str, default=None, help='evaluation vocabulary (default=None). If None, default vocabulary of the data preset will be used.')
    parser.add_argument('-edv', '--eval-drum-vocab', type=str, default=None, help='evaluation vocabulary for drum (default=None). If None, default vocabulary of the data preset will be used.')
    parser.add_argument('-etk', '--eval-subtask-key', type=str, default='default', help='evaluation subtask key (default=default). See config/task.py for more options.')
    parser.add_argument('-t', '--onset-tolerance', type=float, default=0.05, help='onset tolerance (default=0.05).')
    parser.add_argument('-os', '--test-octave-shift', type=str2bool, default=False, help='test optimal octave shift (default=False). True or False')
    parser.add_argument('-w', '--write-model-output', type=str2bool, default=True, help='write model test output to file (default=True). True or False')
    # Trainer configurations
    parser.add_argument('-pr','--precision', type=str, default="bf16-mixed", help='precision (default="bf16-mixed") {32, 16, bf16, bf16-mixed}')
    parser.add_argument('-st', '--strategy', type=str, default='auto', help='strategy (default=auto). auto or deepspeed or ddp')
    parser.add_argument('-n', '--num-nodes', type=int, default=1, help='number of nodes (default=1)')
    parser.add_argument('-g', '--num-gpus', type=str, default='auto', help='number of gpus (default="auto")')
    parser.add_argument('-wb', '--wandb-mode', type=str, default="disabled", help='wandb mode for logging (default="disabled"). "disabled" or "online" or "offline".')
    # Debug
    parser.add_argument('-debug', '--debug-mode', type=str2bool, default=False, help='debug mode (default=False). True or False')
    parser.add_argument('-tps', '--test-pitch-shift', type=int, default=None, help='use pitch shift when testing. debug-purpose only. (default=None). semitone in int.')
    args = parser.parse_args(args)
    if torch.__version__ >= "1.13":
        torch.set_float32_matmul_precision("high")
    args.epochs = None

    # Initialize and update config
    _, _, dir_info, shared_cfg = initialize_trainer(args, stage='test')
    shared_cfg, audio_cfg, model_cfg = update_config(args, shared_cfg, stage='test')

    if args.eval_drum_vocab is not None:  # override eval_drum_vocab
        eval_drum_vocab = drum_vocab_presets[args.eval_drum_vocab]

    # Initialize task manager
    tm = TaskManager(task_name=args.task,
                     max_shift_steps=int(shared_cfg["TOKENIZER"]["max_shift_steps"]),
                     debug_mode=args.debug_mode)
    print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")

    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model
    model = YourMT3(
        audio_cfg=audio_cfg,
        model_cfg=model_cfg,
        shared_cfg=shared_cfg,
        optimizer=None,
        task_manager=tm,  # tokenizer is a member of task_manager
        eval_subtask_key=args.eval_subtask_key,
        write_output_dir=dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None
        ).to(device)
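    # weights_only=False unpickles the full checkpoint dict (not only tensors); use it only with trusted files.
    # map_location=device lets a checkpoint saved on GPU load on a CPU-only runtime.
    checkpoint = torch.load(dir_info["last_ckpt_path"], map_location=device, weights_only=False)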
    state_dict = checkpoint['state_dict']
    new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
    model.load_state_dict(new_state_dict, strict=False)
    return model.eval()


def transcribe(model, audio_info):
    t = Timer()

    # Converting Audio
    t.start()
    audio, sr = torchaudio.load(uri=audio_info['filepath'])
    audio = torch.mean(audio, dim=0).unsqueeze(0)
    audio = torchaudio.functional.resample(audio, sr, model.audio_cfg['sample_rate'])
    audio_segments = slice_padded_array(audio, model.audio_cfg['input_frames'], model.audio_cfg['input_frames'])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    audio_segments = torch.from_numpy(audio_segments.astype('float32')).to(device).unsqueeze(1) # (n_seg, 1, seg_sz)
    t.stop()
    t.print_elapsed_time("converting audio")

    # Inference
    t.start()
    pred_token_arr, _ = model.inference_file(bsz=8, audio_segments=audio_segments)
    t.stop()
    t.print_elapsed_time("model inference")

    # Post-processing
    t.start()
    num_channels = model.task_manager.num_decoding_channels
    n_items = audio_segments.shape[0]
    start_secs_file = [model.audio_cfg['input_frames'] * i / model.audio_cfg['sample_rate'] for i in range(n_items)]
    pred_notes_in_file = []
    n_err_cnt = Counter()
    for ch in range(num_channels):
        pred_token_arr_ch = [arr[:, ch, :] for arr in pred_token_arr]  # (B, L)
        zipped_note_events_and_tie, list_events, ne_err_cnt = model.task_manager.detokenize_list_batches(
            pred_token_arr_ch, start_secs_file, return_events=True)
        pred_notes_ch, n_err_cnt_ch = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie)
        pred_notes_in_file.append(pred_notes_ch)
        n_err_cnt += n_err_cnt_ch
    pred_notes = mix_notes(pred_notes_in_file)  # This is the mixed notes from all channels

    # Write MIDI
    write_model_output_as_midi(pred_notes, '/content/',
                              audio_info['track_name'], model.midi_output_inverse_vocab)
    t.stop()
    t.print_elapsed_time("post processing")
    midifile = os.path.join('/content/model_output/', audio_info['track_name'] + '.mid')
    assert os.path.exists(midifile)
    return midifile


/content/amt/src
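`transcribe` splits the resampled mono signal into segments of `input_frames` samples and gives segment `i` the start time `input_frames * i / sample_rate`. The sketch below reproduces that bookkeeping in plain Python. It assumes, without having checked `utils.audio.slice_padded_array`, that segments do not overlap and that the last one is zero-padded to full length; `segment_with_padding` is a hypothetical name. The values 32767 and 16000 are taken from the `audio_cfg` printed when the checkpoint loads.

```python
import math
from typing import List, Sequence, Tuple


def segment_with_padding(samples: Sequence[float], seg_len: int,
                         sample_rate: int) -> Tuple[List[List[float]], List[float]]:
    """Split samples into non-overlapping segments of seg_len, zero-padding the last.

    Also returns each segment's start time in seconds, mirroring
    start_secs_file in transcribe(). Always returns at least one segment
    (an assumption for empty input).
    """
    n_seg = max(1, math.ceil(len(samples) / seg_len))
    segments = []
    for i in range(n_seg):
        chunk = list(samples[i * seg_len:(i + 1) * seg_len])
        chunk += [0.0] * (seg_len - len(chunk))  # pad the final segment with zeros
        segments.append(chunk)
    start_secs = [seg_len * i / sample_rate for i in range(n_seg)]
    return segments, start_secs


# 100,000 samples at 16 kHz (about 6.25 s) with 32,767-sample segments
segs, starts = segment_with_padding([0.1] * 100_000, 32767, 16000)
print(len(segs), [round(s, 3) for s in starts])  # 4 [0.0, 2.048, 4.096, 6.144]
```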




In [8]:
# @title HTML helper
import re
import base64
def to_data_url(midi_filename):
    """ This is crucial for Colab/WandB support. Thanks to Scott Hawley!!
        https://github.com/drscotthawley/midi-player/blob/main/midi_player/midi_player.py

    """
    with open(midi_filename, "rb") as f:
        encoded_string = base64.b64encode(f.read())
    return 'data:audio/midi;base64,'+encoded_string.decode('utf-8')


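def to_youtube_embed_url(video_url):
    # Capture the 11-character video ID and drop any trailing query (e.g. "&t=30s", "?si=...")
    regex = r"(?:https?:\/\/)?(?:www\.|m\.)?(?:youtube\.com|youtu\.be)\/(?:watch\?v=|embed\/)?([\w-]{11}).*"
    return re.sub(regex, r"https://www.youtube.com/embed/\1", video_url)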


def create_html_from_midi(midifile):
    html_template = """
<!DOCTYPE html>
<html>
<head>
  <title>Awesome MIDI Player</title>
  <script src="https://cdn.jsdelivr.net/combine/npm/tone@14.7.58,npm/@magenta/music@1.23.1/es6/core.js,npm/focus-visible@5,npm/html-midi-player@1.5.0">
  </script>
  <style>
    /* Background color for the section */
    #proll {{background-color:transparent}}
    /* Custom player style */
    #proll midi-player {{
      display: block;
      width: inherit;
      margin: 4px;
      margin-bottom: 0;
      transform-origin: top;
      transform: scaleY(0.8); /* Added scaleY */
    }}
    #proll midi-player::part(control-panel) {{
      background: #d8dae880;
      border-radius: 8px 8px 0 0;
      border: 1px solid #A0A0A0;
    }}
    /* Custom visualizer style */
    #proll midi-visualizer .piano-roll-visualizer {{
      background: #45507328;
      border-radius: 0 0 8px 8px;
      border: 1px solid #A0A0A0;
      margin: 4px;
      margin-top: 1;
      overflow: auto;
      transform-origin: top;
      transform: scaleY(0.8); /* Added scaleY */
    }}
    #proll midi-visualizer svg rect.note {{
      opacity: 0.6;
      stroke-width: 2;
    }}
    #proll midi-visualizer svg rect.note[data-instrument="0"] {{
      fill: #e22;
      stroke: #055;
    }}
    #proll midi-visualizer svg rect.note[data-instrument="2"] {{
      fill: #2ee;
      stroke: #055;
    }}
    #proll midi-visualizer svg rect.note[data-is-drum="true"] {{
      fill: #888;
      stroke: #888;
    }}
    #proll midi-visualizer svg rect.note.active {{
      opacity: 0.9;
      stroke: #34384F;
    }}
    /* Media queries for responsive scaling */
    @media (max-width: 700px) {{ #proll midi-visualizer .piano-roll-visualizer {{transform-origin: top; transform: scaleY(0.75);}} }}
    @media (max-width: 500px) {{ #proll midi-visualizer .piano-roll-visualizer {{transform-origin: top; transform: scaleY(0.7);}} }}
    @media (max-width: 400px) {{ #proll midi-visualizer .piano-roll-visualizer {{transform-origin: top; transform: scaleY(0.6);}} }}
    @media (max-width: 300px) {{ #proll midi-visualizer .piano-roll-visualizer {{transform-origin: top; transform: scaleY(0.5);}} }}
  </style>
</head>
<body>
  <div>
    <a href="{midifile}" target="_blank" style="font-size: 14px;">Download MIDI</a> <br>
  </div>
  <div>
    <section id="proll">
      <midi-player src="{midifile}" sound-font="https://storage.googleapis.com/magentadata/js/soundfonts/sgm_plus" visualizer="#proll midi-visualizer">
      </midi-player>
      <midi-visualizer src="{midifile}">
      </midi-visualizer>
    </section>
  </div>
</body>
</html>
""".format(midifile=midifile)
    html = f"""<div style="display: flex; justify-content: center; align-items: center;">
                  <iframe style="width: 100%; height: 500px; overflow:hidden" srcdoc='{html_template}'></iframe>
            </div>"""
    return html


def create_html_youtube_player(youtube_url):
    youtube_url = to_youtube_embed_url(youtube_url)
    html = f"""<div style="display: flex; justify-content: center; align-items: center;">
                  <iframe width=560 height=100% src='{youtube_url}' title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
              </div>"""
    return html
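`to_youtube_embed_url` relies on a single regular expression. A more defensive option is to parse the URL with `urllib.parse` and extract the video ID explicitly, returning `None` for anything that is not a recognisable YouTube link. This is a sketch only: `youtube_video_id` and `embed_url` are hypothetical helpers, they handle just the `youtube.com/watch?v=…`, `youtu.be/…` and `youtube.com/embed/…` forms, and the 11-character ID length is a long-standing YouTube convention rather than a documented guarantee.

```python
import re
from typing import Optional
from urllib.parse import parse_qs, urlparse

_ID_RE = re.compile(r"[A-Za-z0-9_-]{11}")


def youtube_video_id(url: str) -> Optional[str]:
    """Return the video ID for common YouTube URL forms, or None."""
    if "://" not in url:
        url = "https://" + url  # accept scheme-less input such as youtu.be/...
    parsed = urlparse(url)
    host = (parsed.hostname or "").lower()
    if host.startswith("www."):
        host = host[4:]
    candidate = None
    if host == "youtu.be":
        candidate = parsed.path.lstrip("/").split("/")[0]
    elif host in ("youtube.com", "m.youtube.com"):
        if parsed.path == "/watch":
            candidate = parse_qs(parsed.query).get("v", [None])[0]
        elif parsed.path.startswith("/embed/"):
            candidate = parsed.path.split("/")[2]
    if candidate and _ID_RE.fullmatch(candidate):
        return candidate
    return None


def embed_url(url: str) -> Optional[str]:
    """Build an embeddable player URL, or None if no video ID was found."""
    video_id = youtube_video_id(url)
    return f"https://www.youtube.com/embed/{video_id}" if video_id else None


print(embed_url("https://youtu.be/OXXRoa1U6xU?si=nhJ6lzGenCmk4P7R"))  # https://www.youtube.com/embed/OXXRoa1U6xU
```

Returning `None` instead of a rewritten string lets a caller such as `play_video` reject invalid input outright rather than relying on a `'youtu' in url` substring check.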


In [9]:
# @title Gradio helper
import os
import subprocess
import glob
from typing import Tuple, Dict, Literal
from google.colab import output

from pytube import YouTube
import gradio as gr
import torchaudio

def prepare_media(source_path_or_url: os.PathLike,
                  source_type: Literal['audio_filepath', 'youtube_url'],
                  delete_video: bool = True) -> Dict:
    """prepare media from source path or youtube, and return audio info"""
    # Get audio_file
    if source_type == 'audio_filepath':
        audio_file = source_path_or_url
    elif source_type == 'youtube_url':
        # Download from youtube
        try:
            # Try PyTube first. To route through a proxy, pass e.g.
            # proxies={"http": "http://127.0.0.1:1087", "https": "http://127.0.0.1:1087"}
            yt = YouTube(source_path_or_url)
            audio_stream = min(yt.streams.filter(only_audio=True), key=lambda s: s.bitrate)
            mp4_file = audio_stream.download(output_path='downloaded') # ./downloaded
            # Swap the downloaded extension (e.g. .mp4 or .webm) for .mp3
            audio_file = os.path.splitext(mp4_file)[0] + '.mp3'
            subprocess.run(['ffmpeg', '-y', '-i', mp4_file, '-ac', '1', audio_file], check=True)
            os.remove(mp4_file)
        except Exception as e:
            try:
                # Try alternative
                print(f"Failed with PyTube, error: {e}. Trying yt-dlp...")
                audio_file = './downloaded/yt_audio'
                subprocess.run(['yt-dlp', '-x', source_path_or_url, '-f', 'bestaudio',
                    '-o', audio_file, '--audio-format', 'mp3', '--restrict-filenames',
                    '--force-overwrites'], check=True)  # raise on failure so the except below runs
                # subprocess.run(['yt-dlp', '-x', source_path_or_url, '-f', 'bestaudio',
                #     '-o', audio_file, '--audio-format', 'mp3', '--restrict-filenames',
                #     '--force-overwrites', '--cookiefile', '/content/cookies.txt'])
                audio_file += '.mp3'
            except Exception as e:
                print(f"Alternative downloader failed, error: {e}. Please try again later!")
                return None
    else:
        raise ValueError(source_type)

    # Create info
    info = torchaudio.info(audio_file)
    return {
        "filepath": audio_file,
        "track_name": os.path.splitext(os.path.basename(audio_file))[0],
        "sample_rate": int(info.sample_rate),
        "bits_per_sample": int(info.bits_per_sample),
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        "duration": int(info.num_frames / info.sample_rate),
        "encoding": str.lower(info.encoding),
        }

def process_audio(audio_filepath):
    if audio_filepath is None:
        return None
    audio_info = prepare_media(audio_filepath, source_type='audio_filepath')
    midifile = transcribe(model, audio_info)
    midifile = to_data_url(midifile)
    return create_html_from_midi(midifile) # html midiplayer

def process_video(youtube_url):
    if 'youtu' not in youtube_url:
        return None
    audio_info = prepare_media(youtube_url, source_type='youtube_url')
    if audio_info is None:  # both downloaders failed
        return None
    midifile = transcribe(model, audio_info)
    midifile = to_data_url(midifile)
    return create_html_from_midi(midifile) # html midiplayer

def play_video(youtube_url):
    if 'youtu' not in youtube_url:
        return None
    return create_html_youtube_player(youtube_url)
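`prepare_media` tries PyTube first and falls back to yt-dlp. One caveat is that `subprocess.run` does not raise on a non-zero exit status unless `check=True` is passed, so a failed command alone never reaches an `except` block. The generic sketch below shows the same try-in-order idea with failures made explicit; `first_successful` and `run_checked` are hypothetical helpers, and the stand-in attempts are placeholders, not real PyTube or yt-dlp calls.

```python
import subprocess
import sys
from typing import Callable, List, Optional, Tuple


def first_successful(attempts: List[Tuple[str, Callable[[], str]]]) -> Optional[str]:
    """Call each (name, fn) in order; return the first result, or None if all fail."""
    for name, fn in attempts:
        try:
            return fn()
        except Exception as e:  # any failure moves on to the next attempt
            print(f"{name} failed: {e}")
    return None


def run_checked(cmd: List[str]) -> str:
    # check=True raises CalledProcessError on a non-zero exit status
    subprocess.run(cmd, check=True, capture_output=True)
    return "ok"


def always_fails() -> str:
    raise RuntimeError("stand-in for a blocked downloader")


result = first_successful([
    ("primary", always_fails),
    ("failing command", lambda: run_checked([sys.executable, "-c", "import sys; sys.exit(3)"])),
    ("fallback", lambda: "downloaded/yt_audio.mp3"),
])
print(result)  # downloaded/yt_audio.mp3
```

Returning `None` when every attempt fails matches how `prepare_media` signals failure, so callers only need a single `is None` check.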


In [4]:
# @title Load Checkpoint
model_name = 'YPTF.MoE+Multi (noPS)' # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
precision = 'bf16-mixed' # @param ["32", "bf16-mixed", "16"]
project = '2024'

if model_name == "YMT3+":
    checkpoint = "notask_all_cross_v6_xk2_amp0811_gm_ext_plus_nops_b72@model.ckpt"
    args = [checkpoint, '-p', project, '-pr', precision]
elif model_name == "YPTF+Single (noPS)":
    checkpoint = "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt"
    args = [checkpoint, '-p', project, '-enc', 'perceiver-tf', '-ac', 'spec',
            '-hop', '300', '-atc', '1', '-pr', precision]
elif model_name == "YPTF+Multi (PS)":
    checkpoint = "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt"
    args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256',
            '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf',
            '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
elif model_name == "YPTF.MoE+Multi (noPS)":
    checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
    args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
            '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
            '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
            '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
elif model_name == "YPTF.MoE+Multi (PS)":
    checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt"
    args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
            '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
            '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
            '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
else:
    raise ValueError(model_name)

model = load_model_checkpoint(args=args)

INFO:pytorch_lightning.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
INFO:pytorch_lightning.utilities.rank_zero:`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
INFO:pytorch_lightning.utilities.rank_zero:`Trainer(limit_test_batches=1.0)` was configured so 100% of the batches will be used..


Resuming from ../logs/2024/mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops/checkpoints/last.ckpt
Task: mc13_full_plus_256, Max Shift Steps: 206
"add_melody_metric_to_singing": True
"add_pitch_class_metric":       None
"audio_cfg":                    {'codec': 'spec', 'hop_length': 300, 'audio_backend': 'torchaudio', 'sample_rate': 16000, 'input_frames': 32767, 'n_fft': 2048, 'n_mels': 512, 'f_min': 50.0, 'f_max': 8000.0}
"base_lr":                      None
"eval_drum_vocab":              None
"eval_subtask_key":             default
"eval_vocab":                   None
"init_factor":                  None
"max_steps":                    None
"model_cfg":                    {'encoder_type': 'perceiver-tf', 'decoder_type': 'multi-t5', 'pre_encoder_type': 'conv', 'pre_encoder_type_default': {'t5': None, 'perceiver-tf': 'conv', 'conformer': None}, 'pre_decoder_type': 'mc_shared_linear', 'pre_decoder_type_default': {'t5': {'t5': None}, 'perceiver-tf': {'t5': 'linear', 'mu
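The `if`/`elif` chain in the checkpoint cell repeats most of its flags across models. As an illustrative alternative, not how the original notebook is organised, the same argument lists can be built from a small table. The checkpoint names and flags are copied from that cell; `MODEL_PRESETS` and `build_args` are hypothetical names.

```python
from typing import Dict, List, Tuple

# Flag groups copied from the checkpoint cell
_PTF = ['-enc', 'perceiver-tf', '-ac', 'spec', '-hop', '300', '-atc', '1']
_MULTI = ['-tk', 'mc13_full_plus_256', '-dec', 'multi-t5', '-nl', '26']
_MOE = _MULTI + ['-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe', '-wf', '4',
                 '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
                 '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1']

# model name -> (checkpoint, model-specific flags)
MODEL_PRESETS: Dict[str, Tuple[str, List[str]]] = {
    "YMT3+": ("notask_all_cross_v6_xk2_amp0811_gm_ext_plus_nops_b72@model.ckpt", []),
    "YPTF+Single (noPS)": ("ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt", _PTF),
    "YPTF+Multi (PS)": ("mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt",
                        _MULTI + _PTF),
    "YPTF.MoE+Multi (noPS)": ("mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt", _MOE),
    "YPTF.MoE+Multi (PS)": ("mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt", _MOE),
}


def build_args(model_name: str, project: str = '2024', precision: str = 'bf16-mixed') -> List[str]:
    """Return the argument list for load_model_checkpoint for a named preset."""
    if model_name not in MODEL_PRESETS:
        raise ValueError(model_name)
    checkpoint, flags = MODEL_PRESETS[model_name]
    return [checkpoint, '-p', project, *flags, '-pr', precision]


print(build_args("YMT3+"))
```

Usage: `model = load_model_checkpoint(args=build_args(model_name, project, precision))`. Keeping the presets in one table makes it easier to spot which flags differ between checkpoints.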

In [None]:
# @title Run Gradio
output.no_vertical_scroll()

AUDIO_EXAMPLES = glob.glob('/content/examples/*.*', recursive=True)
YOUTUBE_EXAMPLES = ["https://www.youtube.com/watch?v=vMboypSkj3c", "https://youtu.be/OXXRoa1U6xU?si=nhJ6lzGenCmk4P7R",
                    "https://youtu.be/EOJ0wH6h3rE?si=a99k6BnSajvNmXcn", "https://youtu.be/7mjQooXt28o?si=qqmMxCxwqBlLPDI2",
                    "https://youtu.be/bnS-HK_lTHA?si=PQLVAab3QHMbv0S3", "https://youtu.be/zJB0nnOc7bM?si=EA1DN8nHWJcpQWp_",
                    "https://youtu.be/mIWYTg55h10?si=WkbtKfL6NlNquvT8"]

# theme = 'gradio/dracula_revamped' #'Insuz/Mocha' #gr.themes.Soft()
theme = gr.Theme.from_hub("gradio/dracula_revamped")
# theme.text_sm = '5px'
theme.text_lg = '10px'
theme.text_md = '9px'

theme.body_background_fill_dark = '#060a1c' #'#372037'# '#a17ba5' #'#73d3ac'
theme.border_color_primary_dark = '#45507328'
theme.block_background_fill_dark = '#3845685c'

theme.body_text_color_dark = 'white'
theme.block_title_text_color_dark = 'black'

theme.body_text_color_subdued_dark = '#e4e9e9'


# css = ".gradio-container {background: url(https://miro.medium.com/v2/resize:fit:3840/format:webp/1*AcYLHh0_ve4TNRi6HLFcPA.jpeg)}"
css = """
.gradio-container {
    background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
    background-size: 400% 400%;
    animation: gradient 15s ease infinite;
    height: 100vh;
}
@keyframes gradient {
    0% {
        background-position: 0% 50%;
    }
    50% {
        background-position: 100% 50%;
    }
    100% {
        background-position: 0% 50%;
    }
}
"""


with gr.Blocks(theme=theme, css=css) as demo:

    with gr.Row():
        with gr.Column(scale=10):
            gr.Markdown(
            f"""
            ## 🎶YourMT3+: Multi-instrument Music Transcription with Enhanced Transformer Architectures and Cross-dataset Stem Augmentation

            ### Model name: `{model_name}`

            <div style="display: inline-block;">
                <a href="https://arxiv.org/abs/2407.04822">
                    <img src="https://img.shields.io/badge/arXiv-B31B1B?logo=arxiv&logoColor=fff&style=plastic" alt="arXiv Badge"/>
                </a>
            </div>
            <div style="display: inline-block;">
                <a href="https://github.com/mimbres/YourMT3">
                    <img src="https://img.shields.io/badge/GitHub-181717?logo=github&logoColor=fff&style=plastic" alt="GitHub Badge"/>
                </a>
            </div>
            <div style="display: inline-block;">
                <a href="https://huggingface.co/spaces/mimbres/YourMT3">
                    <img src="https://img.shields.io/badge/Model%20on-🤗-1f425f.svg?style=plastic" alt="Hugging Face Badge"/>
                </a>
            </div>
            """)

    with gr.Group():
        with gr.Tab("Upload audio"):
            # Input
            audio_input = gr.Audio(label="Upload or Record Audio", type="filepath",
                                show_share_button=True, show_download_button=True)
            # Display examples
            gr.Examples(examples=AUDIO_EXAMPLES, inputs=audio_input)
            # Submit button
            transcribe_audio_button = gr.Button("Transcribe", variant="primary")
            # Transcribe
            output_tab1 = gr.HTML()
            transcribe_audio_button.click(process_audio, inputs=audio_input, outputs=output_tab1)

        with gr.Tab("From YouTube"):
            with gr.Column(scale=4):
                # Input URL
                youtube_url = gr.Textbox(label="YouTube Link URL",
                        placeholder="https://youtu.be/...")
                # Play youtube
                youtube_player = gr.HTML(render=True)

                # Display examples
                gr.Examples(examples=YOUTUBE_EXAMPLES, inputs=youtube_url)
            with gr.Column(scale=4):
                # Play button
                play_video_button = gr.Button("Get Audio from YouTube", variant="primary")
                # Submit button
                transcribe_video_button = gr.Button("Transcribe", variant="primary")
            with gr.Column(scale=1):
                # Transcribe
                output_tab2 = gr.HTML(render=True)
                # video_output = gr.Text(label="Video Info")
                transcribe_video_button.click(process_video, inputs=youtube_url, outputs=output_tab2)
                # Play
                play_video_button.click(play_video, inputs=youtube_url, outputs=youtube_player)

# demo.launch(debug=True)
demo.launch(debug=True, allowed_paths=["/content/examples"])




It looks like you are running Gradio on a hosted a Jupyter notebook. For the Gradio app to work, sharing must be enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://1f288158dae1524287.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
