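# Evaluation

Compare predicted waveforms against the held-out test set with the `tqdne` metrics (power spectral density, mean squared error), optionally binned by magnitude and distance. **Note:** predictions are currently a placeholder (the training split) until a generated dataset is available.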

In [1]:
import h5py
import ipywidgets as widgets
from functools import partial

from tqdne.conf import Config
from tqdne.metric import PowerSpectralDensity, MeanSquaredError, BinMetric

from pathlib import Path

2024-02-14 11:48:28 - numexpr.utils - INFO - Note: NumExpr detected 24 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-02-14 11:48:28 - numexpr.utils - INFO - NumExpr defaulting to 8 threads.
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [2]:
# Sanity check: list the parent directory so the relative `../datasets/...` paths used below can be verified.
!ls ..

CHANGELOG.rst	  notebooks	      slurm-51680583.out  slurm-51732192.out
datasets	  outputs	      slurm-51712471.out  slurm-51732322.out
diffusers	  poetry.lock	      slurm-51713032.out  slurm-51742218.out
doc		  pyproject.toml      slurm-51713320.out  slurm-51743741.out
environment.yaml  README.rst	      slurm-51723303.out  tqdne
experiments	  run-scripts	      slurm-51730838.out  wandb
Makefile	  slurm-51631391.out  slurm-51731011.out
MANIFEST.in	  slurm-51667002.out  slurm-51731126.out
nohup.out	  slurm-51671025.out  slurm-51732107.out


In [3]:
# Cap on how many waveforms are read from each file, keeping metric computation and plotting fast.
max_num_samples = 50
config = Config()

# Held-out test split: target waveforms plus their features (passed as `cond` to the metrics below).
test_dataset_path = Path("../datasets/small_data_upsample_test.h5")
# `with` closes the HDF5 handle even if a dataset key is missing or slicing fails.
with h5py.File(test_dataset_path, mode="r") as test_file:
    test_waveforms = test_file["waveform"][:max_num_samples]
    test_features = test_file["features"][:max_num_samples]

# TODO: placeholder for generated dataset -- the training split currently stands in for model
# predictions, so the metrics below compare two real datasets rather than generated vs. real waveforms.
pred_dataset_path = Path("../datasets/small_data_upsample_train.h5")
with h5py.File(pred_dataset_path, mode="r") as pred_file:
    pred_waveforms = pred_file["waveform"][:max_num_samples]

# Predictions, targets and features are handed to the metrics together, presumably paired
# sample-by-sample; fail fast here rather than deep inside a metric if the counts disagree.
assert len(pred_waveforms) == len(test_waveforms) == len(test_features), (
    f"sample count mismatch: pred={len(pred_waveforms)}, "
    f"test={len(test_waveforms)}, features={len(test_features)}"
)

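## Plotting metrics

Pick a metric and a waveform channel. Tick *Plot bins* to split the comparison into magnitude × distance bins; *Num bins* sets the number of bins along each axis.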

In [4]:
# Display name -> metric constructor. Each constructor is later called with `channel=...`;
# PowerSpectralDensity is pre-bound with `fs=config.fs` (presumably the sampling rate).
metrics = {
    "Power Spectral Density": partial(PowerSpectralDensity, fs=config.fs),
    "Mean Squared Error": MeanSquaredError,
}

# Keyword arguments shared by both integer sliders; only value, range and label differ.
_int_slider_opts = dict(
    step=1,
    disabled=False,
    continuous_update=False,  # push the new value only when the handle is released, not while dragging
    orientation='horizontal',
    readout=True,
    readout_format='d',
)

# Metric selector, populated with the names defined in `metrics`.
metric_dropdown = widgets.Dropdown(options=metrics.keys(), description='Metric:')

# Channel index selector, 0..2. NOTE(review): assumes the waveforms have 3 channels -- confirm against the dataset.
channel_slider = widgets.IntSlider(value=0, min=0, max=2, description='Channel:', **_int_slider_opts)

# Toggle: plot the metric as a whole, or split it into magnitude/distance bins.
bin_plot_checkbox = widgets.Checkbox(value=False, description='Plot bins', disabled=False, indent=True)

# Bin count, applied to both the magnitude and distance axes when binning is enabled.
num_bins_slider = widgets.IntSlider(value=10, min=1, max=50, description='Num bins:', **_int_slider_opts)

# Keep the notebook namespace clean; the options dict is only needed to build the sliders.
del _int_slider_opts

# Widget callback: rebuild the chosen metric and redraw its plot.
def update_plot(metric_name, bin_plot, num_bins=10, channel=0):
    """Compute the selected metric on the loaded samples and show its plot.

    Called by ``widgets.interact`` whenever one of the bound controls changes.

    Args:
        metric_name: Key into ``metrics`` selecting the metric constructor.
        bin_plot: If truthy, wrap the metric in ``BinMetric`` so results are split
            into magnitude and distance bins.
        num_bins: Bins along each of the magnitude and distance axes; only used when
            ``bin_plot`` is truthy.
        channel: Waveform channel index forwarded to the metric constructor.
    """
    base_metric = metrics[metric_name](channel=channel)
    metric = (
        BinMetric(base_metric, num_mag_bins=num_bins, num_dist_bins=num_bins)
        if bin_plot
        else base_metric
    )
    metric.reset()

    # Predictions vs. targets; `cond` carries the per-sample features alongside the targets
    # (presumably what BinMetric bins on -- confirm in tqdne.metric).
    predictions = {"high_res": pred_waveforms}
    targets = {"high_res": test_waveforms, "cond": test_features}
    metric.update(pred=predictions, target=targets)

    fig = metric.plot()
    fig.show()

# Create interactive plot. `interact` displays the controls itself and then returns
# `update_plot`; the trailing `;` suppresses that function repr in the cell output.
widgets.interact(
    update_plot,
    metric_name=metric_dropdown,
    channel=channel_slider,
    bin_plot=bin_plot_checkbox,
    num_bins=num_bins_slider,
);

interactive(children=(Dropdown(description='Metric:', options=('Power Spectral Density', 'Mean Squared Error')…

<function __main__.update_plot(metric_name, bin_plot, num_bins=10, channel=0)>