In [None]:
import time
from pathlib import Path
from pprint import pprint
from statistics import mean, stdev

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import torch
from scipy.io import wavfile
from scipy.signal import freqz
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm, trange

from src.dataset import SignalTrainDataset, download_signal_train_dataset_to
from src.loss import forge_validation_criterions_by
from src.model import S4ConditionalModel
from src.model.layer import DSSM, convert_to_decibel
from src.parameter import ConditionalTaskParameter
from src.utils import get_tensor_device, set_random_seed_to

# S4 Hyper-conditioning Dynamic Range Compressor Model Evaluation

This Jupyter Notebook contains routines to evaluate a trained S4 dynamic range compressor (DRC) model with hyper-conditioning, in either its single-chain or side-chain variant.

Edit and execute the code in the [Preparatory Work](#preparatory-work) section first to load the model,
then execute the code in the remaining sections to evaluate each individual metric.

Each evaluation task is wrapped in a function so that its variables do not leak into the global namespace.
All functions save their results to the local file system (under the job's `evaluations` directory).

## <a id="preparatory-work">Preparatory Work</a>

Edit the following global constants to locate the model to be evaluated.

If the model was trained with my provided training script, you don't need to edit any other cell.

In [None]:
# Root of the SignalTrain dataset; passed to `download_signal_train_dataset_to`
# and `SignalTrainDataset` in the next cell.
DATASET_DIR = Path('./data/SignalTrain')
# Parent directory of the training jobs. `CHECKPOINT_DIR / JOB_NAME` must contain
# `config.json` and `model-epoch-{EPOCH}.pth`; evaluation results are written to
# its `evaluations` sub-directory.
CHECKPOINT_DIR = Path('./experiment-result')

# Presets for the trained model variants: uncomment exactly one JOB_NAME/EPOCH pair.
# Each preset lists two candidate JOB_NAME values (a timestamp and a descriptive
# name) -- presumably the same run before/after renaming its directory; use
# whichever one exists under CHECKPOINT_DIR. EPOCH selects the checkpoint to load.

# Single chain without tanh
# JOB_NAME = '2023-5-9-22-34-47'
# JOB_NAME = 'single-chain-no-tanh'
# EPOCH = 54

# Side chain without tanh
# JOB_NAME = '2023-5-10-7-1-39'
# JOB_NAME = 'side-chain-no-tanh'
# EPOCH = 68

# Single chain with tanh
# JOB_NAME = '2023-5-10-15-27-38'
# JOB_NAME = 'single-chain-tanh'
# EPOCH = 66

# Side chain with tanh
# JOB_NAME = '2023-5-10-23-54-7'
# JOB_NAME = 'side-chain-tanh'
# EPOCH = 66

# Single chain with parametered tanh
# JOB_NAME = '2023-6-28-3-32-47'
# JOB_NAME = 'single-chain-ptanh'
# EPOCH = 52

# Side chain with parametered tanh
# JOB_NAME = '2023-6-28-16-11-36'
JOB_NAME = 'side-chain-ptanh'
EPOCH = 40

Execute the following code to load the model and configuration.

In [None]:
job_dir = CHECKPOINT_DIR / JOB_NAME
job_eval_dir = job_dir / 'evaluations'
job_eval_dir.mkdir(parents=True, exist_ok=True)

device = get_tensor_device(apple_silicon=False)  # Some operations are not supported on Apple Silicon
param = ConditionalTaskParameter.from_json(job_dir / 'config.json')
pprint(param.to_dict())

set_random_seed_to(param.random_seed)

download_signal_train_dataset_to(DATASET_DIR)

# Clip length (seconds) of each testing split. This dict is the single source of
# truth: it builds the datasets below and is reused by the real-time-ratio test.
testing_dataset_lengths = {
    'short': 1.5,
    'mid': 4,
    'long': 10,
}
testing_datasets = {
    dataset_name: SignalTrainDataset(DATASET_DIR, 'test', length)
    for dataset_name, length in testing_dataset_lengths.items()
}
# Individual handles kept so cells referencing them by name keep working.
testing_dataset_short = testing_datasets['short']
testing_dataset_mid = testing_datasets['mid']
testing_dataset_long = testing_datasets['long']

# Rebuild the architecture from the saved training configuration; the positional
# argument order must match the `S4ConditionalModel` constructor.
model = S4ConditionalModel(
    param.model_take_side_chain,
    param.model_inner_audio_channel,
    param.model_s4_hidden_size,
    param.s4_learning_rate,
    param.model_depth,
    param.model_film_take_batchnorm,
    param.model_take_residual_connection,
    param.model_convert_to_decibels,
    param.model_take_tanh,
    param.model_activation,
    param.model_take_parametered_tanh,
).eval().to(device)
# `map_location` lets a checkpoint saved on another device load onto `device`.
model.load_state_dict(torch.load(job_dir / f'model-epoch-{EPOCH}.pth', map_location=device))

## General Testing

In [None]:
@torch.no_grad()
def test():
    """Evaluate the model on every testing dataset with all validation criteria.

    For each entry of `testing_datasets`, the per-sample average of every
    validation loss is pretty-printed to `job_eval_dir / f'loss-{dataset_name}.txt'`.
    """
    # The criteria do not depend on the dataset, so build them only once.
    validation_criterions = forge_validation_criterions_by(param.loss_filter_coef, device)

    for dataset_name, dataset in testing_datasets.items():
        dataloader = DataLoader(dataset, 20, num_workers=8, pin_memory=True)

        loss_sums = dict.fromkeys(validation_criterions, 0.0)
        sample_count = 0

        for x, y, parameters in tqdm(dataloader, desc=f'Testing {dataset_name} dataset.'):
            x: Tensor = x.to(device)
            y: Tensor = y.to(device)
            parameters: Tensor = parameters.to(device)

            y_hat: Tensor = model(x, parameters)

            batch_size = x.size(0)
            sample_count += batch_size
            for validation_loss, validation_criterion in validation_criterions.items():
                loss: Tensor = validation_criterion(y_hat.unsqueeze(1), y.unsqueeze(1))
                # Weight each batch by its size so a smaller final batch does not
                # skew the average (assumes each criterion returns a per-batch
                # mean -- TODO confirm against src.loss).
                loss_sums[validation_loss] += loss.item() * batch_size

        # `max(..., 1)` avoids a ZeroDivisionError on an empty dataset.
        validation_losses = {
            validation_loss: total / max(sample_count, 1)
            for validation_loss, total in loss_sums.items()
        }

        with open(job_eval_dir / f'loss-{dataset_name}.txt', 'w') as f:
            pprint(validation_losses, stream=f)

test()

## S4 Frequency Response Analysis

In [None]:
@torch.no_grad()
def s4_frequency_response_analysis():
    """Plot the frequency response of every channel of every S4 (DSSM) kernel.

    For each block in `model.blocks`, the DSSM kernel is generated with a length
    of one second of samples. Each kernel row is treated as an FIR impulse
    response; its magnitude (dB) and unwrapped phase from `scipy.signal.freqz`
    are plotted and saved as `layer-{c}-channel-{r}.png` under
    `job_eval_dir / 's4-impulse-response'`.
    """
    out_dir = job_eval_dir / 's4-impulse-response'
    out_dir.mkdir(exist_ok=True)
    kernel_length = int(SignalTrainDataset.sample_rate * 1)  # one second of samples

    for c, block in enumerate(model.blocks):
        s4 = block.s4
        assert isinstance(s4, DSSM)
        # Copy the whole kernel to the host once instead of once per channel.
        kernel = s4.get_kernel(kernel_length).detach().cpu().numpy()
        for r in trange(param.model_inner_audio_channel, desc=f'Block {c}.'):
            impulse_response = kernel[r, :]
            w, h = freqz(impulse_response)
            title = f'layer-{c + 1}-channel-{r + 1}'

            # Exact spectral zeros give -inf dB (matplotlib skips those points);
            # silence the divide-by-zero warning they would otherwise emit.
            with np.errstate(divide='ignore'):
                magnitude_db = 20 * np.log10(np.abs(h))

            fig, ax = plt.subplots()
            ax.set_title(title)
            ax.plot(w, magnitude_db, 'b')
            ax.set_xlabel('Frequency [rad/sample]')
            ax.set_ylabel('Amplitude [dB]', color='b')
            ax2 = ax.twinx()
            ax2.plot(w, np.unwrap(np.angle(h)), 'g')
            ax2.set_ylabel('Angle (radians)', color='g')
            ax2.grid(True)
            # Without this the right-hand (twin) axis label is clipped on save.
            fig.tight_layout()

            fig.savefig(str(out_dir / f'{title}.png'))
            plt.close(fig)

s4_frequency_response_analysis()

## Evaluate Inference Efficiency

In [None]:
@torch.no_grad()
def evaluate_inference_efficiency(n_runs: int = 10, n_warmup_runs: int = 2):
    """Measure single-clip inference latency of `model` on each testing dataset.

    For every dataset, the first `n_runs` clips are run one at a time (batch of
    1), after `n_warmup_runs` untimed warm-up passes. The mean and standard
    deviation of the latency (ms) and the real-time ratio (latency divided by
    clip duration; below 1 means faster than real time) are written to
    `job_eval_dir / f'inference-efficiency-{dataset_name}-{device_name}.txt'`.

    Args:
        n_runs: Maximum number of timed clips per dataset (at least 2 are needed
            for a standard deviation).
        n_warmup_runs: Untimed passes before timing, to exclude one-off costs
            such as CUDA context/kernel initialisation.

    Raises:
        NotImplementedError: If `device` is neither CPU nor CUDA.
        ValueError: If fewer than two clips are available for timing.
    """
    if device.type == 'cpu':
        print(f'Doing inference speed test on CPU...')
        device_name = 'cpu'
    elif device.type == 'cuda':
        device_name = torch.cuda.get_device_name()
        print(f'Doing inference speed test on {device_name}.')
    else:
        raise NotImplementedError(f'Inference efficiency test can only run on CPU/CUDA')

    def synchronize():
        # CUDA kernels are launched asynchronously; wait for them so that the
        # wall clock measures the computation itself, not just the launches.
        if device.type == 'cuda':
            torch.cuda.synchronize(device)

    def load_input(dataset, index):
        # Fetch one clip and add a batch dimension of 1.
        x, _, cond = dataset[index]
        return x.to(device).unsqueeze(0), cond.to(device).unsqueeze(0)

    for dataset_name, testing_dataset in testing_datasets.items():
        dataset_sample_length = testing_dataset_lengths[dataset_name]
        run_count = min(n_runs, len(testing_dataset))
        if run_count < 2:
            raise ValueError(
                f'Need at least 2 clips to time the {dataset_name} dataset, got {run_count}.'
            )

        warmup_x, warmup_cond = load_input(testing_dataset, 0)
        for _ in range(n_warmup_runs):
            model(warmup_x, warmup_cond)
        synchronize()

        inference_time: list[int] = []
        for i in tqdm(range(run_count)):
            x, cond = load_input(testing_dataset, i)
            synchronize()  # make sure the host-to-device copies are not timed

            tic = time.perf_counter_ns()
            model(x, cond)
            synchronize()
            toc = time.perf_counter_ns()
            inference_time.append(toc - tic)

        inference_time_mean = mean(inference_time) / 1e6  # ns -> ms
        inference_time_stdev = stdev(inference_time) / 1e6  # ns -> ms
        # Clip duration is in seconds; convert to ms to match the latency unit.
        speed_ratio = inference_time_mean / (dataset_sample_length * 1e3)

        with open(job_eval_dir / f'inference-efficiency-{dataset_name}-{device_name}.txt', 'w') as f:
            print(f'Average inference time on {dataset_name} dataset: {inference_time_mean} ms. ', file=f)
            print(f'Inference time standard deviation on {dataset_name} dataset: {inference_time_stdev} ms. ', file=f)
            print(f'Real-time speed ratio on {dataset_name} dataset: {speed_ratio}. ', file=f)

evaluate_inference_efficiency()

## Evaluate Output Audio, Waveform Difference, RMS Difference, and STFT Difference

In [None]:
def acquire_rms(x: npt.NDArray[np.float32], window: int = 100):
    """Compute the RMS level of consecutive, non-overlapping windows of a 1-D signal.

    Args:
        x: One-dimensional signal.
        window: Number of samples per window. The final window is shorter when
            `x.size` is not a multiple of `window`, and its RMS is taken over
            the samples it actually contains.

    Returns:
        A float32 array with one RMS value per window (empty for an empty `x`).
    """
    assert x.ndim == 1
    window_starts = range(0, x.size, window)
    levels = [np.sqrt(np.mean(np.square(x[start:start + window]))) for start in window_starts]
    return np.array(levels, dtype=np.float32)


def _save_series_plot(values: npt.NDArray, seconds_per_point: float, title: str, ylabel: str, path: Path) -> None:
    """Plot a 1-D series against time in seconds on a wide figure and save it as an image."""
    fig, ax = plt.subplots(figsize=(25, 5))
    ax.plot(np.arange(values.size) * seconds_per_point, values)
    ax.set_title(title)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel(ylabel)
    fig.savefig(str(path))
    plt.close(fig)


def _save_spectrogram_plot(spectrogram: npt.NDArray, hop_seconds: float, n_fft: int, sample_rate: int,
                           title: str, path: Path) -> None:
    """Plot a (frequency bin, frame) one-sided spectrogram in dB against time and frequency and save it."""
    frame_times = np.arange(spectrogram.shape[1]) * hop_seconds
    bin_frequencies = np.fft.rfftfreq(n_fft, d=1 / sample_rate)
    fig, ax = plt.subplots(figsize=(25, 5))
    mesh = ax.pcolormesh(frame_times, bin_frequencies, spectrogram, cmap='jet', shading='auto')
    fig.colorbar(mesh, ax=ax, label='Magnitude difference (dB)')
    ax.set_title(title)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    fig.savefig(str(path))
    plt.close(fig)


@torch.no_grad()
def evaluate_output_audio(dataset_names=('long',)):
    """Export audio and target-vs-prediction difference plots for selected testing datasets.

    For every clip of each dataset in `dataset_names` (keys of `testing_datasets`),
    the prediction `y_hat = model(x, cond)` is compared with the target `y`.
    Written under `job_eval_dir`:

    - `output-audio-{name}/`: input, target, prediction and (target - prediction) WAVs;
    - `output-waveform-{name}/`: waveform of (target - prediction);
    - `output-rms-{name}/`: difference of the windowed RMS envelopes;
    - `output-stft-{name}/`: dB difference of the STFT magnitudes.
    """
    sample_rate = SignalTrainDataset.sample_rate
    n_fft, hop_length, rms_window = 1024, 256, 100
    # Floor added to STFT magnitudes so silent bins give finite dB values.
    log_floor = 1e-8
    stft_window = torch.hann_window(n_fft, device=device)

    for dataset_name in dataset_names:
        dataset = testing_datasets[dataset_name]
        dataloader = DataLoader(dataset, 20, num_workers=8, pin_memory=True)

        output_audio_dir = job_eval_dir / f'output-audio-{dataset_name}'
        output_audio_dir.mkdir(exist_ok=True)
        output_rms_dir = job_eval_dir / f'output-rms-{dataset_name}'
        output_rms_dir.mkdir(exist_ok=True)
        output_waveform_dir = job_eval_dir / f'output-waveform-{dataset_name}'
        output_waveform_dir.mkdir(exist_ok=True)
        output_stft_dir = job_eval_dir / f'output-stft-{dataset_name}'
        output_stft_dir.mkdir(exist_ok=True)

        ii = 0  # running clip index across batches, used in file names
        for x, y, cond in tqdm(dataloader, desc=f'Evaluate {dataset_name} dataset.'):
            x: Tensor = x.to(device)
            y: Tensor = y.to(device)
            cond: Tensor = cond.to(device)

            # The model processes the input signal `x`; `y` is the reference output.
            y_hat: Tensor = model(x, cond)

            for i in range(y_hat.size(0)):
                switch, peak_reduction = cond[i, :].flatten().cpu().tolist()
                prefix = f'{ii:03d}-switch={switch}-peak-reduction={peak_reduction}'

                x_audio = x[i, :].flatten()
                y_audio = y[i, :].flatten()
                y_hat_audio = y_hat[i, :].flatten()

                stft_kwargs = dict(
                    n_fft=n_fft, hop_length=hop_length, win_length=n_fft,
                    window=stft_window, return_complex=True,
                )
                y_stft = torch.stft(y_audio, **stft_kwargs)
                y_hat_stft = torch.stft(y_hat_audio, **stft_kwargs)
                # Difference of log magnitudes (dB). Taking the log of a magnitude
                # difference would be NaN wherever |y_hat| > |y|.
                y_diff_stft = (
                    20 * (torch.log10(y_stft.abs() + log_floor) - torch.log10(y_hat_stft.abs() + log_floor))
                ).cpu().numpy()

                x_audio = x_audio.cpu().numpy()
                y_audio = y_audio.cpu().numpy()
                y_hat_audio = y_hat_audio.cpu().numpy()
                y_diff_audio = y_audio - y_hat_audio

                y_diff_rms = acquire_rms(y_audio, rms_window) - acquire_rms(y_hat_audio, rms_window)

                for suffix, audio in (
                    ('x', x_audio), ('y', y_audio), ('y-hat', y_hat_audio), ('y-diff', y_diff_audio),
                ):
                    wavfile.write(output_audio_dir / f'{prefix}-{suffix}.wav', sample_rate, audio)

                _save_series_plot(
                    y_diff_audio, 1 / sample_rate, prefix, 'Amplitude',
                    output_waveform_dir / f'{prefix}-y-diff.png',
                )
                _save_series_plot(
                    y_diff_rms, rms_window / sample_rate, prefix, 'RMS difference',
                    output_rms_dir / f'{prefix}-y-rms-diff.png',
                )
                _save_spectrogram_plot(
                    y_diff_stft, hop_length / sample_rate, n_fft, sample_rate, prefix,
                    output_stft_dir / f'{prefix}-y-stft-diff.png',
                )

                ii += 1

evaluate_output_audio()

## Evaluate Model Step Response

In [None]:
@torch.no_grad()
def evaluate_model_step_response(n_cols: int = 3):
    """Plot the model's response to a level-step test signal for many parameter settings.

    The test signal is 0.2 s of silence, 0.8 s at 1.0, 0.8 s at 0.2, 0.8 s at 1.0
    and 0.2 s of silence. It is processed once per (switch, peak_reduction) pair.
    Each subplot overlays the input (blue, dashed) and the output (red), both
    passed through `convert_to_decibel`, against time. The figure is saved to
    `job_eval_dir / 'model-step-response.png'`.

    Args:
        n_cols: Number of subplot columns; the row count follows from the
            number of parameter pairs.
    """
    parameter_pair = torch.tensor([
        [0, 0],
        [0, 5],
        [0, 20],
        [0, 35],
        [0, 50],
        [0, 65],
        [0, 80],
        [0, 95],
        [0, 100],
        [1, 0],
        [1, 5],
        [1, 20],
        [1, 35],
        [1, 50],
        [1, 65],
        [1, 80],
        [1, 95],
        [1, 100],
        [0, 2],
        [0, 18],
        [0, 34],
        [0, 53],
        [0, 78],
        [0, 97],
        [1, 2],
        [1, 18],
        [1, 34],
        [1, 53],
        [1, 78],
        [1, 97],
    ]).to(device, torch.float32)
    n_pairs = parameter_pair.size(0)

    sr = SignalTrainDataset.sample_rate
    step_signal = torch.cat([
        torch.zeros(int(sr * 0.2)),
        torch.ones(int(sr * 0.8)),
        torch.full((int(sr * 0.8),), 0.2),
        torch.ones(int(sr * 0.8)),
        torch.zeros(int(sr * 0.2)),
    ]).to(device, torch.float32)
    # One copy of the step signal per parameter pair, processed as one batch.
    output_signals: Tensor = model(
        step_signal.repeat(n_pairs, 1),
        parameter_pair,
    )
    # The reference curve is identical in every subplot; convert it once.
    step_signal_decibel = convert_to_decibel(step_signal).cpu().numpy()
    step_time = np.arange(step_signal_decibel.size) / sr

    n_rows = (n_pairs + n_cols - 1) // n_cols  # ceiling division
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    flat_axes = axs.flatten()  # row-major, matching the order of `parameter_pair`
    for ax, output_signal, (switch, peak_reduction) in zip(
        flat_axes, output_signals.split(1), parameter_pair.tolist()
    ):
        output_decibel = convert_to_decibel(output_signal).flatten().cpu().numpy()
        ax.plot(step_time, step_signal_decibel, color='blue', linestyle='dashed', alpha=1.0, label='input')
        ax.plot(np.arange(output_decibel.size) / sr, output_decibel,
                color='red', linestyle='solid', alpha=0.5, label='output')
        ax.set_title(f'{switch = }, {peak_reduction = }')
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Level (dB)')
        ax.legend(loc='lower right')
    # Hide grid cells left over when n_pairs is not a multiple of n_cols.
    for ax in flat_axes[n_pairs:]:
        ax.set_visible(False)
    fig.tight_layout()
    fig.savefig(str(job_eval_dir / 'model-step-response.png'))
    plt.close(fig)

evaluate_model_step_response()