# Evaluation of the dataset with the baseline model

Adapted from https://github.com/csteinmetz1/steerable-nafx/blob/main/steerable-nafx.ipynb

In this experiment we train the model on the concatenated subset and compare its results with those of a pretrained model.
The model is a TCN (temporal convolutional network) with FiLM conditioning layers.

## Setup

In [3]:
import os
import torch
import matplotlib.pyplot as plt
import IPython.display as ipd
import librosa.display
import auraloss
import torchaudio
import numpy as np
import scipy.signal
from pathlib import Path

from tqdm.notebook import tqdm
%matplotlib inline

In [5]:
# Download dataset
if os.path.exists('../datasets/plate-spring') == False:
    !wget -P ../data/raw/ https://zenodo.org/record/3746119/files/plate-spring.zip
    !unzip ../datasets/plate-spring.zip -d ../datasets/
else:
    print('Dataset already downloaded.')

--2023-09-09 18:30:04--  https://zenodo.org/record/3746119/files/plate-spring.zip
Resolving zenodo.org (zenodo.org)... 188.185.124.72
Connecting to zenodo.org (zenodo.org)|188.185.124.72|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 504370887 (481M) [application/octet-stream]
Saving to: ‘../data/raw/plate-spring.zip’


2023-09-09 18:35:33 (1,46 MB/s) - ‘../data/raw/plate-spring.zip’ saved [504370887/504370887]

Archive:  ../datasets/plate-spring.zip
  End-of-central-directory signature not found.  Either this file is not
  a zipfile, or it constitutes one disk of a multi-part archive.  In the
  latter case the central directory and zipfile comment will be found on
  the last disk(s) of this archive.
unzip:  cannot find zipfile directory in one of ../datasets/plate-spring.zip or
        ../datasets/plate-spring.zip.zip, and cannot find ../datasets/plate-spring.zip.ZIP, period.


In [None]:
# Global variables
# Root of the "spring" subset of the plate-spring dataset.
DATA_DIR = '../datasets/plate-spring/spring'

# Audio working folders as pathlib objects, so later cells can join with `/`.
CONVERTED_DIR = Path("audio") / "springset_converted"
PROCESSED_DIR = Path("audio") / "processed"

# Output folders (plain strings, joined with os.path.join elsewhere).
MODELS_DIR = "models/"
PLOTS_DIR = "plots/"

# Default sample rate; cells that call torchaudio.load rebind it to the file's rate.
sample_rate = 16000

## Implementation of the TCN model with FiLM layers for training and inference 

In [None]:
def causal_crop(x, length: int):
    """Keep the most recent ``length`` samples of ``x`` along its last axis.

    An unpadded ("valid") convolution shortens its input, and output sample ``i``
    lines up with input sample ``i + (x.shape[-1] - length)``. Cropping the
    *trailing* ``length`` samples therefore aligns a residual/skip path with the
    convolution output.

    Works on anything with ``.shape`` and ``x[..., a:b]`` slicing (this notebook
    passes both torch tensors and numpy arrays).

    Args:
        x: signal whose last axis is time.
        length: number of samples to keep, ``0 <= length <= x.shape[-1]``.

    Returns:
        ``x`` itself if it already has ``length`` samples, otherwise a view of
        its last ``length`` samples.

    Raises:
        ValueError: if ``length`` is negative or longer than ``x``.

    NOTE(review): an earlier version sliced up to ``x.shape[-1] - 1``, dropping
    the newest sample and delaying the residual by one sample. Checkpoints trained
    with that alignment (possibly ``reverb_full.pt``) may behave slightly
    differently — verify or retrain.
    """
    total = x.shape[-1]
    if length == total:
        return x
    if not 0 <= length <= total:
        raise ValueError(f"cannot crop a signal of {total} samples to length {length}")
    # explicit start index (not -length) so that length == 0 yields an empty slice
    return x[..., total - length:]

class FiLM(torch.nn.Module):
    """Feature-wise linear modulation of conv features by a conditioning vector.

    A single linear "adaptor" predicts a per-channel gain and bias from the
    condition; the features are optionally normalised by a non-affine BatchNorm
    and then mapped to ``x * gain + bias``.
    """

    def __init__(
        self,
        cond_dim,  # length of the conditioning vector (last axis of `cond`)
        num_features,  # number of conv channels being modulated
        batch_norm=True,
    ):
        super().__init__()
        self.num_features = num_features
        self.batch_norm = batch_norm
        if batch_norm:
            # affine=False: scale and shift come from the condition instead
            self.bn = torch.nn.BatchNorm1d(num_features, affine=False)
        # one projection yields both halves: gain and bias
        self.adaptor = torch.nn.Linear(cond_dim, num_features * 2)

    def forward(self, x, cond):
        """Modulate ``x`` with the gain/bias predicted from ``cond``.

        Assumes ``x`` is (batch, num_features, time) and ``cond`` is
        (batch, 1, cond_dim), as built elsewhere in this notebook; the permute
        moves the predicted features onto the channel axis so they broadcast
        over time. TODO confirm for other shapes.
        """
        gain, bias = self.adaptor(cond).chunk(2, dim=-1)
        gain = gain.permute(0, 2, 1)
        bias = bias.permute(0, 2, 1)
        normed = self.bn(x) if self.batch_norm else x
        return normed * gain + bias

class TCNBlock(torch.nn.Module):
    """One causal TCN layer: dilated conv -> optional FiLM -> optional PReLU, plus a residual.

    The convolution is unpadded, so the block shortens its input by
    ``(kernel_size - 1) * dilation`` samples; the 1x1-conv residual is cropped
    with ``causal_crop`` to that same length before being added.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, cond_dim=0, activation=True):
        super().__init__()
        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            dilation=dilation,
            padding=0,
            bias=True,
        )
        # Optional sub-modules; forward() detects them with hasattr().
        if cond_dim > 0:
            self.film = FiLM(cond_dim, out_channels, batch_norm=False)
        if activation:
            self.act = torch.nn.PReLU()
        # bias-free 1x1 projection so the residual has out_channels channels
        self.res = torch.nn.Conv1d(in_channels, out_channels, 1, bias=False)

    def forward(self, x, c=None):
        """Run the block on ``x``; ``c`` is the FiLM condition (unused without FiLM)."""
        y = self.conv(x)
        if hasattr(self, "film"):
            y = self.film(y, c)
        if hasattr(self, "act"):
            y = self.act(y)
        residual = causal_crop(self.res(x), y.shape[-1])
        return y + residual

class TCN(torch.nn.Module):
    """Stack of causal, dilated ``TCNBlock`` layers with optional FiLM conditioning.

    Block ``n`` uses dilation ``dilation_growth ** n`` and a PReLU activation.
    The first block reads ``n_inputs`` channels and the last one writes
    ``n_outputs`` channels; everything in between uses ``n_channels``. The convs
    are unpadded, so the output is ``compute_receptive_field() - 1`` samples
    shorter than the input.
    """

    def __init__(self, n_inputs=1, n_outputs=1, n_blocks=10, kernel_size=13, n_channels=64, dilation_growth=4, cond_dim=0):
        super().__init__()
        self.kernel_size = kernel_size
        self.n_channels = n_channels
        self.dilation_growth = dilation_growth
        self.n_blocks = n_blocks
        # read by compute_receptive_field(); equal to n_blocks, so no wrap-around
        self.stack_size = n_blocks

        self.blocks = torch.nn.ModuleList()
        for n in range(n_blocks):
            # Input and output widths are chosen independently so that a
            # single-block network still maps n_inputs -> n_outputs.
            in_ch = n_inputs if n == 0 else n_channels
            out_ch = n_outputs if n == n_blocks - 1 else n_channels
            dilation = dilation_growth ** n
            self.blocks.append(
                TCNBlock(in_ch, out_ch, kernel_size, dilation, cond_dim=cond_dim, activation=True)
            )

    def forward(self, x, c=None):
        """Apply every block in order, passing the same condition ``c`` to each."""
        for block in self.blocks:
            x = block(x, c)
        return x

    def compute_receptive_field(self):
        """Return the receptive field in samples.

        ``1 + sum((kernel_size - 1) * dilation_n)`` over all blocks; this equals
        ``kernel_size`` for a single block and 1 for an empty stack.
        """
        rf = 1
        for n in range(self.n_blocks):
            rf += (self.kernel_size - 1) * self.dilation_growth ** (n % self.stack_size)
        return rf

## Pretrained Model

In [None]:
#@title Load the pre-trained model
model_path = os.path.join(MODELS_DIR, "reverb_full.pt")
# The checkpoint holds a whole pickled nn.Module (it is used directly via .eval()
# and .compute_receptive_field()), not a state_dict. PyTorch >= 2.6 defaults to
# weights_only=True and would refuse it, so opt out explicitly.
# Unpickling can run arbitrary code: only load checkpoints from a trusted source.
model_verb = torch.load(model_path, map_location="cpu", weights_only=False).eval()
print(model_verb)

In [None]:
#@title Load validation data for testing
# Dry test excerpt. torchaudio.load returns (channels, frames) together with the
# file's sample rate, which replaces the global sample_rate from here on.
process_file = str(CONVERTED_DIR / 'stack_X_test.wav')
x_p, sample_rate = torchaudio.load(process_file)
print(process_file, x_p.shape)
ipd.Audio(rate=sample_rate, data=x_p)

Now set the audio effect parameters. Here are some more insights into the controls:

*   `effect_type` - Choose from one of the pre-trained models (not exposed in this notebook, which always uses the reverb model loaded above).
*   `gain_dB` - Adjust the input gain. This can have a big effect since the effects are very nonlinear.
*   `c0` and `c1` - These are the effect controls which will adjust perceptual aspects of the effect, depending on the effect type. Very large values will often result in more extreme effects.
*   `mix` - Control the wet/dry mix of the effect.
*   `width` - Increase stereo width of the effect.
*   `max_length` - Maximum input length in seconds; longer files are truncated.
*   `stereo` - Convert mono input to stereo output.
*   `tail` - If checked, we will also compute the effect tail (nice for reverbs).
*   `save_pred_file` (pre-trained model) / `output_file` (new model) - Output file path. Choose a new name to avoid overwriting previously processed files.








In [None]:
#@title Setting parameters and processing
gain_dB = 0 #@param {type:"slider", min:-24, max:24, step:0.1}
c0 = -10 #@param {type:"slider", min:-10, max:10, step:0.1}
c1 = -10 #@param {type:"slider", min:-10, max:10, step:0.1}
mix = 100 #@param {type:"slider", min:0, max:100, step:1}
width = 50 #@param {type:"slider", min:0, max:100, step:1}
max_length = 129 #@param {type:"slider", min:5, max:130, step:1}
stereo = False #@param {type:"boolean"}
tail = True #@param {type:"boolean"}
# Leading output samples to discard as start-up transient. Any value > 0 makes the
# saved output start that many samples later in the signal than the target
# (stack_Y_test.wav), so keep 0 when the file is compared sample-by-sample.
transient_samples = 0 #@param {type:"integer"}
save_pred_file = os.path.join(PROCESSED_DIR, "pretrained_y_hat.wav") #@param {type: "string"}

# measure the receptive field
pt_model_rf = model_verb.compute_receptive_field()

# crop input signal if needed (max_length is in seconds)
max_samples = int(sample_rate * max_length)
x_p_crop = x_p[:, :max_samples]
chs = x_p_crop.shape[0]

# if mono and stereo requested
if chs == 1 and stereo:
    x_p_crop = x_p_crop.repeat(2, 1)
    chs = 2

# pad the input: rf - 1 samples of history in front (so the output keeps the
# input length) and, when the tail is requested, the same amount after the end
front_pad = pt_model_rf - 1
back_pad = 0 if not tail else front_pad
x_p_pad = torch.nn.functional.pad(x_p_crop, (front_pad, back_pad))

# design highpass filter (8th-order Butterworth at 20 Hz)
sos = scipy.signal.butter(
    8,
    20.0,
    fs=sample_rate,
    output="sos",
    btype="highpass"
)

# compute linear gain
gain_ln = 10 ** (gain_dB / 20.0)

# process audio with pre-trained model
with torch.no_grad():
    y_hat = torch.zeros(x_p_crop.shape[0], x_p_crop.shape[1] + back_pad)
    for n in range(chs):
        # stereo width: shift both controls in opposite directions per channel
        if n == 0:
            factor = (width*5e-3)
        elif n == 1:
            factor = -(width*5e-3)
        c = torch.tensor([float(c0+factor), float(c1+factor)]).view(1, 1, -1)
        y_hat_ch = model_verb(gain_ln * x_p_pad[n, :].view(1, 1, -1), c)
        y_hat_ch = scipy.signal.sosfilt(sos, y_hat_ch.view(-1).numpy())
        y_hat[n, :] = torch.tensor(y_hat_ch)

# pad the dry signal
x_dry = torch.nn.functional.pad(x_p_crop, (0, back_pad))

# normalize each first
y_hat /= y_hat.abs().max()
x_dry /= x_dry.abs().max()

# mix
mix = mix/100.0
y_hat = (mix * y_hat) + ((1-mix) * x_dry)

# optionally drop the start-up transient (see transient_samples above)
y_hat = y_hat[..., transient_samples:]
y_hat /= y_hat.abs().max()

# save and preview the audio
torchaudio.save(save_pred_file, y_hat.view(chs, -1), sample_rate, encoding="PCM_S", bits_per_sample=16)
ipd.Audio(save_pred_file)

In [None]:
def overlap_waveforms(output, target, sample_rate, start, end, title):
    """Overlay two waveforms over the sample window ``[start, end)``.

    Args:
        output: model output samples (assumed 1-D, e.g. a squeezed numpy array).
        target: reference samples, indexed the same way as ``output``.
        sample_rate: not used; the x axis is labelled in samples.
        start: first sample index of the window.
        end: one past the last sample index; clipped to the shorter of the two
            signals so both curves have as many points as the x axis.
        title: plot title.
    """
    # Clipping avoids matplotlib's "x and y must have same first dimension"
    # error when a signal is shorter than the requested window.
    end = min(end, len(output), len(target))
    o_zoom = output[start:end]
    t_zoom = target[start:end]

    # x axis: absolute sample indices of the window
    time = range(start, end)

    plt.figure(figsize=(12, 4))
    plt.plot(time, o_zoom, alpha=0.5, label="Output")
    plt.plot(time, t_zoom, alpha=0.5, label="Target")
    plt.xlabel("Time (samples)")
    plt.ylabel("Amplitude")
    plt.title(title)
    plt.legend()
    plt.grid()
    plt.show()

In [None]:
# sample window to display
start, end = 11000, 12000

# STFT settings (not used by the code visible in this notebook)
N, hop_size = 512, 256

# saved pre-trained prediction vs. the matching target
o_p, sample_rate = torchaudio.load(PROCESSED_DIR / "pretrained_y_hat.wav")
t_p, t_sr = torchaudio.load(CONVERTED_DIR / "stack_Y_test.wav")

output, target = o_p.numpy().squeeze(), t_p.numpy().squeeze()

overlap_waveforms(output, target, sample_rate, start, end, title="Waveform comparison")

In the run shown here we notice a time-alignment offset between the target and the model output. This could be due to several factors; one possibility is that the data used for evaluation differ from the data used to train the model.

## Train a new model with the chosen dataset

Using the same model architecture, we train a new model on the chosen dataset.
The same hyperparameters are used for training.

In [None]:
# Upload clean sound (Xtrain.wav)
# Dry training input; loading it also rebinds the global sample_rate.
input_file = str(CONVERTED_DIR / 'stack_X_train.wav')
x, sample_rate = torchaudio.load(input_file)
print(input_file, x.shape)

In [None]:
#@title Upload processed sound (Ytrain_0.wav) that will be the target
# Dedicated name: reusing save_pred_file here would overwrite the variable that
# holds the pre-trained prediction path.
target_file = os.path.join(CONVERTED_DIR, 'stack_Y_train.wav')
y, sample_rate = torchaudio.load(target_file)

# Trim whichever signal is longer so input and target stay aligned from sample 0.
if y.shape[-1] != x.shape[-1]:
    print(f"Input and output files are different lengths! Found clean: {x.shape[-1]} processed: {y.shape[-1]}.")
    if y.shape[-1] > x.shape[-1]:
        print("Cropping target...")
        y = y[:, :x.shape[-1]]
    else:
        print("Cropping input...")
        x = x[:, :y.shape[-1]]

print(target_file, y.shape)

In [None]:
if torch.cuda.is_available(): torch.cuda.empty_cache()  # release cached, unused GPU memory

In [None]:
#TCN model hyperparameters and launch training
cond_dim = 2 #@param {type:"slider", min:1, max:10, step:1}
kernel_size = 9 #@param {type:"slider", min:3, max:32, step:1}
n_blocks = 5 #@param {type:"slider", min:2, max:30, step:1}
dilation_growth = 10 #@param {type:"slider", min:1, max:10, step:1}
n_channels = 32 #@param {type:"slider", min:1, max:128, step:1}
n_iters = 2500 #@param {type:"slider", min:0, max:10000, step:1}
length = 262144 #@param {type:"slider", min:0, max:524288, step:1}
lr = 0.001 #@param {type:"number"}

device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: PYTORCH_NO_CUDA_MEMORY_CACHING is an environment variable read when CUDA
# initialises; assigning a Python variable of that name has no effect. Export it
# in the shell before starting the kernel if you need it.

# reshape the audio to (batch=1, channels, samples)
x_batch = x.view(1, x.shape[0], -1)
y_batch = y.view(1, y.shape[0], -1)

# keep only the first channel of input and target
x_batch = x_batch[:, 0:1, :]
y_batch = y_batch[:, 0:1, :]

_, x_ch, x_samp = x_batch.size()
_, y_ch, y_samp = y_batch.size()

# constant all-zero condition, sized to match the FiLM adaptors (cond_dim inputs)
c = torch.zeros(1, 1, cond_dim, device=device)

# build the model
model = TCN(
    n_inputs=x_ch,
    n_outputs=y_ch,
    cond_dim=cond_dim,
    kernel_size=kernel_size,
    n_blocks=n_blocks,
    dilation_growth=dilation_growth,
    n_channels=n_channels)
rf = model.compute_receptive_field()
params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Parameters: {params*1e-3:0.3f} k")
print(f"Receptive field: {rf} samples or {(rf/sample_rate)*1e3:0.1f} ms")
print(model)

# setup loss function, optimizer, and scheduler
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=[32, 128, 512, 2048],
    win_lengths=[32, 128, 512, 2048],
    hop_sizes=[16, 64, 256, 1024])
loss_fn_l1 = torch.nn.L1Loss()

optimizer = torch.optim.Adam(model.parameters(), lr)
# divide the learning rate by 10 at 80% and again at 95% of the iterations
ms1 = int(n_iters * 0.8)
ms2 = int(n_iters * 0.95)
milestones = [ms1, ms2]
print(
    "Learning rate schedule:",
    f"1:{lr:0.2e} ->",
    f"{ms1}:{lr*0.1:0.2e} ->",
    f"{ms2}:{lr*0.01:0.2e}",
)
# `verbose` is deprecated in recent PyTorch (and was False here anyway).
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1)

# move model and data to the selected device (no-op on CPU)
model.to(device)
x_batch = x_batch.to(device)
y_batch = y_batch.to(device)

# iteratively update the weights
pbar = tqdm(range(n_iters))
for n in pbar:
    optimizer.zero_grad()

    # Fixed training window: x_crop carries rf - 1 extra samples of history, so
    # the unpadded model output lines up with y_batch[start_idx:stop_idx].
    # For random windows: start_idx = np.random.randint(rf, x_batch.shape[-1] - length - 1)
    start_idx = rf
    stop_idx = start_idx + length
    x_crop = x_batch[..., start_idx - rf + 1:stop_idx]
    y_crop = y_batch[..., start_idx:stop_idx]

    y_hat = model(x_crop, c)
    loss = loss_fn(y_hat, y_crop)  # optionally: + loss_fn_l1(y_hat, y_crop)

    loss.backward()
    optimizer.step()
    scheduler.step()
    pbar.set_description(f" Loss: {loss.item():0.3e} | ")

# full-length inference: front padding of rf - 1 keeps the output as long as x_batch
model.eval()
x_pad = torch.nn.functional.pad(x_batch, (rf - 1, 0))
with torch.no_grad():
    y_hat = model(x_pad, c)

input_sig = causal_crop(x_batch.view(-1).cpu().numpy(), y_hat.shape[-1])
output = y_hat.view(-1).cpu().numpy()
target = causal_crop(y_batch.view(-1).cpu().numpy(), y_hat.shape[-1])

# apply highpass to output
sos = scipy.signal.butter(8, 20.0, fs=sample_rate, output="sos", btype="highpass")
output = scipy.signal.sosfilt(sos, output)

# Peak-normalise out of place: for CPU tensors .cpu().numpy() shares memory with
# x / y, so an in-place `/=` would also silently rescale the training audio.
input_sig = input_sig / np.max(np.abs(input_sig))
output = output / np.max(np.abs(output))
target = target / np.max(np.abs(target))

In [None]:
time = 3.9 #@param {type:"slider", min:0, max:30, step:0.1}
center = 54 #@param {type:"slider", min:0, max:120, step:1}
# x-limits: +/- `time` around `center`
time_window = [center - time, center + time]

# Overlay target and new-model output on a single axis.
fig, ax = plt.subplots(figsize=(10, 4))
librosa.display.waveshow(target, sr=sample_rate, alpha=0.7, ax=ax, label='Target')
librosa.display.waveshow(output, sr=sample_rate, alpha=0.7, ax=ax, label='Output')

ax.set_title("New Model")
ax.set_xlabel('Time (min)')
ax.set_ylabel('Amplitude')
ax.set_xlim(time_window)
ax.grid(True)
ax.legend(loc='upper right')

plt.tight_layout()
plt.savefig(os.path.join(PLOTS_DIR, 'new_model_out_wave.png'))
print("Done.")

In [None]:
#Save and load your spring model to inference with CPU
model_fn = os.path.join(MODELS_DIR, "spring_reverb_new_nb_02.pt")
torch.save(model, model_fn)
# map_location="cpu": the reload works without a GPU even if `model` lives on CUDA.
# weights_only=False: the file is a whole pickled nn.Module, which PyTorch >= 2.6
# rejects by default. Only unpickle files you trust.
my_spring_reverb_model = torch.load(model_fn, map_location="cpu", weights_only=False)

In [None]:
# Reload the dry test input for the new model: the same file the pre-trained
# model processed, so both outputs are scored against the same target.
# (The path is assigned to process_file, the variable actually loaded below.)
process_file = os.path.join(CONVERTED_DIR, 'stack_X_test.wav')
x_p, sample_rate = torchaudio.load(process_file)
print(process_file, x_p.shape)

## Evaluation of the new model

In [None]:
# Make an Inference

gain_dB = 0 #@param {type:"slider", min:-24, max:24, step:0.1}
c0 = -10 #@param {type:"slider", min:-10, max:10, step:0.1}
c1 = -10 #@param {type:"slider", min:-10, max:10, step:0.1}
mix = 100 #@param {type:"slider", min:0, max:100, step:1}
width = 50 #@param {type:"slider", min:0, max:100, step:1}
max_length = 128 #@param {type:"slider", min:5, max:128, step:1}
stereo = False #@param {type:"boolean"}
tail = True #@param {type:"boolean"}
# Leading output samples to discard as start-up transient. Values > 0 shift the
# saved output relative to the target and skew the sample-wise ESR/MSE, so the
# default is 0.
transient_samples = 0 #@param {type:"integer"}
output_file = os.path.join(PROCESSED_DIR, "new_model_out_inference.wav") #@param {type: "string"}

model.eval()
# measure the receptive field
model_rf = model.compute_receptive_field()

# crop input signal if needed (max_length is in seconds)
max_samples = int(sample_rate * max_length)
x_p_crop = x_p[:, :max_samples]
chs = x_p_crop.shape[0]

# if mono and stereo requested
if chs == 1 and stereo:
    x_p_crop = x_p_crop.repeat(2, 1)
    chs = 2

# pad the input: rf - 1 samples of history in front, plus the same amount after
# the end when the tail is requested
front_pad = model_rf - 1
back_pad = 0 if not tail else front_pad
x_p_pad = torch.nn.functional.pad(x_p_crop, (front_pad, back_pad))

# design highpass filter (8th-order Butterworth at 20 Hz)
sos = scipy.signal.butter(
    8,
    20.0,
    fs=sample_rate,
    output="sos",
    btype="highpass"
)

# compute linear gain
gain_ln = 10 ** (gain_dB / 20.0)

# move data to the model's device
x_p_crop = x_p_crop.to(device)
x_p_pad = x_p_pad.to(device)

# process audio with the newly trained model
with torch.no_grad():
    y_hat = torch.zeros(x_p_crop.shape[0], x_p_crop.shape[1] + back_pad, device=device)
    for n in range(chs):
        # stereo width: shift both controls in opposite directions per channel
        if n == 0:
            factor = (width * 5e-3)
        elif n == 1:
            factor = -(width * 5e-3)
        c = torch.tensor([float(c0+factor), float(c1+factor)], device=device).view(1, 1, -1)
        # scipy filtering needs a CPU copy of the model output
        y_hat_ch = model(gain_ln * x_p_pad[n, :].view(1, 1, -1), c).cpu()
        y_hat_ch = scipy.signal.sosfilt(sos, y_hat_ch.view(-1).numpy())
        y_hat[n, :] = torch.tensor(y_hat_ch)

# pad the dry signal
x_dry = torch.nn.functional.pad(x_p_crop, (0, back_pad))

# normalize each first
y_hat /= y_hat.abs().max()
x_dry /= x_dry.abs().max()

# mix
mix = mix / 100.0
y_hat = (mix * y_hat) + ((1 - mix) * x_dry)

# optionally drop the start-up transient (see transient_samples above)
y_hat = y_hat[..., transient_samples:]
y_hat /= y_hat.abs().max()

torchaudio.save(output_file, y_hat.view(chs, -1).to("cpu"), sample_rate, encoding="PCM_S", bits_per_sample=16)

## Define some metrics

In [None]:
#@title Define ESR and MSE metrics
def _pad_to_common_length(y, y_pred):
    """Zero-pad ``y`` / ``y_pred`` at the end of their last axis to a common length."""
    max_len = max(y.shape[-1], y_pred.shape[-1])
    y = torch.nn.functional.pad(y, (0, max_len - y.shape[-1]))
    y_pred = torch.nn.functional.pad(y_pred, (0, max_len - y_pred.shape[-1]))
    return y, y_pred


def error_to_signal(y, y_pred, pre_emphasis=False, coeff=0.95):
    """
    Error-to-signal ratio ``sum((y - y_pred)**2) / sum(y**2)``:
    https://www.mdpi.com/2076-3417/10/3/766/htm

    The shorter input is zero-padded at the end so both have the same length.
    No pre-emphasis is applied unless ``pre_emphasis`` is True, in which case
    both padded signals first go through ``pre_emphasis_filter(x, coeff)``.

    Returns:
        0-dim tensor. 1e-10 is added to the denominator to avoid dividing by zero.
    """
    y, y_pred = _pad_to_common_length(y, y_pred)
    if pre_emphasis:
        y = pre_emphasis_filter(y, coeff)
        y_pred = pre_emphasis_filter(y_pred, coeff)

    error = torch.sum((y - y_pred) ** 2)
    signal = torch.sum(y ** 2)
    return error / (signal + 1e-10)


def pre_emphasis_filter(x, coeff=0.95):
    """First-order pre-emphasis along the last axis: ``x[n] - coeff * x[n-1]``, first sample kept."""
    return torch.cat((x[..., :1], x[..., 1:] - coeff * x[..., :-1]), dim=-1)


def mean_square_error(y, y_pred):
    """
    Mean square error between ``y`` and ``y_pred`` after zero-padding the
    shorter one at the end. Returns a Python float.
    """
    y, y_pred = _pad_to_common_length(y, y_pred)
    return torch.nn.functional.mse_loss(y_pred, y).item()

In [None]:
#@title Compute ESR and MSE 
# Score both models on the same test excerpt. stack_Y_test.wav is the processed
# (target) counterpart of stack_X_test.wav, the input both inference cells use,
# and the two prediction paths are the files those cells save.
ground_truth = os.path.join(CONVERTED_DIR, 'stack_Y_test.wav')
pred_1 = os.path.join(PROCESSED_DIR, 'pretrained_y_hat.wav')
pred_2 = os.path.join(PROCESSED_DIR, 'new_model_out_inference.wav')

print(torchaudio.info(ground_truth))
print(torchaudio.info(pred_1))
print(torchaudio.info(pred_2))

gt, sr = torchaudio.load(ground_truth)
p1, sr = torchaudio.load(pred_1)
p2, sr = torchaudio.load(pred_2)

# Pad all three to the longest length so both models are scored over the same span
max_len = max(gt.shape[-1], p1.shape[-1], p2.shape[-1])
gt = torch.nn.functional.pad(gt, (0, max_len-gt.shape[-1]), 'constant', 0)
p1 = torch.nn.functional.pad(p1, (0, max_len-p1.shape[-1]), 'constant', 0)
p2 = torch.nn.functional.pad(p2, (0, max_len-p2.shape[-1]), 'constant', 0)

p1_esr = error_to_signal(gt, p1)
p2_esr = error_to_signal(gt, p2)

print(f"First model ESR: {p1_esr}")
print(f"Second model ESR: {p2_esr}")

mse_1 = mean_square_error(gt, p1)
mse_2 = mean_square_error(gt, p2)

print(f"First model MSE: {mse_1}")
print(f"Second model MSE: {mse_2}")

## Plot the results

In [None]:
#@title Plot ESR and MSE
labels = ['Pretrained', 'New Model']
esr = [p1_esr, p2_esr]
mse = [mse_1, mse_2]
colors = ['red', 'green']

# one bar chart per metric, side by side
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 3))
ax1.bar(labels, esr, color=colors, alpha=0.6, width=0.4)
ax2.bar(labels, mse, color=colors, alpha=0.6, width=0.4)

ax1.set_title('Error-to-Signal Ratio')
ax1.set_ylabel('ESR')
ax1.grid(True)

ax2.set_title('Mean Square Error')
ax2.set_ylabel('MSE')
ax2.grid(True)

# write each value just below the top of its bar
for i, value in enumerate(esr):
    ax1.annotate(f'{value:.4f}', xy=(i, value), ha='center', va='top')
for i, value in enumerate(mse):
    ax2.annotate(f'{value:.4f}', xy=(i, value), ha='center', va='top')

fig.subplots_adjust(wspace=0.4)
# use the shared output folder instead of a hard-coded 'plots/' prefix
plt.savefig(os.path.join(PLOTS_DIR, 'bars.png'))
plt.show()

In [None]:
#@title Plot the waveforms

fig, axs = plt.subplots(nrows=3, ncols=1, sharex=True, figsize=(10, 10))

axs[0].set_title('Comparison of the Y validation with the first prediction')
librosa.display.waveshow(gt.numpy(), sr=sample_rate, ax=axs[0], color='b', alpha=0.8, label='Y validation')
librosa.display.waveshow(p1.numpy(), sr=sample_rate, ax=axs[0], color='r', alpha=0.5, label='Pretrained model')

axs[1].set_title('Comparison of the Y validation with the second prediction')
librosa.display.waveshow(gt.numpy(), sr=sample_rate, ax=axs[1], color='b', alpha=0.8, label='Y validation')
librosa.display.waveshow(p2.numpy(), sr=sample_rate, ax=axs[1], color='g', alpha=0.5, label='New model')

axs[2].set_title('Comparison of the two predictions')
librosa.display.waveshow(p1.numpy(), sr=sample_rate, ax=axs[2], color='r', alpha=0.5, label='Pretrained model')
librosa.display.waveshow(p2.numpy(), sr=sample_rate, ax=axs[2], color='g', alpha=0.5, label='New model')

# set the x and y labels, grid and legend for all subplots
for ax in axs:
    ax.set_xlabel('Time (min)')
    ax.set_ylabel('Amplitude')
    ax.grid(True)
    ax.legend(loc='upper right')

plt.tight_layout()
# save before showing (as the other plot cells do) into the shared plots folder
fig.savefig(os.path.join(PLOTS_DIR, 'all-waveforms.png'))
plt.show()

In [None]:
#@title Zoom waveform's time frames
time = 3.9 #@param {type:"slider", min:0, max:30, step:0.1}
center = 54 #@param {type:"slider", min:0, max:120, step:1}

# x-limits: +/- `time` around `center`
time_window = [center - time, center + time]

fig, axs = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(10, 4))

# Left: target vs. pre-trained model
librosa.display.waveshow(gt.numpy(), sr=sample_rate, alpha=0.8, ax=axs[0], color='b', label='Y validation')
librosa.display.waveshow(p1.numpy(), sr=sample_rate, alpha=0.5, ax=axs[0], color='r', label='Pretrained Model')
axs[0].set_title("Comparison of the Y validation with the first prediction")

# Right: target vs. new model
librosa.display.waveshow(gt.numpy(), sr=sample_rate, alpha=0.8, ax=axs[1], color='b', label='Y validation')
librosa.display.waveshow(p2.numpy(), sr=sample_rate, alpha=0.5, ax=axs[1], color='g', label='New Model')
axs[1].set_title("Comparison of the Y validation with the second prediction")

# set the x and y labels, grid, legend and zoom window for all subplots
for ax in axs:
    ax.set_xlabel('Time (min)')
    ax.set_ylabel('Amplitude')
    ax.grid(True)
    ax.legend(loc='upper right')
    ax.set_xlim(time_window)
plt.tight_layout()
plt.savefig(os.path.join(PLOTS_DIR, 'waveforms-time-frame.png'))
plt.show()

In [None]:
#@title Zoom spectrograms' time frames
time = 1.9 #@param {type:"slider", min:0, max:30, step:0.1}
center = 54 #@param {type:"slider", min:0, max:120, step:1}

# x-limits: +/- `time` seconds around `center`
time_window = [center - time, center + time]

fig, axs = plt.subplots(
    nrows=1, ncols=3, sharey=True,
    figsize=(10, 4))

# magnitude spectrograms in dB, each referenced to its own maximum
dgt = librosa.amplitude_to_db(
    np.abs(librosa.stft(gt.numpy().squeeze())), ref=np.max)
dp1 = librosa.amplitude_to_db(
    np.abs(librosa.stft(p1.numpy().squeeze())), ref=np.max)
dp2 = librosa.amplitude_to_db(
    np.abs(librosa.stft(p2.numpy().squeeze())), ref=np.max)

# y_axis='hz' gives the frequency axis real Hz ticks to match the label set below
librosa.display.specshow(
    dgt, sr=sample_rate, x_axis='time', y_axis='hz', ax=axs[0], cmap='hot')
axs[0].set_title('Y validation')

librosa.display.specshow(
    dp1, sr=sample_rate, x_axis='time', y_axis='hz', ax=axs[1], cmap='hot')
axs[1].set_title('Pretrained Model')

librosa.display.specshow(
    dp2, sr=sample_rate, x_axis='time', y_axis='hz', ax=axs[2], cmap='hot')
axs[2].set_title('New Model')

# set the x and y labels, grid and zoom window for all subplots
for ax in axs:
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    ax.grid(True)
    ax.set_xlim(time_window)

plt.savefig(os.path.join(PLOTS_DIR, 'spec-time-frames.png'))
plt.show()