Copyright 2022 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Imports

In [None]:
# Copy google-research:
!git clone https://github.com/google-research/google-research.git

In [None]:
import tensorflow as tf
import os
import numpy as np
import math
import scipy.io.wavfile as wav
import scipy.signal as signal
from matplotlib import pylab as plt

use_colabtools = False
if use_colabtools:
  import colabtools.sound
  from colabtools import sound
else:
  import IPython

In [None]:
# It was validated with TF 2.9.1
# The assert below raises AssertionError for any other TensorFlow version.
print(tf.__version__)
assert tf.__version__=='2.9.1'

# Utils

In [None]:
def WavRead(filename, divide=False, target_sample_rate=16000):
  """Reads a wav file and returns float32 samples at the target sample rate.

  Integer PCM samples are divided by the magnitude of the most negative value
  of their type, so that they nominally fall in [-1, 1). Floating point
  samples (32 or 64 bit) are returned without scaling.

  Arguments:
    filename: path to the wav file.
    divide: unused; kept so that existing callers keep working.
    target_sample_rate: sample rate of the returned data. If the file has a
      different rate, the data is resampled with scipy.signal.resample.

  Returns:
    Tuple (data, target_sample_rate), where data is a float32 numpy array.

  Raises:
    KeyError: if the sample type of the file is not int16, int32, float32
      or float64.
  """
  normalizer = {
      'int32': 2147483648.0,
      'int16': 32768.0,
      'float32': 1.0,
      'float64': 1.0,
      }
  samplerate, wave_data = wav.read(filename)
  # The scale depends on the sample type stored in the file, so it is looked
  # up before wave_data is possibly replaced by its resampled version.
  norm = normalizer[wave_data.dtype.name]
  if samplerate != target_sample_rate:
    desired_length = int(
        round(float(len(wave_data)) / samplerate * target_sample_rate))
    wave_data = signal.resample(wave_data, desired_length)
    print("resample input wav samplerate " + str(samplerate))

  # Normalize floats in range [-1..1).
  data = np.array(wave_data, np.float32) / norm

  return data, target_sample_rate

In [None]:
def RunNonStreaming(input_features, tflite_model_path):
  """Runs a TFLite model once on the whole input (non streaming mode).

  Only the first input and the first output of the model are used: it is
  assumed that they can be accessed by index 0.

  Arguments:
    input_features: features fed into the first input of the model
    tflite_model_path: path to tflite model

  Returns:
    Content of the first output tensor after a single invocation.
  """

  interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
  first_input = interpreter.get_input_details()[0]
  first_output = interpreter.get_output_details()[0]

  # A -1 in the shape signature means that the input has to be resized to the
  # shape of the actual features before tensors are allocated.
  needs_resize = -1 in first_input['shape_signature']
  if needs_resize:
    interpreter.resize_tensor_input(first_input['index'], input_features.shape)

  interpreter.allocate_tensors()
  interpreter.set_tensor(first_input['index'], input_features)
  interpreter.invoke()

  return interpreter.get_tensor(first_output['index'])


def RunStreaming(input_features, step, tflite_model_path, inp_to_out, input_index=0, padding_index=-1):
  """Runs tflite_model in streaming mode.

  It relies on assumption that tflite inputs/outputs are set in order and we can
  access them by index.

  The input is consumed along axis 1 in consecutive packets of `step`
  elements. A trailing remainder shorter than `step` is not processed.
  Every model input which is neither the data input nor the padding input is
  treated as a streaming state: it starts as zeros and, after every
  invocation, is replaced by the output it is mapped to in `inp_to_out`.

  Arguments:
    input_features: input features, sliced along axis 1
    step: stride to process input data
    tflite_model_path: path to tflite model
    inp_to_out: mapping from the index of a model input to the index of the
      model output paired with it. It needs an entry for input_index (the
      output which is returned) and for every state input.
    input_index: index of input data in TFLite module
    padding_index: index of padding data in TFLite module.
      It is optional: if -1 then ignored.

  Returns:
    Output produced by streaming: it is a concatenation of outputs produced
     by streaming mode, so that we can compare it with non streaming output.
     None if the input is shorter than one step.
  """

  interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()

  # Index 0 is a valid padding input; only negative values disable it.
  use_padding = padding_index >= 0

  # create input states (all zeros before the first packet)
  state_indexes = [
      s for s in range(len(input_details))
      if s not in [input_index, padding_index]
  ]
  input_states = {
      s: np.zeros(input_details[s]['shape'], dtype=input_details[s]['dtype'])
      for s in state_indexes
  }

  # Outputs are collected and concatenated once after the loop, instead of
  # re-concatenating the growing result on every step.
  stream_outputs = []

  for start in range(0, input_features.shape[1] - step + 1, step):
    input_packet = input_features[:, start:start + step]

    # set input audio data (by default input data at index 0, 1)
    interpreter.set_tensor(input_details[input_index]['index'], input_packet)
    if use_padding:
      paddings_packet = tf.zeros(input_packet.shape[0:2])
      interpreter.set_tensor(input_details[padding_index]['index'], paddings_packet)

    # set input states
    for s in state_indexes:
      interpreter.set_tensor(input_details[s]['index'], input_states[s])

    # run inference
    interpreter.invoke()

    # get output data (and ignore output paddings)
    stream_outputs.append(
        interpreter.get_tensor(output_details[inp_to_out[input_index]]['index']))

    # get output states and set it back to input states
    # which will be fed in the next inference cycle
    # The function `get_tensor()` returns a copy of the tensor data.
    # Use `tensor()` in order to get a pointer to the tensor.
    for s in state_indexes:
      input_states[s] = interpreter.get_tensor(output_details[inp_to_out[s]]['index'])

  if not stream_outputs:
    return None
  if len(stream_outputs) == 1:
    # Kept for backward compatibility: a single packet is returned as is,
    # without being converted by tf.concat.
    return stream_outputs[0]
  return tf.concat(stream_outputs, axis=1)


In [None]:
def Wav2Spectrogram(wav_data):
  """Converts a waveform into a log magnitude spectrogram.

  Processing steps: pre-emphasis, framing (50 ms frames with a 12.5 ms step,
  computed for a sample rate of 16000), periodic Hann window, RFFT with at
  least 2048 points, and log of the magnitude plus a small offset.

  Arguments:
    wav_data: waveform samples; a batch dimension is added in front of it.

  Returns:
    Log magnitude spectrogram, with the added batch dimension in front.
  """
  sample_rate = 16000
  frame_size_ms = 50.0
  frame_step_ms = 12.5
  preemph = 0.97
  min_fft_size = 2048
  log_offset = 1e-2

  frame_size = int(round(sample_rate * frame_size_ms / 1000.0))
  frame_step = int(round(sample_rate * frame_step_ms / 1000.0))

  batched = tf.expand_dims(wav_data, 0)

  # Pre-emphasis: y[n] = x[n] - preemph * x[n - 1].
  # One zero is padded on the left side of axis 1, because of pre-emphasis:
  # this way the output has as many samples as the input.
  paddings = [[0, 0] for _ in range(batched.shape.rank)]
  paddings[1] = [1, 0]
  padded = tf.pad(batched, paddings, 'constant')
  emphasized = padded[:, 1:] - preemph * padded[:, :-1]

  # Framing
  frames = tf.signal.frame(emphasized, frame_size, frame_step, pad_end=False)

  # Windowing
  windowed = frames * tf.signal.hann_window(frame_size, periodic=True)

  # RFFT: power of two which covers a frame, but not less than min_fft_size.
  frame_pow2 = int(math.pow(2, math.ceil(math.log(frame_size, 2))))
  fft_size = max(min_fft_size, frame_pow2)
  spectrum = tf.signal.rfft(windowed, [fft_size])

  # Log of magnitude; the offset keeps the argument of log above zero.
  return tf.math.log(tf.math.abs(spectrum) + log_offset)

# Load input wav

### Set path to input wav file:

In [None]:
# Name of the input wav and the folder which holds it, relative to the clone of
# https://github.com/google-research/google-research/tree/master/specinvert/vctk/input
wav_file_name = "p232_118.wav"
wav_path = os.path.join("google-research/specinvert/vctk/input/", wav_file_name)

In [None]:
# Read the input wav as float32 samples; WavRead resamples to its default
# target_sample_rate (16000) and returns that rate as sample_rate.
wav_data, sample_rate = WavRead(wav_path)

In [None]:
%matplotlib inline
plt.plot(wav_data)
if use_colabtools:
  colabtools.sound.PlaySound(wav_data, sample_rate)
else:
  IPython.display.Audio(wav_path) 

# Convert wav to spectrogram

In [None]:
# Log magnitude spectrogram of the input wav; it is the input of all
# inversion models below.
spectrogram_magnitude = Wav2Spectrogram(wav_data)

In [None]:
spectrogram_magnitude.shape

TensorShape([1, 247, 1025])

In [None]:
%matplotlib inline
# Show the spectrogram of the first (and only) batch item as an image.
plt.imshow(spectrogram_magnitude[0])

# Prepare TFLite models

In [None]:
# Download TFLite modules
# and place them in the current folder of the notebook
!wget http://storage.googleapis.com/gresearch/specinvert/non_stream_GAN.tflite
!wget http://storage.googleapis.com/gresearch/specinvert/stream_GAN_lookahead1.tflite
!wget http://storage.googleapis.com/gresearch/specinvert/stream_GAN_causal.tflite
!wget http://storage.googleapis.com/gresearch/specinvert/stream_GL.tflite

# Invert spectrogram with non streaming MelGAN

In [None]:
# Invert the whole spectrogram with a single invocation of the model.
non_stream_tfl = RunNonStreaming(spectrogram_magnitude, "non_stream_GAN.tflite")

In [None]:
%matplotlib inline
plt.plot(non_stream_tfl[0])
if use_colabtools:
  colabtools.sound.PlaySound(non_stream_tfl[0], sample_rate)
else:
  IPython.display.Audio(non_stream_tfl[0], rate=16000, autoplay=True)   

# Invert spectrogram with streaming MelGAN lookahead 1 hop

In [None]:
# Mapping of input output indexes in TFLite:
# input 0 is paired with output 0.
inp_to_out_n = {0: 0}
stream_lookahead_path_tfl_path = "stream_GAN_lookahead1.tflite"
output_stream_lookahead_tfl = RunStreaming(
    spectrogram_magnitude,
    step=1,
    tflite_model_path=stream_lookahead_path_tfl_path,
    inp_to_out=inp_to_out_n,
    input_index=0)

In [None]:
%matplotlib inline
plt.plot(output_stream_lookahead_tfl[0])
if use_colabtools:
  colabtools.sound.PlaySound(output_stream_lookahead_tfl[0], sample_rate)
else:
  IPython.display.Audio(output_stream_lookahead_tfl[0], rate=16000, autoplay=True)   

# Invert spectrogram with streaming causal MelGAN (no lookahead)

In [None]:
# Mapping of input output indexes in TFLite:
# input 0 is paired with output 0.
inp_to_out_n = {0: 0}

stream_causal_path_tfl_path = "stream_GAN_causal.tflite"
output_stream_causal_tfl = RunStreaming(
    spectrogram_magnitude,
    step=1,
    tflite_model_path=stream_causal_path_tfl_path,
    inp_to_out=inp_to_out_n,
    input_index=0)

In [None]:
%matplotlib inline
plt.plot(output_stream_causal_tfl[0])
if use_colabtools:
  colabtools.sound.PlaySound(output_stream_causal_tfl[0], sample_rate)
else:
  IPython.display.Audio(output_stream_causal_tfl[0], rate=16000, autoplay=True)   

# Invert spectrogram with streaming GL

In [None]:
# Mapping of input output indexes in TFLite: key is the index of a model
# input, value is the index of the model output paired with it.
inp_to_out = {
    0: 2,
    1: 3,
    2: 0,
    3: 1,
    4: 4,
    5: 5,
}

stream_gl_tfl_path = "stream_GL.tflite"
# The spectrogram is fed into input 4; the other inputs are streaming states.
output_stream_gl_tfl = RunStreaming(
    spectrogram_magnitude,
    step=1,
    tflite_model_path=stream_gl_tfl_path,
    inp_to_out=inp_to_out,
    input_index=4)

In [None]:
%matplotlib inline
plt.plot(output_stream_gl_tfl[0])
if use_colabtools:
  colabtools.sound.PlaySound(output_stream_gl_tfl[0], sample_rate)
else:
  IPython.display.Audio(output_stream_gl_tfl[0], rate=16000, autoplay=True)   