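# Simple Training

Train a DDSP `Autoencoder` with the AM-synth configuration (`models/am_nsynth.gin`) on an instrument TFRecord dataset stored on Google Drive (default instrument: `violin`). Then listen to and plot a reconstruction of one example.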




In [None]:
USE_PRIVATE_DISTRO = True

DRIVE_BASE_DIR = '/content/drive/MyDrive/SMC 10/DDSP-10/' 
DRIVE_DISTRO = DRIVE_BASE_DIR + 'dist/ddsp-1.2.0.tar.gz'

if USE_PRIVATE_DISTRO:
    print("[INFO] Using private distro.")
    from google.colab import drive
    drive.mount('/content/drive')
    !pip install -qU "$DRIVE_DISTRO"
else:
    !pip install -qU ddsp

import warnings
import gin

%reload_ext tensorboard
import tensorboard as tb

import seaborn as sns
import matplotlib.pyplot as plt
%config InlineBackend.figure_format='retina'

from ddsp.colab.colab_utils import specplot
from ddsp.colab.colab_utils import play
from ddsp.training import data
from ddsp.training import models
from ddsp import core

#### Some configuration

In [None]:
TIME_STEPS = 1000
N_SAMPLES = 64000
SAMPLE_RATE = 16000
FRAME_RATE = 250

INSTRUMENT = 'violin'

sns.set(style="whitegrid")
warnings.filterwarnings("ignore")

DRIVE_CHECKPOINTS_DIR = DRIVE_BASE_DIR + 'audio/AM_' + \
                         INSTRUMENT + '_checkpoints/'
!mkdir -p "$DRIVE_CHECKPOINTS_DIR"

DRIVE_TFRECORD_PATTERN = DRIVE_BASE_DIR + 'audio/' + \
                         INSTRUMENT + '_dataset/train.tfrecord*'

#### Start Tensorboard

In [None]:
# Launch TensorBoard on the directory that ddsp_run uses as --save_dir.
tensorboard_args = '--logdir "%s"' % DRIVE_CHECKPOINTS_DIR
tb.notebook.start(tensorboard_args)

#### Train the model

In [None]:
!ddsp_run \
  --mode=train \
  --alsologtostderr \
  --save_dir="$DRIVE_CHECKPOINTS_DIR" \
  --gin_file=models/am_nsynth.gin \
  --gin_file=datasets/tfrecord.gin \
  --gin_param="TFRecordProvider.file_pattern='$DRIVE_TFRECORD_PATTERN'" \
  --gin_param="TFRecordProvider.frame_rate=$FRAME_RATE" \
  --gin_param="train_util.train.batch_size=8" \
  --gin_param="train_util.train.num_steps=1000" \
  --gin_param="train_util.train.steps_per_save=100" \
  --gin_param="train_util.train.steps_per_summary=25" \
  --gin_param="trainers.Trainer.checkpoints_to_keep=5" \

#### Load the trained model from the checkpoint directory

In [None]:
# Evaluation stream over the same TFRecord pattern used for training:
# one batch of size 1, repeated indefinitely so next() never exhausts it.
data_provider_eval = data.TFRecordProvider(DRIVE_TFRECORD_PATTERN, frame_rate=FRAME_RATE)
dataset_eval = data_provider_eval.get_batch(batch_size=1, shuffle=True).take(1).repeat()
dataset_eval_iter = iter(dataset_eval)

# Re-create the model with the configuration saved by the training run.
# NOTE(review): this reads the step-0 operative config; if training was
# resumed, later operative_config-<step>.gin files may also exist.
gin_file = DRIVE_CHECKPOINTS_DIR + 'operative_config-0.gin'
# unlock_config: allows parsing even if the gin config has been locked.
# skip_unknown: ignore bindings for configurables not imported in this kernel
# instead of raising.
with gin.unlock_config():
  gin.parse_config_file(gin_file, skip_unknown=True)

model = models.Autoencoder()
model.restore(DRIVE_CHECKPOINTS_DIR)

In [None]:
# Run one evaluation example through the model and compare it by ear.
frame = next(dataset_eval_iter)
audio_baseline = frame['audio']

controls = model(frame, training=False)
audio_full = model.get_audio_from_outputs(controls)

print('Original Audio')
play(audio_baseline)

print('Full reconstruction')
play(audio_full)

# Listen to each synthesizer's output on its own, for those present in this
# model's outputs.
for synth in ['harmonic', 'am', 'noise']:
  if synth in controls:
    print('Only ' + synth)
    play(controls[synth]['signal'])

# specplot(audio_baseline)
# specplot(audio_full)


def get(key):
  """Look up a '/'-separated key in `controls` and return batch item 0."""
  return core.nested_lookup(key, controls)[0]


# The plots read AM-synth controls, so guard them the same way as the
# per-synth loop above instead of failing on models without an 'am' synth.
if 'am' in controls:
  amps = get('am/controls/amps')
  mod_amps = get('am/controls/mod_amps')

  f0 = get('am/controls/f0_hz')
  mod_f0 = get('am/controls/mod_f0_hz')

  fig, ax = plt.subplots(1, 2, figsize=(10.5, 3))
  fig.suptitle('Synthesized audio', fontsize=14)
  ax[0].plot(amps)
  ax[0].plot(mod_amps)
  ax[0].set_xlabel('Time step')
  ax[0].set_ylabel('Amplitude')
  ax[0].legend(['Carrier', 'Modulator'])
  ax[1].plot(f0)
  ax[1].plot(mod_f0)
  ax[1].set_xlabel('Time step')
  # Keys end in '_hz', so the frequencies are in Hertz.
  ax[1].set_ylabel('Frequency (Hz)')
  _ = ax[1].legend(['Carrier', 'Modulator'])
else:
  print("[INFO] No 'am' synth in model outputs; skipping control plots.")

In [None]:
# Intentional guard: raising SystemExit makes a "Run all" stop at this cell.
# Skip or delete this cell to continue executing past it.
_stop_message = "Stop right there!"
raise SystemExit(_stop_message)