In [1]:
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from src.model import S4ConditionalSideChainModel
from src.parameter import ConditionalTaskParameter
from src.utils import get_tensor_device, set_random_seed_to
from src.dataset import SignalTrainDataset, download_signal_train_dataset_to, SwitchValue, PeakReductionValue

# S4 Hyper-conditioning Dynamic Range Compressor Model Evaluation

This Jupyter notebook contains routines to evaluate a trained S4 dynamic range compressor (DRC) model with hyper-conditioning.

Edit and execute the code in the [Preparatory Work](#preparatory-work) section first to load the model,
and then execute the code in the remaining sections to evaluate each individual metric.
Each evaluation task is wrapped in a function so that its variables stay out of the global scope.

## <a id="preparatory-work">Preparatory Work</a>

Edit the following global constants to locate the model to be evaluated.

If the model was trained with the provided training script, you don't need to edit any other cell.

In [None]:
# Location of the SignalTrain dataset (passed to download_signal_train_dataset_to)
# and of the directory holding one sub-folder per training job.
DATASET_DIR = Path('data', 'SignalTrain')
CHECKPOINT_DIR = Path('experiment-result')

# Training job (sub-folder of CHECKPOINT_DIR) and the saved epoch to evaluate;
# together they select `model-epoch-{EPOCH}.pth` inside the job folder.
JOB_NAME = '2023-3-30-4-20-19'
EPOCH = '30'

# Test-set batching and segmentation.
TESTING_DATASET_BATCH_SIZE = 1
TESTING_DATASET_SEGMENT_LENGTH = 3.0  # presumably seconds — confirm against SignalTrainDataset

Execute the following code to load the model and configuration.

In [None]:
device = get_tensor_device(apple_silicon=False)
param = ConditionalTaskParameter.from_json(CHECKPOINT_DIR / JOB_NAME / 'config.json')

set_random_seed_to(param.random_seed)

download_signal_train_dataset_to(DATASET_DIR)
testing_dataset = SignalTrainDataset(DATASET_DIR, 'test', TESTING_DATASET_SEGMENT_LENGTH)
testing_dataloader = DataLoader(testing_dataset, TESTING_DATASET_BATCH_SIZE, shuffle=False)

model = S4ConditionalSideChainModel(
    param.model_version,
    param.model_control_parameter_mlp_depth,
    param.model_control_parameter_mlp_hidden_size,
    param.model_film_take_batch_normalization,
    param.model_inner_audio_channel,
    param.model_s4_hidden_size,
    param.s4_learning_rate,
    param.model_depth,
    param.model_activation,
    param.model_convert_to_decibels,
).eval().to(device)
model.load_state_dict(torch.load(CHECKPOINT_DIR / JOB_NAME / f'model-epoch-{EPOCH}.pth', map_location=device))

## Evaluate Model Step Response

In [None]:
@torch.no_grad()
def evaluate_model_step_response(evaluated_model=None, sample_rate=None, tensor_device=None):
    """Feed a synthetic step signal through the model under several control settings.

    The input is 0.55 s of silence followed by 0.45 s at a constant amplitude
    of 0.3. It is repeated once per (switch, peak_reduction) control pair. There
    are four groups of pairs:

    * ``*_compressing`` groups use switch value 0; ``*_limiting`` groups use 1.
    * ``in_dataset_*`` groups use peak reductions 0, 5, 20, 35, 50, 65, 80, 95, 100.
    * ``out_dataset_*`` groups use peak reductions 2, 18, 34, 53, 78, 97.
      Presumably these are settings absent from the training data — confirm.

    Args:
        evaluated_model: Callable invoked as ``evaluated_model(signal, parameters)``.
            Defaults to the global ``model`` loaded in the preparatory section.
        sample_rate: Sample rate used to size the step signal. Defaults to
            ``testing_dataset.sample_rate``.
        tensor_device: Device for the input tensors. Defaults to the global ``device``.

    Returns:
        dict mapping each group name to a ``(parameters, output)`` tuple.
        ``parameters`` is a float32 tensor of shape ``(n, 2)``. ``output`` is
        whatever the model returns for the ``(n, num_samples)`` batch of step
        signals.
    """
    # Fall back to the notebook globals so the zero-argument call keeps working.
    if evaluated_model is None:
        evaluated_model = model
    if sample_rate is None:
        sample_rate = testing_dataset.sample_rate
    if tensor_device is None:
        tensor_device = device

    compressing_switch, limiting_switch = 0, 1
    in_dataset_peak_reductions = [0, 5, 20, 35, 50, 65, 80, 95, 100]
    out_dataset_peak_reductions = [2, 18, 34, 53, 78, 97]

    def make_parameter_pairs(switch_value, peak_reductions):
        # One (switch, peak_reduction) row per setting, as float32 on the target device.
        return torch.tensor(
            [[switch_value, peak_reduction] for peak_reduction in peak_reductions]
        ).to(tensor_device, torch.float32)

    # The silence is built as two segments (0.1 s + 0.45 s) rather than one
    # 0.55 s run, so the total sample count after int() truncation stays exactly
    # as before.
    step_signal = torch.cat([
        torch.zeros(int(sample_rate * 0.1)),
        torch.zeros(int(sample_rate * 0.45)),
        torch.zeros(int(sample_rate * 0.45)) + 0.3,
    ]).to(tensor_device, torch.float32)

    parameter_groups = {
        'in_dataset_compressing': make_parameter_pairs(compressing_switch, in_dataset_peak_reductions),
        'in_dataset_limiting': make_parameter_pairs(limiting_switch, in_dataset_peak_reductions),
        'out_dataset_compressing': make_parameter_pairs(compressing_switch, out_dataset_peak_reductions),
        'out_dataset_limiting': make_parameter_pairs(limiting_switch, out_dataset_peak_reductions),
    }

    results = {}
    for group_name, parameter_pairs in parameter_groups.items():
        # One copy of the step signal per control pair: shape (n, num_samples).
        batched_signal = step_signal.repeat(parameter_pairs.size(0), 1)
        results[group_name] = (parameter_pairs, evaluated_model(batched_signal, parameter_pairs))
    return results

## Evaluate Waveform Difference