# Baselines

In [3]:
import os
import torch
import IPython
import auraloss
import torchaudio
import numpy as np
import scipy.signal
import torch.nn.functional as F
import pyloudnorm as pyln
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import IPython.display as ipd
%matplotlib inline

ModuleNotFoundError: No module named 'src'

In [None]:
# Global variables
DATA_DIR = Path("data/raw/")
CONVERTED_DIR = Path("data/springset_converted")
PROCESSED_DIR = Path("data/processed")
MODELS_DIR = Path("models/")
PLOTS_DIR = Path("plots/")

sample_rate = 16000

In [None]:
def causal_crop(x, length: int):
    if x.shape[-1] != length:
        stop = x.shape[-1] - 1
        start = stop - length
        x = x[..., start:stop]
    return x

class FiLM(torch.nn.Module):
    def __init__(
        self,
        cond_dim,  # dim of conditioning input
        num_features,  # dim of the conv channel
        batch_norm=True,
    ):
        super().__init__()
        self.num_features = num_features
        self.batch_norm = batch_norm
        if batch_norm:
            self.bn = torch.nn.BatchNorm1d(num_features, affine=False)
        self.adaptor = torch.nn.Linear(cond_dim, num_features * 2)

    def forward(self, x, cond):

        cond = self.adaptor(cond)
        g, b = torch.chunk(cond, 2, dim=-1)
        g = g.permute(0, 2, 1)
        b = b.permute(0, 2, 1)

        if self.batch_norm:
            x = self.bn(x)  # apply BatchNorm without affine
        x = (x * g) + b  # then apply conditional affine

        return x

class TCNBlock(torch.nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size, dilation, cond_dim=None, activation=True):
    super().__init__()
    self.conv = torch.nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size,
        dilation=dilation,
        padding=0, #((kernel_size-1)//2)*dilation,
        bias=True)
    if cond_dim > 0:
      self.film = FiLM(cond_dim, out_channels, batch_norm=False)
    if activation:
      #self.act = torch.nn.Tanh()
      self.act = torch.nn.PReLU()
    self.res = torch.nn.Conv1d(in_channels, out_channels, 1, bias=False)

  def forward(self, x, c=None):
    x_in = x
    x = self.conv(x)
    if hasattr(self, "film"):
      x = self.film(x, c)
    if hasattr(self, "act"):
      x = self.act(x)
    x_res = causal_crop(self.res(x_in), x.shape[-1])
    x = x + x_res

    return x

class TCN(torch.nn.Module):
  def __init__(self, n_inputs=1, n_outputs=1, n_blocks=10, kernel_size=13, n_channels=64, dilation_growth=4, cond_dim=0):
    super().__init__()
    self.kernel_size = kernel_size
    self.n_channels = n_channels
    self.dilation_growth = dilation_growth
    self.n_blocks = n_blocks
    self.stack_size = n_blocks

    self.blocks = torch.nn.ModuleList()
    for n in range(n_blocks):
      if n == 0:
        in_ch = n_inputs
        out_ch = n_channels
        act = True
      elif (n+1) == n_blocks:
        in_ch = n_channels
        out_ch = n_outputs
        act = True
      else:
        in_ch = n_channels
        out_ch = n_channels
        act = True

      dilation = dilation_growth ** n
      self.blocks.append(TCNBlock(in_ch, out_ch, kernel_size, dilation, cond_dim=cond_dim, activation=act))

  def forward(self, x, c=None):
    c = torch.tensor([-0.1,0.0]).view(1,1,-1)
    for block in self.blocks:
      x = block(x, c)

    return x

  def compute_receptive_field(self):
    """Compute the receptive field in samples."""
    rf = self.kernel_size
    for n in range(1, self.n_blocks):
        dilation = self.dilation_growth ** (n % self.stack_size)
        rf = rf + ((self.kernel_size - 1) * dilation)
    return rf

In [None]:
# setup the pre-trained models
model_verb = torch.load("models/reverb_full.pt", map_location="cpu").eval()

## 1. Pre-trained models

In [None]:
process_file = os.path.join(CONVERTED_DIR, 'stack_Y_train.wav')
x_p, sample_rate = torchaudio.load(process_file)
print(process_file, x_p.shape)
ipd.Audio(data=x_p, rate=sample_rate)

NameError: name 'CONVERTED_DIR' is not defined

Now set the audio effect parameters.
Here are some more insights into the controls:

- `effect_type` - Choose from one of the pre-trained models.
- `gain_dB` - Adjust the input gain. This can have a big effect since the effects are very nonlinear.
- `c0` and `c1` - These are the effect controls which will adjust perceptual aspects of the effect, depending on the effect type. Very large values will often result in more extreme effects.
- `mix` - Control the wet/dry mix of the effect.
- `width` - Increase stereo width of the effect.
- `max_length` - If you uploaded a very long file this will truncate it.
- `stereo` - Convert mono input to stereo output.
- `tail` - If checked, we will also compute the effect tail (nice for reverbs).


In [None]:
effect_type = "Reverb" #@param ["Compressor", "Reverb", "Amp", "Analog Delay", "Synth2Synth"]
gain_dB = -18 #@param {type:"slider", min:-24, max:24, step:0.1}
c0 = 0 #@param {type:"slider", min:-10, max:10, step:0.1}
c1 = 0 #@param {type:"slider", min:-10, max:10, step:0.1}
mix = 100 #@param {type:"slider", min:0, max:100, step:1}
width = 100 #@param {type:"slider", min:0, max:100, step:1}
max_length = 120 #@param {type:"slider", min:5, max:120, step:1}
stereo = False #@param {type:"boolean"}
tail = False #@param {type:"boolean"}

# select model type
if effect_type == "Reverb":
  pt_model = model_verb

# measure the receptive field
pt_model_rf = pt_model.compute_receptive_field()
pt_model_params = sum(p.numel() for p in pt_model.parameters() if p.requires_grad)

print(f"Parameters: {pt_model_params*1e-3:0.3f} k")
print(f"Receptive field: {pt_model_rf} samples or {(pt_model_rf/sample_rate)*1e3:0.1f} ms")

# crop input signal if needed
max_samples = int(sample_rate * max_length)
x_p_crop = x_p[:,:max_samples]
chs = x_p_crop.shape[0]

# if mono and stereo requested
if chs == 1 and stereo:
  x_p_crop = x_p_crop.repeat(2,1)
  chs = 2

# pad the input signal
front_pad = pt_model_rf-1
back_pad = 0 if not tail else front_pad
x_p_pad = torch.nn.functional.pad(x_p_crop, (front_pad, back_pad))

# design highpass filter
sos = scipy.signal.butter(
    8,
    20.0,
    fs=sample_rate,
    output="sos",
    btype="highpass"
)

# compute linear gain
gain_ln = 10 ** (gain_dB / 20.0)

# process audio with pre-trained model
with torch.no_grad():
  y_hat = torch.zeros(x_p_crop.shape[0], x_p_crop.shape[1] + back_pad)
  for n in range(chs):
    if n == 0:
      factor = (width*5e-3)
    elif n == 1:
      factor = -(width*5e-3)
    c = torch.tensor([float(c0+factor), float(c1+factor)]).view(1,1,-1)
    y_hat_ch = pt_model(gain_ln * x_p_pad[n,:].view(1,1,-1), c)
    y_hat_ch = scipy.signal.sosfilt(sos, y_hat_ch.view(-1).numpy())
    y_hat_ch = torch.tensor(y_hat_ch)
    y_hat[n,:] = y_hat_ch

# pad the dry signal
x_dry = torch.nn.functional.pad(x_p_crop, (0,back_pad))

# normalize each first
y_hat /= y_hat.abs().max()
x_dry /= x_dry.abs().max()

# mix
mix = mix/100.0
y_hat = (mix * y_hat) + ((1-mix) * x_dry)

# remove transient
y_hat = y_hat[...,8192:]
y_hat /= y_hat.abs().max()

torchaudio.save("output.mp3", y_hat.view(chs,-1), sample_rate, compression=320.0)
print("Done.")
print("Sending audio to browser...")

# show the audio
IPython.display.display(IPython.display.Audio("output.mp3"))

In [None]:
print(pt_model)

In [None]:
target_file = os.path.join(CONVERTED_DIR, "stack_Y_train.wav")
target, sample_rate = torchaudio.load(target_file)
print(target_file, target[0].shape)

In [None]:
def peak_normalize(tensor):
    # max_values = torch.max(torch.abs(tensor), dim=1, keepdim=True).values
    # normalized_tensor = tensor / max_values
    # return normalized_tensor
    torch.nn.functional.normalize(tensor, p=2, dim=1)
    return tensor

batch_size = 1

trainset = SpringDataset(root_dir=DATA_DIR, split='train', transform=peak_normalize)
train_size = int(0.8 * len(trainset))
val_size = len(trainset) - train_size
train, valid = torch.utils.data.random_split(trainset, [train_size, val_size])

train_loader = torch.utils.data.DataLoader(train, batch_size, num_workers=0, shuffle=True, drop_last=True)
valid_loader = torch.utils.data.DataLoader(valid, batch_size, num_workers=0, shuffle=False, drop_last=True)

testset = SpringDataset(root_dir=DATA_DIR, split="test", transform=peak_normalize)
test_loader = torch.utils.data.DataLoader(testset, batch_size, num_workers=0, drop_last=True)

In [None]:
pt_model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = pt_model.to(device)

mae = torch.nn.L1Loss().to(device)
dc = auraloss.time.DCLoss().to(device)
esr = auraloss.time.ESRLoss().to(device)
mrstft =  auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=[32, 128, 512, 2048],
    win_lengths=[32, 128, 512, 2048],
    hop_sizes=[16, 64, 256, 1024],
    sample_rate=sample_rate,
    perceptual_weighting=False,
    ).to(device)
    
criterions = {"mae": mae, "esr": esr, "dc": dc, "mrstft": mrstft}
test_results = {"mae": [], "esr": [], "dc": [], "mrstft": []}
metrics = [mae, esr, dc, mrstft]

# Lists to store the metrics

with torch.no_grad():
    for n_iters, (input, target) in enumerate(test_loader):
        input = input.to(device)
        target = target.to(device)
        rf = pt_model.compute_receptive_field()

        input = torch.nn.functional.pad(input, (rf-1, rf-1))
        input = input[:, 0, :]  # Select the first channel
        c_rand = torch.tensor([-0.1, 0.0], device=device).view(1, 1, -1)

        output = pt_model(0.2 * input, c_rand)
        input = causal_crop(input, output.shape[-1])
        output /= output.abs().max()

        # # apply high pass filter to remove DC
        # sos = scipy.signal.butter(8, 20.0, fs=sample_rate, output="sos", btype="highpass")
        # output = scipy.signal.sosfilt(sos, output.cpu().view(-1).numpy())

        # remove start transient
        output = output[4410:]
        input = input.view(-1)[4410:].cpu().numpy()

        output = torch.from_numpy(output * 0.8).to(device)
        input = torch.from_numpy(input / np.max(np.abs(input))).to(device)
        output = output.view(1, 1, -1)  # Reshape the output tensor
    
        min_length = min(output.shape[2], target.shape[2])
        output = output[:, :, :min_length]
        target = target[:, :, :min_length]
        # print(f"output:{output.shape},target:{target.shape}")
        
        # Compute metrics means for current batch
        for name, metric in criterions.items():
            # print(name, metric(output, target).item())
            batch_score = metric(output, target).item()
            test_results[name].append(batch_score)
            
# # Compute the mean values/
for metric_name, metric_values in zip(metrics, test_results.values()):
    mean_value = sum(metric_values) / len(metric_values)
    print(f'{metric_name} Mean value: {mean_value:.6f}')

## 2. Training on the Spring Dataset

In [None]:
def peak_normalize(tensor):
    return tensor / tensor.abs().max()

def stack_samples(batch):
    dry_samples = [x for x, y in batch]
    wet_samples = [y for x, y in batch]

    stacked_dry_samples = torch.cat(dry_samples, dim=1)
    stacked_wet_samples = torch.cat(wet_samples, dim=1)

    return stacked_dry_samples, stacked_wet_samples

trainset = SpringDataset(root_dir=DATA_DIR, split='train', transform=peak_normalize)

In [None]:
# Set the desired length of the generated array
desired_length = 100

# Randomly select a certain number of items from the dataset
selected_indices = np.random.choice(len(trainset), size=desired_length, replace=False)

# Create a list to store the selected items
selected_items = []

# Iterate over the selected indices and retrieve the corresponding items
for index in selected_indices:
    item = trainset[index]
    selected_items.append(item)

# Use the stack_samples() function to stack the selected samples
stacked_dry_samples, stacked_wet_samples = stack_samples(selected_items)

# Print the shapes of the stacked samples
print(f"Stacked dry samples shape: {stacked_dry_samples.shape}")
print(f"Stacked wet samples shape: {stacked_wet_samples.shape}")

<img src="plots/pretrained_reverb_params.png" alt="Screenshot" width="700">


In [None]:
# set TCN model parameters:
cond_dim = 2 #@param {type:"slider", min:1, max:10, step:1}
kernel_size = 9 #@param {type:"slider", min:3, max:32, step:1}
n_blocks = 5 #@param {type:"slider", min:2, max:30, step:1}
dilation_growth = 10 #@param {type:"slider", min:1, max:10, step:1}
n_channels = 32 #@param {type:"slider", min:1, max:128, step:1}
n_iters = 2500 #@param {type:"slider", min:0, max:10000, step:1}
length = 262144 #@param {type:"slider", min:0, max:524288, step:1}
lr = 0.001 #@param {type:"number"}

x = stacked_dry_samples
y = stacked_wet_samples

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# reshape the audio
x_batch = x.view(1,x.shape[0],-1)
y_batch = y.view(1,y.shape[0],-1)
c = torch.tensor([0.0, 0.0], device=device).view(1,1,-1)

# crop length
x_batch = x_batch[:,0:1,:]
y_batch = y_batch[:,0:1,:]

_, x_ch, x_samp = x_batch.size()
_, y_ch, y_samp = y_batch.size()

# build the model
model = TCN(
    n_inputs=x_ch,
    n_outputs=y_ch,
    cond_dim=cond_dim,
    kernel_size=kernel_size,
    n_blocks=n_blocks,
    dilation_growth=dilation_growth,
    n_channels=n_channels)
rf = model.compute_receptive_field()
params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Parameters: {params*1e-3:0.3f} k")
print(f"Receptive field: {rf} samples or {(rf/sample_rate)*1e3:0.1f} ms")

# setup loss function, optimizer, and scheduler
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=[32, 128, 512, 2048],
    win_lengths=[32, 128, 512, 2048],
    hop_sizes=[16, 64, 256, 1024])
loss_fn_l1 = torch.nn.L1Loss()

optimizer = torch.optim.Adam(model.parameters(), lr)
ms1 = int(n_iters * 0.8)
ms2 = int(n_iters * 0.95)
milestones = [ms1, ms2]
print(
    "Learning rate schedule:",
    f"1:{lr:0.2e} ->",
    f"{ms1}:{lr*0.1:0.2e} ->",
    f"{ms2}:{lr*0.01:0.2e}",
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones,
    gamma=0.1,
    verbose=False,
)

# move tensors to GPU
if torch.cuda.is_available():
    model.to(device)
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)
    c = c.to(device)

# pad input so that output is same size as input
#x_pad = torch.nn.functional.pad(x_batch, (rf-1, 0))

# iteratively update the weights
pbar = tqdm(range(n_iters))
for n in pbar:
    optimizer.zero_grad()

    start_idx = rf #np.random.randint(rf, x_batch.shape[-1]-length-1)
    stop_idx = start_idx + length
    x_crop = x_batch[...,start_idx-rf+1:stop_idx]
    y_crop = y_batch[...,start_idx:stop_idx]

    y_hat = model(x_crop, c)
    loss = loss_fn(y_hat, y_crop) #+ loss_fn_l1(y_hat, y_crop)

    loss.backward()
    optimizer.step()

    scheduler.step()
    if (n+1) % 1 == 0:
        pbar.set_description(f" Loss: {loss.item():0.3e} | ")

y_hat /= y_hat.abs().max()

model.eval()
x_pad = torch.nn.functional.pad(x_batch, (rf-1, 0))
with torch.no_grad():
    y_hat = model(x_pad, c)

input = causal_crop(x_batch.view(-1).detach().cpu().numpy(), y_hat.shape[-1])
output = y_hat.view(-1).detach().cpu().numpy()
target = causal_crop(y_batch.view(-1).detach().cpu().numpy(), y_hat.shape[-1])

# apply highpass to output
sos = scipy.signal.butter(8, 20.0, fs=sample_rate, output="sos", btype="highpass")
output = scipy.signal.sosfilt(sos, output)

input /= np.max(np.abs(input))
output /= np.max(np.abs(output))
target /= np.max(np.abs(target))

In [None]:
# Save and load your spring model to inference with CPU
model_fn = os.path.join(MODELS_DIR, "baseline_spring_reverb.pt")
torch.save(model, model_fn)
my_spring_reverb_model = torch.load(model_fn)

## Evaluation

In [None]:
# Load the test set
testset = SpringDataset(root_dir=DATA_DIR, split="test")
test_loader = torch.utils.data.DataLoader(testset, batch_size=1, drop_last=True)

Found 4 files in ../datasets/plate-spring/spring
Using dry_val_test.h5 and wet_val_test.h5 for test split.


In [None]:
# setup the pre-trained models
model = torch.load("models/baseline_spring_reverb.pt", map_location="cpu").eval()

In [None]:
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
rf = model.compute_receptive_field()
mae = torch.nn.L1Loss().to(device)
dc = auraloss.time.DCLoss().to(device)
esr = auraloss.time.ESRLoss().to(device)
mrstft =  auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=[32, 128, 512, 2048],
    win_lengths=[32, 128, 512, 2048],
    hop_sizes=[16, 64, 256, 1024],
    sample_rate=sample_rate,
    perceptual_weighting=False,
    ).to(device)
    
criterions = {"mae": mae, "esr": esr, "dc": dc, "mrstft": mrstft}
test_results = {"mae": [], "esr": [], "dc": [], "mrstft": []}
metrics = [mae, esr, dc, mrstft]

# Lists to store the metrics

with torch.no_grad():
    for n_iters, (input, target) in enumerate(test_loader):

        input = torch.nn.functional.pad(input, (rf-1, rf-1))
        input = input[0,:]
        input = input.view(1,1,-1).to(device)
        c_rand = torch.tensor([-0.1,0.0], device=device).view(1,1,-1)

        output = model(input)
        input = causal_crop(input, output.shape[-1])

        output /= output.abs().max()
        output = output.view(-1).cpu().numpy()
        # # apply high pass filter to remove DC
        # sos = scipy.signal.butter(8, 20.0, fs=sample_rate, output="sos", btype="highpass")
        # output = scipy.signal.sosfilt(sos, output.cpu().view(-1).numpy())

        # remove start transient
        output = output[4410:]
        input = input.view(-1)[4410:].cpu().numpy()

        output = torch.from_numpy(output * 0.8).to(device)
        input = torch.from_numpy(input / np.max(np.abs(input))).to(device)
        output = output.view(1, 1, -1)  # Reshape the output tensor
    
        min_length = min(output.shape[2], target.shape[2])
        output = output[:, :, :min_length]
        target = target[:, :, :min_length]
        # print(f"output:{output.shape},target:{target.shape}")

        # Compute metrics means for current batch
        for name, metric in criterions.items():
            batch_score = metric(output, target).item()
            test_results[name].append(batch_score)

# # Compute the mean values/
for metric_name, metric_values in zip(metrics, test_results.values()):
    mean_value = sum(metric_values) / len(metric_values)
    print(f'{metric_name} Mean value: {mean_value:.6f}')

L1Loss() Mean value: 0.205089
ESRLoss() Mean value: 5.463211
DCLoss() Mean value: 0.160171
MultiResolutionSTFTLoss(
  (stft_losses): ModuleList(
    (0-3): 4 x STFTLoss(
      (spectralconv): SpectralConvergenceLoss()
      (logstft): STFTMagnitudeLoss(
        (distance): L1Loss()
      )
      (linstft): STFTMagnitudeLoss(
        (distance): L1Loss()
      )
    )
  )
) Mean value: 2.697644


In [None]:
print(model)