##### Copyright 2021 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
import io
import os

import numpy as np
from scipy.io import wavfile
from scipy.signal import resample
import tensorflow.compat.v1 as tf
from tensorflow.io import gfile
from colabtools import sound

In [None]:
# @title Helper class for separation model inference.
class SeparationModel(object):
  """Tensorflow audio separation model.

  Imports a graph from a metagraph file, restores its variables from a
  checkpoint into a dedicated session, and runs inference via `separate`.

  The session stays open until `close` is called. The model can also be used
  as a context manager, which closes the session on exit.
  """

  def __init__(self,
               checkpoint_path,
               metagraph_path):
    """Builds the graph and restores the model variables.

    Args:
      checkpoint_path: Checkpoint path handed to the saver's `restore`.
      metagraph_path: Path of the metagraph handed to
        `tf.train.import_meta_graph`.
    """
    self.graph = tf.Graph()
    self.sess = tf.Session(graph=self.graph)
    try:
      with self.graph.as_default():
        new_saver = tf.train.import_meta_graph(metagraph_path)
        new_saver.restore(self.sess, checkpoint_path)
      self.input_placeholder = self.graph.get_tensor_by_name(
          'input_audio/receiver_audio:0')
      self.output_tensor = self.graph.get_tensor_by_name(
          'denoised_waveforms:0')
    except Exception:
      # Do not leak the session when loading fails half-way.
      self.sess.close()
      raise

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    self.close()

  def close(self):
    """Closes the underlying session."""
    self.sess.close()

  def separate(self, mixture_waveform):
    """Separates a mixture waveform into sources.

    Args:
      mixture_waveform: numpy.ndarray of shape (num_mics, num_samples).

    Returns:
      numpy.ndarray of separated waveforms of shape (num_sources, num_samples).
    """
    # The input placeholder is fed a batch of one: add a leading batch axis.
    mixture_waveform_input = np.expand_dims(mixture_waveform, 0)
    feed_dict = {self.input_placeholder: mixture_waveform_input}

    separated_waveforms = self.sess.run(self.output_tensor, feed_dict=feed_dict)
    # Strip the batch axis again so callers get the single example back.
    return separated_waveforms[0]

Manually download the pre-trained speech enhancement model files using [gsutil](https://cloud.google.com/storage/docs/gsutil) with:

`gsutil cp -r gs://gresearch/cochlear_implant/speech_enhancement_model .`

In [None]:
# @title Load speech enhancement model.
MODEL_PATH = '/path/to/model'

# Both model files are looked up directly inside MODEL_PATH.
checkpoint, metagraph = (
    os.path.join(MODEL_PATH, basename)
    for basename in ('checkpoint', 'inference.meta'))
model = SeparationModel(checkpoint_path=checkpoint, metagraph_path=metagraph)

In [None]:
# @title Get some wav paths.
PATH_AUDIO = '/path/to/audio'
PATH_ENHANCED = PATH_AUDIO + '_enhanced'

audio_clip_matcher = '*.wav'  #@param
# The public `tensorflow.io.gfile` module spells this `glob` (lower case).
# Sorting makes the processing order reproducible across runs.
wavs = sorted(gfile.glob(os.path.join(PATH_AUDIO, audio_clip_matcher)))

In [None]:
#@title File read/write functions.
def write_wav(filename, waveform, sample_rate=16000):
  """Write an audio waveform (float numpy array) as a 16-bit .wav file.

  Args:
    filename: String, destination path. Missing parent directories are
      created.
    waveform: Float numpy array. Samples are scaled by 2**15, clipped to the
      int16 range and rounded before being written.
    sample_rate: Int, sample rate stored in the wav header.
  """
  directory = os.path.dirname(filename)
  if directory:  # A bare file name has no parent directory to create.
    gfile.makedirs(directory)
  samples = np.round(
      np.clip(np.asarray(waveform) * 2**15, -32768, 32767)).astype(np.int16)
  # scipy seeks backwards to patch the RIFF size field, which a write-mode
  # GFile handle may not support, so encode in memory and write the bytes once.
  buffer = io.BytesIO()
  wavfile.write(buffer, sample_rate, samples)
  with gfile.GFile(filename, 'wb') as fh:
    fh.write(buffer.getvalue())


def read_wav(wav_path, sample_rate=16000, channel=None):
  """Read a wav file as numpy array.

  Integer PCM data is scaled to roughly [-1, 1) according to its sample
  width; floating point data is kept at its stored scale.

  Args:
    wav_path: String, path to .wav file.
    sample_rate: Int, sample rate for audio to be converted to.
    channel: Int, option to select a particular channel for stereo audio.

  Returns:
    Audio as float numpy array. Multi-channel audio keeps its channel axis
    unless `channel` is given.
  """
  with gfile.GFile(wav_path, 'rb') as f:
    # Decode from an in-memory copy so scipy can seek freely.
    sr_read, x = wavfile.read(io.BytesIO(f.read()))

  if x.dtype == np.uint8:
    # 8-bit PCM is unsigned with its zero level at 128.
    x = (x.astype(np.float32) - 128.0) / 128.0
  elif np.issubdtype(x.dtype, np.integer):
    # Signed PCM: divide by the magnitude of the most negative value
    # (2**15 for int16).
    x = x.astype(np.float32) / -float(np.iinfo(x.dtype).min)
  else:
    x = x.astype(np.float32)

  if sr_read != sample_rate:
    # Resample along the time axis to the requested rate.
    x = resample(x, int(round((float(sample_rate) / sr_read) * len(x))))
  if x.ndim > 1 and channel is not None:
    return x[:, channel]
  return x

In [None]:
# @title Enhance some wavs and play audio.
# read_wav returns only the samples (resampled to the requested rate), so the
# sample rate is fixed here rather than unpacked from its result.
sr = 16000
for wav in wavs:
  print('wav path:', wav)
  print('Input')
  receiver_audio = read_wav(wav, sample_rate=sr)

  sound.PlaySound(receiver_audio, sr)

  # separate() already drops the batch axis, so its result is indexed by
  # source directly.
  # NOTE(review): assumes the documented (num_sources, num_samples) output
  # with speech first and noise second -- confirm against the model.
  enhanced = model.separate(receiver_audio[np.newaxis])
  output_path = os.path.join(PATH_ENHANCED, os.path.basename(wav))
  # write_wav creates the output directory itself.
  write_wav(output_path, enhanced[0], sr)

  print('Speech estimate')
  sound.PlaySound(enhanced[0], sr)
  print('Noise estimate')
  sound.PlaySound(enhanced[1], sr)