In [1]:
import os
import yaml
os.environ["CUDA_VISIBLE_DEVICES"]="-1" 

import tensorflow as tf
tf.compat.v1.disable_eager_execution()

import numpy as np

from parallel_wavegan.models import TFMelGANGenerator, MelGANGenerator
import torch

In [2]:
vocoder_conf = '../egs/ljspeech/voc1/conf/melgan.v1.long.yaml'

In [3]:
with open(vocoder_conf) as f:
    config = yaml.load(f, Loader=yaml.Loader)

In [4]:
inputs = tf.keras.Input(batch_shape=[None, None, 80], dtype=tf.float32)
audio = TFMelGANGenerator(**config["generator_params"])(inputs)
tf_melgan = tf.keras.models.Model(inputs, audio)
tf_melgan.summary()

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, None, 80)]        0         
_________________________________________________________________
tf_mel_gan_generator (TFMelG (None, None, 1)           4260257   
Total params: 4,260,257
Trainable params: 4,260,257
Non-trainable params: 0
_________________________________________________________________


# load pytorch generator checkpoint

In [5]:
pytorch_melgan = MelGANGenerator(**config["generator_params"])
pytorch_melgan.load_state_dict(
    torch.load("./checkpoint/pytorch_generator/generator_4000000.pth", map_location='cpu'))
pytorch_melgan.remove_weight_norm()
pytorch_melgan.to("cpu")

MelGANGenerator(
  (melgan): Sequential(
    (0): ReflectionPad1d((3, 3))
    (1): Conv1d(80, 512, kernel_size=(7,), stride=(1,))
    (2): LeakyReLU(negative_slope=0.2)
    (3): ConvTranspose1d(512, 256, kernel_size=(16,), stride=(8,), padding=(4,))
    (4): ResidualStack(
      (stack): Sequential(
        (0): LeakyReLU(negative_slope=0.2)
        (1): ReflectionPad1d((1, 1))
        (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,))
        (3): LeakyReLU(negative_slope=0.2)
        (4): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      )
      (skip_layer): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (5): ResidualStack(
      (stack): Sequential(
        (0): LeakyReLU(negative_slope=0.2)
        (1): ReflectionPad1d((3, 3))
        (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), dilation=(3,))
        (3): LeakyReLU(negative_slope=0.2)
        (4): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      )
      (skip_layer): Conv1d(256, 256, kernel_size=(1,)

In [6]:
state_dict = pytorch_melgan.state_dict()
tf_vars = tf.compat.v1.global_variables()
print("Number TF variables: ", len(tf_vars))
print("Number Pytorch variables: ", len(state_dict.keys()))

Number TF variables:  84
Number Pytorch variables:  84


# Convert weight from pytorch checkpoint to tensorflow

In [7]:
def reorder_tf_vars(tf_vars):
    """
    Reorder tf variables to match with pytorch
    state dict order. In this case, each TF layer
    have order is bias -> weight while pytorch layer
    is weight -> bias.
    """
    tf_new_var = []
    for i in range(0, len(tf_vars), 2):
        tf_new_var.append(tf_vars[i+1])
        tf_new_var.append(tf_vars[i])
    return tf_new_var

In [8]:
tf_vars = reorder_tf_vars(tf_vars)

In [9]:
def convert_weights_pytorch_to_tensorflow(weights_pytorch):
    weights_tensorflow = np.transpose(weights_pytorch, (0,2,1)) # [f_output, kernel_size, f_input]
    weights_tensorflow = np.transpose(weights_tensorflow, (1,0,2)) # [kernel-size, f_output, f_input]
    weights_tensorflow = np.transpose(weights_tensorflow, (0,2,1)) # [kernel-size, f_input, f_output]
    weights_tensorflow = np.expand_dims(weights_tensorflow, 1) # [kernel-size, f_input, 1, f_output]
    return weights_tensorflow

In [10]:
for i, var_name in enumerate(state_dict):
    try:
        tf_name = tf_vars[i]
        torch_tensor = state_dict[var_name].numpy()
        if torch_tensor.ndim >= 2:
            tensorflow_tensor = convert_weights_pytorch_to_tensorflow(torch_tensor)
        else:
            tensorflow_tensor = torch_tensor

        tf.keras.backend.set_value(tf_name, tensorflow_tensor)
    except:
        print(tf_name)

# Sanity check output variance

In [11]:
fake_mels = np.random.sample((1, 80, 250)).astype(np.float32)

In [12]:
y_pytorch = pytorch_melgan(torch.Tensor(fake_mels).to("cpu"))

In [13]:
y_tensorflow = tf_melgan.predict(np.transpose(fake_mels, (0, 2, 1)))

In [14]:
y_pytorch[0,0,:]

tensor([0.0037, 0.0119, 0.0176,  ..., 0.0129, 0.0058, 0.0022],
       grad_fn=<SliceBackward>)

In [15]:
y_tensorflow[0,:,0]

array([0.00366604, 0.01187663, 0.0176357 , ..., 0.01294548, 0.00580924,
       0.00215334], dtype=float32)

# Saved Tensorflow model to saved_model format

In [16]:
tf.saved_model.save(tf_melgan, "./checkpoint/tensorflow_generator/")

INFO:tensorflow:Assets written to: ./checkpoint/tensorflow_generator/assets


INFO:tensorflow:Assets written to: ./checkpoint/tensorflow_generator/assets
