<a href="https://colab.research.google.com/github/magenta/mt3/blob/main/mt3/colab/music_transcription_with_transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="在Colab中打开"/></a> <-打开原Colab笔记本
# 使用多乐器转录模型进行音乐转录 
此笔记本是一个互动演示,展示谷歌[Magenta](https://g.co/magenta)团队创建的一些[音乐转录模型](https://g.co/magenta/mt3)。您可以上传音频,并让其中一个模型自动将其转录。
<img src="https://magenta.tensorflow.org/assets/transcription-with-transformers/architecture_diagram.png" alt="基于Transformer的转录体系结构">
该笔记本支持两种预训练模型:
1. 我们在 [ISMIR 2021论文](https://archives.ismir.net/ismir2021/paper/000030.pdf)中提出的钢琴转录模型 
2. 我们在[ICLR 2022论文](https://openreview.net/pdf?id=iMSjopcOn0p)中提出的多乐器转录模型 
***警告***:两个模型都没有针对歌声进行训练。如果您上传包含人声的音频,您很可能会得到奇怪的结果。多乐器转录还不是一个完全成熟的模型，仍存在未解决的问题,所以无论如何您都可能会得到奇怪的结果。  
无论如何,我们希望您在转录的过程中玩得开心!如果有任何有趣的输出,欢迎在推特 [@GoogleMagenta](https://twitter.com/googlemagenta)上发文。
### 运行说明: 
* 尽量确保使用GPU运行时,点击: __运行时>更改运行时类型> GPU__ （CPU也支持，但是运算速度较慢）
* 单击每个单元格左侧的▶️来执行该单元格 
* 在*__加载模型__*单元格中,选择`ismir2021`进行钢琴转录或`mt3`进行多乐器转录 
* 在*__上传音频__*单元格中,当提示时从电脑中选择MP3或WAV文件 
* 使用*__转录音频__*单元格转录音频(根据音频的长度,可能需要几分钟) 
--- 
此笔记本会向Google Analytics发送基本的使用数据。有关更多信息,请参阅[Google的隐私政策](https://policies.google.com/privacy)。 

In [None]:
#欢迎使用 MT3:多任务多轨道音乐转录

#Copyright 2021 Google LLC。保留所有权利。

#根据 Apache 许可证 2.0 版("许可证")授权
#仅允许在遵守许可证的情况下使用此文件
#您可以通过以下网址获得许可证副本:

#     http://www.apache.org/licenses/LICENSE-2.0

#除非适用法律要求或书面同意,否则在本许可证下分发的
#软件是在“原样”基础上分发的,不附带任何种类的保证或
#条件,无论明示还是默示。

#有关授权和限制的具体语言,请参阅许可证。

# ==============================================================================


# @title 配置环境
# @markdown 安装MT3及其依赖项(可能需要几分钟)。

!apt-get update -qq && apt-get install -qq libfluidsynth2 build-essential libasound2-dev libjack-dev

# 安装带CUDA支持的jax,否则t5x会将其替换为非CUDA版本
!pip install "jax[cuda11_local]>=0.4.10" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

#安装mt3
!git clone --branch=main https://github.com/magenta/mt3
!mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp

# TODO(iansimon): 待ddsp更新后,移除对numba/llvmlite/ddsp的版本锁定
!python3 -m pip install nest-asyncio numba==0.56.4 llvmlite==0.39.1 pyfluidsynth==1.3.0 -e .
!python3 -m pip install --no-dependencies --upgrade ddsp

#复制检查点
!gsutil -q -m cp -r gs://mt3/checkpoints .

#复制音色(最初来自https://sites.google.com/site/soundfonts4u)
!gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .

import json
import IPython

# 以下函数(load_gtag和log_event)用于处理谷歌分析(Google Analytics)的事件日志记录。
# 日志记录是匿名的,只存储音频和转录结果的一些非常基本的统计信息,
# 例如音频长度、转录出的音符数量。

def load_gtag():
  """Render the Google Analytics gtag.js bootstrap snippet as cell output.

  The snippet has to be displayed in the same cell execution that later logs
  an event; it does not persist from one cell execution to the next, so every
  cell that logs calls this function first.
  """
  gtag_snippet = '''
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4P250YRJ08"></script>
<script>
  window.dataLayer = window.dataLayer || [];
  function gtag(){dataLayer.push(arguments);}
  gtag('js', new Date());
  gtag('config', 'G-4P250YRJ08',
       {'referrer': document.referrer.split('?')[0],
        'anonymize_ip': true,
        'page_title': '',
        'page_referrer': '',
        'cookie_prefix': 'magenta',
        'cookie_domain': 'auto',
        'cookie_expires': 0,
        'cookie_flags': 'SameSite=None;Secure'});
</script>
'''
  # Wrap the markup in an HTML display object, then emit it to the output.
  snippet_output = IPython.display.HTML(gtag_snippet)
  IPython.display.display(snippet_output)

def log_event(event_name, event_details):
  """Emit a gtag 'event' call as Javascript cell output.

  Args:
    event_name: Name of the event; inserted between single quotes in the
      generated Javascript.
    event_details: Dictionary of event details; serialized with `json.dumps`
      and passed as the third argument of the `gtag` call.
  """
  payload = json.dumps(event_details)
  gtag_call = f"gtag('event', '{event_name!s}', {payload});"
  script_output = IPython.display.Javascript(gtag_call)
  IPython.display.display(script_output)

# gtag.js has to be loaded in the same cell execution that logs the event.
load_gtag()
# Record that the setup cell ran to completion (no extra details attached).
log_event('setupComplete', {})

In [None]:
#@title 模块导入与定义

import functools
import os

import numpy as np
import tensorflow.compat.v2 as tf

import functools
import gin
import jax
import librosa
import note_seq
import seqio
import t5
import t5x

from mt3 import metrics_utils
from mt3 import models
from mt3 import network
from mt3 import note_sequences
from mt3 import preprocessors
from mt3 import spectrograms
from mt3 import vocabularies

from google.colab import files

import nest_asyncio
nest_asyncio.apply()

SAMPLE_RATE = 16000
SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'

def upload_audio(sample_rate):
  """Prompt for a file upload and decode the first file into audio samples.

  Args:
    sample_rate: Sample rate passed to
      `note_seq.audio_io.wav_data_to_samples_librosa` for decoding.

  Returns:
    The value returned by `wav_data_to_samples_librosa` for the first
    uploaded file.

  Raises:
    ValueError: If the upload dialog returned no files (for example because
      it was cancelled), instead of failing later with a bare IndexError.
  """
  data = list(files.upload().values())
  if not data:
    # Nothing to decode; fail with an explicit message rather than on data[0].
    raise ValueError('未上传任何文件。请上传单个文件')
  if len(data) > 1:
    # Only a notice: processing continues with the first uploaded file.
    print('您上传了多个文件。请上传单个文件')
  return note_seq.audio_io.wav_data_to_samples_librosa(
    data[0], sample_rate=sample_rate)



class InferenceModel(object):
  """Wrapper around a T5X model for music transcription.

  Construction parses the gin configs, builds the model and restores it from
  a checkpoint. Calling the instance on audio samples runs preprocessing,
  prediction and postprocessing and returns the estimated note sequence.
  """

  def __init__(self, checkpoint_path, model_type='mt3'):
    """Builds the model and restores it from a checkpoint.

    Args:
      checkpoint_path: Checkpoint path handed to `restore_from_checkpoint`.
      model_type: Either 'ismir2021' or 'mt3'. Selects the number of velocity
        bins, the note encoding spec, the input length and the
        model-specific gin file.

    Raises:
      ValueError: If `model_type` is neither 'ismir2021' nor 'mt3'.
    """

    # Model-specific constants.
    if model_type == 'ismir2021':
      num_velocity_bins = 127
      self.encoding_spec = note_sequences.NoteEncodingSpec
      self.inputs_length = 512
    elif model_type == 'mt3':
      num_velocity_bins = 1
      self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
      self.inputs_length = 256
    else:
      raise ValueError('unknown model_type: %s' % model_type)

    # Gin config paths are hard-coded under /content/mt3/gin.
    gin_files = ['/content/mt3/gin/model.gin',
                 f'/content/mt3/gin/{model_type}.gin']

    self.batch_size = 8
    self.outputs_length = 1024
    self.sequence_length = {'inputs': self.inputs_length, 
                            'targets': self.outputs_length}

    self.partitioner = t5x.partitioning.PjitPartitioner(
        num_partitions=1)

    # Build the codec and vocabulary.
    self.spectrogram_config = spectrograms.SpectrogramConfig()
    self.codec = vocabularies.build_codec(
        vocab_config=vocabularies.VocabularyConfig(
            num_velocity_bins=num_velocity_bins))
    self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
    self.output_features = {
        'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
        'targets': seqio.Feature(vocabulary=self.vocabulary),
    }

    # Create the T5X model.
    self._parse_gin(gin_files)
    self.model = self._load_model()

    # Restore from the checkpoint.
    self.restore_from_checkpoint(checkpoint_path)

  @property
  def input_shapes(self):
    """Shapes of the encoder and decoder token inputs, keyed by name.

    Passed as `input_shapes` to the train state initializer in
    `restore_from_checkpoint`.
    """
    return {
          'encoder_input_tokens': (self.batch_size, self.inputs_length),
          'decoder_input_tokens': (self.batch_size, self.outputs_length)
    }

  def _parse_gin(self, gin_files):
    """Parses the gin files used to train the model.

    Args:
      gin_files: List of gin config file paths to parse.
    """
    # Extra bindings parsed together with the files.
    gin_bindings = [
        'from __gin__ import dynamic_registration',
        'from mt3 import vocabularies',
        'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
        'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
    ]
    # The config is parsed while unlocked and is left un-finalized.
    with gin.unlock_config():
      gin.parse_config_files_and_bindings(
          gin_files, gin_bindings, finalize_config=False)

  def _load_model(self):
    """Loads the T5X model after the training gin config has been parsed.

    Returns:
      A `models.ContinuousInputsEncoderDecoderModel` wrapping a
      `network.Transformer` configured from the gin-provided `T5Config`.
    """
    model_config = gin.get_configurable(network.T5Config)()
    module = network.Transformer(config=model_config)
    return models.ContinuousInputsEncoderDecoderModel(
        module=module,
        input_vocabulary=self.output_features['inputs'].vocabulary,
        output_vocabulary=self.output_features['targets'].vocabulary,
        optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
        input_depth=spectrograms.input_depth(self.spectrogram_config))


  def restore_from_checkpoint(self, checkpoint_path):
    """Restores the training state from a checkpoint, resets self._predict_fn().

    Args:
      checkpoint_path: Path given to `t5x.utils.RestoreCheckpointConfig`.
    """
    train_state_initializer = t5x.utils.TrainStateInitializer(
      optimizer_def=self.model.optimizer_def,
      init_fn=self.model.get_initial_variables,
      input_shapes=self.input_shapes,
      partitioner=self.partitioner)

    restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
        path=checkpoint_path, mode='specific', dtype='float32')

    # The prediction function is built from the train state axes before the
    # state itself is restored.
    train_state_axes = train_state_initializer.train_state_axes
    self._predict_fn = self._get_predict_fn(train_state_axes)
    self._train_state = train_state_initializer.from_checkpoint_or_scratch(
        [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))

  # NOTE(review): lru_cache on a method includes `self` in the cache key, so
  # the cache holds a reference to the instance.
  @functools.lru_cache()
  def _get_predict_fn(self, train_state_axes):
    """Generates a partitioned prediction function for decoding.

    Args:
      train_state_axes: Train state axes; its `params` attribute is used as
        the axis resources of the first argument.

    Returns:
      The result of partitioning the prediction function with
      `self.partitioner`.
    """
    def partial_predict_fn(params, batch, decode_rng):
      # NOTE(review): `decode_rng` is accepted but not forwarded; the decoder
      # is always given `decode_rng=None`.
      return self.model.predict_batch_with_aux(
          params, batch, decoder_params={'decode_rng': None})
    return self.partitioner.partition(
        partial_predict_fn,
        in_axis_resources=(
            train_state_axes.params,
            t5x.partitioning.PartitionSpec('data',), None),
        out_axis_resources=t5x.partitioning.PartitionSpec('data',)
    )

  def predict_tokens(self, batch, seed=0):
    """Predicts tokens from a preprocessed dataset batch.

    Args:
      batch: One batch from the model-ready dataset.
      seed: Integer seed used to build the PRNG key passed to the prediction
        function.

    Returns:
      The predictions decoded with `self.vocabulary.decode_tf`, converted via
      `.numpy()`.
    """
    prediction, _ = self._predict_fn(
        self._train_state.params, batch, jax.random.PRNGKey(seed))
    return self.vocabulary.decode_tf(prediction).numpy()

  def __call__(self, audio):
    """Infers notes from audio samples.

    Args:
      audio: 1-d numpy array of audio samples (16kHz) for a single example.

    Returns:
      A note_sequence of the transcribed audio.
    """
    ds = self.audio_to_dataset(audio)
    ds = self.preprocess(ds)

    # Convert to model features (no packing) and batch for prediction.
    model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)(
        ds, task_feature_lengths=self.sequence_length)
    model_ds = model_ds.batch(self.batch_size)

    # Lazily flatten the per-batch predictions into one token array per example.
    inferences = (tokens for batch in model_ds.as_numpy_iterator()
                  for tokens in self.predict_tokens(batch))

    # Pair each preprocessed example with its predicted tokens.
    predictions = []
    for example, tokens in zip(ds.as_numpy_iterator(), inferences):
      predictions.append(self.postprocess(tokens, example))

    result = metrics_utils.event_predictions_to_ns(
        predictions, codec=self.codec, encoding_spec=self.encoding_spec)
    return result['est_ns']

  def audio_to_dataset(self, audio):
    """Creates a TF Dataset holding the audio frames and their times.

    Args:
      audio: Audio samples, forwarded to `_audio_to_frames`.

    Returns:
      A single-element `tf.data.Dataset` with keys 'inputs' and 'input_times'.
    """
    frames, frame_times = self._audio_to_frames(audio)
    return tf.data.Dataset.from_tensors({
        'inputs': frames,
        'input_times': frame_times,
    })

  def _audio_to_frames(self, audio):
    """Pads the audio and splits it into frames.

    Args:
      audio: Audio samples to pad at the end and split.

    Returns:
      A `(frames, times)` tuple: the output of `spectrograms.split_audio` and
      one time per frame, computed as frame index / frames_per_second.
    """
    frame_size = self.spectrogram_config.hop_width
    # Pad the end up to a multiple of the hop width.
    # NOTE(review): when len(audio) is already a multiple of frame_size this
    # pads one full extra frame - confirm that is intended.
    padding = [0, frame_size - len(audio) % frame_size]
    audio = np.pad(audio, padding, mode='constant')
    frames = spectrograms.split_audio(audio, self.spectrogram_config)
    num_frames = len(audio) // frame_size
    times = np.arange(num_frames) / self.spectrogram_config.frames_per_second
    return frames, times

  def preprocess(self, ds):
    """Applies the preprocessing chain to the dataset, in order.

    Args:
      ds: Dataset produced by `audio_to_dataset`.

    Returns:
      The dataset after every preprocessor in the chain has been applied.
    """
    pp_chain = [
        functools.partial(
            t5.data.preprocessors.split_tokens_to_inputs_length,
            sequence_length=self.sequence_length,
            output_features=self.output_features,
            feature_key='inputs',
            additional_feature_keys=['input_times']),
        # During training, caching happens at this point.
        preprocessors.add_dummy_targets,
        functools.partial(
            preprocessors.compute_spectrograms,
            spectrogram_config=self.spectrogram_config)
    ]
    for pp in pp_chain:
      ds = pp(ds)
    return ds

  def postprocess(self, tokens, example):
    """Builds the prediction record for one example.

    Args:
      tokens: Predicted tokens for the example.
      example: The preprocessed example; its 'input_times' entry is read.

    Returns:
      A dict with 'est_tokens' (tokens trimmed at EOS), 'start_time' and an
      empty 'raw_inputs' list.
    """
    tokens = self._trim_eos(tokens)
    start_time = example['input_times'][0]
    # Round down to the nearest token step.
    start_time -= start_time % (1 / self.codec.steps_per_second)
    return {
        'est_tokens': tokens,
        'start_time': start_time,
        # Internal MT3 code expects raw inputs, which are not used here.
        'raw_inputs': []
    }

  @staticmethod
  def _trim_eos(tokens):
    """Converts tokens to an int32 array and cuts it at the first EOS.

    Args:
      tokens: Sequence of token ids.

    Returns:
      The tokens before the first `vocabularies.DECODED_EOS_ID`, or all of
      them when no EOS is present.
    """
    tokens = np.array(tokens, np.int32)
    if vocabularies.DECODED_EOS_ID in tokens:
      # argmax over the boolean mask gives the index of the first EOS.
      tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
    return tokens



In [None]:
# @title 修复T5X,jax以及jaxlib
# @markdown 当  `模块导入与定义`  单元格报错时,运行此单元格并重启。 \
# @markdown * 启用CPU会导致GPU不可用！！！！！

#卸载T5X  jax  jaxlib
!pip uninstall jax jaxlib t5x

#安装T5X
!pip install git+https://github.com/google-research/t5x.git@2b010160e7fe8a4505a6d1032a7b737a633636e5

#安装jax jaxlib
!pip install git+https://github.com/google/jax.git@47df8628a0fa83900e38431c88a7a0e27660b7aa
!pip install "jax[cuda11_local]==0.4.9" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
#!pip install jaxlib==0.4.9  #启用此项并禁用上一项以切换至CPU

In [None]:
# @title 重启运行时后运行此单元格
import json
import IPython

# 以下函数(load_gtag和log_event)用于处理谷歌分析(Google Analytics)的事件日志记录。
# 日志记录是匿名的,只存储音频和转录结果的一些非常基本的统计信息,
# 例如音频长度、转录出的音符数量。

def load_gtag():
  """Render the Google Analytics gtag.js bootstrap snippet as cell output.

  The snippet has to be displayed in the same cell execution that later logs
  an event; it does not persist from one cell execution to the next, so every
  cell that logs calls this function first.
  """
  gtag_snippet = '''
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4P250YRJ08"></script>
<script>
  window.dataLayer = window.dataLayer || [];
  function gtag(){dataLayer.push(arguments);}
  gtag('js', new Date());
  gtag('config', 'G-4P250YRJ08',
       {'referrer': document.referrer.split('?')[0],
        'anonymize_ip': true,
        'page_title': '',
        'page_referrer': '',
        'cookie_prefix': 'magenta',
        'cookie_domain': 'auto',
        'cookie_expires': 0,
        'cookie_flags': 'SameSite=None;Secure'});
</script>
'''
  # Wrap the markup in an HTML display object, then emit it to the output.
  snippet_output = IPython.display.HTML(gtag_snippet)
  IPython.display.display(snippet_output)

def log_event(event_name, event_details):
  """Emit a gtag 'event' call as Javascript cell output.

  Args:
    event_name: Name of the event; inserted between single quotes in the
      generated Javascript.
    event_details: Dictionary of event details; serialized with `json.dumps`
      and passed as the third argument of the `gtag` call.
  """
  payload = json.dumps(event_details)
  gtag_call = f"gtag('event', '{event_name!s}', {payload});"
  script_output = IPython.display.Javascript(gtag_call)
  IPython.display.display(script_output)

# gtag.js has to be loaded in the same cell execution that logs the event.
load_gtag()
# Record that this setup cell ran to completion (no extra details attached).
log_event('setupComplete', {})

In [None]:
#@title 加载模型
#@markdown `ismir2021` 模型仅用于钢琴音频的转录,
#@markdown 支持音符速度预测。\
#@markdown `mt3` 模型支持多乐器音频的转录，
#@markdown 但不支持音符速度预测。

MODEL = "mt3" #@param["ismir2021", "mt3"]

# Checkpoint location for the selected model.
checkpoint_path = f'/content/checkpoints/{MODEL}/'

# gtag.js has to be loaded in the same cell execution that logs events.
load_gtag()

# Log an event before and after constructing the model; the model name is
# sent as the event category.
log_event('loadModelStart', {'event_category': MODEL})
inference_model = InferenceModel(checkpoint_path, MODEL)
log_event('loadModelComplete', {'event_category': MODEL})


In [None]:
#@title 上传音频

# gtag.js has to be loaded in the same cell execution that logs events.
load_gtag()

log_event('uploadAudioStart', {})
audio = upload_audio(sample_rate=SAMPLE_RATE)
# 'value' is the number of samples divided by the sample rate, rounded.
log_event('uploadAudioComplete', {'value': round(len(audio) / SAMPLE_RATE)})

# Play back the uploaded audio in the notebook.
note_seq.notebook_utils.colab_play(audio, sample_rate=SAMPLE_RATE)

In [None]:
#@title 转录音频至MIDI
#@markdown 这可能需要几分钟时间，具体取决于您上传的音频文件的长度。

# gtag.js has to be loaded in the same cell execution that logs events.
load_gtag()

# 'value' is the number of samples divided by the sample rate, rounded.
log_event('transcribeStart', {
    'event_category': MODEL,
    'value': round(len(audio) / SAMPLE_RATE)
})

# Run the model on the uploaded audio; the result is the estimated
# note sequence.
est_ns = inference_model(audio)

# Log counts of non-drum notes, drum notes, and distinct programs among the
# non-drum notes.
log_event('transcribeComplete', {
    'event_category': MODEL,
    'value': round(len(audio) / SAMPLE_RATE),
    'numNotes': sum(1 for note in est_ns.notes if not note.is_drum),
    'numDrumNotes': sum(1 for note in est_ns.notes if note.is_drum),
    'numPrograms': len(set(note.program for note in est_ns.notes
                           if not note.is_drum))
})

# Play the transcription through fluidsynth with the SoundFont at SF2_PATH,
# then plot it.
note_seq.play_sequence(est_ns, synth=note_seq.fluidsynth, 
                       sample_rate=SAMPLE_RATE, sf2_path=SF2_PATH)
note_seq.plot_sequence(est_ns)

In [None]:
#@title 下载转录好的MIDI

# gtag.js has to be loaded in the same cell execution that logs events.
load_gtag()
# Log the audio length in rounded seconds plus counts of non-drum notes,
# drum notes, and distinct programs among the non-drum notes.
log_event('downloadTranscription', {
    'event_category': MODEL,
    'value': round(len(audio) / SAMPLE_RATE),
    'numNotes': sum(1 for note in est_ns.notes if not note.is_drum),
    'numDrumNotes': sum(1 for note in est_ns.notes if note.is_drum),
    'numPrograms': len(set(note.program for note in est_ns.notes
                           if not note.is_drum))
})

# Write the note sequence to a MIDI file, then hand that file to the
# Colab download helper.
note_seq.sequence_proto_to_midi_file(est_ns, '/tmp/transcribed.mid')
files.download('/tmp/transcribed.mid')