# Neural Amp Modeler (Trainer) [DEPRECATED]

**This trainer is deprecated and will be removed after version 0.5.3. Instead use either `easy_colab.ipynb` or the local GUI trainer.**

This notebook trains a neural amp model from two pairs of input/output WAV files that you recorded with the amp you want to model.

**To use this notebook**:
Go to [colab.research.google.com](https://colab.research.google.com/), select the "GitHub" tab, and select this notebook. Or, if you've cloned the repo, you can upload it from your computer.

🔶**Before you run**🔶

Make sure to get a GPU! (Runtime->Change runtime type->Select "GPU" from the "Hardware accelerator" dropdown menu)

⚠**Warning**⚠

Google Colab GPU instances only last for 12 hours.
Plan your training accordingly!

## Steps:
1. Upload audio files
2. Installation
3. Settings
4. Run!
5. Check
6. Export
7. Download your files

## Step 1: Upload audio files
We're gonna need data. **Read this because it's important. Your model lives and dies with its data!**

### Some tips for making good data:
I'm going to assume you know about proper gain staging for reamping. Beyond that, here are a few things that are less obvious:
* **Show your model everything!** The model is going to learn from your examples, so demonstrate everything! Play loud, play soft, play single notes, chords, different pickups, play through an overdrive pedal (you wanted your model to understand how pedals sound, right?), etc etc. Just think: You'll ask "But can the model clean up like the real thing?" _Just show it!_ (**Don't riff(!!!)** It sounds weird, but riffs are repetitive, and repetition is wasted data. Instead, just play every fret up and down every string. It's boring, but it's good data!)
* **"How much data?"** More is better, but there are diminishing returns. About 3 minutes is a good compromise, but up to maybe 15 minutes can still help if you really want the best model possible.
* 🔶**Measure the latency!**🔶 Most interfaces will have a little lag between when they send the signal and when the reamp comes back. Use your DAW to figure out how many samples it is--I'll ask you for it below. _This is important--If there's too much delay, then the model may not learn well. The closer you get this, the better the results will be, but don't over-compensate or else you're effectively asking the model to predict the future!_
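If you'd like a rough cross-check on the number you read off in your DAW, here is a minimal sketch that estimates the lag by cross-correlating a DI signal with its reamp. The function name `estimate_latency` and the `max_lag` parameter are made up for this example and are not part of `nam`. It assumes both signals are already loaded as 1-D NumPy arrays at the same sample rate. An amp is nonlinear, so treat the result as a starting point, not a replacement for measuring.

```python
import numpy as np


def estimate_latency(di: np.ndarray, reamped: np.ndarray, max_lag: int = 2000) -> int:
    """Return the lag (in samples) at which `reamped` best lines up with `di`."""
    n = min(len(di), len(reamped))
    max_lag = min(max_lag, n - 1)
    # Remove any DC offset so it doesn't dominate the correlation.
    di = di[:n] - np.mean(di[:n])
    reamped = reamped[:n] - np.mean(reamped[:n])
    # Score each candidate lag; abs() tolerates an amp that flips polarity.
    scores = [
        abs(np.dot(di[: n - lag], reamped[lag:n]))
        for lag in range(max_lag + 1)
    ]
    return int(np.argmax(scores))


# Synthetic demo: noise pushed through a tanh "amp" and delayed by 37 samples.
rng = np.random.default_rng(0)
di = rng.standard_normal(5000)
reamped = np.concatenate([np.zeros(37), np.tanh(3 * di[:-37])])
print(estimate_latency(di, reamped, max_lag=100))
```

The search only looks at non-negative lags, which matches the advice above: the reamp can arrive late, but it can't arrive early.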

### What you need
You'll need two pairs of files (4 in total):
* A training pair (`x_train.wav`, `y_train.wav`) for the model to fit to.
* A validation pair (`x_test.wav`, `y_test.wav`) to check how the model's doing on something new.

`x_train.wav` and `x_test.wav` should be two (different!) DI files, and `y_train.wav` and `y_test.wav` should be their corresponding outputs that you reamped. **The train files should hold most of the data; the test files can be just a few seconds long.** The point of the test files is to just quickly check if your model gets it right if it sees something new (but not _too_ new--shouldn't you be training on those? ⬆)

### What to do
Upload the input (DI) and output (amped) files you want to use by clicking the Folder icon on the left ⬅ and then clicking the upload icon.

Once you're done, run the next cell and I'll check that everything looks good.

In [None]:
from pathlib import Path
# I'm just gonna check that you were paying attention ;)
for name in ("x_train.wav", "y_train.wav", "x_test.wav", "y_test.wav"):
  if not Path(name).exists():
    raise RuntimeError(f"I didn't find all of your data files. Where is {name}?")
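The cell above only checks that the files exist. As an optional extra, here is a sketch that compares the headers of an input/output pair using Python's standard `wave` module. The helper names `wav_summary` and `check_pair` are invented for this example. The checks encode an assumption, not a documented `nam` rule: a DI file and its reamp should share a sample rate and a length, since they are supposed to line up sample for sample.

```python
import wave


def wav_summary(path: str) -> tuple:
    """Return (sample_rate, n_channels, n_frames) from a PCM WAV header."""
    # Note: the stdlib reader handles integer PCM only; it raises wave.Error
    # on other encodings such as 32-bit float.
    with wave.open(path, "rb") as f:
        return f.getframerate(), f.getnchannels(), f.getnframes()


def check_pair(x_path: str, y_path: str) -> None:
    """Raise ValueError if the two files differ in sample rate or length."""
    x_rate, _, x_frames = wav_summary(x_path)
    y_rate, _, y_frames = wav_summary(y_path)
    if x_rate != y_rate:
        raise ValueError(
            f"Sample rates differ: {x_path} is {x_rate} Hz, {y_path} is {y_rate} Hz"
        )
    if x_frames != y_frames:
        raise ValueError(
            f"Lengths differ: {x_path} has {x_frames} samples, {y_path} has {y_frames}"
        )
```

To use it here, you would call `check_pair("x_train.wav", "y_train.wav")` and `check_pair("x_test.wav", "y_test.wav")` after uploading.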

## Step 2: Installation
Install `nam` into this Colab instance.

In [None]:
!pip install "neural-amp-modeler<=0.5.3"

In [None]:
from time import time
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from nam.data import Split, init_dataset
from nam.models import Model
from nam.models.losses import esr

## Step 3: Settings
The defaults are what I tend to start with and should usually work well, but if you'd like, you can make changes.

In [None]:
data_config = {
    "train": {
        "x_path": "x_train.wav",
        "y_path": "y_train.wav",
        "ny": 8192
    },
    "validation": {
        "x_path": "x_test.wav",
        "y_path": "y_test.wav",
        "ny": None
    },
    "common": {
        "delay": int(input("What is the latency (in samples) of your reamp? "))
    }
}
model_config = {
    "net": {
        "name": "WaveNet",
        # This is the "standard" model in easy mode / the local GUI trainer.
        "config": {
            "layers_configs": [
                {
                    "condition_size": 1,
                    "input_size": 1,
                    "channels": 16,
                    "head_size": 8,
                    "kernel_size": 3,
                    "dilations": [1,2,4,8,16,32,64,128,256,512],
                    "activation": "Tanh",
                    "gated": False,
                    "head_bias": False
                },
                {
                    "condition_size": 1,
                    "input_size": 16,
                    "channels": 8,
                    "head_size": 1,
                    "kernel_size": 3,
                    "dilations": [1,2,4,8,16,32,64,128,256,512],
                    "activation": "Tanh",
                    "gated": False,
                    "head_bias": True
                }
            ]
        }
    },
    "loss": {
        "val_loss": "esr"
    },
    "optimizer": {
        "lr": 0.004
    },
    "lr_scheduler": {
        "class": "ExponentialLR",
        "kwargs": {
            "gamma": 0.993
        }
    }
}
learning_config = {
    "train_dataloader": {
        "batch_size": 16,
        "shuffle": True,
        "pin_memory": True,
        "drop_last": True,
        "num_workers": 0
    },
    "val_dataloader": {},
    "trainer": {
        "accelerator": "gpu", 
        "devices": 1,
        "max_epochs": 100
    }
}
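To get a feel for the `"optimizer"` and `"lr_scheduler"` settings above, here is a small sketch of how an exponential schedule shrinks the learning rate. `lr_at_epoch` is a made-up helper, and it assumes the scheduler steps once per epoch, so that each epoch multiplies the rate by `gamma`.

```python
def lr_at_epoch(initial_lr: float, gamma: float, epoch: int) -> float:
    """Learning rate after `epoch` scheduler steps of exponential decay."""
    return initial_lr * gamma ** epoch


# Values taken from model_config above: lr=0.004, gamma=0.993.
for epoch in (0, 25, 50, 100):
    print(epoch, f"{lr_at_epoch(0.004, 0.993, epoch):.5f}")
```

Under that assumption, the rate has dropped to roughly half its starting value by epoch 100, so the default `"max_epochs"` ends training while the steps are still a useful size.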

## Step 4: Run!
Let's rock

In [None]:
model = Model.init_from_config(model_config)

In [None]:
data_config["common"]["nx"] = model.net.receptive_field
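For intuition about what `receptive_field` means, here is a sketch using the textbook formula for a stack of dilated convolutions: each layer with kernel size `k` and dilation `d` adds `(k - 1) * d` samples of context. This is my own back-of-the-envelope version, not `nam`'s implementation, and it assumes there are no other convolutions in the network that widen the window. `model.net.receptive_field` is the value that actually gets used.

```python
def receptive_field(layers_configs) -> int:
    """Input samples needed to produce one output sample (textbook formula)."""
    rf = 1
    for layer in layers_configs:
        k = layer["kernel_size"]
        rf += sum((k - 1) * d for d in layer["dilations"])
    return rf


# Kernel sizes and dilations copied from the two entries in model_config.
dilations = [2 ** i for i in range(10)]  # 1, 2, 4, ..., 512
standard_layers = [
    {"kernel_size": 3, "dilations": dilations},
    {"kernel_size": 3, "dilations": dilations},
]
print(receptive_field(standard_layers))  # 4093
```

At 48 kHz, 4093 samples is about 85 ms of audio, which is how far back the model can "hear" when predicting each output sample.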

In [None]:
dataset_train = init_dataset(data_config, Split.TRAIN)
dataset_validation = init_dataset(data_config, Split.VALIDATION)
train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"])
val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"])

In [None]:
trainer = pl.Trainer(
    callbacks=[
        pl.callbacks.model_checkpoint.ModelCheckpoint(
            filename="{epoch:04d}_{step}_{ESR:.3e}_{MSE:.3e}",
            save_top_k=3,
            monitor="val_loss",
            every_n_epochs=1,
        ),
        pl.callbacks.model_checkpoint.ModelCheckpoint(
            filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1
        ),
    ],
    **learning_config["trainer"],
)

Here we go!

🕙For 3 minutes of training data, training will take just over 10 minutes.🕙

But, if you want to stop early, you can always press the stop button.

If you want to train shorter or longer, you can also change the `"max_epochs"` above.

In [None]:
trainer.fit(model, train_dataloader, val_dataloader)
# Monitor the progress in lightning_logs/version_0/checkpoints.
#
# Many models can get a good result (rule of thumb: look for ESR<0.01) in about 15 
# minutes of training, but if you're more patient, it'll probably keep getting better.

In [None]:
# Go to best checkpoint
best_checkpoint = trainer.checkpoint_callback.best_model_path
if best_checkpoint != "":
    model = Model.load_from_checkpoint(
        best_checkpoint,
        **Model.parse_config(model_config),
    )
model.cpu()
model.eval()

## Step 5: Check
Let's look at how well our model matches the real thing.

In [None]:
def plot(
    model,
    ds,
    savefig=None,
    show=True,
    window_start: Optional[int] = None,
    window_end: Optional[int] = None,
):
    with torch.no_grad():
        tx = len(ds.x) / 48_000
        print(f"Run (t={tx})")
        t0 = time()
        output = model(ds.x).flatten().cpu().numpy()
        t1 = time()
        print(f"Took {t1 - t0} ({tx / (t1 - t0):.2f}x)")

    plt.figure(figsize=(16, 5))
    plt.plot(output[window_start:window_end], label="Prediction")
    plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target")
    plt.title(f"ESR={esr(torch.Tensor(output), ds.y):.4f}")
    plt.legend()
    if savefig is not None:
        plt.savefig(savefig)
    if show:
        plt.show()

In [None]:
plot(
    model,
    dataset_validation,
    window_start=100_000,  # Start of the plotting window, in samples
    window_end=101_000,  # End of the plotting window, in samples
)
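The number in the plot title is the ESR (error-to-signal ratio). As a sketch of what it measures, here is the usual definition in NumPy: the energy of the error divided by the energy of the target. `esr_numpy` is a name I made up so it doesn't shadow the `esr` imported from `nam`, and I'm assuming `nam`'s version matches this standard definition.

```python
import numpy as np


def esr_numpy(prediction, target) -> float:
    """Error-to-signal ratio: sum of squared error over sum of squared target."""
    prediction = np.asarray(prediction, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    return float(np.sum((target - prediction) ** 2) / np.sum(target ** 2))


target = np.sin(np.linspace(0.0, 20.0, 1000))
print(esr_numpy(target, target))                 # perfect match
print(esr_numpy(np.zeros_like(target), target))  # predicting silence
print(esr_numpy(0.9 * target, target))           # 10% too quiet everywhere
```

A perfect match scores 0 and predicting silence scores 1. A prediction that is 10% too quiet everywhere scores about 0.01, which gives a sense of scale for the ESR<0.01 rule of thumb.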

## Step 6: Export your model
Now we'll use NAM's exporting utility to convert the model from its PyTorch representation to something that you can put into the plugin.

In [None]:
Path("exported_model").mkdir(exist_ok=True)  # Don't fail if this cell is re-run
model.net.export("exported_model")

## Step 7: Download your artifacts
We're done! 
Go to the file browser on the left panel ⬅ and download `model.nam` from the `exported_model` directory (you may need to hit the refresh button).

Additionally, if you want to continue to train this model later you can download the lightning model artifacts from `lightning_logs`. If not, that's fine.
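If you're not sure which checkpoint files exist, here is a sketch that lists them. `list_checkpoints` is a made-up helper. It assumes the `lightning_logs/version_N/checkpoints/` layout mentioned in Step 4 and that checkpoint files end in `.ckpt`.

```python
from pathlib import Path


def list_checkpoints(root: str = "lightning_logs") -> list:
    """Return all .ckpt files under `root`, oldest first by modification time."""
    paths = Path(root).glob("version_*/checkpoints/*.ckpt")
    return sorted(paths, key=lambda p: p.stat().st_mtime)


# Prints nothing if no checkpoints have been written yet.
for ckpt in list_checkpoints():
    print(ckpt)
```

The last path printed is the most recently written checkpoint.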

# 🎸 **ENJOY!** 🎸