<a href="https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/train_autoencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


##### Copyright 2020 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");





In [None]:
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Train a DDSP Autoencoder on GPU

This notebook demonstrates how to install the DDSP library and use our command-line scripts to train a synthesis model on your own data. If run inside Colab, it will automatically use a free Google Cloud GPU.

At the end, you'll have a custom-trained checkpoint that you can download to use with the [DDSP Timbre Transfer Colab](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/timbre_transfer.ipynb).

<img src="https://storage.googleapis.com/ddsp/additive_diagram/ddsp_autoencoder.png" alt="DDSP Autoencoder figure" width="700">


## Install Dependencies

First we install the required dependencies using Miniconda with Python 3.9 for compatibility.

In [None]:
#@title Install DDSP

#@markdown Install ddsp in a conda environment with Python 3.9 for compatibility.
#@markdown This transfers a lot of data and _should take about 5 minutes_.
#@markdown You can ignore warnings.

!rm -rf /content/miniconda
!curl -L https://repo.anaconda.com/miniconda/Miniconda3-py39_23.11.0-2-Linux-x86_64.sh -o miniconda.sh
!chmod +x miniconda.sh
!sh miniconda.sh -b -p /content/miniconda
!sudo apt-get install -y libportaudio2
!/content/miniconda/bin/conda install -y -c conda-forge cudatoolkit=11.2 cudnn=8.1
!/content/miniconda/bin/pip install tensorflow==2.11 tensorflow-probability==0.19.0 tensorflowjs==3.18.0 tensorflow-datasets==4.9.0 tflite-support==0.1.0a1 ddsp[data_preparation]==3.7.0 hmmlearn

# Initialize the global Google Drive path (it is set in the Drive setup cell below).
DRIVE_DIR = ''
print('\nDone installing DDSP in conda environment!')

## Set Up Google Drive (Optional, Recommended)

This notebook requires uploading audio and saving checkpoints. While you can do this with direct uploads and downloads, we recommend connecting your Google Drive account. This enables faster file transfer and regular saving of checkpoints, so that you do not lose your work if the Colab kernel restarts (common when training for more than 12 hours).

#### Login and mount your drive

This will require an authentication code. You should then be able to see your drive in the file browser on the left panel.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

#### Set your base directory
* In drive, put all of the audio (.wav, .mp3) files with which you would like to train in a single folder.
 * Typically works well with 10-20 minutes of audio from a single monophonic source (also, one acoustic environment). 
* Use the file browser in the left panel to find a folder with your audio, right-click **"Copy Path", paste below**, and run the cell.
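To check whether a folder holds roughly the suggested 10-20 minutes of audio before training, you can sum the durations of its files. The sketch below is a hypothetical helper (`wav_minutes` is not part of DDSP) that uses only Python's standard `wave` module, so it counts PCM `.wav` files only; `.mp3` files are ignored, and some WAV encodings (e.g. 32-bit float) may raise `wave.Error`.

```python
import glob
import os
import wave


def wav_minutes(folder):
  """Sum the duration, in minutes, of the PCM .wav files in `folder`.

  Hypothetical helper: `wave` cannot parse .mp3, so those are skipped.
  """
  total_seconds = 0.0
  for path in sorted(glob.glob(os.path.join(folder, '*.wav'))):
    with wave.open(path, 'rb') as w:
      # Duration = number of frames / frames per second.
      total_seconds += w.getnframes() / w.getframerate()
  return total_seconds / 60.0
```

For example, `wav_minutes(DRIVE_DIR)` returns a value you can compare against the 10-20 minute guideline.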

In [None]:
#@markdown (ex. `/content/drive/My Drive/...`) Leave blank to skip loading from Drive.
DRIVE_DIR = '' #@param {type: "string"}

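import os

if DRIVE_DIR:
  # Fail early if the pasted path does not exist.
  assert os.path.exists(DRIVE_DIR), 'Folder not found: {}'.format(DRIVE_DIR)
  print('Drive Folder Exists:', DRIVE_DIR)
else:
  print('DRIVE_DIR is blank; skipping Google Drive.')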


## Make directories to save model and data

In [None]:
AUDIO_DIR = 'data/audio'
AUDIO_FILEPATTERN = AUDIO_DIR + '/*'
!mkdir -p $AUDIO_DIR

if DRIVE_DIR:
  SAVE_DIR = os.path.join(DRIVE_DIR, 'ddsp-solo-instrument')
else:
  SAVE_DIR = '/content/models/ddsp-solo-instrument'
!mkdir -p "$SAVE_DIR"

## Prepare Dataset


#### Upload training audio

Upload the audio files to train your model on. If `DRIVE_DIR` is set, the `.mp3` and `.wav` files in that folder are used; otherwise you will be prompted to upload files from your computer.

In [None]:
import glob
import os

if DRIVE_DIR:
  mp3_files = glob.glob(os.path.join(DRIVE_DIR, '*.mp3'))
  wav_files = glob.glob(os.path.join(DRIVE_DIR, '*.wav'))
  audio_files = mp3_files + wav_files
else:
  from google.colab import files
  uploaded = files.upload()
  audio_files = list(uploaded.keys())

for fname in audio_files:
  target_name = os.path.join(AUDIO_DIR, 
                             os.path.basename(fname).replace(' ', '_'))
  print('Copying {} to {}'.format(fname, target_name))
  !cp "$fname" $target_name
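The globs in the cell above are case-sensitive, so a file named `Take1.WAV` would be skipped. A minimal sketch of a case-insensitive alternative, assuming the later preprocessing step accepts such files: `plan_copies` is a hypothetical helper (not part of DDSP) that pairs each audio file with a target path whose spaces are replaced by underscores, as the cell above does.

```python
import os

# Extensions accepted by the upload cell above, compared case-insensitively.
AUDIO_EXTENSIONS = ('.wav', '.mp3')


def plan_copies(source_dir, target_dir):
  """Return (source, target) path pairs for audio files in `source_dir`.

  Hypothetical helper: spaces in file names become underscores in the target.
  """
  pairs = []
  for name in sorted(os.listdir(source_dir)):
    if name.lower().endswith(AUDIO_EXTENSIONS):
      target = os.path.join(target_dir, name.replace(' ', '_'))
      pairs.append((os.path.join(source_dir, name), target))
  return pairs
```

In the notebook you could then call `shutil.copy(src, dst)` for each pair returned by `plan_copies(DRIVE_DIR, AUDIO_DIR)`, which avoids shell quoting entirely.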

### Preprocess raw audio into TFRecord dataset

We need to do some preprocessing on the raw audio you uploaded to get it into the correct format for training. This involves turning the full audio into short (4-second) examples, inferring the fundamental frequency (or "pitch") with [CREPE](http://github.com/marl/crepe), and computing the loudness. These features will then be stored in a sharded [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) file for easier loading. Depending on the amount of input audio, this process usually takes a few minutes.
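`ddsp_prepare_tfrecord` does the splitting for you, but the arithmetic is worth seeing. The sketch below assumes the 16 kHz sample rate used for playback later in this notebook and the 250 Hz frame rate used in the statistics script; the 1-second hop between windows is a hypothetical choice, not necessarily what `ddsp_prepare_tfrecord` uses.

```python
def example_windows(num_samples, sample_rate=16000, example_secs=4.0,
                    hop_secs=1.0):
  """Return (start, end) sample indices of fixed-length examples.

  `hop_secs` is an assumed value; windows that would run past the end of
  the audio are dropped.
  """
  win = int(example_secs * sample_rate)
  hop = int(hop_secs * sample_rate)
  if num_samples < win:
    return []
  return [(s, s + win) for s in range(0, num_samples - win + 1, hop)]


def frames_per_example(example_secs=4.0, frame_rate=250):
  """Number of f0/loudness feature frames covering one example."""
  return int(example_secs * frame_rate)
```

With these assumptions, 10 seconds of audio yields 7 overlapping 4-second examples of 64,000 samples each, and each example carries 1,000 feature frames.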

* (Optional) Transfer the dataset from Drive. If you've already created a dataset in a previous run, this cell skips the dataset creation step and copies the dataset from `$DRIVE_DIR/data`.

In [None]:
import glob
import os

TRAIN_TFRECORD = 'data/train.tfrecord'
TRAIN_TFRECORD_FILEPATTERN = TRAIN_TFRECORD + '*'

# Copy dataset from drive if dataset has already been created.
drive_data_dir = os.path.join(DRIVE_DIR, 'data') 
drive_dataset_files = glob.glob(drive_data_dir + '/*')

if DRIVE_DIR and len(drive_dataset_files) > 0:
  !cp "$drive_data_dir"/* data/

else:
  # Make a new dataset.
  if not glob.glob(AUDIO_FILEPATTERN):
    raise ValueError('No audio files found. Please use the previous cell to '
                     'upload.')

  cmd = (
      "unset PYTHONPATH PYTHONHOME && "
      "export LD_LIBRARY_PATH=/content/miniconda/lib:$LD_LIBRARY_PATH && "
      "/content/miniconda/bin/ddsp_prepare_tfrecord "
      f"--input_audio_filepatterns={AUDIO_FILEPATTERN} "
      f"--output_tfrecord_path={TRAIN_TFRECORD} "
      "--num_shards=10 "
      "--alsologtostderr"
  )
  !{cmd}

  # Copy dataset to drive for safe-keeping.
  if DRIVE_DIR:
    !mkdir -p "$drive_data_dir"/
    print('Saving to {}'.format(drive_data_dir))
    !cp $TRAIN_TFRECORD_FILEPATTERN "$drive_data_dir"/

### Save dataset statistics for timbre transfer

Quantile normalization helps match the loudness of timbre transfer inputs to the
loudness of the dataset, so let's compute it here and save it, along with other dataset statistics, in a pickle file.
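The transform itself is fitted by ddsp's `fit_quantile_transform` in the next cell. As a simplified, rank-based illustration of the idea (not ddsp's implementation), the hypothetical `quantile_match` below finds the quantile of a loudness value within one set of values and returns the value at the same quantile of another set.

```python
import bisect


def _quantile(sorted_vals, q):
  """Linearly interpolated quantile of an already-sorted list, q in [0, 1]."""
  pos = q * (len(sorted_vals) - 1)
  lo = int(pos)
  hi = min(lo + 1, len(sorted_vals) - 1)
  frac = pos - lo
  return sorted_vals[lo] * (1.0 - frac) + sorted_vals[hi] * frac


def quantile_match(x, source, target):
  """Map `x` from the `source` distribution onto the `target` distribution.

  The rank of `x` among the source values gives a quantile q; the result
  is the target value at that same quantile.
  """
  src, tgt = sorted(source), sorted(target)
  if len(src) < 2 or len(tgt) < 2:
    raise ValueError('need at least two source and two target values')
  q = (bisect.bisect_right(src, x) - 1) / (len(src) - 1)
  q = min(max(q, 0.0), 1.0)  # clamp values outside the source range
  return _quantile(tgt, q)
```

For example, with input loudness values `[0, 1, 2, 3]` and dataset loudness values `[-60, -40, -20, 0]`, an input of `2` maps to about `-20`, and anything above the input range maps to the dataset maximum.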

In [None]:
SCRIPT = r'''
import os
import pickle
import sys
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import ddsp
import ddsp.training
from ddsp.training.postprocessing import detect_notes, fit_quantile_transform
from ddsp import spectral_ops
from ddsp.core import hz_to_midi
import tensorflow.compat.v2 as tf

file_pattern = sys.argv[1]
save_dir = sys.argv[2]

data_provider = ddsp.training.data.TFRecordProvider(file_pattern)
ds = data_provider.get_batch(1, repeats=1)

loudness, power, f0, f0_conf = [], [], [], []
batch = next(iter(ds))
audio_key = "audio_16k" if "audio_16k" in batch.keys() else "audio"
i = 0
for batch in iter(ds):
    loudness.append(batch["loudness_db"])
    power.append(spectral_ops.compute_power(batch[audio_key], frame_size=1024, frame_rate=250))
    f0.append(batch["f0_hz"])
    f0_conf.append(batch["f0_confidence"])
    i += 1
print(f"Computing statistics for {i} examples.")

loudness = np.vstack(loudness)
power = np.vstack(power)
f0 = np.vstack(f0)
f0_conf = np.vstack(f0_conf)

# Trim to match dimensions (power may have +/- 1 frame from padding)
trim_end = 20
min_len = min(loudness.shape[1], power.shape[1]) - trim_end
f0_trimmed = f0[:, :min_len]
pitch_trimmed = hz_to_midi(f0_trimmed)
power_trimmed = power[:, :min_len]
loudness_trimmed = loudness[:, :min_len]
f0_conf_trimmed = f0_conf[:, :min_len]

mask_on, _ = detect_notes(loudness_trimmed, f0_conf_trimmed)
mask_on = np.logical_or(
    mask_on, np.logical_not(np.any(mask_on, axis=1, keepdims=True)))
quantile_transform = fit_quantile_transform(loudness_trimmed, mask_on)

def get_stats(x, prefix, note_mask=None):
    if note_mask is None:
        mean_max = np.mean(np.max(x, axis=-1))
        mean_min = np.mean(np.min(x, axis=-1))
    else:
        max_list = [np.max(x_i[m]) for x_i, m in zip(x, note_mask) if np.sum(m) > 0]
        mean_max = np.mean(max_list)
        min_list = [np.min(x_i[m]) for x_i, m in zip(x, note_mask) if np.sum(m) > 0]
        mean_min = np.mean(min_list)
        x = x[note_mask]
    return {
        f"mean_{prefix}": np.mean(x), f"max_{prefix}": np.max(x),
        f"min_{prefix}": np.min(x), f"mean_max_{prefix}": mean_max,
        f"mean_min_{prefix}": mean_min, f"std_{prefix}": np.std(x),
    }

ds_stats = {}
ds_stats.update(get_stats(pitch_trimmed, "pitch"))
ds_stats.update(get_stats(power_trimmed, "power"))
ds_stats.update(get_stats(loudness_trimmed, "loudness"))
ds_stats.update(get_stats(pitch_trimmed, "pitch_note", mask_on))
ds_stats.update(get_stats(power_trimmed, "power_note", mask_on))
ds_stats.update(get_stats(loudness_trimmed, "loudness_note", mask_on))
ds_stats["quantile_transform"] = quantile_transform

pickle_path = os.path.join(save_dir, "dataset_statistics.pkl")
with tf.io.gfile.GFile(pickle_path, "wb") as fp:
    pickle.dump(ds_stats, fp)
print(f"Saved dataset statistics to: {pickle_path}")
'''

with open('/content/save_stats.py', 'w') as f:
  f.write(SCRIPT)

PICKLE_FILE_PATH = os.path.join(SAVE_DIR, "dataset_statistics.pkl")
cmd = (
    "unset PYTHONPATH PYTHONHOME && "
    "export LD_LIBRARY_PATH=/content/miniconda/lib:$LD_LIBRARY_PATH && "
    "/content/miniconda/bin/python /content/save_stats.py "
    f"'{TRAIN_TFRECORD_FILEPATTERN}' '{SAVE_DIR}'"
)
!{cmd}
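The pitch statistics above are computed on MIDI note numbers, converted from f0 with `ddsp.core.hz_to_midi`. The standard conversion places A4 = 440 Hz at note 69, with 12 notes per octave: $m = 69 + 12 \log_2(f / 440)$. Below is a plain-Python sketch of that formula. Mapping non-positive frequencies (e.g. unvoiced frames) to 0 is this sketch's own assumption to avoid `log(0)`.

```python
import math


def hz_to_midi(f0_hz):
  """Convert a frequency in Hz to a (fractional) MIDI note number.

  Non-positive inputs return 0.0 here (an assumption of this sketch).
  """
  if f0_hz <= 0.0:
    return 0.0
  return 12.0 * math.log2(f0_hz / 440.0) + 69.0
```

So 880 Hz, one octave above A4, maps to note 81, and middle C (about 261.63 Hz) maps to about 60.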

Let's load the dataset in the `ddsp` library and have a look at one of the examples.

In [None]:
SCRIPT = r'''
import os
import sys
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import ddsp.training

file_pattern = sys.argv[1]
output_dir = sys.argv[2]

data_provider = ddsp.training.data.TFRecordProvider(file_pattern)
dataset = data_provider.get_dataset(shuffle=False)

try:
    ex = next(iter(dataset))
except StopIteration:
    raise ValueError(
        'TFRecord contains no examples. Please try re-running the pipeline '
        'with different audio file(s).')

os.makedirs(output_dir, exist_ok=True)
for key in ex:
    val = ex[key]
    if hasattr(val, 'numpy'):
        val = val.numpy()
    np.save(os.path.join(output_dir, f'{key}.npy'), np.array(val))
print('Saved dataset example to:', output_dir)
'''

with open('/content/load_example.py', 'w') as f:
  f.write(SCRIPT)

cmd = (
    "unset PYTHONPATH PYTHONHOME && "
    "export LD_LIBRARY_PATH=/content/miniconda/lib:$LD_LIBRARY_PATH && "
    "/content/miniconda/bin/python /content/load_example.py "
    f"'{TRAIN_TFRECORD_FILEPATTERN}' '/content/example_output'"
)
!{cmd}

# --- Display results in Colab kernel ---
import warnings
warnings.filterwarnings("ignore")

import base64
import io

import numpy as np
import matplotlib.pyplot as plt
from IPython import display
from scipy.io import wavfile
from scipy import signal as scipy_signal

SAMPLE_RATE = 16000


def play(array_of_floats, sample_rate=SAMPLE_RATE):
  """Play audio in colab using HTML5 audio widget."""
  if len(array_of_floats.shape) == 2:
    array_of_floats = array_of_floats[0]
  normalizer = float(np.iinfo(np.int16).max)
  # Clip to [-1, 1] so out-of-range samples don't wrap around in int16.
  array_of_ints = np.array(
      np.clip(np.asarray(array_of_floats), -1.0, 1.0) * normalizer,
      dtype=np.int16)
  memfile = io.BytesIO()
  wavfile.write(memfile, sample_rate, array_of_ints)
  html = """<audio controls>
              <source controls src="data:audio/wav;base64,{base64_wavfile}"
              type="audio/wav" />
              Your browser does not support the audio element.
            </audio>"""
  html = html.format(
      base64_wavfile=base64.b64encode(memfile.getvalue()).decode('ascii'))
  memfile.close()
  display.display(display.HTML(html))


def specplot(audio, vmin=-5, vmax=1, rotate=True, size=512 + 256):
  """Plot the log magnitude spectrogram of audio."""
  if len(audio.shape) == 2:
    audio = audio[0]
  f, t, Sxx = scipy_signal.stft(audio, fs=SAMPLE_RATE, nperseg=size,
                                 noverlap=size * 3 // 4)
  logmag = np.log10(np.abs(Sxx) + 1e-7)
  if rotate:
    logmag = np.flipud(logmag)
  plt.matshow(logmag, vmin=vmin, vmax=vmax, cmap=plt.cm.magma, aspect='auto')
  plt.xticks([])
  plt.yticks([])
  plt.xlabel('Time')
  plt.ylabel('Frequency')


# Load and display
audio = np.load('/content/example_output/audio.npy')
f0_hz = np.load('/content/example_output/f0_hz.npy')
loudness_db = np.load('/content/example_output/loudness_db.npy')
f0_confidence = np.load('/content/example_output/f0_confidence.npy')

specplot(audio)
play(audio)

f, ax = plt.subplots(3, 1, figsize=(14, 4))
x = np.linspace(0, 4.0, len(loudness_db))
ax[0].set_ylabel('loudness_db')
ax[0].plot(x, loudness_db)
ax[1].set_ylabel('F0_Hz')
ax[1].set_xlabel('seconds')
ax[1].plot(x, f0_hz)
ax[2].set_ylabel('F0_confidence')
ax[2].set_xlabel('seconds')
ax[2].plot(x, f0_confidence)
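The `play` helper converts float audio to 16-bit PCM, writes a WAV file in memory, and embeds it as a base64 data URI in an HTML `<audio>` tag. A standard-library-only sketch of the same steps, with hypothetical function names, clips samples to [-1, 1] so they cannot overflow the 16-bit range:

```python
import base64
import io
import struct
import wave


def float_to_pcm16(samples):
  """Clip float samples to [-1, 1] and scale them to 16-bit integers."""
  return [int(round(max(-1.0, min(1.0, s)) * 32767)) for s in samples]


def wav_data_uri(samples, sample_rate=16000):
  """Encode mono float audio as a base64 WAV data URI for an <audio> tag."""
  pcm = float_to_pcm16(samples)
  buf = io.BytesIO()
  with wave.open(buf, 'wb') as w:
    w.setnchannels(1)
    w.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
    w.setframerate(sample_rate)
    w.writeframes(struct.pack('<%dh' % len(pcm), *pcm))
  encoded = base64.b64encode(buf.getvalue()).decode('ascii')
  return 'data:audio/wav;base64,' + encoded
```

The returned string can be used directly as the `src` of an `<audio>` element.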

## Train Model

We will now train a "solo instrument" model. This means the model is conditioned only on the fundamental frequency (f0) and loudness, with no instrument ID or latent timbre feature. If you uploaded audio of multiple instruments, the neural network you train will attempt to model all of the timbres, but will likely associate certain timbres with particular f0 and loudness conditions.

First, let's start up a [TensorBoard](https://www.tensorflow.org/tensorboard) to monitor our loss as training proceeds. 

Initially, TensorBoard will report `No dashboards are active for the current data set.`, but once training begins, the dashboards should appear.

In [None]:
%reload_ext tensorboard
import tensorboard as tb
tb.notebook.start('--logdir "{}"'.format(SAVE_DIR))

### We will now begin training. 

Note that we specify [gin configuration](https://github.com/google/gin-config) files for both the model architecture (`models/solo_instrument.gin`) and the dataset (`datasets/tfrecord.gin`), which are both predefined in the library. You could also create your own. We then override specific parameters, such as `batch_size` (defined in the model gin file) and the TFRecord file pattern (defined in the dataset gin file).
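When the training command string is run through the shell, the quotes are stripped and each override reaches `ddsp_run` as a single `--gin_param=name=value` argument. The sketch below builds that argument list in Python using the same flags and values as the training cell; the helper names are hypothetical. Building a list and quoting it with `shlex.join` keeps paths with spaces (such as `My Drive`) intact.

```python
import shlex


def ddsp_run_args(save_dir, file_pattern, batch_size=16, num_steps=30000,
                  steps_per_save=300, checkpoints_to_keep=10):
  """Argument list matching the training cell's `ddsp_run` command."""
  return [
      '/content/miniconda/bin/ddsp_run',
      '--mode=train',
      '--alsologtostderr',
      f'--save_dir={save_dir}',
      '--gin_file=models/solo_instrument.gin',
      '--gin_file=datasets/tfrecord.gin',
      # String values are quoted for gin itself, not for the shell.
      f"--gin_param=TFRecordProvider.file_pattern='{file_pattern}'",
      f'--gin_param=batch_size={batch_size}',
      f'--gin_param=train_util.train.num_steps={num_steps}',
      f'--gin_param=train_util.train.steps_per_save={steps_per_save}',
      f'--gin_param=trainers.Trainer.checkpoints_to_keep={checkpoints_to_keep}',
  ]


def ddsp_run_command(save_dir, file_pattern, **overrides):
  """Shell-quoted version of `ddsp_run_args`, safe for paths with spaces."""
  return shlex.join(ddsp_run_args(save_dir, file_pattern, **overrides))
```

Note that the training cell also unsets `PYTHONPATH`/`PYTHONHOME` and sets `LD_LIBRARY_PATH` before running this command; any alternative launcher would need the same environment.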

### Training Notes:
* Models typically perform well when the loss drops to the range of ~4.5-5.0.
* Depending on the dataset, this usually takes anywhere from 5k to 30k training steps.
* The default is set to 30k, but you can stop training at any time. For timbre transfer, it's best to stop before the loss drops too far below ~5.0 to avoid overfitting.
* On the Colab GPU, this can take around 3-20 hours.
* We **highly recommend** saving checkpoints directly to your Drive account, as Colab sessions restart automatically after about 12 hours and you may lose all of your checkpoints.
* By default, checkpoints are saved every 300 steps, keeping at most 10 checkpoints (at ~60MB/checkpoint this is ~600MB). Feel free to adjust these numbers depending on how often you want to save and how much space you have on your Drive.
* If you're restarting a session and `DRIVE_DIR` points to a directory that was previously used for training, training should resume from the last checkpoint.
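The checkpoint settings above trade disk space against how much progress a restart can cost. A small, hypothetical helper makes the arithmetic explicit, using the notebook's defaults and the approximate ~60MB checkpoint size quoted above:

```python
def checkpoint_budget(num_steps=30000, steps_per_save=300,
                      checkpoints_to_keep=10, mb_per_checkpoint=60):
  """Estimate number of saves, checkpoints kept, and disk used (MB).

  `mb_per_checkpoint` is the approximate size quoted in the notes. A
  restart loses at most the steps since the last save (< steps_per_save).
  """
  num_saves = num_steps // steps_per_save
  kept = min(num_saves, checkpoints_to_keep)
  return {'num_saves': num_saves, 'kept': kept,
          'disk_mb': kept * mb_per_checkpoint}
```

With the defaults, `checkpoint_budget()` gives 100 saves over 30k steps, 10 checkpoints kept, and about 600MB on disk. Saving every 1,000 steps and keeping 5 checkpoints would cut disk use to about 300MB, but a restart could lose up to 1,000 steps of progress.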

In [None]:
cmd = (
    "unset PYTHONPATH PYTHONHOME && "
    "export LD_LIBRARY_PATH=/content/miniconda/lib:$LD_LIBRARY_PATH && "
    "/content/miniconda/bin/ddsp_run "
    "--mode=train "
    "--alsologtostderr "
    f"--save_dir='{SAVE_DIR}' "
    "--gin_file=models/solo_instrument.gin "
    "--gin_file=datasets/tfrecord.gin "
    f"--gin_param=\"TFRecordProvider.file_pattern='{TRAIN_TFRECORD_FILEPATTERN}'\" "
    "--gin_param=\"batch_size=16\" "
    "--gin_param=\"train_util.train.num_steps=30000\" "
    "--gin_param=\"train_util.train.steps_per_save=300\" "
    "--gin_param=\"trainers.Trainer.checkpoints_to_keep=10\""
)
!{cmd}

## Resynthesis

Check how well the model reconstructs the training data.

In [None]:
SCRIPT = r'''
import os
import sys
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import ddsp.training
import gin

file_pattern = sys.argv[1]
save_dir = sys.argv[2]
output_dir = sys.argv[3]

os.makedirs(output_dir, exist_ok=True)

data_provider = ddsp.training.data.TFRecordProvider(file_pattern)
dataset = data_provider.get_batch(batch_size=1, shuffle=False)

try:
    batch = next(iter(dataset))
except StopIteration:
    raise ValueError(
        'TFRecord contains no examples. Please try re-running the pipeline '
        'with different audio file(s).')

# Parse the gin config.
gin_file = os.path.join(save_dir, 'operative_config-0.gin')
gin.parse_config_file(gin_file)

# Load model
model = ddsp.training.models.Autoencoder()
model.restore(save_dir)

# Resynthesize audio.
outputs = model(batch, training=False)
audio_gen = model.get_audio_from_outputs(outputs)
audio = batch['audio']

# Convert to numpy
if hasattr(audio_gen, 'numpy'):
    audio_gen = audio_gen.numpy()
if hasattr(audio, 'numpy'):
    audio = audio.numpy()

np.save(os.path.join(output_dir, 'audio_gen.npy'), np.array(audio_gen))
np.save(os.path.join(output_dir, 'audio_orig.npy'), np.array(audio))
print('Saved resynthesis outputs to:', output_dir)
'''

with open('/content/resynthesize.py', 'w') as f:
  f.write(SCRIPT)

cmd = (
    "unset PYTHONPATH PYTHONHOME && "
    "export LD_LIBRARY_PATH=/content/miniconda/lib:$LD_LIBRARY_PATH && "
    "/content/miniconda/bin/python /content/resynthesize.py "
    f"'{TRAIN_TFRECORD_FILEPATTERN}' '{SAVE_DIR}' '/content/resynth_output'"
)
!{cmd}

# --- Display results ---
import numpy as np

audio_gen = np.load('/content/resynth_output/audio_gen.npy')
audio_orig = np.load('/content/resynth_output/audio_orig.npy')

print('Original Audio')
specplot(audio_orig)
play(audio_orig)

print('Resynthesis')
specplot(audio_gen)
play(audio_gen)

## Download Checkpoint

Below you can download the final checkpoint. You are now ready to use it in the [DDSP Timbre Transfer Colab](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/timbre_transfer.ipynb).

In [None]:
from google.colab import files
import tensorflow as tf
import os

CHECKPOINT_ZIP = 'my_solo_instrument.zip'
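latest_checkpoint = tf.train.latest_checkpoint(SAVE_DIR)
if latest_checkpoint is None:
  raise ValueError('No checkpoint found in {}. Train the model first.'.format(SAVE_DIR))
latest_checkpoint_fname = os.path.basename(latest_checkpoint)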
!cd "$SAVE_DIR" && zip $CHECKPOINT_ZIP $latest_checkpoint_fname* operative_config-0.gin dataset_statistics.pkl
!cp "$SAVE_DIR/$CHECKPOINT_ZIP" ./
files.download(CHECKPOINT_ZIP)
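If the `zip` command-line tool is unavailable, the same bundle can be built with Python's `zipfile` module. This is a sketch under an assumption: it supposes the checkpoints are stored as `ckpt-<step>.index` plus matching data files (the usual TensorFlow checkpoint layout), and it picks the highest step itself instead of calling `tf.train.latest_checkpoint`. Both helper names are hypothetical.

```python
import glob
import os
import re
import zipfile


def latest_checkpoint_prefix(save_dir):
  """Return the path prefix of the highest-step `ckpt-<step>`, or None.

  Assumes the `ckpt-<step>.index` naming scheme.
  """
  best_step, best_prefix = -1, None
  for path in glob.glob(os.path.join(save_dir, 'ckpt-*.index')):
    match = re.fullmatch(r'ckpt-(\d+)\.index', os.path.basename(path))
    if match and int(match.group(1)) > best_step:
      best_step = int(match.group(1))
      best_prefix = path[:-len('.index')]
  return best_prefix


def zip_checkpoint(save_dir, zip_path):
  """Bundle the latest checkpoint, gin config, and dataset statistics."""
  prefix = latest_checkpoint_prefix(save_dir)
  if prefix is None:
    raise FileNotFoundError(f'No checkpoint found in {save_dir}')
  extras = [os.path.join(save_dir, 'operative_config-0.gin'),
            os.path.join(save_dir, 'dataset_statistics.pkl')]
  with zipfile.ZipFile(zip_path, 'w') as zf:
    for path in sorted(glob.glob(prefix + '.*')) + extras:
      # Store files flat, as `zip` run inside SAVE_DIR does.
      zf.write(path, arcname=os.path.basename(path))
  return zip_path
```

`zip_checkpoint(SAVE_DIR, CHECKPOINT_ZIP)` would produce a zip with the same contents as the shell command above. It raises `FileNotFoundError` if the gin config or statistics file is missing.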