# Convert MelGAN generator from pytorch to tensorflow

This notebook provides the procedure for converting the MelGAN generator from PyTorch to TensorFlow.  
The TensorFlow version can accelerate inference on both CPU and GPU.

In [1]:
import os

# disable cuda for this demonstration
# This is set before torch / tensorflow are imported so that neither framework
# gets a chance to initialise CUDA with the GPUs still visible.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import yaml

import numpy as np
import torch
import tensorflow as tf

# Switch to TF1-style graph mode. The conversion cells below collect the model's
# variables with tf.compat.v1.global_variables() and assign them through
# tf.keras.backend.set_value(); presumably that collection is only populated in
# graph mode -- confirm before removing this call.
tf.compat.v1.disable_eager_execution()

from parallel_wavegan.models import MelGANGenerator
from parallel_wavegan.models.tf_models import TFMelGANGenerator

## Define Tensorflow and Pytorch models

In [2]:
# load vocoder config
# NOTE(review): yaml.Loader can construct arbitrary Python objects. That is
# presumably acceptable for this project's own config file, but this cell should
# not be pointed at untrusted YAML.
vocoder_conf = '../egs/ljspeech/voc1/conf/melgan.v1.long.yaml'
with open(vocoder_conf) as conf_file:
    config = yaml.load(conf_file, Loader=yaml.Loader)

In [3]:
# define Tensorflow MelGAN generator
# The generator layer is wrapped in a functional Keras model. The first two
# input axes are left unspecified (None); the last axis has 80 channels.
generator_params = config["generator_params"]
inputs = tf.keras.Input(batch_shape=[None, None, 80], dtype=tf.float32)
tf_generator = TFMelGANGenerator(**generator_params)
audio = tf_generator(inputs)
tf_melgan = tf.keras.models.Model(inputs=inputs, outputs=audio)
tf_melgan.summary()

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, None, 80)]        0         
_________________________________________________________________
tf_mel_gan_generator (TFMelG (None, None, 1)           4260257   
Total params: 4,260,257
Trainable params: 4,260,257
Non-trainable params: 0
_________________________________________________________________


In [4]:
# define pytorch model
generator_params = config["generator_params"]
pytorch_melgan = MelGANGenerator(**generator_params)
# TFMelGANGenerator does not support weight norm, so it is removed on the
# pytorch side to keep both parameter sets comparable.
pytorch_melgan.remove_weight_norm()
# Last expression of the cell: the notebook displays the module it returns.
pytorch_melgan.to("cpu")

MelGANGenerator(
  (melgan): Sequential(
    (0): ReflectionPad1d((3, 3))
    (1): Conv1d(80, 512, kernel_size=(7,), stride=(1,))
    (2): LeakyReLU(negative_slope=0.2)
    (3): ConvTranspose1d(512, 256, kernel_size=(16,), stride=(8,), padding=(4,))
    (4): ResidualStack(
      (stack): Sequential(
        (0): LeakyReLU(negative_slope=0.2)
        (1): ReflectionPad1d((1, 1))
        (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,))
        (3): LeakyReLU(negative_slope=0.2)
        (4): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      )
      (skip_layer): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (5): ResidualStack(
      (stack): Sequential(
        (0): LeakyReLU(negative_slope=0.2)
        (1): ReflectionPad1d((3, 3))
        (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), dilation=(3,))
        (3): LeakyReLU(negative_slope=0.2)
        (4): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      )
      (skip_layer): Conv1d(256, 256, kernel_size=(1,)

In [5]:
# check the number of variables are the same
state_dict = pytorch_melgan.state_dict()
tf_vars = tf.compat.v1.global_variables()
print("Number Tensorflow variables: ", len(tf_vars))
print("Number Pytorch variables: ", len(state_dict.keys()))
# The conversion below pairs the two collections purely by position, so a count
# mismatch must stop the notebook here instead of only being printed.
assert len(tf_vars) == len(state_dict), (
    "tensorflow and pytorch variable counts differ: {} vs {}".format(
        len(tf_vars), len(state_dict)))

Number Tensorflow variables:  84
Number Pytorch variables:  84


## Convert parameters from pytorch to tensorflow

In [6]:
def reorder_tf_vars(tf_vars):
    """
    Swap every consecutive pair of tensorflow variables.

    Each tensorflow layer lists its variables as bias -> weight while the
    pytorch state dict lists them as weight -> bias, so exchanging the two
    members of every pair aligns the tensorflow order with the pytorch one.

    Args:
        tf_vars: Indexable sequence of variables; its length must be even
            (an odd length raises IndexError on the unpaired last element).

    Returns:
        A new list holding the same variables with each pair swapped.
    """
    # Every even position takes its right-hand neighbour, every odd position
    # takes its left-hand neighbour.
    return [
        tf_vars[pos + 1] if pos % 2 == 0 else tf_vars[pos - 1]
        for pos in range(len(tf_vars))
    ]

In [7]:
# change the order of variables to be the same as pytorch
# The list is re-read from the graph instead of reordering tf_vars in place, so
# running this cell more than once cannot swap the pairs back again.
tf_vars = reorder_tf_vars(tf.compat.v1.global_variables())

In [8]:
def convert_weights_pytorch_to_tensorflow(weights_pytorch):
    """
    Convert a pytorch Conv1d weight array to the tensorflow Conv2D kernel layout.

    Pytorch (f_output, f_input, kernel_size)
        -> TF (kernel_size, 1, f_input, f_output)

    Args:
        weights_pytorch: 3-D numpy array holding a pytorch convolution weight.

    Returns:
        4-D numpy array with the axes reversed and a singleton axis inserted
        at position 1.
    """
    # Reversing the axes gives [kernel_size, f_input, f_output] in one step.
    weights_tensorflow = np.transpose(weights_pytorch, (2, 1, 0))
    # The singleton axis goes right after kernel_size:
    # [kernel_size, 1, f_input, f_output].
    return np.expand_dims(weights_tensorflow, 1)

In [9]:
# convert pytorch's variables to tensorflow's one
# The two collections are paired purely by position, so tf_vars must already
# follow the pytorch state-dict order (weight -> bias).
failed_vars = []
for i, var_name in enumerate(state_dict):
    tf_var = None  # stays None if the lookup itself fails, so reporting cannot break
    try:
        tf_var = tf_vars[i]
        torch_tensor = state_dict[var_name].numpy()
        if torch_tensor.ndim >= 2:
            # multi-dimensional tensors need their axes rearranged for tensorflow
            tensorflow_tensor = convert_weights_pytorch_to_tensorflow(torch_tensor)
        else:
            # 1-D tensors are copied unchanged
            tensorflow_tensor = torch_tensor
        tf.keras.backend.set_value(tf_var, tensorflow_tensor)
    except Exception as err:
        # Keep going so that every failing pair is reported, but remember it.
        failed_vars.append(var_name)
        print("failed to convert {} -> {}: {!r}".format(var_name, tf_var, err))

# A partially converted generator would silently keep its initial weights, so
# stop here rather than let later cells run on a half-converted model.
if failed_vars:
    raise RuntimeError(
        "could not convert {} variable(s): {}".format(len(failed_vars), failed_vars))

## Check that both outputs are almost equal

In [10]:
# Feed the same random input to both generators and compare the waveforms.
# The pytorch model takes the array as (1, 80, 250); the Keras model takes it
# with the last two axes swapped, matching its 80-channel last input axis.
fake_mels = np.random.sample((1, 80, 250)).astype(np.float32)
mels_torch = torch.Tensor(fake_mels).to("cpu")
mels_tf = np.transpose(fake_mels, (0, 2, 1))

with torch.no_grad():
    y_pytorch = pytorch_melgan(mels_torch)
y_tensorflow = tf_melgan.predict(mels_tf)

# Compare the single output channel of the first batch item from each model.
wave_pytorch = y_pytorch[0, 0, :].numpy()
wave_tensorflow = y_tensorflow[0, :, 0]
np.testing.assert_almost_equal(wave_pytorch, wave_tensorflow)

## Save Tensorflow and Pytorch models for benchmark

In [11]:
# Create both export directories, then write the tensorflow model as a
# SavedModel and the pytorch parameters as a state dict.
for export_dir in ("./checkpoint/tensorflow_generator/", "./checkpoint/pytorch_generator/"):
    os.makedirs(export_dir, exist_ok=True)
tf.saved_model.save(tf_melgan, "./checkpoint/tensorflow_generator/")
pytorch_state = pytorch_melgan.state_dict()
torch.save(pytorch_state, "./checkpoint/pytorch_generator/checkpoint.pkl")

INFO:tensorflow:Assets written to: ./checkpoint/tensorflow_generator/assets


INFO:tensorflow:Assets written to: ./checkpoint/tensorflow_generator/assets


## Inference speed benchmark

From here, we compare the inference speed of the pytorch model and the converted tensorflow model.

In [None]:
# To enable eager mode, we need to restart the ipython kernel
# (eager execution was disabled at the top of this notebook for the conversion).
# NOTE(review): the script relies on the front end exposing the `Jupyter`
# JavaScript object; if it does not, restart the kernel manually before
# running the cells below.
from IPython.display import HTML
HTML("<script>Jupyter.notebook.kernel.restart()</script>")

In [1]:
# setup pytorch model
import os

# Hide all GPUs before torch is imported, so that it cannot initialise CUDA
# with the devices still visible.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import yaml

import numpy as np
import torch

from parallel_wavegan.models import MelGANGenerator

vocoder_conf = '../egs/ljspeech/voc1/conf/melgan.v1.long.yaml'
with open(vocoder_conf) as f:
    config = yaml.load(f, Loader=yaml.Loader)
pytorch_melgan = MelGANGenerator(**config["generator_params"])
# The checkpoint was written after remove_weight_norm() had been applied, so
# the same call is made here before the state dict is loaded.
pytorch_melgan.remove_weight_norm()
pytorch_melgan.load_state_dict(torch.load(
    "./checkpoint/pytorch_generator/checkpoint.pkl", map_location="cpu"))
# Last expression of the cell: the notebook displays the module it returns.
pytorch_melgan.to("cpu")

MelGANGenerator(
  (melgan): Sequential(
    (0): ReflectionPad1d((3, 3))
    (1): Conv1d(80, 512, kernel_size=(7,), stride=(1,))
    (2): LeakyReLU(negative_slope=0.2)
    (3): ConvTranspose1d(512, 256, kernel_size=(16,), stride=(8,), padding=(4,))
    (4): ResidualStack(
      (stack): Sequential(
        (0): LeakyReLU(negative_slope=0.2)
        (1): ReflectionPad1d((1, 1))
        (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,))
        (3): LeakyReLU(negative_slope=0.2)
        (4): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      )
      (skip_layer): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (5): ResidualStack(
      (stack): Sequential(
        (0): LeakyReLU(negative_slope=0.2)
        (1): ReflectionPad1d((3, 3))
        (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), dilation=(3,))
        (3): LeakyReLU(negative_slope=0.2)
        (4): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      )
      (skip_layer): Conv1d(256, 256, kernel_size=(1,)

In [2]:
# setup tensorflow model
import tensorflow as tf

from tensorflow.python.framework import convert_to_constants
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

class TFMelGAN(object):
    """Inference wrapper around an exported tensorflow MelGAN SavedModel.

    The SavedModel under ``saved_path`` is loaded once at construction time.
    Inputs are stored with ``set_mels`` and turned into audio with
    ``run_inference``.
    """

    def __init__(self, saved_path):
        """Load the SavedModel located at ``saved_path``."""
        self.saved_path = saved_path
        # callable returned by _load_model(); used for every inference call
        self.graph = self._load_model()
        # input for the next run_inference() call; None until set_mels() is used
        self.mels = None
        # result of the latest run_inference() call; None before the first run
        self.audios = None

    def _load_model(self):
        """Load the default serving signature and convert its variables to constants.

        Returns:
            The function returned by ``convert_variables_to_constants_v2`` for
            the SavedModel's default serving signature.
        """
        loaded = tf.saved_model.load(
            self.saved_path, tags=[tag_constants.SERVING])
        serving_fn = loaded.signatures[
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        return convert_to_constants.convert_variables_to_constants_v2(serving_fn)

    def set_mels(self, values):
        """Store ``values`` as the input used by the next ``run_inference`` call."""
        self.mels = values

    def get_mels(self):
        """Return the currently stored input (None if nothing was set)."""
        return self.mels

    def get_audio(self):
        """Return the output of the latest ``run_inference`` call (None before it)."""
        return self.audios

    def run_inference(self):
        """Run the loaded function on the stored input and return the audio.

        Returns:
            The first output of the function as a numpy array, with index 0
            taken along its third axis.

        Raises:
            ValueError: If no input has been stored with ``set_mels`` yet.
        """
        if self.mels is None:
            # Fail with a clear message instead of an error from inside tensorflow.
            raise ValueError("no input set: call set_mels() before run_inference()")
        tf_mels = tf.constant(self.mels)
        # Take the first output and select index 0 on its last (third) axis.
        self.audios = self.graph(tf_mels)[0].numpy()[:, :, 0]
        return self.audios
    
# Load the SavedModel from the directory that the conversion part of this
# notebook wrote with tf.saved_model.save().
tf_melgan = TFMelGAN(saved_path='./checkpoint/tensorflow_generator/')

In [3]:
# warmup
# Each model is run twice before timing. The pytorch passes run under no_grad
# so the warmup uses the same mode as the timed cell and keeps no autograd graph.
fake_mels = np.random.sample((4, 80, 500)).astype(np.float32)
with torch.no_grad():
    for warmup_round in range(2):
        y = pytorch_melgan(torch.Tensor(fake_mels))
tf_melgan.set_mels(np.random.sample((4, 500, 80)).astype(np.float32))
for warmup_round in range(2):
    y = tf_melgan.run_inference()

In [4]:
%%time
# check pytorch inference speed
# Times a single forward pass over the (4, 80, 500) batch created in the warmup
# cell, with gradients disabled. The numpy -> torch.Tensor conversion is part of
# the timed work.
with torch.no_grad():
    y = pytorch_melgan(torch.Tensor(fake_mels))

CPU times: user 10min 19s, sys: 8.18 s, total: 10min 27s
Wall time: 26.4 s


In [5]:
%%time
# check tensorflow inference speed
# Times a single run_inference() call on the (4, 500, 80) input stored with
# set_mels() in the warmup cell. The tf.constant() and .numpy() conversions done
# inside run_inference() are part of the timed work.
y = tf_melgan.run_inference()

CPU times: user 7.62 s, sys: 6.51 s, total: 14.1 s
Wall time: 1.63 s
