In [45]:
import os
from collections import OrderedDict
import pickle

import ddsp.training
import gin
import tensorflow as tf

In [43]:
def save_weights(instrument, output_path=None, time_steps=1000):
    """Dump the trainable weights of a pretrained DDSP autoencoder to a pickle.

    The model directory is `pretrained/<instrument lower-cased>`. It must hold
    the gin file `operative_config-0.gin` and a checkpoint whose file names
    contain `ckpt`.

    Args:
        instrument: Instrument name; lower-cased to locate the model directory
            and to build the default output file name.
        output_path: Where to write the pickle. Defaults to
            `<instrument lower-cased>.pkl` in the current working directory.
        time_steps: Number of conditioning frames used for the dummy forward
            pass and for the overridden gin parameters.

    Returns:
        An `OrderedDict` mapping each trainable variable's name to the result
        of `variable.value().numpy()`, in `trainable_variables` order. The
        same mapping is what gets pickled.

    Raises:
        FileNotFoundError: If the model directory contains no file whose name
            includes `ckpt`.
    """
    name = instrument.lower()

    # Pretrained models.
    model_dir = f'pretrained/{name}'
    gin_file = os.path.join(model_dir, 'operative_config-0.gin')

    # Parse the gin config that was saved alongside the checkpoint.
    with gin.unlock_config():
        gin.parse_config_file(gin_file, skip_unknown=True)

    # Assumes only one checkpoint in the folder, 'ckpt-[iter]'. Sorting only
    # makes the pick deterministic if that assumption does not hold.
    ckpt_files = sorted(f for f in os.listdir(model_dir) if 'ckpt' in f)
    if not ckpt_files:
        raise FileNotFoundError(
            f'No checkpoint file (name containing "ckpt") found in {model_dir!r}')
    # Strip the extension(s): everything before the first '.' is the prefix.
    ckpt = os.path.join(model_dir, ckpt_files[0].split('.')[0])

    # Keep the training hop size so dimensions and sampling rates stay equal.
    time_steps_train = gin.query_parameter('F0LoudnessPreprocessor.time_steps')
    n_samples_train = gin.query_parameter('Harmonic.n_samples')
    hop_size = int(n_samples_train / time_steps_train)
    n_samples = time_steps * hop_size

    gin_params = [
        'Harmonic.n_samples = {}'.format(n_samples),
        'FilteredNoise.n_samples = {}'.format(n_samples),
        'F0LoudnessPreprocessor.time_steps = {}'.format(time_steps),
        'oscillator_bank.use_angular_cumsum = True',  # Avoids cumsum accumulation errors.
    ]
    with gin.unlock_config():
        gin.parse_config(gin_params)

    # Set up the model just to predict audio given new conditioning. A separate
    # local name is used so the `instrument` argument is not overwritten.
    model = ddsp.training.models.Autoencoder()
    model.restore(ckpt)

    # One forward pass on all-ones conditioning; the output itself is discarded.
    # NOTE(review): presumably this is what creates the model's variables so
    # that `trainable_variables` is populated -- confirm against DDSP/Keras.
    audio_features = {
        'f0_hz': tf.ones((1, time_steps, 1)),
        'loudness_db': tf.ones((1, time_steps, 1)),
    }
    model(audio_features, training=False)

    variables = OrderedDict(
        (var.name, var.value().numpy()) for var in model.trainable_variables)

    if output_path is None:
        output_path = f'{name}.pkl'
    # Context manager guarantees the file is flushed and closed.
    with open(output_path, 'wb') as pkl_file:
        pickle.dump(variables, pkl_file)

    return variables

In [44]:
# Run the weight export for the pretrained violin model.
inst = save_weights(instrument='violin')