# Evaluation

In [None]:
import h5py
import ipywidgets as widgets
import torch

from tqdne.conf import Config
from tqdne.metric import PowerSpectralDensityFID

In [None]:
# Maximum number of waveforms loaded from each dataset.
max_num_samples = 1000
config = Config()


def _load_waveforms(path, start, stop):
    """Return ``file["waveform"][start:stop]`` of the HDF5 file at *path* as a float32 tensor.

    The file is opened read-only in a ``with`` block so it is closed even if reading fails.
    """
    with h5py.File(path, mode="r") as file:
        waveforms = file["waveform"][start:stop]
    return torch.tensor(waveforms, dtype=torch.float32)


test_dataset_path = config.datasetdir / config.data_upsample_test
test_waveforms = _load_waveforms(test_dataset_path, 0, max_num_samples)

pred_dataset_path = config.datasetdir / config.data_upsample_test  # TODO: placeholder for generated dataset
# While the placeholder reuses the test file, start the "predictions" right after the target
# samples so the two sets never overlap, whatever max_num_samples is. (If the file holds fewer
# than 2 * max_num_samples waveforms, this slice comes back shorter than the target set.)
pred_offset = max_num_samples
pred_waveforms = _load_waveforms(pred_dataset_path, pred_offset, pred_offset + max_num_samples)

## Plotting metrics

In [None]:
# Metric instances selectable in the dropdown. Each one must support reset(),
# __call__(pred=..., target=...) and plot(), as used by update_plot below.
metrics = [PowerSpectralDensityFID(fs=config.fs)]

# Dropdown of (label, value) pairs: the label is the metric's class name, the value is the
# metric instance itself, which interact passes to update_plot as `metric`.
metric_dropdown = widgets.Dropdown(
    options=[(metric.__class__.__name__, metric) for metric in metrics],
    value=metrics[0],
    description="Metric:",
)


def update_plot(metric):
    """Evaluate *metric* on the loaded waveforms, print the result and show its plot.

    Invoked by ``widgets.interact`` with the metric currently selected in the dropdown.
    Predictions and targets are the module-level ``pred_waveforms`` and ``test_waveforms``,
    both passed under the ``"high_res"`` key.
    """
    # Reset before evaluating, presumably so state left over from a previous call does not
    # leak into this result.
    metric.reset()
    result = metric(pred={"high_res": pred_waveforms}, target={"high_res": test_waveforms})
    print(result)

    # Create the plot
    metric.plot().show()


# interact returns the wrapped function; the trailing semicolon stops Jupyter from echoing
# its repr (<function update_plot ...>) as the cell output underneath the widget.
widgets.interact(update_plot, metric=metric_dropdown);