In [1]:
import warnings
warnings.filterwarnings("ignore")

import time

import ddsp
from ddsp.training import (data, decoders, encoders, models, preprocessing, 
                           train_util, trainers)
from ddsp.colab.colab_utils import play, specplot, DEFAULT_SAMPLE_RATE
import gin
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

from kymatio.tensorflow import Scattering1D

# Audio sample rate reused by the synthesizer configuration further down;
# taken from the DDSP Colab utilities imported above.
sample_rate = DEFAULT_SAMPLE_RATE  # 16000

## get data

In [None]:
# Locations of the raw training audio and of the TFRecord dataset built from it.
train_audio = '/scratch/hh2263/drum_data_ver2/train/'
TRAIN_TFRECORD = '/scratch/hh2263/drum_data_ver2/tf_dataset/train.tfrecord'
#pkl_dir = '/scratch/hh2263/drum_data_ver2/drumv2_sc-pkl/' #ignore pkl files for now.

# One-off dataset preparation: run the command below first to create a
# TFRecord-formatted dataset from the wav2shape audio folder.
# It is kept as text (not executed) so that re-running the notebook does not
# rebuild the dataset; paste it into its own cell to run it.
# A raw string is used so the shell line continuations (`\`) are stored
# verbatim instead of being interpreted as Python escape sequences.
# NOTE: despite its name this variable does not hold a gin config, and it is
# reassigned by the gin configuration cell later in the notebook.
gin_string = r"""

!ddsp_prepare_tfrecord \
--input_audio_filepatterns=/scratch/hh2263/drum_data_ver2/train/*wav \
--output_tfrecord_path=$TRAIN_TFRECORD \
--num_shards=10 \
--alsologtostderr
"""

In [None]:
# Fetch one example batch to work with (this could become a generator later).
data_provider = data.FTMProvider(split='test')

# batch_size=1, unshuffled; keep only the first batch and repeat it forever.
dataset = (
    data_provider
    .get_batch(batch_size=1, shuffle=False)
    .take(1)
    .repeat()
)

# Pull that batch out once so its audio and length can be inspected.
batch = next(iter(dataset))
audio = batch['audio']
n_samples = audio.shape[1]  # size of the second axis of the audio tensor

## build model

Pipeline: audio → wav2shape encoder (computes the scattering transform, log-scales it, and passes it through the network to produce a 5-d parameter vector θ) → no decoder → FTM processor (takes the 5-d vector and synthesizes the sound) → loss: spectral loss on the resulting audio.


In [2]:
# Get a distribution strategy; the model and trainer are created inside
# `strategy.scope()` further down. The captured log output below reports a
# MirroredStrategy on the local CPU device for this run.
strategy = train_util.get_strategy() #Get a distribution strategy





INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


In [3]:
# Build the model: scattering preprocessor -> wav2shape encoder -> FTM synth.

# Preprocessor: scattering transform.
# Users should pass their own instantiation of kymatio.

# Parameters of the scattering transform.
J = 8
shape = (2**15,)
Q = 1
order = 2
scattering = Scattering1D(J=J, shape=shape, Q=Q, max_order=order)

preprocessor = preprocessing.ScatteringPreprocessor(scattering=scattering, eps=1e-3)

# Encoder is wav2shape or some FC structure - it needs to output estimates of
# p_x, p_y, w11, tau11, p, D, alpha (one scalar each, see `output_splits`).
encoder = encoders.wav2shapeEncoder(
    k_size=8,
    nchan_out=16,
    # One-element tuple: the trailing comma matters, since ('x') is just the
    # string 'x'. The key comes from the preprocessor.
    input_keys=('scattering_scaled',),
    # The split names must match the keys requested in `dag` below; the last
    # one is 'alpha_est' (not 'alpha') so that the FTM input can be resolved.
    output_splits=(('position_x', 1),
                   ('position_y', 1),
                   ('w_est', 1),
                   ('tau_est', 1),
                   ('p_est', 1),
                   ('D_est', 1),
                   ('alpha_est', 1)),
    name='wav2shape_encoder')

# No decoder: the encoder outputs are fed directly to the synthesizer.
decoder = None

# Create Processors.
# n_samples is tied to the scattering input length so the two cannot drift.
ftm = ddsp.ftm.FTM(n_samples=shape[0],
                   sample_rate=sample_rate,
                   mode=20,
                   name='ftm')

# Create ProcessorGroup.
# Each DAG entry is (processor, [keys of its inputs]).
dag = [(ftm, ['position_x', 'position_y', 'w_est', 'tau_est',
              'p_est', 'D_est', 'alpha_est'])]

processor_group = ddsp.processors.ProcessorGroup(dag=dag,
                                                 name='ftm_processor')

# Loss functions.
spectral_loss = ddsp.losses.SpectralLoss(loss_type='L1',
                                         mag_weight=1.0,
                                         logmag_weight=1.0)

with strategy.scope():
    # Put it together in a model.
    model = models.Autoencoder(preprocessor=preprocessor,
                               encoder=encoder,
                               decoder=decoder,
                               processor_group=processor_group,
                               losses=[spectral_loss])
    trainer = trainers.Trainer(model, strategy, learning_rate=1e-3)

## Gin

In [None]:
# Gin version of the model built in the previous code cell: scattering
# preprocessor -> wav2shape encoder -> (no decoder) -> FTM synthesizer, trained
# with a spectral loss. Values must be plain literals (gin does not evaluate
# expressions such as 2**15), and one-element tuples need a trailing comma.
gin_string = """
import ddsp
import ddsp.training

# Preprocessor
models.Autoencoder.preprocessor = @preprocessing.ScatteringPreprocessor()
preprocessing.ScatteringPreprocessor.eps = 1e-3
# NOTE(review): the Python cell also passes a kymatio `scattering` object to
# ScatteringPreprocessor; it is not bound here - confirm how it should be supplied.


# Encoder
models.Autoencoder.encoder = @encoders.wav2shapeEncoder()

encoders.wav2shapeEncoder.k_size = 8
encoders.wav2shapeEncoder.nchan_out = 16
encoders.wav2shapeEncoder.activation = 'linear'
encoders.wav2shapeEncoder.input_keys = ('scattering_scaled',)
encoders.wav2shapeEncoder.output_splits = (
    ('position_x', 1),
    ('position_y', 1),
    ('w_est', 1),
    ('tau_est', 1),
    ('p_est', 1),
    ('D_est', 1),
    ('alpha_est', 1),
)
encoders.wav2shapeEncoder.name = 'wav2shape_encoder'

# Decoder
models.Autoencoder.decoder = None


# ProcessorGroup
models.Autoencoder.processor_group = @processors.ProcessorGroup()

processors.ProcessorGroup.dag = [
  (@ftm/ftm.FTM(),
    ['position_x', 'position_y', 'w_est', 'tau_est',
     'p_est', 'D_est', 'alpha_est']),
]
processors.ProcessorGroup.name = 'ftm_processor'


# FTM Synthesizer
ftm/ftm.FTM.name = 'ftm'
ftm/ftm.FTM.n_samples = 32768
ftm/ftm.FTM.sample_rate = 16000
ftm/ftm.FTM.mode = 20


# Losses
models.Autoencoder.losses = [
    @losses.SpectralLoss(),
]
losses.SpectralLoss.loss_type = 'L1'
losses.SpectralLoss.mag_weight = 1.0
losses.SpectralLoss.logmag_weight = 1.0
"""

# Unlock so the bindings can be (re)parsed when the cell is run again.
with gin.unlock_config():
  gin.parse_config(gin_string)

with strategy.scope():
  # Autoencoder arguments are filled by gin.
  model = ddsp.training.models.Autoencoder()
  trainer = trainers.Trainer(model, strategy, learning_rate=1e-3)

## train

In [None]:
# NOTE: the training loop below is wrapped in a triple-quoted string, so
# running this cell only evaluates a string literal and performs no training.
# Remove the surrounding quotes to actually run it; it relies on `dataset`
# and `trainer` created in earlier cells.
"""

# Build model, easiest to just run forward pass.
dataset = trainer.distribute_dataset(dataset)
trainer.build(next(iter(dataset)))

dataset_iter = iter(dataset)

for i in range(300):
    losses = trainer.train_step(dataset_iter)
    res_str = 'step: {}\t'.format(i)
    for k, v in losses.items():
        res_str += '{}: {:.2f}\t'.format(k, v)
    print(res_str)
"""

In [None]:
# Training script.
# The leading `!` runs the command in a shell from the notebook. Notes on the
# gin files are kept up here because a comment placed after a trailing `\`
# breaks the line continuation:
#   models/solo_instrument.gin  - one gin file for model configuration
#   datasets/tfrecord.gin       - one gin file for dataset configuration
#   eval/basic_f0_ld.gin        - one gin file for evaluation storage??
!ddsp_run \
  --mode=train \
  --save_dir=/tmp/$USER-ddsp-0 \
  --gin_file=models/solo_instrument.gin \
  --gin_file=datasets/tfrecord.gin \
  --gin_file=eval/basic_f0_ld.gin \
  --gin_param="TFRecordProvider.file_pattern='/path/to/dataset_name.tfrecord*'" \
  --gin_param="batch_size=16" \
  --alsologtostderr

In [None]:
# Evaluation script.
# The leading `!` runs the command in a shell from the notebook. The dataset
# gin file is read from `datasets/`, the same directory the training command
# uses for its dataset config.
!ddsp_run \
  --mode=eval \
  --save_dir=/tmp/$USER-ddsp-0 \
  --gin_file=datasets/nsynth.gin \
  --gin_file=eval/basic_f0_ld.gin \
  --alsologtostderr
    