In [None]:
import h5py
import numpy as np

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

In [None]:
# Location of the training dataset; change this to point at your own copy.
DATA_PATH = '/scratch/ns4486/numerical-relativity-interpolation/Proca_fiducial_scaled_cropped.hdf5'

# Read the 'Train' input/target datasets fully into memory, then release the
# file handle. Indexing with [] (instead of .get) raises a KeyError naming the
# missing group/dataset rather than an AttributeError on None.
with h5py.File(DATA_PATH, 'r') as hf:
    inputs = np.array(hf['Train']['input'])
    outputs = np.array(hf['Train']['target'])

print("Shapes of inputs and outputs:")
print((inputs.shape, outputs.shape))
print("Min, Mean, Max in inputs")
print((inputs.min(), inputs.mean(), inputs.max()))

In [None]:
cats = pd.qcut(np.ravel(inputs[:, 0]), q=10)  # 10 equal-population bins over inputs[:, 0]

In [None]:
# One [left, right, group] record per quantile bin, with group ids starting at 1.
# enumerate() replaces the manual counter, so no stale `group` variable is left
# behind in the notebook namespace.
group_interval_values = [
    [interval.left, interval.right, group]
    for group, interval in enumerate(cats.categories, start=1)
]

In [None]:
# Interval table with one Viridis_r colour per group (the palette length must
# equal the number of groups for the column assignment to succeed).
group_intervals_df = pd.DataFrame(
    group_interval_values,
    columns=['left', 'right', 'group'],
).assign(color=px.colors.sequential.Viridis_r)
group_intervals_df

In [None]:
def estimate_group(x, intervals=None):
    """Return the `group` id of the half-open interval (left, right] containing x.

    Args:
        x: Scalar value to classify.
        intervals: DataFrame with 'left', 'right' and 'group' columns. Defaults
            to the notebook-level ``group_intervals_df``. Rows are assumed to be
            sorted ascending (as the pd.qcut-derived table is).

    Returns:
        The matching group id. Values at or below the first row's left edge map
        to the first group. Anything else not covered (above the last right
        edge, or NaN) maps to the last group.
    """
    if intervals is None:
        intervals = group_intervals_df
    # Iterating columns with zip avoids building a Series per row (iterrows).
    for left, right, group in zip(intervals['left'], intervals['right'], intervals['group']):
        if left < x <= right:
            return group
    if x <= intervals['left'].iloc[0]:
        return intervals['group'].iloc[0]
    # Use the table's own last id rather than a hard-coded 10, so a different
    # number of bins still classifies out-of-range values correctly.
    return intervals['group'].iloc[-1]

In [None]:
def _hex_to_rgba(hex_color, opacity):
    """Convert a '#rrggbb' string and an opacity into an 'rgba(r, g, b, a)' string."""
    hex_digits = hex_color.lstrip('#')
    rgb = tuple(int(hex_digits[i:i + 2], 16) for i in (0, 2, 4))
    return 'rgba' + str(rgb + (opacity,))


def estimate_color(x, intervals=None):
    """Map a scalar to an rgba colour string based on which interval contains it.

    Each group gets an equal slice of the [0, 1] opacity range. Within its
    interval (left, right], opacity rises linearly from (group-1)*step to
    group*step, where step = 1 / number_of_groups.

    Args:
        x: Scalar value to colour.
        intervals: DataFrame with 'left', 'right', 'group' and 'color'
            ('#rrggbb') columns. Defaults to the notebook-level
            ``group_intervals_df``. Rows are assumed sorted ascending.

    Returns:
        A string such as 'rgba(68, 1, 84, 0.35)'. Values at or below the first
        left edge get the first colour with opacity 0. Anything else not covered
        (above the last right edge, or NaN) gets the last colour with opacity 1.
    """
    if intervals is None:
        intervals = group_intervals_df
    # Equals 0.1 for the 10 decile groups; adapts if the bin count changes.
    step = 1 / len(intervals)
    for left, right, group, color in zip(
        intervals['left'], intervals['right'], intervals['group'], intervals['color']
    ):
        if left < x <= right:
            opacity = ((group - 1) * step) + (step * (x - left) / (right - left))
            # float() keeps str(tuple) from emitting a NumPy scalar repr
            # (e.g. 'np.float64(0.3)' under NumPy 2), which is not valid CSS.
            return _hex_to_rgba(color, float(opacity))
    if x <= intervals['left'].iloc[0]:
        return _hex_to_rgba(intervals['color'].iloc[0], 0)
    return _hex_to_rgba(intervals['color'].iloc[-1], 1)

In [None]:
frame = 0
index = 0

# Flatten one volume into (x, y, z, value) rows. np.indices + C-order ravel
# produces rows in the same order as nested i/j/k loops, but without a Python
# iteration per voxel.
volume = inputs[index, frame]
assert volume.ndim == 3, f"expected a 3-D volume, got shape {volume.shape}"
grid_x, grid_y, grid_z = np.indices(volume.shape)
df = pd.DataFrame({
    'x': grid_x.ravel(),
    'y': grid_y.ravel(),
    'z': grid_z.ravel(),
    'value': volume.ravel(),
})
df

In [None]:
df['value'].map(estimate_group)  # preview: group id for every voxel value

In [None]:
# Attach a per-voxel rgba string, used below as the scatter marker colour.
df['color'] = df['value'].map(estimate_color)
df

In [None]:
# 3-D scatter of every voxel, coloured by the precomputed 'color' column;
# hovering a point shows its raw value.
scatter = go.Scatter3d(
    x=df['x'],
    y=df['y'],
    z=df['z'],
    mode='markers',
    marker={'color': df['color']},
    hovertext=df['value'],
)
camera = {'eye': {'x': 1.25, 'y': 1.25, 'z': 1.25}}

fig = go.Figure(data=[scatter])
fig.update_layout(scene_camera=camera)
fig.show()

In [1]:
# Isosurface of the voxel grid held in `df`.
# NOTE(review): isomin/isomax are hard-coded; confirm they lie within the
# min/max of `inputs` printed when the data was loaded.
fig = go.Figure(data=go.Isosurface(
    x=df['x'].values,
    y=df['y'].values,
    z=df['z'].values,  # previously df['x'], which put the x coordinate on the z axis
    value=df['value'].values,
    isomin=0,
    isomax=13,
))

fig.show()

NameError: name 'go' is not defined

In [None]:
# Volume rendering of the voxel grid in `df` between isomin and isomax.
volume_trace = go.Volume(
    x=df['x'].values,
    y=df['y'].values,
    z=df['z'].values,
    value=df['value'].values,
    isomin=0.01,
    isomax=0.2,
    surface_count=17,  # needs to be a large number for good volume rendering
    opacity=0.1,  # needs to be small to see through all surfaces
)
fig = go.Figure(data=volume_trace)
fig.show()