In [1]:
from src.models.end_to_end import EndToEndSystem
from src.plot_utils import dafx_from_name
from src.dataset.paired_audio_dataset import PairedAudioDataset
from src.callbacks.metrics import *
from src.utils import *
from tqdm import tqdm

import os
import glob
import pandas as pd
import pytorch_lightning as pl

In [2]:
def get_val_checkpoint_filename(checkpoint_folder):
    """Return the newest validation checkpoint found directly inside a folder.

    A checkpoint counts as a validation checkpoint when its *file name*
    (not any parent directory) contains the substring ``"val"`` and it has
    the ``.ckpt`` extension.

    Args:
        checkpoint_folder: Directory to search (not searched recursively).
            A trailing slash is accepted.

    Returns:
        Path of the matching checkpoint with the largest
        ``os.path.getctime`` value.

    Raises:
        ValueError: If the folder holds no matching ``.ckpt`` file.
    """
    # Escape the folder so glob metacharacters in the path ("[", "*", "?")
    # are taken literally; only the "*.ckpt" part acts as a wildcard.
    pattern = os.path.join(glob.escape(checkpoint_folder), "*.ckpt")

    # Test the base name only: matching against the full path would accept
    # every checkpoint whenever a parent directory happens to contain "val"
    # (for example ".../evaluation/...").
    val_files = [
        path for path in glob.glob(pattern)
        if "val" in os.path.basename(path)
    ]
    if not val_files:
        raise ValueError(
            f"No validation checkpoint (*val*.ckpt) found in {checkpoint_folder!r}"
        )
    return max(val_files, key=os.path.getctime)

In [3]:
def get_checkpoint_for_effect(effect_name, checkpoints_dir):
    """Find the validation checkpoint file belonging to an effect.

    The effect name is mapped to a checkpoint id by
    ``effect_to_end_to_end_checkpoint_id``; the search then happens in
    ``<checkpoints_dir>/<checkpoint id>/checkpoints/``.

    Args:
        effect_name: Name of the effect to look up.
        checkpoints_dir: Directory containing one sub-folder per checkpoint id.

    Returns:
        Whatever ``get_val_checkpoint_filename`` returns for that folder.
    """
    run_id = effect_to_end_to_end_checkpoint_id(effect_name)
    return get_val_checkpoint_filename(
        os.path.join(checkpoints_dir, run_id + "/checkpoints/")
    )

In [28]:
def get_results_filename(results_dir, dafx, dataset):
    """Build the CSV path used to store metrics for one effect/dataset pair.

    Args:
        results_dir: Directory prefix; joined to the file name with "/".
        dafx: Effect name. Only its last whitespace-separated word is used,
            lower-cased.
        dataset: Dataset name, inserted verbatim.

    Returns:
        ``"<results_dir>/<last word of dafx, lower-cased>_<dataset>.csv"``.
    """
    effect_slug = dafx.split()[-1].lower()
    basename = "{}_{}.csv".format(effect_slug, dataset)
    return results_dir + "/" + basename

In [5]:
# --- Evaluation configuration -------------------------------------------------
# Everything a reader may want to tune lives in this cell.

# Root of the project checkout. Defaults to the original author's location;
# set the L5PROJ_ROOT environment variable to run the notebook elsewhere
# without editing any path below.
PROJECT_ROOT = os.environ.get("L5PROJ_ROOT", "/home/kieran/Level5ProjectAudioVAE")

SAMPLE_RATE = 24_000
NUM_EXAMPLES = 1_000
DAFX = "mda Overdrive"
CHECKPOINTS_DIR = os.path.join(PROJECT_ROOT, "src", "train_scripts", "l5proj_end2end")
AUDIO_DIR = os.path.join(PROJECT_ROOT, "src", "audio")
DATASET = "daps"
# Input folder name is "<dataset>_<sample rate>", e.g. "daps_24000".
DATASET_INPUT_DIRS = [f"{DATASET}_{SAMPLE_RATE}"]
# Resolved eagerly: raises here if no validation checkpoint exists for DAFX.
CHECKPOINT = get_checkpoint_for_effect(DAFX, CHECKPOINTS_DIR)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
RESULTS_DIR = os.path.join(
    PROJECT_ROOT, "src", "notebooks", "evaluation", "data", "metrics"
)
SEED = 1234

In [None]:
# Fix the random seeds through PyTorch Lightning before any data is loaded,
# so that repeated runs of the notebook are comparable.
# NOTE(review): this cell has no execution count in the saved notebook, so it
# was presumably not run for the outputs shown below - confirm and re-run.
pl.seed_everything(SEED)

In [6]:
# Window sizes for the multi-resolution STFT loss. The same values are used
# for the FFT size and the window length; each hop is half of its window.
_stft_windows = [32, 128, 512, 2048, 8192, 32768]

# Metric callables keyed by their short name. Each one is later called as
# metric(prediction, target).
metrics = dict(
    PESQ=PESQ(SAMPLE_RATE),
    MRSTFT=auraloss.freq.MultiResolutionSTFTLoss(
        fft_sizes=list(_stft_windows),
        hop_sizes=[size // 2 for size in _stft_windows],
        win_lengths=list(_stft_windows),
        # Only the linear and log magnitude terms are weighted in.
        w_sc=0.0,
        w_phs=0.0,
        w_lin_mag=1.0,
        w_log_mag=1.0,
    ),
    MSD=MelSpectralDistance(SAMPLE_RATE),
    SCE=SpectralCentroidError(SAMPLE_RATE),
    CFE=CrestFactorError(),
    LUFS=LoudnessError(SAMPLE_RATE),
    RMS=RMSEnergyError(),
)

In [7]:
# Restore the trained end-to-end system from CHECKPOINT, move it to the
# selected device and switch it to evaluation mode.
model = EndToEndSystem.load_from_checkpoint(CHECKPOINT)
model = model.to(DEVICE)
model.eval()

EndToEndSystem(
  (audio_encoder): SpectrogramVAE(
    (encoder_conv): Sequential(
      (0): Sequential(
        (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): ReLU()
        (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): Sequential(
        (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): ReLU()
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): Sequential(
        (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): ReLU()
        (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (3): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): ReLU()
        (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (mu): Linear(in_fea

In [8]:
# Build the validation split for the chosen effect and wrap it in a
# single-example, unshuffled loader driven by a seeded generator.
dafx = dafx_from_name(DAFX)

dataset = PairedAudioDataset(
    dafx=dafx,
    audio_dir=AUDIO_DIR,
    input_dirs=DATASET_INPUT_DIRS,
    subset="val",
    train_frac=0.8,
    # Twice the training length is requested per example.
    length=model.hparams.train_length * 2,
    num_examples_per_epoch=NUM_EXAMPLES,
    augmentations=dict(),
    effect_input=False,
    effect_output=True,
    random_effect_threshold=0.0,
    dummy_setting=False,
)

# manual_seed returns the generator itself, so it can be created and seeded
# in one expression.
g = torch.Generator().manual_seed(SEED)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    timeout=6000,
    generator=g,
)

100%|██████████████████████████████████████████| 9/9 [00:00<00:00, 14407.91it/s]


Loaded 9 files for val = 0.38 hours.





In [15]:
# Run the model over the whole loader and keep (target, prediction) pairs on
# the CPU for the metric pass that follows.
outputs = []

# Pure inference: disabling autograd avoids building a graph for every
# forward pass, which saves memory and time without changing the values.
with torch.no_grad():
    for batch in tqdm(loader):
        x, y = batch
        x, y_ref, y = get_training_reference(x, y)

        # p and z are returned by the model but not used in this evaluation.
        y_hat, p, z = model(x.to(DEVICE), y=y_ref.to(DEVICE))

        outputs.append({
            "y": y.detach().cpu(),
            "y_hat": y_hat.detach().cpu(),
        })

100%|██████████| 1000/1000 [00:22<00:00, 44.53it/s]


In [37]:
# One empty list per metric, in the same order as `metrics`. Deriving the keys
# from `metrics` keeps the two in sync, so adding or renaming a metric cannot
# cause a KeyError in the scoring loop. Re-run this cell before re-running the
# scoring loop, otherwise values are appended to the previous results.
results = {metric_name: [] for metric_name in metrics}

In [38]:
# Score every stored prediction with every metric. A metric that fails on an
# example is recorded as NaN so the result lists stay aligned across metrics.
for output in tqdm(outputs):
    for metric_name, metric in metrics.items():
        try:
            val = metric(output["y_hat"], output["y"])
            if isinstance(val, torch.Tensor):
                # Store single-element tensors as plain Python numbers so the
                # resulting DataFrame column is numeric instead of an object
                # column of 0-d arrays; anything larger is kept as an array.
                val = val.item() if val.numel() == 1 else val.numpy()
        except Exception as e:
            # Deliberately best-effort: report the failure and carry on.
            print("Some error occurred: ", e)
            # np.nan, not np.NaN: the upper-case alias was removed in NumPy 2.0
            # and would make this handler itself raise.
            val = np.nan
        results[metric_name].append(val)

 50%|█████     | 503/1000 [01:18<01:14,  6.66it/s]

Some error occurred:  b'No utterances detected'


100%|██████████| 1000/1000 [02:36<00:00,  6.40it/s]


In [39]:
# One column per metric, one row per evaluated example.
df = pd.DataFrame(data=results)

In [40]:
# Bare last expression: the notebook renders the frame with its rich table
# display instead of the plain-text dump that print() produces.
df.head()

       PESQ     MRSTFT        MSD         SCE        CFE      LUFS       RMS
0  2.955839  1.2800151  7.4726267    5.682373  10.355343  3.377809  8.342247
1  4.505720  0.9200942  4.2486234  259.835449   2.124372  0.950178  2.357044
2  3.639986   1.345643   5.394623  599.538269   1.662407  1.564906  4.213348
3  4.253411  0.4026405  6.6475515  104.319763   1.042591  4.088740  9.854973
4  4.379947  1.0551684  4.6367507  513.889954   0.028324  2.132826  5.311535


In [41]:
# Persist the per-example metric table as CSV.
results_filename = get_results_filename(RESULTS_DIR, DAFX, DATASET)
# Create the target folder first so a fresh checkout does not fail on a
# missing directory; a no-op when it already exists.
os.makedirs(os.path.dirname(results_filename), exist_ok=True)
df.to_csv(results_filename)