<a href="https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/pitch_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2021 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");





In [None]:
# Copyright 2021 Google LLC. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# DDSP Pitch Detection Demo

This notebook is a demo of pitch detection using inverse audio synthesis (DDSP-INV).


* [ICML Workshop paper](https://openreview.net/forum?id=RlVTYWhsky7)
* [Audio Examples](http://goo.gl/magenta/ddsp-inv) 

This notebook extracts pitch (fundamental frequency, f0) along with sinusoidal and harmonic components from input audio (either uploaded files or recorded from the microphone), and resynthesizes the audio from the model. The DDSP-INV model is hierarchical and provides resynthesis from both the sinusoidal model and the harmonic model.

<img src="https://storage.googleapis.com/ddsp-inv/full_stack/diagram.png" alt="DDSP Pitch Detection" width="700">


### Instructions for running:

* Make sure to use a GPU runtime: click __Runtime >> Change Runtime Type >> GPU__
* Press ▶️ on the left of each of the cells
* View the code: Double-click any of the cells
* Hide the code: Double-click the right side of the cell





In [None]:
#@title # Step 1: Install DDSP

#@markdown Install ddsp in a conda environment with Python 3.9 for compatibility.
#@markdown This transfers a lot of data and _should take about 5 minutes_.
#@markdown You can ignore warnings.

# ddsp is not installed into Colab's own Python. Instead a separate Miniconda
# (Python 3.9) environment is created under /content/miniconda, and Step 3
# runs its inference script with /content/miniconda/bin/python.

# Remove any previous installation so re-running this cell starts clean.
!rm -rf /content/miniconda
# Download the pinned Miniconda3 (Python 3.9) installer and install it
# non-interactively (-b) into /content/miniconda (-p).
!curl -L https://repo.anaconda.com/miniconda/Miniconda3-py39_23.11.0-2-Linux-x86_64.sh -o miniconda.sh
!chmod +x miniconda.sh
!sh miniconda.sh -b -p /content/miniconda
# System audio library; presumably needed by an audio dependency of the
# packages below — TODO confirm which one.
!sudo apt-get install -y libportaudio2
# GPU libraries installed into the conda env. Step 3 puts
# /content/miniconda/lib first on LD_LIBRARY_PATH, presumably so the
# inference script picks these up.
!/content/miniconda/bin/conda install -y -c conda-forge cudatoolkit=11.2 cudnn=8.1
# Pinned packages, installed with the conda env's pip (not Colab's) so they
# land in the interpreter that Step 3 runs.
!/content/miniconda/bin/pip install tensorflow==2.11 tensorflow-probability==0.19.0 tensorflowjs==3.18.0 tensorflow-datasets==4.9.0 tflite-support==0.1.0a1 ddsp==3.7.0 hmmlearn
print('\nDone installing DDSP in conda environment!')

In [None]:
#@title # Step 2: Record or Upload Audio
#@markdown * Either record audio from microphone or upload audio from file (.mp3 or .wav) 
#@markdown * Audio should be monophonic (single instrument / voice)

# Colab form fields: the `#@param` annotations render these assignments as a
# dropdown and a number input.
record_or_upload = "Upload (.mp3 or .wav)"  #@param ["Record", "Upload (.mp3 or .wav)"]

# Only used when record_or_upload == "Record".
record_seconds =     5#@param {type:"number", min:1, max:10, step:1}

# Suppress all Python warnings to keep the cell output readable.
import warnings
warnings.filterwarnings("ignore")

import base64
import io
import os

import numpy as np
import matplotlib.pyplot as plt
from IPython import display
from scipy.io import wavfile

from google.colab import files as colab_files
from google.colab import output

# Sample rate (Hz) that recorded/uploaded audio is resampled to. The Step 3
# inference script also assumes 16 kHz (it passes sample_rate=16000 to CREPE).
SAMPLE_RATE = 16000


def play(array_of_floats, sample_rate=SAMPLE_RATE):
  """Play audio in colab using HTML5 audio widget.

  The samples are encoded as a 16-bit PCM WAV file and embedded in the page
  as a base64 data URI.

  Args:
    array_of_floats: Audio samples, nominally in [-1, 1]; values outside that
      range are clipped. If 2-D, only the first row is played.
    sample_rate: Sample rate in Hz written to the WAV header.
  """
  samples = np.asarray(array_of_floats)
  if samples.ndim == 2:
    samples = samples[0]
  # Clip before converting: casting out-of-range floats to int16 does not
  # saturate (numpy's result is platform-dependent, typically wrap-around),
  # so any overshoot in resynthesized audio would turn into loud clicks.
  normalizer = float(np.iinfo(np.int16).max)
  array_of_ints = (np.clip(samples, -1.0, 1.0) * normalizer).astype(np.int16)
  with io.BytesIO() as memfile:
    wavfile.write(memfile, sample_rate, array_of_ints)
    wav_bytes = memfile.getvalue()
  html = """<audio controls>
              <source controls src="data:audio/wav;base64,{base64_wavfile}"
              type="audio/wav" />
              Your browser does not support the audio element.
            </audio>"""
  html = html.format(
      base64_wavfile=base64.b64encode(wav_bytes).decode('ascii'))
  display.display(display.HTML(html))


def record_audio(seconds=3, sample_rate=SAMPLE_RATE):
  """Record audio from the browser microphone.

  Injects JavaScript that records for ``seconds`` with the browser's
  MediaRecorder API and returns the recording as a base64 data URL. The
  recording is then decoded with pydub.

  Args:
    seconds: Recording duration in seconds.
    sample_rate: Output sample rate in Hz.

  Returns:
    1-D float32 numpy array of mono samples, scaled by 1/32767 from 16-bit PCM.
  """
  record_js_code = """
  const sleep  = time => new Promise(resolve => setTimeout(resolve, time))
  const b2text = blob => new Promise(resolve => {
    const reader = new FileReader()
    reader.onloadend = () => resolve(reader.result)
    reader.readAsDataURL(blob)
  })

  var record = time => new Promise(async resolve => {
    const stream = await navigator.mediaDevices.getUserMedia({ audio: true })
    const recorder = new MediaRecorder(stream)
    const chunks = []
    recorder.ondataavailable = e => chunks.push(e.data)
    recorder.start()
    await sleep(time)
    recorder.onstop = async () => {
      // Release the microphone; otherwise the browser keeps it open (and
      // shows it as in use) after the recording has finished.
      stream.getTracks().forEach(track => track.stop())
      const blob = new Blob(chunks)
      const text = await b2text(blob)
      resolve(text)
    }
    recorder.stop()
  })
  """
  print('Starting recording for {} seconds...'.format(seconds))
  display.display(display.Javascript(record_js_code))
  audio_string = output.eval_js('record(%d)' % (seconds * 1000.0))
  print('Finished recording!')
  # Strip the "data:<mime>;base64," prefix of the data URL.
  audio_bytes = base64.b64decode(audio_string.split(',')[1])
  # Convert bytes to numpy using pydub
  from pydub import AudioSegment
  segment = AudioSegment.from_file(io.BytesIO(audio_bytes))
  segment = segment.set_frame_rate(sample_rate).set_channels(1).set_sample_width(2)
  samples = np.array(segment.get_array_of_samples()).astype(np.float32)
  samples = samples / float(np.iinfo(np.int16).max)
  return samples


def upload_audio(sample_rate=SAMPLE_RATE):
  """Prompt for audio uploads and decode each file to mono floats.

  Args:
    sample_rate: Rate in Hz every upload is resampled to.

  Returns:
    ``(fnames, audios)``: the uploaded filenames in upload order and, for each
    one, a 1-D float32 array (mono, 16-bit PCM scaled by 1/32767).
  """
  from pydub import AudioSegment

  def _decode(raw_bytes):
    # Normalize to mono 16-bit at the target rate before converting to floats.
    seg = (AudioSegment.from_file(io.BytesIO(raw_bytes))
           .set_frame_rate(sample_rate)
           .set_channels(1)
           .set_sample_width(2))
    pcm = np.array(seg.get_array_of_samples()).astype(np.float32)
    return pcm / float(np.iinfo(np.int16).max)

  uploaded = colab_files.upload()
  names = list(uploaded)
  return names, [_decode(uploaded[name]) for name in names]


def specplot(audio, vmin=-5, vmax=1, rotate=True, size=512 + 256):
  """Plot the log magnitude spectrogram of audio.

  Uses scipy's STFT (no ddsp needed) with a window of ``size`` samples and 75%
  overlap, and draws log10 magnitudes with ``plt.matshow``.

  Args:
    audio: 1-D audio, or 2-D audio of which only the first row is plotted.
    vmin: Lower color limit, in log10 magnitude.
    vmax: Upper color limit, in log10 magnitude.
    rotate: If True, flip the frequency axis so low frequencies are drawn at
      the bottom.
    size: STFT window length in samples.
  """
  from scipy import signal as scipy_signal
  audio = np.asarray(audio)
  if audio.ndim == 2:
    audio = audio[0]
  # For input shorter than `size`, scipy shrinks nperseg to the input length
  # but keeps the requested noverlap, and then raises "noverlap must be less
  # than nperseg". Derive both from the effective window length instead.
  nperseg = min(size, audio.shape[-1])
  _, _, stft = scipy_signal.stft(
      audio, fs=SAMPLE_RATE, nperseg=nperseg, noverlap=nperseg * 3 // 4)
  logmag = np.log10(np.abs(stft) + 1e-7)
  if rotate:
    logmag = np.flipud(logmag)
  plt.matshow(logmag, vmin=vmin, vmax=vmax, cmap=plt.cm.magma, aspect='auto')
  plt.xticks([])
  plt.yticks([])
  plt.xlabel('Time')
  plt.ylabel('Frequency')


# --- Record or Upload ---
if record_or_upload == "Record":
  audio = record_audio(seconds=record_seconds)
else:
  filenames, audios = upload_audio()
  if not audios:
    raise ValueError('No audio file was uploaded; re-run this cell and choose a file.')
  if len(audios) > 1:
    print(f'Multiple files uploaded; using only the first: {filenames[0]}')
  audio = audios[0]

# The inference script indexes audio as (batch, n_samples), so add a batch
# dimension of one.
if audio.ndim == 1:
  audio = audio[np.newaxis, :]

# Save audio for the inference script
np.save('/content/input_audio.npy', audio)
print(f'Audio shape: {audio.shape}, saved to /content/input_audio.npy')

# Plot and play. plt.show() renders the spectrogram before the audio widget
# instead of at the end of the cell.
specplot(audio)
plt.show()
play(audio)


In [None]:
#@title # Step 3: Run Pitch Detection

#@markdown Choose a pretrained model and run pitch detection.
#@markdown Models separately trained on the [URMP](http://www2.ece.rochester.edu/projects/air/projects/URMP/annotations_5P.html), [MDB-stem-synth](https://zenodo.org/record/1481172#.Xzouy5NKhTY), and [MIR1k](https://sites.google.com/site/unvoicedsoundseparation/mir-1k) datasets.

model = 'urmp' #@param ['urmp', 'mdb_stem_synth', 'mir1k']


# Source of the inference script. It is written to disk and run with the
# conda interpreter from Step 1 (Python 3.9 + ddsp==3.7.0) rather than
# Colab's own Python, so ddsp and TensorFlow are only imported there.
SCRIPT = r'''
"""DDSP-INV Pitch Detection inference script.
Runs inside conda environment with Python 3.9 and ddsp==3.7.0.
Reads input audio, loads model, runs DDSP-INV and CREPE pitch detection,
and writes output data as .npy files.
"""
import argparse
import os
import shutil
import subprocess
import time
import warnings
warnings.filterwarnings("ignore")

import numpy as np

import ddsp
import ddsp.training
import gin
import tensorflow.compat.v2 as tf


def main():
  parser = argparse.ArgumentParser()
  parser.add_argument('--audio_path', required=True)
  parser.add_argument('--output_dir', default='/content/pitch_output')
  parser.add_argument('--model', default='urmp')
  args = parser.parse_args()

  os.makedirs(args.output_dir, exist_ok=True)

  # Load audio
  audio = np.load(args.audio_path)
  print(f'Loaded audio: shape={audio.shape}')

  # --- Download and load model checkpoint ---
  PRETRAINED_DIR = '/content/pretrained'
  shutil.rmtree(PRETRAINED_DIR, ignore_errors=True)
  os.makedirs(PRETRAINED_DIR, exist_ok=True)
  GCS_CKPT_DIR = 'gs://ddsp-inv/ckpts'
  model_dir_gcs = os.path.join(GCS_CKPT_DIR, '%s_ckpt' % args.model.lower())
  # Arguments are passed as a list (no shell), so the model name is never
  # interpreted by a shell; gsutil expands the gs:// wildcard itself.
  # check=True stops here if the download fails, instead of failing later
  # with an unrelated IndexError.
  subprocess.run(['gsutil', 'cp', model_dir_gcs + '/*', PRETRAINED_DIR],
                 check=True)
  model_dir = PRETRAINED_DIR

  # Find gin config file
  gin_file_pattern = os.path.join(model_dir, 'operative_config*.gin')
  gin_files = tf.io.gfile.glob(gin_file_pattern)
  if not gin_files:
    raise FileNotFoundError('No operative_config*.gin found in %s' % model_dir)
  gin_file = gin_files[0]

  # The old checkpoints use 'TranscribingAutoencoder' as the gin class name,
  # but it has been renamed to 'InverseSynthesis'. Patch the gin config.
  with open(gin_file, 'r') as f:
    gin_text = f.read()
  gin_text = gin_text.replace('TranscribingAutoencoder', 'InverseSynthesis')
  patched_gin_file = os.path.join(model_dir, 'patched_config.gin')
  with open(patched_gin_file, 'w') as f:
    f.write(gin_text)

  # Parse gin config
  with gin.unlock_config():
    gin.parse_config_file(patched_gin_file, skip_unknown=True)

  # Find checkpoint
  ckpt_files = [f for f in tf.io.gfile.listdir(model_dir) if 'ckpt' in f]
  if not ckpt_files:
    raise FileNotFoundError('No checkpoint files found in %s' % model_dir)
  ckpt_name = ckpt_files[0].split('.')[0]
  ckpt = os.path.join(model_dir, ckpt_name)

  # Ensure dimensions and sampling rates are equal
  time_steps_train = 125
  n_samples_train = 64000
  hop_size = int(n_samples_train / time_steps_train)

  time_steps = int(audio.shape[1] / hop_size)
  if time_steps == 0:
    raise ValueError('Audio is too short: need at least %d samples, got %d.'
                     % (hop_size, audio.shape[1]))
  n_samples = time_steps * hop_size
  audio = audio[:, :n_samples]

  gin_params = [
      'InverseSynthesis.n_samples = {}'.format(n_samples),
      'oscillator_bank.use_angular_cumsum = True',
  ]

  with gin.unlock_config():
    gin.parse_config(gin_params)

  # --- Build and restore model ---
  print('Loading model...')
  model = ddsp.training.models.InverseSynthesis()
  model.restore(ckpt)

  # Build model by running a batch through it.
  start_time = time.time()
  _ = model({'audio': audio}, training=False)
  print('Restoring model took %.1f seconds' % (time.time() - start_time))

  # --- Predict with DDSP-INV ---
  start_time = time.time()
  print('\nExtracting f0 with DDSP-INV...')
  controls = model({'audio': audio}, training=False)
  print('Prediction took %.1f seconds' % (time.time() - start_time))

  # --- Predict with CREPE ---
  start_time = time.time()
  print('\nExtracting f0 with CREPE...')
  ddsp.spectral_ops.reset_crepe()
  f0_crepe, f0_confidence = ddsp.spectral_ops.compute_f0(
      audio[0],
      sample_rate=16000,
      frame_rate=31.25,
      viterbi=False)
  print('Prediction took %.1f seconds' % (time.time() - start_time))

  # --- Synthesize comparison audio ---
  synth = ddsp.synths.Wavetable(n_samples=n_samples, scale_fn=None)
  wavetable = np.sin(np.linspace(0, 2.0 * np.pi, 2048))[np.newaxis, np.newaxis, :]
  amps = np.ones([1, time_steps, 1]) * 0.1
  audio_crepe = synth(amps, wavetable, f0_crepe[np.newaxis, :, np.newaxis])
  audio_ddsp_inv = synth(controls['harm_amp'], wavetable, controls['f0_hz'])

  # --- Convert all tensors to numpy and save ---
  def to_np(x):
    return np.array(x.numpy() if hasattr(x, 'numpy') else x)

  np.save(os.path.join(args.output_dir, 'audio.npy'), to_np(audio))
  np.save(os.path.join(args.output_dir, 'sin_audio.npy'), to_np(controls['sin_audio']))
  np.save(os.path.join(args.output_dir, 'harm_audio.npy'), to_np(controls['harm_audio']))
  np.save(os.path.join(args.output_dir, 'f0_hz.npy'), to_np(controls['f0_hz']))
  np.save(os.path.join(args.output_dir, 'sin_freqs.npy'), to_np(controls['sin_freqs']))
  np.save(os.path.join(args.output_dir, 'sin_amps.npy'), to_np(controls['sin_amps']))
  np.save(os.path.join(args.output_dir, 'harm_freqs.npy'), to_np(controls['harm_freqs']))
  np.save(os.path.join(args.output_dir, 'harm_amps.npy'), to_np(controls['harm_amps']))
  np.save(os.path.join(args.output_dir, 'audio_ddsp_inv.npy'), to_np(audio_ddsp_inv))
  np.save(os.path.join(args.output_dir, 'audio_crepe.npy'), to_np(audio_crepe))
  np.save(os.path.join(args.output_dir, 'f0_crepe.npy'), to_np(f0_crepe))

  print('Done! Outputs saved to', args.output_dir)


if __name__ == '__main__':
  main()
'''

with open('/content/pitch_detection_inference.py', 'w') as f:
  f.write(SCRIPT)
print('Inference script written to /content/pitch_detection_inference.py')


# Run inference in conda environment
cmd = (
    "unset PYTHONPATH PYTHONHOME && "
    "export LD_LIBRARY_PATH=/content/miniconda/lib:$LD_LIBRARY_PATH && "
    "/content/miniconda/bin/python /content/pitch_detection_inference.py "
    f"--audio_path=/content/input_audio.npy "
    f"--output_dir=/content/pitch_output "
    f"--model={model}"
)
print('Running pitch detection...')
!{cmd}

In [None]:
#@title # Step 4: View Results

#@markdown Load and display the pitch detection results.

# This cell re-imports its dependencies and re-defines its helpers rather than
# relying on Step 2, presumably so it can be re-run on its own after a kernel
# restart, as long as Step 3's .npy outputs exist on disk.

# Suppress all Python warnings to keep the cell output readable.
import warnings
warnings.filterwarnings("ignore")

import base64
import io

import numpy as np
import matplotlib.pyplot as plt
from IPython import display
from scipy.io import wavfile

# Must match the rate used in Steps 2 and 3 (16 kHz).
SAMPLE_RATE = 16000


def play(array_of_floats, sample_rate=SAMPLE_RATE):
  """Play audio in colab using HTML5 audio widget.

  The samples are encoded as a 16-bit PCM WAV file and embedded in the page
  as a base64 data URI.

  Args:
    array_of_floats: Audio samples, nominally in [-1, 1]; values outside that
      range are clipped. If 2-D, only the first row is played.
    sample_rate: Sample rate in Hz written to the WAV header.
  """
  samples = np.asarray(array_of_floats)
  if samples.ndim == 2:
    samples = samples[0]
  # Clip before converting: casting out-of-range floats to int16 does not
  # saturate (numpy's result is platform-dependent, typically wrap-around),
  # so any overshoot in resynthesized audio would turn into loud clicks.
  normalizer = float(np.iinfo(np.int16).max)
  array_of_ints = (np.clip(samples, -1.0, 1.0) * normalizer).astype(np.int16)
  with io.BytesIO() as memfile:
    wavfile.write(memfile, sample_rate, array_of_ints)
    wav_bytes = memfile.getvalue()
  html = """<audio controls>
              <source controls src="data:audio/wav;base64,{base64_wavfile}"
              type="audio/wav" />
              Your browser does not support the audio element.
            </audio>"""
  html = html.format(
      base64_wavfile=base64.b64encode(wav_bytes).decode('ascii'))
  display.display(display.HTML(html))


def specplot(audio, vmin=-5, vmax=1, rotate=True, size=512 + 256):
  """Plot the log magnitude spectrogram of audio.

  Uses scipy's STFT (no ddsp needed) with a window of ``size`` samples and 75%
  overlap, and draws log10 magnitudes with ``plt.matshow``.

  Args:
    audio: 1-D audio, or 2-D audio of which only the first row is plotted.
    vmin: Lower color limit, in log10 magnitude.
    vmax: Upper color limit, in log10 magnitude.
    rotate: If True, flip the frequency axis so low frequencies are drawn at
      the bottom.
    size: STFT window length in samples.
  """
  from scipy import signal as scipy_signal
  audio = np.asarray(audio)
  if audio.ndim == 2:
    audio = audio[0]
  # For input shorter than `size`, scipy shrinks nperseg to the input length
  # but keeps the requested noverlap, and then raises "noverlap must be less
  # than nperseg". Derive both from the effective window length instead.
  nperseg = min(size, audio.shape[-1])
  _, _, stft = scipy_signal.stft(
      audio, fs=SAMPLE_RATE, nperseg=nperseg, noverlap=nperseg * 3 // 4)
  logmag = np.log10(np.abs(stft) + 1e-7)
  if rotate:
    logmag = np.flipud(logmag)
  plt.matshow(logmag, vmin=vmin, vmax=vmax, cmap=plt.cm.magma, aspect='auto')
  plt.xticks([])
  plt.yticks([])
  plt.xlabel('Time')
  plt.ylabel('Frequency')


def hz_to_midi(hz):
  """Convert Hz to MIDI note number (no ddsp needed).

  Mirrors ``ddsp.core.hz_to_midi``: 440 Hz maps to MIDI 69 and each octave
  adds 12. Non-positive frequencies map to 0 rather than to a huge negative
  note number that would swamp the pitch plot's y-axis.

  Args:
    hz: Scalar or array of frequencies in Hz.

  Returns:
    numpy array of (fractional) MIDI note numbers, same shape as ``hz``.
  """
  hz = np.asarray(hz)
  # Clamp before log2 so 0 Hz yields no -inf or divide-by-zero warning; those
  # entries are replaced by 0 below anyway.
  midi = 12.0 * (np.log2(np.maximum(hz, 1e-7)) - np.log2(440.0)) + 69.0
  return np.where(hz <= 0.0, 0.0, midi)


# --- Load outputs ---
output_dir = '/content/pitch_output'


def load_output(name):
  """Load `<output_dir>/<name>.npy`, as written by the Step 3 inference script."""
  return np.load(f'{output_dir}/{name}.npy')


def plot_sinusoids(freqs, amps, title):
  """Scatter-plot sinusoid tracks for one example.

  Each column of ``freqs``/``amps`` is drawn as one track over time, with
  marker area proportional to its amplitude.

  Args:
    freqs: Frequencies in Hz; assumed shape (n_frames, n_sinusoids), TODO
      confirm against the model's outputs.
    amps: Amplitudes, same shape as ``freqs``.
    title: Figure title.
  """
  plt.figure(figsize=(6, 6))
  frames = np.arange(freqs.shape[0])
  for track_amps, track_freqs in zip(np.transpose(amps), np.transpose(freqs)):
    plt.scatter(frames, track_freqs, s=track_amps * 200, linewidths=1)
  # Set once after plotting; SAMPLE_RATE / 2 (8 kHz) is the Nyquist frequency.
  plt.ylim(0, SAMPLE_RATE / 2)
  plt.title(title)
  plt.ylabel('Frequency (Hz)')
  plt.xlabel('Time')
  plt.xticks([])
  plt.show()


audio = load_output('audio')
sin_audio = load_output('sin_audio')
harm_audio = load_output('harm_audio')
f0_hz = load_output('f0_hz')
sin_freqs = load_output('sin_freqs')
sin_amps = load_output('sin_amps')
harm_freqs = load_output('harm_freqs')
harm_amps = load_output('harm_amps')
audio_ddsp_inv = load_output('audio_ddsp_inv')
audio_crepe = load_output('audio_crepe')
f0_crepe = load_output('f0_crepe')

# Index of the batch example to display; Step 2 saves the input as a batch of one.
k = 0

# --- Plot Pitch Comparison ---
plt.figure(figsize=(6, 4))
f0_crepe_midi = hz_to_midi(f0_crepe)
f0_harm_midi = hz_to_midi(f0_hz)
plt.plot(np.ravel(f0_crepe_midi), label='crepe')
plt.plot(np.ravel(f0_harm_midi[k]), label='ddsp-inv')
plt.ylabel('Pitch (MIDI)')
plt.xlabel('Time')
plt.xticks([])
plt.legend(loc='upper right')
plt.show()

# --- Audio Playback ---
for label, clip in [('Original', audio),
                    ('Sinusoidal Resynthesis', sin_audio[k]),
                    ('Harmonic Resynthesis', harm_audio[k]),
                    ('DDSP-INV Pitch', audio_ddsp_inv[k]),
                    ('CREPE Pitch', audio_crepe[k])]:
  print(label)
  play(clip)

# --- Spectrograms ---
for title, clip in [('Original', audio),
                    ('Sinusoidal Resynthesis', sin_audio[k]),
                    ('Harmonic Resynthesis', harm_audio[k])]:
  specplot(clip)
  plt.title(title)
  plt.show()

# --- Sinusoid Scatter Plots ---
plot_sinusoids(sin_freqs[k], sin_amps[k], 'Sinusoids (Sinusoidal)')
plot_sinusoids(harm_freqs[k], harm_amps[k], 'Sinusoids (Harmonic)')