In [2]:
import sys
sys.path.append("../git_libraries/ddsp-pytorch")
sys.path.append("../git_libraries/ddsp-pytorch/train")
sys.path.append("../git_libraries/ddsp-pytorch/components")
sys.path.append("../git_libraries/ddsp-pytorch/configs")
sys.path.append("../git_libraries/ddsp-pytorch/data")

DEV = 'cuda'
data_dir = "../datasets/datasets-master/GT-ITM-Flute-99"

## 0. Pre-process Dataset

In [30]:
from pathlib import Path
import os
data_dir = Path("../datasets/datasets-master/GT-ITM-Flute-99")
resample_dir = data_dir.parent/ 'flute_99_resampled'
resample_dir.mkdir(exist_ok=True)
for wav in data_dir.glob("*.wav"):
  os.system(f"ffmpeg -y -loglevel fatal -i {wav} -ac 1 -ar 16000 {resample_dir / (wav.stem + '_16kHz.wav')}")


In [23]:
os.system(f"crepe {data_dir.parent/'flute_99_resampled'} --viterbi --step-size 4")

/bin/bash: crepe: command not found


32512

In [32]:
import torch
import torch.nn as nn
import torchaudio
from omegaconf import OmegaConf
import sys, os, tqdm, glob
import numpy as np

from torch.utils.data.dataloader import DataLoader
import torch.optim as optim

from trainer.trainer import Trainer
from trainer.io import setup, set_seeds

from dataset.audiodata import SupervisedAudioData, AudioData
from network.autoencoder.autoencoder import AutoEncoder
from loss.mss_loss import MSSLoss
from optimizer.radam import RAdam
from pathlib import Path

%load_ext autoreload
%autoreload 2


# 1. Train Model

In [53]:
config = OmegaConf.load("../git_libraries/ddsp-pytorch/configs/violin.yaml")
set_seeds(config.seed)
Trainer.set_experiment_name(config.experiment_name)

In [55]:
net = AutoEncoder(config).to(DEV)

loss = MSSLoss([2048, 1024, 512, 256], use_reverb=config.use_reverb).to(DEV)
def metric(output, gt):
  with torch.no_grad():
   return -loss(output, gt)


In [62]:
data_dir = Path("../datasets/datasets-master/flute_99_resampled")
train_data =  list(data_dir.glob("*.wav")) * config.batch_size
train_data_csv = [wav.parent / (wav.stem + ".f0.csv") for wav in train_data]
# train_data_csv = [
#     os.path.dirname(wav)
#     + f"/f0_{config.frame_resolution:.3f}/"
#     + os.path.basename(os.path.splitext(wav)[0])
#     + ".f0.csv"
#     for wav in train_data
# ]
print(len(train_data)//config.batch_size)

88


In [64]:
valid_data = list((data_dir / "test").glob("*.wav"))
valid_data_csv = [wav.parent / (wav.stem + ".f0.csv") for wav in valid_data]
# valid_data_csv = [
#     os.path.dirname(wav)
#     + f"/f0_{config.frame_resolution:.3f}/"
#     + os.path.basename(os.path.splitext(wav)[0])
#     + ".f0.csv"
#     for wav in valid_data
# ]
print(len(valid_data))

10


In [65]:
train_dataset = SupervisedAudioData(
    sample_rate=config.sample_rate,
    paths=train_data,
    csv_paths=train_data_csv,
    seed=config.seed,
    waveform_sec=config.waveform_sec,
    frame_resolution=config.frame_resolution,
)

valid_dataset = SupervisedAudioData(
    sample_rate=config.sample_rate,
    paths=valid_data,
    csv_paths=valid_data_csv,
    seed=config.seed,
    waveform_sec=config.valid_waveform_sec,
    frame_resolution=config.frame_resolution,
    random_sample=False,
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
)

valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=int(config.batch_size // (config.valid_waveform_sec / config.waveform_sec)),
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=False,
)

In [46]:
# Setting Optimizer
if config.optimizer == "adam":
    optimizer = optim.Adam(filter(lambda x: x.requires_grad, net.parameters()), lr=config.lr)
elif config.optimizer == "radam":
    optimizer = RAdam(filter(lambda x: x.requires_grad, net.parameters()), lr=config.lr)
else:
    raise NotImplementedError


In [47]:
# Setting Scheduler
if config.lr_scheduler == "cosine":
    # restart every T_0 * validation_interval steps
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=20, eta_min=config.lr_min
    )
elif config.lr_scheduler == "plateau":
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", patience=5, factor=config.lr_decay
    )
elif config.lr_scheduler == "multi":
    # decay every ( 10000 // validation_interval ) steps
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer,
        [(x + 1) * 10000 // config.validation_interval for x in range(10)],
        gamma=config.lr_decay,
    )
elif config.lr_scheduler == "no":
    scheduler = None
else:
    raise ValueError(f"unknown lr_scheduler :: {config.lr_scheduler}")


In [48]:
def validation_callback():
    global save_counter, save_interval
    # Save generated audio per every validation
    net.eval()

    def tensorboard_audio(data_loader, phase):

        bd = next(iter(data_loader))
        for k, v in bd.items():
            bd[k] = v.cuda()

        original_audio = bd["audio"][0]
        estimation = net(bd)

        if config.use_reverb:
            reconed_audio = estimation["audio_reverb"][0, : len(original_audio)]
            trainer.tensorboard.add_audio(
                f"{trainer.config['experiment_id']}/{phase}_recon",
                reconed_audio.cpu(),
                trainer.config["step"],
                sample_rate=config.sample_rate,
            )

        reconed_audio_dereverb = estimation["audio_synth"][0, : len(original_audio)]
        trainer.tensorboard.add_audio(
            f"{trainer.config['experiment_id']}/{phase}_recon_dereverb",
            reconed_audio_dereverb.cpu(),
            trainer.config["step"],
            sample_rate=config.sample_rate,
        )
        trainer.tensorboard.add_audio(
            f"{trainer.config['experiment_id']}/{phase}_original",
            original_audio.cpu(),
            trainer.config["step"],
            sample_rate=config.sample_rate,
        )

    tensorboard_audio(train_dataloader, phase="train")
    tensorboard_audio(valid_dataloader, phase="valid")

    save_counter += 1
    if save_counter % save_interval == 0:
        trainer.save(trainer.ckpt + f"-{trainer.config['step']}")


In [66]:
trainer = Trainer(
    net,
    criterion=loss,
    metric=metric,
    train_dataloader=train_dataloader,
    val_dataloader=valid_dataloader,
    optimizer=optimizer,
    lr_scheduler=scheduler,
    ckpt=config.ckpt,
    is_data_dict=True,
    experiment_id=os.path.splitext(os.path.basename(config.ckpt))[0],
    tensorboard_dir=config.tensorboard_dir,
)

save_counter = 0
save_interval = 10

trainer.register_callback(validation_callback)
trainer.add_external_config(config)


In [57]:
trainer.config

defaultdict(float,
            {'max_train_metric': -100000000.0,
             'max_val_metric': -100000000.0,
             'max_test_metric': -100000000.0,
             'tensorboard_dir': '../tensorboard_log/',
             'timestamp': '20211114_142430',
             'clip_gradient_norm': False,
             'is_data_dict': True,
             'experiment_id': '200131',
             'config__metadata': ContainerMetadata(ref_type=typing.Any, object_type=None, optional=True, key=None, flags={}, flags_root=False, resolver_cache=defaultdict(<class 'dict'>, {}), key_type=typing.Any, element_type=typing.Any),
             'config__parent': None,
             'config__flags_cache': {'struct': None},
             'config__content': {'batch_size': 64,
              'bidirectional': False,
              'ckpt': '../../ckpt/violin/200131.pth',
              'crepe': 'full',
              'experiment_name': 'DDSP_violin',
              'f0_threshold': 0.5,
              'frame_resolution': 0.004,

In [67]:
trainer.train(step=config.num_step, validation_interval=config.validation_interval)


|---------|---------|---------|---------|---------|---------|---------|
|   step  |train_loss|train_metric| val_loss|val_metric|    lr   |   time  |
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|1000-best|82.756420|-82.756420|164.310028|-164.310028| 0.001000|90.340890|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|   2000  |82.488003|-82.488003|164.310188|-164.310188| 0.001000|91.434499|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|3000-best|82.614553|-82.614553|164.308380|-164.308380| 0.001000|91.569626|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|4000-best|82.798254|-82.798254|164.307266|-164.307266| 0.001000|91.756862|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|   5000  |82.655777|-82.655777|164.308853|-164.308853| 0.001000|91.646980|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|   6000  |82.687615|-82.687615|164.308220|-164.308220| 0.001000|91.804785|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|   7000  |82.924833|-82.924833|164.316971|-164.316971| 0.001000|91.791871|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|8000-best|82.746657|-82.746657|164.305611|-164.305611| 0.001000|91.707196|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|9000-best|82.748947|-82.748947|164.301430|-164.301430| 0.001000|91.743277|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  10000  |82.706520|-82.706520|164.303978|-164.303978| 0.000980|91.728432|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|11000-best|82.805070|-82.805070|164.300957|-164.300957| 0.000980|91.740675|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  12000  |82.637506|-82.637506|164.313019|-164.313019| 0.000980|91.700974|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  13000  |82.794443|-82.794443|164.313187|-164.313187| 0.000980|91.696833|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  14000  |82.705108|-82.705108|164.309837|-164.309837| 0.000980|91.644739|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  15000  |82.819343|-82.819343|164.316795|-164.316795| 0.000980|91.669772|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  16000  |82.382693|-82.382693|164.306511|-164.306511| 0.000980|91.808795|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  17000  |82.781577|-82.781577|164.320351|-164.320351| 0.000980|91.637954|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  18000  |82.552828|-82.552828|164.321251|-164.321251| 0.000980|91.663061|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  19000  |82.638182|-82.638182|164.307892|-164.307892| 0.000980|91.837503|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  20000  |82.802564|-82.802564|164.311180|-164.311180| 0.000960|91.809103|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  21000  |82.621818|-82.621818|164.314804|-164.314804| 0.000960|91.639196|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  22000  |82.567645|-82.567645|164.309219|-164.309219| 0.000960|91.787132|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  23000  |82.659238|-82.659238|164.308571|-164.308571| 0.000960|91.668063|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  24000  |82.779346|-82.779346|164.319450|-164.319450| 0.000960|91.648122|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  25000  |82.840183|-82.840183|164.307854|-164.307854| 0.000960|91.809179|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  26000  |82.727431|-82.727431|164.308754|-164.308754| 0.000960|91.698351|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  27000  |82.716462|-82.716462|164.305656|-164.305656| 0.000960|91.708323|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  28000  |83.013525|-83.013525|164.303581|-164.303581| 0.000960|91.738077|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  29000  |82.707791|-82.707791|164.306419|-164.306419| 0.000960|91.712373|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  30000  |82.652360|-82.652360|164.320740|-164.320740| 0.000941|91.764746|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  31000  |82.770935|-82.770935|164.311600|-164.311600| 0.000941|91.690772|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  32000  |82.751112|-82.751112|164.312828|-164.312828| 0.000941|91.671338|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  33000  |82.788318|-82.788318|164.302711|-164.302711| 0.000941|91.752187|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  34000  |82.831611|-82.831611|164.306877|-164.306877| 0.000941|91.688044|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  35000  |82.636993|-82.636993|164.311455|-164.311455| 0.000941|91.729205|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  36000  |82.716445|-82.716445|164.315872|-164.315872| 0.000941|91.668897|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  37000  |82.728664|-82.728664|164.313805|-164.313805| 0.000941|91.732434|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  38000  |82.738425|-82.738425|164.310074|-164.310074| 0.000941|91.664852|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  39000  |82.521798|-82.521798|164.310806|-164.310806| 0.000941|91.776742|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  40000  |82.636348|-82.636348|164.313362|-164.313362| 0.000922|91.866961|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  41000  |82.997714|-82.997714|164.310799|-164.310799| 0.000922|91.709276|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  42000  |82.748334|-82.748334|164.307503|-164.307503| 0.000922|91.988487|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

|  43000  |82.698873|-82.698873|164.314743|-164.314743| 0.000922|91.936343|
|---------|---------|---------|---------|---------|---------|---------|


                                                                                                                                                                                                       

KeyboardInterrupt: 

In [54]:
net

AutoEncoder(
  (decoder): Decoder(
    (mlp_f0): MLP(
      (mlp_layer1): Sequential(
        (0): Linear(in_features=1, out_features=512, bias=True)
        (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (2): ReLU(inplace=True)
      )
      (mlp_layer2): Sequential(
        (0): Linear(in_features=512, out_features=512, bias=True)
        (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (2): ReLU(inplace=True)
      )
      (mlp_layer3): Sequential(
        (0): Linear(in_features=512, out_features=512, bias=True)
        (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (2): ReLU(inplace=True)
      )
    )
    (mlp_loudness): MLP(
      (mlp_layer1): Sequential(
        (0): Linear(in_features=1, out_features=512, bias=True)
        (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (2): ReLU(inplace=True)
      )
      (mlp_layer2): Sequential(
        (0): Linear(in_features=512, out_features=512, bias=True)
 

In [63]:
torchaudio.__version__

'0.10.0+cu113'

In [69]:
audio_path = "../datasets/datasets-master/GT-ITM-Flute-99/51_OGr_CnocBui_Tk2Tu1R4_TheFishermansLilt.wav"
y, sr = torchaudio.load(audio_path)

if sr != net.config.sample_rate:
  y = torchaudio.functional.resample(y, sr, net.config.sample_rate)

In [70]:
recon = net.reconstruction(y)

In [71]:
recon

{'harmonic': tensor([ 0.0010, -0.0001,  0.0004,  ..., -0.0011, -0.0011, -0.0011],
        device='cuda:0'),
 'noise': tensor([ 1.9868e-10,  4.4803e-08,  4.4018e-07,  ..., -1.1652e-06,
         -4.0710e-07,  2.9430e-09], device='cuda:0'),
 'audio_synth': tensor([ 0.0010, -0.0001,  0.0004,  ..., -0.0031,  0.0008, -0.0012],
        device='cuda:0'),
 'audio_reverb': tensor([ 9.5092e-04, -1.3779e-04,  1.9597e-04,  ...,  1.0118e-06,
         -1.0107e-06,  1.0298e-06], device='cuda:0'),
 'a': tensor([0.8052, 0.8241, 0.8320,  ..., 0.8603, 0.9064, 0.9751], device='cuda:0'),
 'c': tensor([[0.0084, 0.0087, 0.0087,  ..., 0.0082, 0.0084, 0.0073],
         [0.0061, 0.0064, 0.0067,  ..., 0.0072, 0.0069, 0.0071],
         [0.0113, 0.0111, 0.0111,  ..., 0.0111, 0.0118, 0.0120],
         ...,
         [0.0143, 0.0142, 0.0141,  ..., 0.0138, 0.0141, 0.0147],
         [0.0099, 0.0098, 0.0099,  ..., 0.0099, 0.0078, 0.0075],
         [0.0103, 0.0095, 0.0091,  ..., 0.0090, 0.0094, 0.0095]],
        device='c

In [73]:
import IPython.display as ipd

In [74]:
ipd.Audio(recon["audio_synth"].cpu(), rate=net.config.sample_rate)