##### Copyright 2021 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Yamnet controlled mixing of speech enhanced audio
This notebook will mix a variable amount of speech enhanced (cleaned) and original (noisy) audio, using one of two strategies:
1. Mix a fixed ratio of cleaned and noisy audio.
1. Mix a variable ratio of cleaned and noisy audio, determined by causally running [YAMNet](https://www.tensorflow.org/hub/tutorials/yamnet) model inference. YAMNet is used here to estimate, every 0.480 s, how much the audio is like music instead of speech/silence. Because the speech enhancement model will often remove music, we reduce the fraction of speech enhanced audio that is mixed whenever YAMNet detects music.

Inputs:
* A directory containing wav files of the original (uncleaned/noisy) audio.
* A directory containing wav files of the speech enhanced audio, having the same basename as the corresponding original audio file.

Outputs:
* A directory containing wav files of the mixed audio, produced with either the fixed or the variable strategy.

**Note**: YAMNet takes an input window of 0.960 s to make a prediction, and does so every 0.480 s (i.e., windows are overlapped by 50%). Hence, to ensure a causal mixing strategy, we implement a fixed default mix ratio for the first 0.960 s.

In [None]:
import os
from tensorflow.io import gfile
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import csv

from IPython.display import Audio, display
import ipywidgets
import matplotlib.pyplot as plt
from scipy.io import wavfile
from scipy.signal import lfilter
from scipy.signal import resample

from google.colab import widgets

# Make sure tensors are evaluated eagerly; cells below call `.numpy()` on
# values returned by the YAMNet model.
# NOTE(review): presumably redundant on TF 2.x, where eager execution is the
# default -- confirm before removing.
tf.compat.v1.enable_eager_execution()

In [None]:
# @title Mount Google Drive
from google.colab import drive
# Directory at which Google Drive is mounted.
ROOT_DIR = '/content/gdrive'
# force_remount=True is passed so the mount is redone when this cell is re-run.
drive.mount(ROOT_DIR, force_remount=True)

In [None]:
# @title Helper function for playing audio.
def PlaySound(samples, sample_rate=16000):
  """Shows an audio player for `samples` inside its own output widget.

  Args:
    samples: Audio samples, handed unchanged to `IPython.display.Audio`.
    sample_rate: Int, playback rate in Hz given to the player.
  """
  player_area = ipywidgets.Output()
  # Build and render the player while the widget is capturing output, so that
  # everything it emits ends up inside the widget.
  with player_area:
    player = Audio(samples, rate=sample_rate)
    display(player)
  display(player_area)

In [None]:
#@title File read/write functions.
def write_wav(filename, waveform, sample_rate=16000):
  """Saves a float waveform as a 16-bit integer .wav file.

  Args:
    filename: Destination handed to `scipy.io.wavfile.write`.
    waveform: Float numpy array; +/-1.0 corresponds to int16 full scale.
    sample_rate: Int, sample rate in Hz stored in the file.
  """
  scaled = waveform * 2**15
  # Clip before casting so that out-of-range samples saturate instead of
  # wrapping around in int16.
  quantized = np.round(np.clip(scaled, -32768, 32767)).astype(np.int16)
  wavfile.write(filename, sample_rate, quantized)

def read_wav(wav_path, sample_rate=16000, channel=None):
  """Read a wav file as a float32 numpy array with full scale at +/-1.0.

  Integer PCM samples are divided by the full-scale value of their own dtype
  (2**15 for int16, 2**31 for int32), unsigned 8-bit samples are re-centred
  around zero first, and floating-point samples are returned unscaled.

  Args:
    wav_path: String, path to .wav file.
    sample_rate: Int, sample rate for audio to be converted to.
    channel: Int, option to select a particular channel for stereo audio. When
      None, all channels are returned (shape (num_samples, num_channels)).

  Returns:
    Audio as float numpy array.
  """
  sr_read, x = wavfile.read(wav_path)

  # Normalise according to the dtype actually stored in the file; dividing
  # everything by 2**15 is only correct for int16 data.
  if np.issubdtype(x.dtype, np.floating):
    x = x.astype(np.float32)
  elif x.dtype == np.uint8:
    x = (x.astype(np.float32) - 128.0) / 128.0
  else:
    full_scale = float(2**(8 * x.dtype.itemsize - 1))
    x = x.astype(np.float32) / full_scale

  # Pick the channel before resampling so that channels which are about to be
  # discarded are not resampled.
  if x.ndim > 1 and channel is not None:
    x = x[:, channel]
  if sr_read != sample_rate:
    x = resample(x, int(round((float(sample_rate) / sr_read) * len(x))))
  return x

In [None]:
# @title YAMNet inference code.
def class_names_from_csv(class_map_csv_text):
  """Reads the class-map CSV and returns its `display_name` column as a list.

  Args:
    class_map_csv_text: Path of the class-map CSV file, opened with
      `tf.io.gfile.GFile`.

  Returns:
    List of display names, in the order the rows appear in the file.
  """
  with tf.io.gfile.GFile(class_map_csv_text) as class_map_file:
    return [record['display_name'] for record in csv.DictReader(class_map_file)]


# Load the YAMNet model from TensorFlow Hub.
yamnet_model = hub.load('https://tfhub.dev/google/yamnet/1')
# Path of the class-map CSV provided by the loaded model.
class_map_path = yamnet_model.class_map_path().numpy()
# Display names of the classes, in class-map row order.
CLASS_NAMES = class_names_from_csv(class_map_path)
# See https://github.com/tensorflow/models/blob/master/research/audioset/yamnet/yamnet_class_map.csv
# Position of the 'Silence' class within CLASS_NAMES.
INDEX_SILENCE = np.where(np.array(CLASS_NAMES) == 'Silence')[0][0]
# NOTE(review): the two ranges below are hard-coded row positions in the class
# map linked above; presumably rows 132-276 are the music classes and rows
# 0-66 the speech / human-voice classes -- confirm if the class map changes.
INDICES_MUSIC = slice(132, 277)
INDICES_SPEECH = slice(0, 67)

In [None]:
# @title Code for variable mixing.


def get_music_fraction(scores):
  """For each inference window, convert YAMNet prediction to fraction of music.

  While YAMNet outputs predictions on 521 classes, here we only look at the
  fraction of music relative to music + speech + silence. The music and speech
  terms are each the maximum score over their class range.

  Args:
    scores: A (N, 521) shape array (numpy array or eager tensor) of YAMNet
      predictions in [0, 1.0], for the 521 classes. N is roughly
      input_duration / 0.480 s.

  Returns:
    music_normalized: A 1D numpy array of fraction YAMNet predicted music in
    [0, 1.0]. Windows where music, speech and silence all score zero map to
    0.0 rather than NaN.
  """
  # Convert once up front so the arithmetic below is plain numpy, whether the
  # caller passes the model's tensor output or an ordinary array.
  scores = np.asarray(scores)
  music = np.max(scores[:, INDICES_MUSIC], axis=1)
  speech = np.max(scores[:, INDICES_SPEECH], axis=1)
  speech_and_silence = speech + scores[:, INDEX_SILENCE]
  total = music + speech_and_silence

  # Divide only where the denominator is positive; the remaining entries keep
  # their initial value of 0.0 (no evidence of music).
  music_normalized = np.zeros_like(total)
  np.divide(music, total, out=music_normalized, where=total > 0)
  return music_normalized


def map_fraction_music_to_fraction_speech_enhancement_to_mix(
    x, threshold=0.2, ceiling=0.4):
  """Maps detected music fraction to the fraction of original audio to mix in.

  Inputs at or below `threshold` give 0.0. Larger inputs are compressed with a
  fifth root and scaled by `ceiling`, so an input of 1.0 gives `ceiling`.

  Args:
    x: A 1D array of fraction YAMNet predicted music in [0, 1.0].
    threshold: Float in [0, 1.0]; values not above this are mapped to 0.0.
    ceiling: Float, the maximum output value (i.e. the output for 1.0 input).

  Returns:
    A 1D array of fraction of non-speech-enhanced audio to mix in [0, 1.0].
  """
  is_above_threshold = x > threshold
  compressed = x**(1 / 5)
  # The boolean mask acts as a 0/1 gate on the compressed curve.
  return is_above_threshold * compressed * ceiling


def get_causal_speech_enhancement_mixing_strategy(
    x, default_non_speech_enhanced_mix=0.05, num_periods_to_run_default_mix=2):
  """Smooths and delays a per-window mix so that it only uses past predictions.

  The input is preceded by default-mix values and smoothed with a three-tap
  weighting, so that the first `num_periods_to_run_default_mix` outputs are
  exactly the default mix and every later output depends only on earlier
  inputs.

  Args:
    x: A 1D array of fraction of non-speech-enhanced audio to mix in [0, 1.0].
    default_non_speech_enhanced_mix: Float in [0.0, 1.0], the fraction of non
      speech enhanced audio used before any prediction is available.
    num_periods_to_run_default_mix: Int, number of periods to run the default
      mix for. For the YAMNet model example inference, this should be at least
      2, in order to use the YAMNet predictions in a causal way.

  Returns:
    A 1D array of fraction of non-speech enhanced audio to mix in [0, 1.0],
    of length len(x) + 1.

  Raises:
    ValueError: If num_periods_to_run_default_mix is smaller than 2.
  """
  if num_periods_to_run_default_mix < 2:
    raise ValueError(
        'num_periods_to_run_default_mix=%d would yield non-causal result' %
        num_periods_to_run_default_mix)

  # Smoothing weights: 60% for the current prediction, 30% for the previous
  # one and 10% for the one before that.
  smoothing_weights = np.array([.6, .3, .1])

  # Enough leading defaults to cover both the default-mix periods and the
  # history the smoothing needs for its first output.
  num_leading_defaults = (
      num_periods_to_run_default_mix + smoothing_weights.size - 1)
  leading_defaults = [default_non_speech_enhanced_mix] * num_leading_defaults
  padded = np.append(leading_defaults, x)

  smoothed = np.convolve(padded, smoothing_weights, 'valid')
  return smoothed[:len(x) + 1]


def gen_audio_mixing_waveform(mix_strategy_discrete, samples_per_window,
                              cross_fade):
  """Map a discrete mixing strategy to a mix waveform with crossfade.

  Each discrete value is held for `samples_per_window` samples and the
  resulting staircase is smoothed with a causal, unit-gain Hann-shaped FIR
  filter. The filter is started in steady state at the first value, so the
  output begins at mix_strategy_discrete[0] instead of ramping up from 0.

  Args:
    mix_strategy_discrete: A 1D array of fraction of non-speech enhanced audio
      to mix in [0, 1.0], each element corresponding to a 0.480 s window. This
      array should be long enough so that len(mix_strategy_discrete) *
      samples_per_window is at least as long as the input audio.
    samples_per_window: Int, number of audio samples per window.
    cross_fade: Int, number of audio samples over which to crossfade. Values
      whose Hann window sums to zero (e.g. 0 or 2) disable the crossfade.

  Returns:
    mix_continuous_crossfaded: A 1D numpy array of audio with values in [0,
    1.0], representing the fraction of non-speech-enhanced audio to mix. The
    output length is an integer multiple of samples_per_window and should be
    cropped to match the exact length of input audio.
  """
  mix_continuous = np.repeat(mix_strategy_discrete,
                             samples_per_window).astype(float)
  window = np.hanning(cross_fade)
  window_sum = np.sum(window)
  # A degenerate window cannot be normalised (it would give NaN), and an empty
  # strategy has nothing to smooth: return the unsmoothed staircase.
  if mix_continuous.size == 0 or window_sum <= 0:
    return mix_continuous

  taps = window / window_sum
  # Prepend copies of the first value so the filter has already settled when
  # the real signal starts, then drop the outputs belonging to that padding.
  num_pad = len(taps) - 1
  padded = np.concatenate(
      [np.full(num_pad, mix_continuous[0]), mix_continuous])
  mix_continuous_crossfaded = lfilter(taps, 1, padded)[num_pad:]
  return mix_continuous_crossfaded


def run_yamnet_mix_and_save_audio(audio_clip_subpath,
                                  input_path,
                                  input_enhanced_path,
                                  output_path,
                                  strategy='variable'):
  """Mixes one clip's original and speech enhanced audio and saves the result.

  Reads the original and enhanced versions of the clip, runs YAMNet on the
  original audio, builds a per-sample mix waveform with the requested
  strategy, then plays, plots and writes the mixed audio. A .png plot is saved
  next to the output .wav, under the same base name.

  Args:
    audio_clip_subpath: String, the input .wav filename, of original and
      enhanced audio.
    input_path: String, path to directory with original audio.
    input_enhanced_path: String, path to directory with speech enhanced audio.
    output_path: String, path where the mixed audio will be saved.
    strategy: String, either 'variable' or 'fixed', for the variable mixing
      strategy utilizing YAMNet, or a baseline fixed strategy.

  Raises:
    ValueError: If the original and enhanced audio differ in shape, or if
      `strategy` is not one of the supported values.
  """
  noisy_wav_path = os.path.join(input_path, audio_clip_subpath)
  original_audio = read_wav(noisy_wav_path, sample_rate=SAMPLE_RATE)

  enhanced_wav_path = os.path.join(input_enhanced_path, audio_clip_subpath)
  cleaned_audio = read_wav(
      enhanced_wav_path, sample_rate=SAMPLE_RATE, channel=0)

  if original_audio.shape != cleaned_audio.shape:
    raise ValueError('Cleaned audio shape does not match: %s, %s' %
                     (original_audio.shape, cleaned_audio.shape))

  # The model is run for both strategies: its scores and spectrogram are
  # needed for the plot even when the mix itself is fixed.
  scores, _, spectrogram = yamnet_model(original_audio)
  num_samples = original_audio.shape[0]

  if strategy == 'fixed':
    mix_waveform = np.ones((num_samples,)) * FIXED_NOISE_FRACTION
  elif strategy == 'variable':
    fraction_music = get_music_fraction(scores)
    fraction_noisy = map_fraction_music_to_fraction_speech_enhancement_to_mix(
        fraction_music)
    mix_strategy_discrete = get_causal_speech_enhancement_mixing_strategy(
        fraction_noisy)
    full_length_mix = gen_audio_mixing_waveform(
        mix_strategy_discrete, SAMPLES_PER_INFERENCE_PERIOD, CROSS_FADE)
    # The generated waveform spans whole inference periods; crop it to the
    # audio length.
    mix_waveform = full_length_mix[:num_samples]
  else:
    raise ValueError('Invalid mixing strategy: %s' % strategy)

  noisy_part = mix_waveform * original_audio
  cleaned_part = (1 - mix_waveform) * cleaned_audio
  mixed_audio = noisy_part + cleaned_part

  display_audio(original_audio, cleaned_audio, mixed_audio)

  mixed_wav_path = os.path.join(output_path, audio_clip_subpath)
  plot_path = os.path.splitext(mixed_wav_path)[0] + '.png'
  visualize_mixing(
      scores,
      spectrogram,
      original_audio,
      mix_waveform,
      mixed_audio,
      output_plot_filename=plot_path)

  write_wav(
      os.path.join(output_path, audio_clip_subpath), mixed_audio, SAMPLE_RATE)

In [None]:
# @title Code for listening and visualizing mixing.
def visualize_mixing(scores,
                     spectrogram,
                     waveform,
                     mix_waveform,
                     mixed_waveform=None,
                     output_plot_filename=None):
  """Generates a plot showing the input and mixed audio, and YAMNet predictions.

  The figure stacks, top to bottom: the input waveform, the spectrogram
  returned by the model, the scores of the five classes with the highest mean
  score, the mix waveform, and (when given) the mixed audio.

  Args:
    scores: A N x 521, array of predictions in [0, 1.0], for the 521 classes. N
      is roughly input_duration / 0.480 s.
    spectrogram: A 2D array, the spectrogram of input audio, for visualization.
    waveform: A 1D array, the input audio.
    mix_waveform: A 1D array in [0, 1.0] denoting the fraction noise to mix.
    mixed_waveform: Optional 1D array, the mixed audio. When None, the mixed
      audio panel is left out.
    output_plot_filename: String, output filename. When None, the figure is
      not saved.
  """
  duration = len(waveform) / SAMPLE_RATE

  scores_np = scores.numpy()
  spectrogram_np = spectrogram.numpy()

  # One row per panel; the mixed-audio panel is optional.
  num_rows = 5 if mixed_waveform is not None else 4

  plt.figure(figsize=(10, 7))

  # Plot the waveform.
  plt.subplot(num_rows, 1, 1)
  plt.plot(np.arange(0, waveform.shape[0]) / SAMPLE_RATE, waveform)
  plt.xlim([0, duration])
  plt.xlabel('time (s)')
  plt.ylabel('input')

  # Plot the log-mel spectrogram (returned by the model).
  plt.subplot(num_rows, 1, 2)
  plt.imshow(
      spectrogram_np.T[:, :int(100 * duration)],
      aspect='auto',
      interpolation='nearest',
      origin='lower')
  plt.xlabel('spectrogram # (10 ms hop each)')

  # Plot and label the model output scores for the top-scoring classes.
  mean_scores = np.mean(scores_np, axis=0)
  top_n = 5
  top_class_indices = np.argsort(mean_scores)[::-1][:top_n]
  plt.subplot(num_rows, 1, 3)
  plt.imshow(
      scores_np[:, top_class_indices].T,
      aspect='auto',
      interpolation='nearest',
      cmap='gray_r')
  plt.xlim([-1.5, (duration / 0.480) - 1.5])

  # Label the top_N classes.
  yticks = range(0, top_n, 1)
  plt.yticks(yticks, [CLASS_NAMES[top_class_indices[x]] for x in yticks])
  _ = plt.ylim(-0.5 + np.array([top_n, 0]))
  plt.xlabel('prediction # (960 ms input, every 480 ms)')
  plt.ylabel('YAMNet pred.')

  # Plot the per-sample mixing strategy.
  plt.subplot(num_rows, 1, 4)
  plt.plot(np.arange(0, mix_waveform.shape[0]) / SAMPLE_RATE, mix_waveform)
  plt.xlim([0, duration])
  plt.ylim([-0.01, 1.01])
  plt.xlabel('time (s)')
  plt.ylabel('strategy')

  # Plot the mixed audio, if it was provided.
  if mixed_waveform is not None:
    plt.subplot(num_rows, 1, 5)
    plt.plot(
        np.arange(0, mixed_waveform.shape[0]) / SAMPLE_RATE, mixed_waveform)
    plt.xlim([0, duration])
    plt.xlabel('time (s)')
    plt.ylabel('mixed')
  plt.tight_layout()
  if output_plot_filename is not None:
    # Image data is binary; the context manager closes the file once the
    # figure has been written.
    with gfile.GFile(output_plot_filename, 'wb') as plot_file:
      plt.savefig(plot_file, dpi=300, bbox_inches='tight')


def display_audio(original_audio, cleaned_audio, mixed_audio):
  """Shows labelled players for the three clips side by side in a 1 x 3 grid.

  Args:
    original_audio: Samples of the original audio.
    cleaned_audio: Samples of the speech enhanced audio.
    mixed_audio: Samples of the mixed audio.
  """
  labelled_clips = (
      ('original', original_audio),
      ('cleaned', cleaned_audio),
      ('mixed', mixed_audio),
  )
  grid = widgets.Grid(1, 3)
  for column, (label, clip) in enumerate(labelled_clips):
    with grid.output_to(0, column):
      print(label)
      PlaySound(clip, SAMPLE_RATE)

In [None]:
# Desired rate, required by YAMNet.
SAMPLE_RATE = 16000
# Number of audio samples covered by each entry of the discrete mixing
# strategy (0.480 s at SAMPLE_RATE).
SAMPLES_PER_INFERENCE_PERIOD = int(
    SAMPLE_RATE * 0.480)  # time between YAMNet inference windows

# Length, in samples, of the smoothing window used to crossfade between mix
# levels (0.100 s at SAMPLE_RATE).
CROSS_FADE = int(0.100 * SAMPLE_RATE)
# Fraction of original (noisy) audio mixed in by the 'fixed' strategy.
FIXED_NOISE_FRACTION = 0.05

In [None]:
# Specify input paths.
# NOTE(review): these are relative paths, while Drive is mounted at ROOT_DIR
# ('/content/gdrive'); presumably this relies on the working directory being
# /content -- confirm.
PATH_ORIGINAL = 'gdrive/My Drive/cihack_audio'  # E.g. gdrive/My Drive/cihack_audio
# Directory with the speech enhanced files, named after the original directory.
PATH_SPEECH_ENHANCED = PATH_ORIGINAL + '_enhanced'

# Specify output paths.
# The driver loop picks the mixing strategy by looking for 'variable' or
# 'fixed' in these paths, so keep those words in the names.
PATH_MIXED_VARIABLE = PATH_SPEECH_ENHANCED + '_mixed_variable'
PATH_MIXED_FIXED = PATH_SPEECH_ENHANCED + '_mixed_fixed'

# Each entry is (original dir, speech enhanced dir, output dir).
PATH_SETS_TO_MIX = [
    (PATH_ORIGINAL, PATH_SPEECH_ENHANCED, PATH_MIXED_VARIABLE),
    (PATH_ORIGINAL, PATH_SPEECH_ENHANCED, PATH_MIXED_FIXED),
]

In [None]:
audio_clip_matcher = '*.wav'  #@param

# Mix every matching clip once per (original, enhanced, output) path set.
for input_path, enhanced_input_path, output_path in PATH_SETS_TO_MIX:
  # The strategy is inferred from the output path. It is the same for every
  # clip in a path set, so resolve it once here and fail loudly instead of
  # silently reusing the previous path set's strategy.
  if 'variable' in output_path:
    strategy = 'variable'
  elif 'fixed' in output_path:
    strategy = 'fixed'
  else:
    raise ValueError(
        'Cannot infer mixing strategy from output path: %s' % output_path)

  wavs = gfile.glob(os.path.join(input_path, audio_clip_matcher))
  gfile.makedirs(output_path)
  for wav in wavs:
    run_yamnet_mix_and_save_audio(
        os.path.basename(wav),
        input_path,
        enhanced_input_path,
        output_path,
        strategy=strategy)