# Simple Training




In [None]:
USE_PRIVATE_DISTRO = True

DRIVE_BASE_DIR = '/content/drive/MyDrive/SMC 10/DDSP-10/' 
DRIVE_DISTRO = DRIVE_BASE_DIR + 'dist/ddsp-1.2.0.tar.gz'

if USE_PRIVATE_DISTRO:
    print("[INFO] Using private distro.")
    from google.colab import drive
    drive.mount('/content/drive')
    !pip install -qU "$DRIVE_DISTRO"
else:
    !pip install -qU ddsp

%tensorflow_version 2.x
import tensorflow as tf
#import tensorflow.compat.v2 as tf

%reload_ext tensorboard
import tensorboard as tb

import seaborn as sns
import matplotlib.pyplot as plt
%config InlineBackend.figure_format='retina'

from ddsp.colab.colab_utils import specplot
from ddsp.colab.colab_utils import play
from ddsp.training import data
from ddsp.training import decoders
from ddsp.training import eval_util
from ddsp.training import evaluators
from ddsp.training import models
from ddsp.training import preprocessing
from ddsp.training import train_util
from ddsp.training import trainers
from ddsp import core
from ddsp import losses
from ddsp import processors
from ddsp import synths

from absl import logging

In [None]:
TIME_STEPS = 1000
N_SAMPLES = 64000
SAMPLE_RATE = 16000
FRAME_RATE = 250

INSTRUMENT = 'violin'

logging.set_verbosity(logging.INFO)
sns.set_theme(style='whitegrid')
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

DRIVE_CHECKPOINTS_DIR = DRIVE_BASE_DIR + 'audio/fm_' + \
                         INSTRUMENT + '_checkpoints/'
!mkdir -p "$DRIVE_CHECKPOINTS_DIR"

DRIVE_TFRECORD_PATTERN = DRIVE_BASE_DIR + 'audio/' + \
                         INSTRUMENT + '_dataset/train.tfrecord*'

### Define architecture

In [None]:
# Preprocessor built with time_steps=TIME_STEPS.
preprocessor = preprocessing.F0LoudnessPreprocessor(time_steps=TIME_STEPS)

# Decoder outputs as (name, channel count) pairs. The names are the same keys
# the DAG below feeds into the FM synth; the plotting cell reads 4 channels per
# operator and 6 modulator channels.
fm_output_splits = (
    ('op1', 4),
    ('op2', 4),
    ('op3', 4),
    ('op4', 4),
    ('modulators', 6),
    # ('noise_magnitudes', 3),
)

decoder = decoders.RnnFcDecoder(
    rnn_channels=128,
    rnn_type='gru',
    ch=128,
    layers_per_stack=1,
    input_keys=('ld_scaled', 'f0_scaled'),
    output_splits=fm_output_splits,
)

fm = synths.FrequencyModulation(
    n_samples=N_SAMPLES,
    sample_rate=SAMPLE_RATE,
    amp_scale_fn=core.exp_sigmoid,
    name='fm',
)

# Disabled filtered-noise branch, kept for reference:
# noise = synths.FilteredNoise(window_size=0,
#                              initial_bias=-10.0,
#                              scale_fn=core.exp_sigmoid,
#                              name='noise')
# add = processors.Add(name='add')

# Processor graph: only the FM synth is active.
fm_inputs = ['f0_hz', 'op1', 'op2', 'op3', 'op4', 'modulators']
dag = [
    (fm, fm_inputs),
    # (noise, ['noise_magnitudes']),
    # (add, ['noise/signal', 'fm/signal']),
]

processor_group = processors.ProcessorGroup(dag=dag, name='processor_group')

spectral_loss = losses.SpectralLoss(
    loss_type='L1',
    mag_weight=1.0,
    logmag_weight=1.0,
)



### Get a distribution strategy


In [None]:
# Distribution strategy obtained from ddsp's train_util. The model/trainer cell
# below builds inside strategy.scope() and passes the strategy to the Trainer.
strategy = train_util.get_strategy()

### Get the model and the trainer

In [None]:
# Build the autoencoder (no encoder; decoder + processor group) and its trainer
# inside the distribution strategy's scope.
with strategy.scope():
  model = models.Autoencoder(
      preprocessor=preprocessor,
      encoder=None,
      decoder=decoder,
      processor_group=processor_group,
      losses=[spectral_loss],
  )
  trainer = trainers.Trainer(
      model,
      strategy,
      checkpoints_to_keep=5,
      learning_rate=1e-4,
  )

### Get the data providers

In [None]:
# Provider reading the TFRecord files matched by DRIVE_TFRECORD_PATTERN.
data_provider = data.TFRecordProvider(
    DRIVE_TFRECORD_PATTERN,
    frame_rate=FRAME_RATE,
)

# Endless iterator over single-example batches, used for listening/plotting.
# NOTE(review): shuffle=True sits upstream of take(1).repeat(), so successive
# next() calls may not return the same example — confirm that is intended.
single_example_batches = data_provider.get_batch(batch_size=1, shuffle=True)
dataset_batch = single_example_batches.take(1).repeat()
dataset_batch_iter = iter(dataset_batch)

In [None]:
# Pull one batch and inspect it: audio player, loudness / f0 curves for the
# first item in the batch, then a spectrogram.
frame = next(dataset_batch_iter)

play(frame['audio'])

fig, (loudness_ax, f0_ax) = plt.subplots(1, 2, figsize=(10.5, 3))
fig.suptitle('Original audio', fontsize=14)

loudness_ax.plot(frame['loudness_db'][0])
loudness_ax.set_ylabel('Amplitude')

f0_ax.plot(frame['f0_hz'][0])
f0_ax.set_ylabel('Freqs')

specplot(frame['audio'])

### Start Tensorboard

In [None]:
# Launch TensorBoard inline with its log directory set to the checkpoint folder.
tensorboard_args = f'--reload_interval 15 --logdir "{DRIVE_CHECKPOINTS_DIR}"'
tb.notebook.start(tensorboard_args)

### Train

In [None]:
# Run training. The same Drive directory is used for saving and for restoring
# checkpoints.
train_util.train(
    data_provider=data_provider,
    trainer=trainer,
    batch_size=16,
    num_steps=10000,
    steps_per_summary=50,
    steps_per_save=100,
    save_dir=DRIVE_CHECKPOINTS_DIR,
    restore_dir=DRIVE_CHECKPOINTS_DIR,
    early_stop_loss_value=8.0,
    report_loss_to_hypertune=False,
)

In [None]:
# Run the model on one batch and compare its output with the input audio.
frame = next(dataset_batch_iter)
audio_baseline = frame['audio']

controls = model(frame, training=False)
audio_full = model.get_audio_from_outputs(controls)

# Audio players, each preceded by its printed label.
for label, audio in (
    ('Original Audio', audio_baseline),
    ('Full reconstruction', audio_full),
    ('Only FM', controls['fm']['signal']),
    # ('Only noise', controls['noise']['signal']),
):
  print(label)
  play(audio)

# Spectrograms: input first, reconstruction second.
for audio in (audio_baseline, audio_full):
  specplot(audio)

# Control signals handed to the FM synth; index 0 selects the first batch item.
fm_controls = controls['fm']['controls']

# MODULATORS
# -----------------------------------------------------------------
plt.figure(figsize=(7, 3))
modulators = fm_controls['modulators']
for channel in range(6):
  plt.plot(modulators[0, :, channel])
plt.legend(['m21', 'm31', 'm32', 'm41', 'm42', 'm43'])
plt.ylabel('Modulation')
plt.show()

# OPERATORS
# -----------------------------------------------------------------
# 2x2 grid, one panel per operator channel, four curves (op1..op4) per panel.
# Top row: channel 0 (amplitude) and channel 1 (index), each with a legend.
# Bottom row: channels 2 and 3, labelled 'From' / 'To'
# (NOTE(review): presumably envelope parameters — confirm against the synth).
f, ax = plt.subplots(2, 2, figsize=(16, 6))
operators = [fm_controls['op' + str(number)] for number in range(1, 5)]

panels = (
    (ax[0][0], 0, 'Amplitude', ['Amp 1', 'Amp 2', 'Amp 3', 'Amp 4']),
    (ax[0][1], 1, 'Index', ['Idx 1', 'Idx 2', 'Idx 3', 'Idx 4']),
    (ax[1][0], 2, 'From', None),
    (ax[1][1], 3, 'To', None),
)
for panel, channel, ylabel, legend_labels in panels:
  for operator in operators:
    panel.plot(operator[0, :, channel])
  if legend_labels is not None:
    panel.legend(legend_labels)
  panel.set_ylabel(ylabel)

In [None]:
# Deliberate hard stop. NOTE(review): presumably placed here so that "Run all"
# halts at this point — confirm before removing.
raise SystemExit("Stop right there!")