<div align="center">

# Steerable discovery of neural audio effects

  [Christian J. Steinmetz](https://www.christiansteinmetz.com/)  and  [Joshua D. Reiss](http://www.eecs.qmul.ac.uk/~josh/)


[Code](https://github.com/csteinmetz1/steerable-nafx) • [Paper](https://arxiv.org/abs/2112.02926) • [Demo](https://csteinmetz1.github.io/steerable-nafx)	• [Slides]()

<img src="https://csteinmetz1.github.io/steerable-nafx/assets/steerable-headline.svg">

</div>

## Abstract
Applications of deep learning for audio effects often focus on modeling analog effects or learning to control effects to emulate a trained audio engineer. 
However, deep learning approaches also have the potential to expand creativity through neural audio effects that enable new sound transformations. 
While recent work demonstrated that neural networks with random weights produce compelling audio effects, control of these effects is limited and unintuitive.
To address this, we introduce a method for the steerable discovery of neural audio effects.
This method enables the design of effects using example recordings provided by the user. 
We demonstrate how this method produces an effect similar to the target effect, along with interesting inaccuracies, while also providing perceptually relevant controls.


\* *Accepted to NeurIPS 2021 Workshop on Machine Learning for Creativity and Design*



# Setup
Run this first to install and import the relevant things...

In [None]:
# Install the audio I/O, loss-function and loudness-metering packages that the
# import cell below relies on.
!pip install torchaudio auraloss pyloudnorm

Download some sounds.

In [None]:
# Fetch the example recordings into the working directory. The drum and guitar
# pairs are the clean/processed examples for steering; piano_clean.wav is the
# test signal read by the plotting cell.
!wget https://csteinmetz1.github.io/sounds/assets/drum_kit_clean.wav 
!wget https://csteinmetz1.github.io/sounds/assets/drum_kit_comp_agg.wav 
!wget https://csteinmetz1.github.io/sounds/assets/acgtr_clean.wav 
!wget https://csteinmetz1.github.io/sounds/assets/acgtr_reverb.wav 
!wget https://csteinmetz1.github.io/sounds/assets/piano_clean.wav 

Download the pre-trained models.

In [None]:
# Fetch the pre-trained model files that are later restored with torch.load.
# NOTE(review): only stdout is redirected to /dev/null; wget usually reports
# progress on stderr, so this may not silence the output -- confirm whether
# `-q` was intended.
!wget https://csteinmetz1.github.io/steerable-nafx/models/compressor_full.pt > /dev/null
!wget https://csteinmetz1.github.io/steerable-nafx/models/reverb_full.pt > /dev/null
!wget https://csteinmetz1.github.io/steerable-nafx/models/amp_full.pt > /dev/null
!wget https://csteinmetz1.github.io/steerable-nafx/models/delay_full.pt > /dev/null
!wget https://csteinmetz1.github.io/steerable-nafx/models/synth2synth_full.pt > /dev/null

In [None]:
import sys
import math
import torch
import librosa.display
import IPython
import auraloss
import torchaudio
import numpy as np
import scipy.signal
from google.colab import files
from tqdm.notebook import tqdm
from time import sleep
import matplotlib
import pyloudnorm as pyln
import matplotlib.pyplot as plt
from IPython.display import Image
%matplotlib inline

In [None]:
# Adapted from:
# https://github.com/LCAV/pyroomacoustics/blob/master/pyroomacoustics/experimental/rt60.py
def measure_rt60(h, fs=1, decay_db=30, rt60_tgt=None):
    """
    Analyze the RT60 of an impulse response.

    The energy decay curve is obtained by Schroeder (backward) integration of
    the squared response. The time needed to fall from -5 dB to
    -(5 + decay_db) dB is measured and extrapolated linearly to a 60 dB decay.

    Args:
        h (ndarray): The discrete time impulse response as 1d array.
        fs (float, optional): Sample rate of the impulse response. With the
            default of 1 the estimate is expressed in samples. (Default: 1)
        decay_db (float, optional): The decay in decibels for which we
            actually estimate the time. (Default: 30)
        rt60_tgt (float, optional): Accepted for interface compatibility;
            not used by this implementation. (Default: None)
    Returns:
        est_rt60 (float): Estimated RT60. ``np.array(0.0)`` is returned when
            the decay cannot be measured (e.g. an all-zero response, or one
            that never drops by the requested amount).
    """

    h = np.array(h)
    fs = float(fs)

    # The power of the impulse response
    power = h ** 2
    energy = np.cumsum(power[::-1])[::-1]  # Integration according to Schroeder

    try:
        # remove the possibly all zero tail
        i_nz = np.max(np.where(energy > 0)[0])
        energy = energy[:i_nz]
        energy_db = 10 * np.log10(energy)
        energy_db -= energy_db[0]

        # -5 dB headroom: skip the direct sound before timing the decay
        i_5db = np.min(np.where(-5 - energy_db > 0)[0])
        t_5db = i_5db / fs

        # first sample that has decayed by a further `decay_db`
        i_decay = np.min(np.where(-5 - decay_db - energy_db > 0)[0])
        t_decay = i_decay / fs

        # compute the decay time and extrapolate it to 60 dB
        decay_time = t_decay - t_5db
        est_rt60 = (60 / decay_db) * decay_time
    except (ValueError, IndexError, ZeroDivisionError):
        # ValueError: np.max/np.min of an empty selection (threshold never
        # reached); IndexError: nothing left after trimming the tail;
        # ZeroDivisionError: decay_db == 0. Anything else is a real error and
        # is allowed to propagate.
        est_rt60 = np.array(0.0)

    return est_rt60

In [None]:
def causal_crop(x, length: int):
    """Trim the last axis of ``x`` down to ``length`` samples.

    When the last axis already has ``length`` entries, ``x`` is returned
    untouched. Otherwise the window of ``length`` entries that ends one
    sample before the end of the axis is returned (the final sample itself
    is excluded).
    """
    n_samples = x.shape[-1]
    if n_samples == length:
        return x
    end = n_samples - 1
    return x[..., end - length:end]

class FiLM(torch.nn.Module):
    """Conditional affine transform of a feature map.

    A linear layer maps the conditioning input to ``2 * num_features`` values
    that are split into a per-feature gain and bias. The features are
    optionally passed through a non-affine ``BatchNorm1d`` first, then scaled
    by the gain and offset by the bias.
    """

    def __init__(
        self,
        cond_dim,  # dim of conditioning input
        num_features,  # dim of the conv channel
        batch_norm=True,
    ):
        super().__init__()
        self.num_features = num_features
        self.batch_norm = batch_norm
        if batch_norm:
            self.bn = torch.nn.BatchNorm1d(num_features, affine=False)
        # one linear layer emits the gain and the bias side by side
        self.adaptor = torch.nn.Linear(cond_dim, 2 * num_features)

    def forward(self, x, cond):
        """Return ``x`` (optionally normalised) scaled and shifted by ``cond``."""
        gain, bias = self.adaptor(cond).chunk(2, dim=-1)

        # swap the last two axes so the feature axis lines up with axis 1 of x
        gain = gain.permute(0, 2, 1)
        bias = bias.permute(0, 2, 1)

        normed = self.bn(x) if self.batch_norm else x
        return normed * gain + bias

class TCNBlock(torch.nn.Module):
    """One dilated-convolution block with an optional FiLM stage and a residual path.

    The main path is an unpadded dilated ``Conv1d`` followed by FiLM
    conditioning (only when ``cond_dim > 0``) and a ``PReLU`` (only when
    ``activation`` is true). The residual path is a bias-free 1x1 convolution
    of the block input, cropped to the main path's length before the sum.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, cond_dim=0, activation=True):
        super().__init__()
        # No padding: the output is shorter than the input.
        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            dilation=dilation,
            padding=0,
            bias=True,
        )
        # The optional stages are only created when requested; forward()
        # detects them by attribute presence.
        if cond_dim > 0:
            self.film = FiLM(cond_dim, out_channels, batch_norm=False)
        if activation:
            self.act = torch.nn.PReLU()
        self.res = torch.nn.Conv1d(in_channels, out_channels, 1, bias=False)

    def forward(self, x, c=None):
        """Process ``x``; ``c`` is forwarded to the FiLM stage when present."""
        h = self.conv(x)
        if hasattr(self, "film"):
            h = self.film(h, c)
        if hasattr(self, "act"):
            h = self.act(h)

        # bring the residual to the (shorter) length of the main path
        skip = causal_crop(self.res(x), h.shape[-1])
        return h + skip

class TCN(torch.nn.Module):
    """Stack of ``TCNBlock`` layers whose dilation grows geometrically.

    Block ``n`` uses dilation ``dilation_growth ** n``. The first block maps
    ``n_inputs`` to ``n_channels``; the last block (when there is more than
    one) maps ``n_channels`` to ``n_outputs``; the blocks in between keep
    ``n_channels``. Every block is built with its activation enabled.
    """

    def __init__(self, n_inputs=1, n_outputs=1, n_blocks=10, kernel_size=13, n_channels=64, dilation_growth=4, cond_dim=0):
        super().__init__()
        self.kernel_size = kernel_size
        self.n_channels = n_channels
        self.dilation_growth = dilation_growth
        self.n_blocks = n_blocks
        self.stack_size = n_blocks

        self.blocks = torch.nn.ModuleList()
        final_idx = n_blocks - 1
        for idx in range(n_blocks):
            is_first = idx == 0
            # the first-block rule wins when the network has a single block
            is_final = (not is_first) and idx == final_idx
            in_ch = n_inputs if is_first else n_channels
            out_ch = n_outputs if is_final else n_channels
            self.blocks.append(
                TCNBlock(
                    in_ch,
                    out_ch,
                    kernel_size,
                    dilation_growth ** idx,
                    cond_dim=cond_dim,
                    activation=True,
                )
            )

    def forward(self, x, c=None):
        """Run ``x`` through every block in order, passing ``c`` to each."""
        out = x
        for layer in self.blocks:
            out = layer(out, c)
        return out

    def compute_receptive_field(self):
        """Compute the receptive field in samples."""
        span = self.kernel_size - 1
        extra = sum(
            span * self.dilation_growth ** (n % self.stack_size)
            for n in range(1, self.n_blocks)
        )
        return self.kernel_size + extra

In [None]:
# Restore every pre-trained effect on the CPU and put it in inference mode.
# NOTE(review): torch.load unpickles the stored objects, so only load files
# from a trusted source; presumably the TCN/TCNBlock/FiLM classes must already
# be defined in this session for the load to succeed -- confirm.
def _load_pretrained(path):
  restored = torch.load(path, map_location="cpu")
  return restored.eval()

model_comp = _load_pretrained("compressor_full.pt")
model_verb = _load_pretrained("reverb_full.pt")
model_amp = _load_pretrained("amp_full.pt")
model_delay = _load_pretrained("delay_full.pt")
model_synth = _load_pretrained("synth2synth_full.pt")

# 1. Pre-trained models
Jump right in by processing your own audio with some pre-trained models. 

1. Upload your input audio.
2. Select your desired pre-trained model.
3. Set the audio effect parameters.

In [None]:
#@title Upload input audio
# Open the Colab upload dialog, then decode the last entry of the returned mapping.
process_upload = files.upload()
process_file = list(process_upload)[-1]
x_p, sample_rate = torchaudio.load(process_file)
print(process_file, x_p.shape)

# Show an inline player so the upload can be checked by ear.
process_player = IPython.display.Audio(data=x_p, rate=sample_rate)
IPython.display.display(process_player)

Now set the audio effect parameters. 
Here are some more insights into the controls:

- `effect_type` - Choose from one of the pre-trained models.
- `gain_dB` - Adjust the input gain. This can have a big effect since the effects are very nonlinear.
- `c0` and `c1` - These are the effect controls which will adjust perceptual aspects of the effect, depending on the effect type. Very large values will often result in more extreme effects.
- `mix` - Control the wet/dry mix of the effect.
- `width` - Increase stereo width of the effect.
- `max_length` - If you uploaded a very long file this will truncate it.
- `stereo` - Convert mono input to stereo output.
- `tail` - If checked, we will also compute the effect tail (nice for reverbs). 


In [None]:
effect_type = "Compressor" #@param ["Compressor", "Reverb", "Amp", "Analog Delay", "Synth2Synth"]
gain_dB = -24 #@param {type:"slider", min:-24, max:24, step:0.1}
c0 = -1.4 #@param {type:"slider", min:-10, max:10, step:0.1}
c1 = 3 #@param {type:"slider", min:-10, max:10, step:0.1}
mix = 70 #@param {type:"slider", min:0, max:100, step:1}
width = 50 #@param {type:"slider", min:0, max:100, step:1}
max_length = 30 #@param {type:"slider", min:5, max:120, step:1}
stereo = True #@param {type:"boolean"}
tail = True #@param {type:"boolean"}

# select model type
# A lookup table fails loudly on an unknown name instead of silently reusing a
# `pt_model` left over from an earlier run of this cell.
pretrained_models = {
  "Compressor": model_comp,
  "Reverb": model_verb,
  "Amp": model_amp,
  "Analog Delay": model_delay,
  "Synth2Synth": model_synth,
}
if effect_type not in pretrained_models:
  raise ValueError(f"Unknown effect_type: {effect_type!r}")
pt_model = pretrained_models[effect_type]

# measure the receptive field
pt_model_rf = pt_model.compute_receptive_field()

# crop input signal if needed
max_samples = int(sample_rate * max_length)
x_p_crop = x_p[:,:max_samples]
chs = x_p_crop.shape[0]

# if mono and stereo requested
if chs == 1 and stereo:
  x_p_crop = x_p_crop.repeat(2,1)
  chs = 2

# pad the input signal: the front padding makes the model output as long as
# the (cropped) input, the optional back padding leaves room for the tail
front_pad = pt_model_rf-1
back_pad = front_pad if tail else 0
x_p_pad = torch.nn.functional.pad(x_p_crop, (front_pad, back_pad))

# design highpass filter (8th-order Butterworth at 20 Hz)
sos = scipy.signal.butter(
    8,
    20.0,
    fs=sample_rate,
    output="sos",
    btype="highpass"
)

# compute linear gain
gain_ln = 10 ** (gain_dB / 20.0)

# stereo width is created by offsetting the controls in opposite directions
width_offset = width * 5e-3

# process audio with pre-trained model
with torch.no_grad():
  y_hat = torch.zeros(x_p_crop.shape[0], x_p_crop.shape[1] + back_pad)
  for n in range(chs):
    # first channel is shifted up, every other channel is shifted down
    # (defined for any channel count, not only for n in {0, 1})
    factor = width_offset if n == 0 else -width_offset
    c = torch.tensor([float(c0+factor), float(c1+factor)]).view(1,1,-1)
    y_hat_ch = pt_model(gain_ln * x_p_pad[n,:].view(1,1,-1), c)
    y_hat_ch = scipy.signal.sosfilt(sos, y_hat_ch.view(-1).numpy())
    y_hat_ch = torch.tensor(y_hat_ch)
    y_hat[n,:] = y_hat_ch

# pad the dry signal so it matches the length of the processed signal
x_dry = torch.nn.functional.pad(x_p_crop, (0,back_pad))

# normalize each first; the clamp keeps a silent signal from dividing by zero
peak_floor = 1e-8
y_hat /= y_hat.abs().max().clamp(min=peak_floor)
x_dry /= x_dry.abs().max().clamp(min=peak_floor)

# mix
mix = mix/100.0
y_hat = (mix * y_hat) + ((1-mix) * x_dry)

# remove transient
# NOTE(review): assumes the clip is longer than 8192 samples; a shorter clip
# leaves nothing to normalise or save.
y_hat = y_hat[...,8192:]
y_hat /= y_hat.abs().max().clamp(min=peak_floor)

torchaudio.save("output.mp3", y_hat.view(chs,-1), sample_rate, compression=320.0)
print("Done.")
print("Sending audio to browser...")

# show the audio
IPython.display.display(IPython.display.Audio("output.mp3"))

Click the three dots to download your processed audio file.

# 2. Steering (training)
Use a pair of audio examples in order to construct neural audio effects.

There are two options. Either start with the pre-loaded audio examples, or upload your own clean/processed audio recordings for the steering process.

a.) Use some of our pre-loaded audio examples. Choose from the compressor or reverb effect.

In [None]:
#@title Use pre-loaded audio examples for steering
effect_type = "Compressor" #@param ["Compressor", "Reverb"]

# Pick the clean/processed recording pair that matches the chosen effect.
if effect_type == "Compressor":
  input_file, output_file = "drum_kit_clean.wav", "drum_kit_comp_agg.wav"
elif effect_type == "Reverb":
  input_file, output_file = "acgtr_clean.wav", "acgtr_reverb.wav"

# Decode both recordings and keep only their first channel (still 2-D).
x, sample_rate = torchaudio.load(input_file)
x = x[:1]
y, sample_rate = torchaudio.load(output_file)
y = y[:1]

# Report the shape of each clip and show an inline player for it.
for caption, clip in (("input file", x), ("output file", y)):
  print(caption, clip.shape)
  IPython.display.display(IPython.display.Audio(data=clip, rate=sample_rate))

  b.) or, load your own input/output sounds. 
  
  The files must have the same length.
  

In [None]:
#@title Upload clean sound (x)
# Open the upload dialog for the clean recording and decode the last entry
# of the returned mapping.
input_upload = files.upload()
input_file = list(input_upload)[-1]
x, sample_rate = torchaudio.load(input_file)
print(input_file, x.shape)

# Inline player for a quick listen.
input_player = IPython.display.Audio(data=x, rate=sample_rate)
IPython.display.display(input_player)

In [None]:
# upload the same file processed with an effect
#@title Upload processed sound (y)
output_upload = files.upload()
output_file = list(output_upload)[-1]
y, sample_rate = torchaudio.load(output_file)

# Training needs equally long signals: trim whichever recording is longer so
# that both end at the same sample.
n_clean = x.shape[-1]
n_processed = y.shape[-1]
if n_clean != n_processed:
  print(f"Input and output files are different lengths! Found clean: {x.shape[-1]} processed: {y.shape[-1]}.")
  if n_processed > n_clean:
    print(f"Cropping target...")
    y = y[:, :n_clean]
  else:
    print(f"Cropping input...")
    x = x[:, :n_processed]

print(output_file, y.shape)
output_player = IPython.display.Audio(data=y, rate=sample_rate)
IPython.display.display(output_player)

Now it's time to generate the neural audio effect by training the TCN to emulate the input/output function from the target audio effect. Adjusting the parameters will enable you to tweak the optimization process. 

In [None]:
#@title TCN model training parameters
cond_dim = 2 #@param {type:"slider", min:1, max:10, step:1}
kernel_size = 13 #@param {type:"slider", min:3, max:32, step:1}
n_blocks = 5 #@param {type:"slider", min:2, max:30, step:1}
dilation_growth = 8 #@param {type:"slider", min:1, max:10, step:1}
n_channels = 8 #@param {type:"slider", min:1, max:128, step:1}
n_iters = 2499 #@param {type:"slider", min:0, max:10000, step:1}
length = 228308 #@param {type:"slider", min:0, max:524288, step:1}
lr = 0.001 #@param {type:"number"}

device = "cuda" if torch.cuda.is_available() else "cpu"

# reshape the audio to (batch, channels, samples) and keep the first channel only
x_batch = x.view(1,x.shape[0],-1)[:,0:1,:]
y_batch = y.view(1,y.shape[0],-1)[:,0:1,:]

# all-zero conditioning used during training; sized from `cond_dim` so that
# the slider can be moved away from 2 without a shape mismatch in the model
c = torch.zeros(1, 1, cond_dim, device=device)

x_ch = x_batch.shape[1]
y_ch = y_batch.shape[1]

# build the model
model = TCN(
    n_inputs=x_ch,
    n_outputs=y_ch,
    cond_dim=cond_dim,
    kernel_size=kernel_size,
    n_blocks=n_blocks,
    dilation_growth=dilation_growth,
    n_channels=n_channels)
rf = model.compute_receptive_field()
params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Parameters: {params*1e-3:0.3f} k")
print(f"Receptive field: {rf} samples or {(rf/sample_rate)*1e3:0.1f} ms")

# setup loss function, optimizer, and scheduler
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=[32, 128, 512, 2048],
    win_lengths=[32, 128, 512, 2048],
    hop_sizes=[16, 64, 256, 1024])
# optional time-domain term; only used if added to the loss below
loss_fn_l1 = torch.nn.L1Loss()

optimizer = torch.optim.Adam(model.parameters(), lr)

# drop the learning rate by 10x at 80% and again at 95% of the iterations
ms1 = int(n_iters * 0.8)
ms2 = int(n_iters * 0.95)
milestones = [ms1, ms2]
print(
    "Learning rate schedule:",
    f"1:{lr:0.2e} ->",
    f"{ms1}:{lr*0.1:0.2e} ->",
    f"{ms2}:{lr*0.01:0.2e}",
)
# `verbose` is left at its default: it was False here anyway and the keyword
# is deprecated in recent PyTorch releases
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones,
    gamma=0.1,
)

# move model and tensors to the selected device (a no-op on CPU)
model.to(device)
x_batch = x_batch.to(device)
y_batch = y_batch.to(device)

# iteratively update the weights
pbar = tqdm(range(n_iters))
for n in pbar:
  optimizer.zero_grad()

  # fixed training window; swap in
  # np.random.randint(rf, x_batch.shape[-1]-length-1) for random crops
  start_idx = rf
  stop_idx = start_idx + length
  # the input gets rf-1 samples of extra context so that the unpadded model
  # output lines up with y_crop sample for sample
  x_crop = x_batch[...,start_idx-rf+1:stop_idx]
  y_crop = y_batch[...,start_idx:stop_idx]

  y_hat = model(x_crop, c)
  loss = loss_fn(y_hat, y_crop) #+ loss_fn_l1(y_hat, y_crop)

  loss.backward()
  optimizer.step()

  scheduler.step()
  pbar.set_description(f" Loss: {loss.item():0.3e} | ")

# run the trained model over the whole recording; the front padding makes the
# output as long as the input
model.eval()
x_pad = torch.nn.functional.pad(x_batch, (rf-1, 0))
with torch.no_grad():
  y_hat = model(x_pad, c)

# `input_sig` instead of `input`, so the builtin of that name stays usable
input_sig = causal_crop(x_batch.view(-1).detach().cpu().numpy(), y_hat.shape[-1])
output = y_hat.view(-1).detach().cpu().numpy()
target = causal_crop(y_batch.view(-1).detach().cpu().numpy(), y_hat.shape[-1])

# apply highpass to output
sos = scipy.signal.butter(8, 20.0, fs=sample_rate, output="sos", btype="highpass")
output = scipy.signal.sosfilt(sos, output)

# peak-normalise all three signals for listening and plotting
input_sig /= np.max(np.abs(input_sig))
output /= np.max(np.abs(output))
target /= np.max(np.abs(target))

fig, ax = plt.subplots(nrows=1, sharex=True)
librosa.display.waveshow(target, sr=sample_rate, alpha=0.5, ax=ax, label='Target')
librosa.display.waveshow(output, sr=sample_rate, color='r', alpha=0.5, ax=ax, label='Output')

print("Input (clean)")
IPython.display.display(IPython.display.Audio(data=input_sig, rate=sample_rate))
print("Target")
IPython.display.display(IPython.display.Audio(data=target, rate=sample_rate))
print("Output")
IPython.display.display(IPython.display.Audio(data=output, rate=sample_rate))
plt.legend()
# plt.show() takes no figure argument
plt.show()

## 2D Plot
Now we can generate a 2D plot of the parameter space...

In [None]:
#@title Generate plot
size = 22 * 2
max_cond = 5
min_cond = -5
values = np.zeros((size,size))
space = np.linspace(min_cond, max_cond, num=size, dtype=np.float32)

model.eval()

# The highpass design does not depend on the control values, so it is built
# once here instead of once per grid point.
sos = scipy.signal.butter(8, 20.0, fs=sample_rate, output="sos", btype="highpass")


def _plot_control_grid(grid, cbar_label):
  """Draw `grid` (indexed [c0, c1]) as an image and return the figure."""
  fig, ax = plt.subplots(1,1, figsize=(5,5))
  # imshow puts the first array axis on the vertical axis; transpose so that
  # c0 runs along x and c1 along y, as the axis labels state
  img = ax.imshow(grid.T, interpolation='nearest', origin="lower")
  ax.set_xticks([])
  ax.set_xticklabels([])
  ax.set_yticks([])
  ax.set_yticklabels([])
  cbar = fig.colorbar(img,fraction=0.046, pad=0.04)
  cbar.set_label(cbar_label)
  ax.set_xlabel(r"$c_0$")
  ax.set_ylabel(r"$c_1$")
  return fig


if effect_type == "Reverb":
  # probe the model with a single impulse and measure the decay of its response
  impulse = torch.zeros(1, 1, 65536*3)
  impulse[..., 16384*2] = 1.2
  impulse = impulse.to(device)
  for xidx, x_c in enumerate(tqdm(space)):
    for yidx, y_c in enumerate(space):
      c = torch.tensor([x_c,y_c], device=device).view(1,1,-1).float()
      with torch.no_grad():
        y_hat = model(impulse, c).cpu().view(-1)
      y_hat = scipy.signal.sosfilt(sos, y_hat.numpy())
      # measure_rt60 is called with its default fs of 1, so the estimate is
      # in samples; convert it to seconds
      rt60 = measure_rt60(y_hat)
      rt60_sec = rt60 / sample_rate
      values[xidx,yidx] = rt60_sec

  fig = _plot_control_grid(values, r"$T_{60}$ (sec)")

elif effect_type == "Compressor":
  # measure the loudness of a peak-normalised piano excerpt after processing
  test_signal, _ = torchaudio.load("piano_clean.wav")
  test_signal = test_signal[0,:65536*3]
  test_signal = test_signal.view(1,1,-1)
  test_signal = test_signal.to(device)

  meter = pyln.Meter(sample_rate)

  for xidx, x_c in enumerate(tqdm(space)):
    for yidx, y_c in enumerate(space):
      c = torch.tensor([x_c,y_c], device=device).view(1,1,-1).float()
      with torch.no_grad():
        y_hat = model(test_signal, c).cpu().view(-1)
      y_hat = scipy.signal.sosfilt(sos, y_hat.numpy())
      y_hat /= np.max(np.abs(y_hat))
      dB_lufs = meter.integrated_loudness(y_hat.reshape(-1,1))
      values[xidx,yidx] = dB_lufs

  fig = _plot_control_grid(values, "dBFS LUFS")


In [None]:
# Tighten the layout of the current `fig`, write it out as a PDF and hand the
# file to the browser for download.
plot_path = "plot.pdf"
fig.tight_layout()
fig.savefig(plot_path, dpi=300)
files.download(plot_path)

## Process new sounds

In [None]:
# Run the freshly trained model on a recording and compare input and output.
x_whole, sample_rate = torchaudio.load("acgtr_clean.wav")
# pad both ends by rf-1 samples, then keep the first channel only
x_whole = torch.nn.functional.pad(x_whole, (rf-1, rf-1))
x_whole = x_whole[0,:]
x_whole = x_whole.view(1,1,-1).to(device)
c_rand = torch.tensor([-0.1,0.0], device=device).view(1,1,-1)

with torch.no_grad():
  # the input is attenuated to 0.2x before it is fed to the model
  y_whole = model(0.2 * x_whole, c_rand)
  # trim the padded input to the length of the model output
  x_whole = causal_crop(x_whole, y_whole.shape[-1])

y_whole /= y_whole.abs().max()

# apply high pass filter to remove DC
sos = scipy.signal.butter(8, 20.0, fs=sample_rate, output="sos", btype="highpass")
y_whole = scipy.signal.sosfilt(sos, y_whole.cpu().view(-1).numpy())

# remove start transient
y_whole = y_whole[4410:]
x_whole = x_whole.view(-1)[4410:].cpu().numpy()

y_whole = (y_whole * 0.8)
IPython.display.display(IPython.display.Audio(data=x_whole, rate=sample_rate))
IPython.display.display(IPython.display.Audio(data=y_whole, rate=sample_rate))

# peak-normalise both signals for the waveform plot
x_whole /= np.max(np.abs(x_whole))
y_whole /= np.max(np.abs(y_whole))

fig, ax = plt.subplots(nrows=1, sharex=True)
librosa.display.waveshow(y_whole, sr=sample_rate, color='r', alpha=0.5, ax=ax, label='Output')
# x_whole already has the same length as y_whole here, so no further crop is needed
librosa.display.waveshow(x_whole, sr=sample_rate, alpha=0.5, ax=ax, label='Input')
plt.legend()
# plt.show() takes no figure argument
plt.show()


In [None]:
# Serialize the trained model object itself (not just a state dict) to disk.
# The commented-out lines are alternative output names; enable the one that
# matches the effect that was trained.
#torch.save(model, "./reverb_full.pt")
#torch.save(model, "./compressor_full.pt")
torch.save(model, "./delay_full.pt")


In [None]:
# Hand the saved model file to the browser for download. The path must name a
# file that already exists on disk; the commented-out lines are the
# alternative names for the other effects.
#files.download("./reverb_full.pt")
#files.download("./compressor_full.pt")
files.download("./delay_full.pt")
