In [None]:
%matplotlib widget

In [None]:
from glob import glob
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from natsort import natsorted

from ray_tools.simulation.torch_datasets import RayDataset

In [None]:
import ipywidgets as widgets
from IPython.display import display, Markdown

In [None]:
# Input data: every .h5 file in h5_dir, in natural sort order.
h5_dir = '/scratch/metrix-hackathon/datasets/metrix_simulation/ray_enhance_final'
h5_files = natsorted(glob(f"{h5_dir}/*.h5"))
# Fail here with a clear message rather than with an obscure error further down
# (e.g. an empty dataset or an out-of-range slider).
assert h5_files, f"No .h5 files found in {h5_dir}"

# Key prefix of the image-plane entries; each variant appends "/<variant>".
key_base = "1e6/ray_output/ImagePlane/ml"
key_variants = [-25, -20, -15, -10, -5, 0, 5, 10, 15, 20, 25, 30]
keys = [f"{key_base}/{variant}" for variant in key_variants]

# NOTE(review): not used in the visible cells; 256x256 is presumably the
# histogram size per plane — confirm before relying on it.
planes_shape = (len(key_variants), 256, 256)
# Sample shown when the interactive sliders are first displayed.
initial_sample = 53

# Figure identifiers (also reused as axes titles) and the shared colormap.
plt_title_var = "Image planes variance"
plt_title_3d = "Image planes in 3D"
cmap = 'turbo'

In [None]:
# Load only the image-plane sub-groups listed in `keys`, without transforms.
dataset = RayDataset(
    h5_files=h5_files,
    nested_groups=False,
    sub_groups=keys,
    transform=None,
)
n_samples = len(dataset)
display(Markdown(f"Successfully loaded {n_samples} samples"))

In [None]:
def get_image_planes_3d(dataset, idx, plane_base, plane_variants):
    """Stack the histograms of one sample's image planes into a single array.

    Args:
        dataset: Indexable collection; ``dataset[idx]`` must return a mapping
            from ``"{plane_base}/{variant}"`` keys to mappings that contain a
            ``'histogram'`` entry.
        idx: Index of the sample to read.
        plane_base: Common key prefix of the image planes.
        plane_variants: Suffixes appended to ``plane_base``; their order sets
            the order of the stacked planes.

    Returns:
        ``np.ndarray`` produced by ``np.stack`` along a new leading axis, with
        one entry per variant.
    """
    record = dataset[idx]
    planes = []
    for variant in plane_variants:
        planes.append(record[f'{plane_base}/{variant}']['histogram'])
    return np.stack(planes)

def get_planes_variance_2d(dataset, idx, plane_base, plane_variants):
    """Per-position variance across the image planes of one sample.

    Args:
        dataset: Indexable collection accepted by ``get_image_planes_3d``.
        idx: Index of the sample to read.
        plane_base: Common key prefix of the image planes.
        plane_variants: Suffixes appended to ``plane_base``.

    Returns:
        ``np.var`` of the stacked planes along axis 0 (the plane axis), so the
        result has the shape of a single histogram.
    """
    # Use the caller's plane_base/plane_variants; previously these parameters
    # were ignored and the notebook globals key_base/key_variants were read.
    planes = get_image_planes_3d(dataset, idx, plane_base, plane_variants)
    return np.var(planes, axis=0)

def get_xyzv_lists(dataset, idx, plane_base, plane_variants):
    """Indices and values of every non-zero entry in a sample's plane stack.

    Args:
        dataset: Indexable collection accepted by ``get_image_planes_3d``.
        idx: Index of the sample to read.
        plane_base: Common key prefix of the image planes.
        plane_variants: Suffixes appended to ``plane_base``.

    Returns:
        Tuple ``(xs, ys, zs, vs)`` of column arrays of shape ``(N, 1)``, where N
        is the number of non-zero entries, in row-major order. ``xs`` indexes
        the plane (variant) axis, ``ys``/``zs`` the two remaining axes, and
        ``vs`` holds the values. All four share numpy's common dtype of the
        integer indices and the histogram values. With no non-zero entries,
        each array has shape ``(0, 1)``.
    """
    # Use the caller's plane_base/plane_variants rather than the notebook globals.
    planes = get_image_planes_3d(dataset, idx, plane_base, plane_variants)
    # Vectorised replacement for the former per-voxel Python loop.
    xs, ys, zs = np.nonzero(planes)
    vs = planes[xs, ys, zs]
    # column_stack + hsplit keeps the (N, 1) column shape and common dtype of
    # the previous implementation. It also yields a (0, 4) table when nothing
    # is non-zero, whereas the former list-of-tuples build crashed in hsplit.
    xyzv = np.column_stack((xs, ys, zs, vs))
    xs, ys, zs, vs = np.hsplit(xyzv, 4)
    return xs, ys, zs, vs

## Variance over image planes

In [None]:
def on_change_var(change):
    """Redraw the variance image for the sample selected on the slider."""
    if change['type'] == 'change' and change['name'] == 'value':
        sample_id = change['new']
        hists_var = get_planes_variance_2d(dataset, sample_id, key_base, key_variants)
        im_var.set_data(hists_var)
        # set_data keeps the previous colour limits; rescale to this sample's range.
        im_var.autoscale()
        fig_var.canvas.draw_idle()

slider_var = widgets.IntSlider(
    # Clamp so the initial sample is always a valid index.
    value=min(initial_sample, len(dataset) - 1),
    min=0,
    # IntSlider's max is inclusive; valid indices are 0 .. len(dataset) - 1.
    max=len(dataset) - 1,
    step=1,
    description='Sample ID:',
    orientation='horizontal',
)
hists_var = get_planes_variance_2d(dataset, slider_var.value, key_base, key_variants)

slider_var.observe(on_change_var, names='value')
display(slider_var)

# Close a previous instance of this figure so re-running the cell does not stack figures.
plt.close(plt_title_var)
fig_var = plt.figure(plt_title_var, figsize=(7, 5))
ax_var = fig_var.add_subplot(1, 1, 1)
ax_var.set_title(plt_title_var)
im_var = ax_var.imshow(hists_var, cmap=cmap, aspect='auto')
fig_var.colorbar(im_var, ax=ax_var)
plt.show()

## 3D visualization of image planes

In [None]:
def _scatter_planes(ax, xs, ys, zs, vs, max_alpha=0.3):
    """Draw non-zero entries as white points whose opacity scales with value.

    Returns the collection created by ``ax.scatter``.
    """
    # Flatten so (N, 1) column arrays give a 1-D per-point alpha.
    vs = np.ravel(vs)
    # Opacity proportional to the value, capped at max_alpha. Skip the
    # normalisation when there are no points, since vs.max() would raise.
    alpha = vs / vs.max() * max_alpha if vs.size else None
    return ax.scatter(xs, ys, zs, c='w', alpha=alpha, s=1)


def on_change_3d(change):
    """Redraw the 3D scatter for the sample selected on the slider."""
    global scatter_3d
    if change['type'] == 'change' and change['name'] == 'value':
        sample_id = change['new']
        xs, ys, zs, vs = get_xyzv_lists(dataset, sample_id, key_base, key_variants)
        # 3D scatter collections have no set_data(). Replace the artist instead,
        # which also copes with a different point count and per-point alphas.
        scatter_3d.remove()
        scatter_3d = _scatter_planes(ax_3d, xs, ys, zs, vs)
        fig_3d.canvas.draw_idle()

slider_3d = widgets.IntSlider(
    # Clamp so the initial sample is always a valid index.
    value=min(initial_sample, len(dataset) - 1),
    min=0,
    # IntSlider's max is inclusive; valid indices are 0 .. len(dataset) - 1.
    max=len(dataset) - 1,
    step=1,
    description='Sample ID:',
    orientation='horizontal',
)
xs, ys, zs, vs = get_xyzv_lists(dataset, slider_3d.value, key_base, key_variants)

slider_3d.observe(on_change_3d, names='value')
display(slider_3d)

# Close a previous instance of *this* figure on re-run. The variance figure
# created in the previous cell must stay open.
plt.close(plt_title_3d)
fig_3d = plt.figure(plt_title_3d, figsize=(14, 10))
ax_3d = fig_3d.add_subplot(1, 1, 1, projection='3d')
ax_3d.set_title(plt_title_3d)
# NOTE(review): points are drawn white ('w'); on a default white figure
# background they may be invisible — confirm the intended style.
scatter_3d = _scatter_planes(ax_3d, xs, ys, zs, vs)
ax_3d.set_axis_off()
plt.show()