In [None]:
# Install the mikeoliphant fork of neural-amp-modeler (devel branch); it provides the `nam` package imported below.
!pip install git+https://github.com/mikeoliphant/neural-amp-modeler.git@devel --quiet

# Enable the TensorBoard notebook extension so the `%tensorboard` magic in the next cell works.
%load_ext tensorboard

In [None]:
# Show training curves inline.
# NOTE(review): assumes Lightning logs land in /content/lightning_logs, i.e. the notebook's cwd is /content
# and `train_path` (used as default_root_dir) is "." — verify if either changes.
%tensorboard --logdir /content/lightning_logs

import pytorch_lightning as pl

from nam.train.colab import _check_for_files
from nam.train.core import _detect_input_version, _calibrate_delay, init_dataset, Split, _get_configs, Architecture, _get_dataloaders, Model, Version

# Locate the uploaded input audio file.
# NOTE(review): `_check_for_files` presumably verifies the expected input/output
# audio files are present and returns (input version, input file name) — confirm.
input_version, input_basename = _check_for_files()

# Root directory Lightning uses for checkpoints/logs (passed as default_root_dir).
train_path = "."

# The recorded output audio the model is trained to reproduce.
output_path = "output.wav"

# Detect the input file's version; `strong_match` is presumably True when the
# detection was unambiguous (currently not acted upon).
input_version, strong_match = _detect_input_version(input_basename)

# Measure the input/output latency. Use `output_path` rather than repeating the
# literal so calibration always runs on the same file the model is trained on.
# NOTE(review): passing None as the first argument presumably means "auto-detect the delay".
delay = _calibrate_delay(None, input_version, input_basename, output_path, silent=False)

# Number of training epochs; forwarded to `_get_configs`.
epochs = 300

# Build data, model and learning configs for an LSTM using the FEATHER preset.
# `_get_configs` returns its own `learning_config`, which is the one the
# dataloaders and the Trainer consume, so no hand-built learning config is needed.
data_config, model_config, learning_config = _get_configs(
    input_version,
    input_basename,
    output_path,
    delay,
    epochs,
    "LSTM",
    Architecture.FEATHER,
    ny=20000,
    lr=0.01,
    lr_decay=0.007,
    batch_size=16,
    fit_cab=False,
)

# Replace the preset's network settings with a single small LSTM layer.
# NOTE(review): train_burn_in / train_truncate are presumably sample counts
# controlling LSTM warm-up and truncated backprop length — confirm in nam's LSTM.
model_config["net"]["config"] = {
    "num_layers": 1,
    "hidden_size": 16,
    "train_burn_in": 4096,
    "train_truncate": 512,
}

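# Alternative (disabled): a small two-stack WaveNet instead of the LSTM above.
# To use it, uncomment this call and the `model_config["net"]["config"]`
# override below, and comment out the LSTM configuration above.
# data_config, model_config, learning_config = _get_configs(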
#     input_version,
#     input_basename,
#     output_path,
#     delay,
#     epochs,
#     "WaveNet",
#     Architecture.FEATHER,
#     ny = 8192,
#     lr = 0.004,
#     lr_decay = 0.007,
#     batch_size = 16,
#     fit_cab = False
# )

# model_config["net"]["config"] = {
#             "layers_configs": [
#                 {
#                     "input_size": 1,
#                     "condition_size": 1,
#                     "channels": 4,
#                     "head_size": 2,
#                     "kernel_size": 3,
#                     "dilations": [1, 2, 4, 8, 16, 32, 64],
#                     "activation": "Tanh",
#                     "gated": False,
#                     "head_bias": False,
#                 },
#                 {
#                     "condition_size": 1,
#                     "input_size": 4,
#                     "channels": 2,
#                     "head_size": 1,
#                     "kernel_size": 3,
#                     "dilations": [128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
#                     "activation": "Tanh",
#                     "gated": False,
#                     "head_bias": True,
#                 },
#             ],
#             "head_scale": 0.02,
# }

# Instantiate the model, then the train/validation dataloaders built from the configs.
model = Model.init_from_config(model_config)
train_dataloader, val_dataloader = _get_dataloaders(data_config, learning_config, model)

# Checkpoint callback #1: keep the 3 best checkpoints ranked by the monitored "val_loss".
# It must stay first in the callbacks list: `trainer.checkpoint_callback` returns the
# first ModelCheckpoint, and its best_model_path is used to reload the best weights.
best_ckpt_callback = pl.callbacks.model_checkpoint.ModelCheckpoint(
    filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4f}_{MSE:.3e}",
    save_top_k=3,
    monitor="val_loss",
    every_n_epochs=1,
)
# Checkpoint callback #2: an unmonitored checkpoint written every epoch.
last_ckpt_callback = pl.callbacks.model_checkpoint.ModelCheckpoint(
    filename="checkpoint_last_{epoch:04d}_{step}",
    every_n_epochs=1,
)

# Trainer settings (e.g. max_epochs) come from the returned learning config.
trainer = pl.Trainer(
    callbacks=[best_ckpt_callback, last_ckpt_callback],
    default_root_dir=train_path,
    **learning_config["trainer"],
)
trainer.fit(model, train_dataloader, val_dataloader)

# Go to best checkpoint: reload the weights from the best checkpoint recorded by
# `trainer.checkpoint_callback` (the first ModelCheckpoint given to the Trainer).
# If none was saved, keep the in-memory model as-is.
best_checkpoint = trainer.checkpoint_callback.best_model_path
if best_checkpoint:  # empty string (or None) means no best checkpoint exists
    model = Model.load_from_checkpoint(
        best_checkpoint,
        **Model.parse_config(model_config),
    )
# Move to CPU and switch to inference mode before exporting.
model.cpu()
model.eval()

# `_get_valid_export_directory` was previously used without ever being imported,
# which raised NameError. It lives in the same module as `_check_for_files`.
from nam.train.colab import _get_valid_export_directory

print("Exporting your model...")
model_export_outdir = _get_valid_export_directory()
# exist_ok=False: fail instead of writing into an existing directory.
model_export_outdir.mkdir(parents=True, exist_ok=False)
model.net.export(model_export_outdir)
print(f"Model exported to {model_export_outdir}. Enjoy!")