# Implicit Function Learning

In [6]:
import copy
import os

import numpy as np
import open3d as o3d
import plotly.express as px
import plotly.figure_factory as ff
import plotly.graph_objects as go
import plotly.io as pio
from PIL import Image
from easy_o3d import utils
from scipy.spatial.transform import Rotation
from skimage.measure import marching_cubes

# Output mode flags read by every figure cell below.
# slides: when True, cells choose the larger sizes/heights meant for slides.
slides = False
# show: when True, figures are displayed with fig.show(); otherwise they are
# written as HTML fragments (without plotly.js) into out_dir.
show = False
# Destination directory for the exported HTML fragments.
# NOTE(review): the cells below also read `font_size`, `camera`, `color_path`,
# `chair_mesh`, `depth_points`, `uniform_random_samples`, `surface_samples`,
# `mesh` and `mesh_iso_*`, none of which are defined in this excerpt —
# presumably set in notebook cells omitted from this export; confirm.
out_dir = "../figures/implicit"

## 1. Simple Implicit Function

In [12]:
def get_sphere_mesh(xyz, radius=1):
    """Create a sphere centred at ``xyz`` and return it in Plotly ``Mesh3d`` form.

    Args:
        xyz: Indexable centre position; only entries 0, 1 and 2 are read.
        radius: Sphere radius forwarded to Open3D's sphere primitive.

    Returns:
        Tuple ``(x, y, z, i, j, k)``: the three vertex-coordinate columns
        followed by the three triangle-corner index columns.
    """
    sphere = o3d.geometry.TriangleMesh.create_sphere(radius=radius)
    sphere.translate([xyz[0], xyz[1], xyz[2]])

    # Transposing turns the (N, 3) arrays into three column views each.
    x, y, z = np.asarray(sphere.vertices).T
    i, j, k = np.asarray(sphere.triangles).T
    return x, y, z, i, j, k

In [None]:
# Complete sphere
# Unit sphere at the origin; shown only in the "Explicit/Parametric" view.
x, y, z, i, j, k = get_sphere_mesh(np.zeros(3), radius=1)

sphere_mesh_plot = go.Mesh3d(x=x, y=y, z=z,
                             i=i, j=j, k=k,
                             hovertemplate="<b>x</b> %{x:.2f}<br><b>y</b> %{y:.2f}<br><b>z</b> %{z:.2f}<extra><b>Point</b></extra>",
                             color="#00CC96",
                             showscale=False)

# Point on sphere
# Small marker sphere close to the unit-sphere surface point whose angles are
# named in the hover text (azimuth 30°, polar 55°).
x, y, z, i, j, k = get_sphere_mesh([0.71, 0.41, 0.57], radius=0.05)

point_plot = go.Mesh3d(x=x, y=y, z=z,
                       i=i, j=j, k=k,
                       hovertemplate="<b>azimuth: 30°<br>polar: 55°</b><extra></extra>",
                       color="white",
                       showscale=False)

# Half sphere
# Higher-resolution sphere (Open3D's default radius) cropped by its own bounding
# box shifted by -1 along x, which keeps only the x <= 0 half so the interior
# point stays visible in the "Implicit" view.
# NOTE(review): create_sphere is used as a class-level factory elsewhere in this
# file; the temporary TriangleMesh() instance here appears to be discarded.
sphere_mesh = o3d.geometry.TriangleMesh().create_sphere(resolution=200)
bbox = sphere_mesh.get_axis_aligned_bounding_box().translate([-1, 0, 0])
cut_sphere_mesh = sphere_mesh.crop(bbox)
vertices = np.asarray(cut_sphere_mesh.vertices)
triangles = np.asarray(cut_sphere_mesh.triangles)
x = vertices[:, 0]
y = vertices[:, 1]
z = vertices[:, 2]
i = triangles[:, 0]
j = triangles[:, 1]
k = triangles[:, 2]

cut_sphere_mesh_plot = go.Mesh3d(x=x, y=y, z=z,
                                 i=i, j=j, k=k,
                                 hovertemplate="f(x,y,z)=0<extra><b>Surface</b></extra>",
                                 color="#00CC96",
                                 showscale=False,
                                 visible=False)

# Point inside sphere
# Marker at the origin, labelled f(x,y,z)<0.
x, y, z, i, j, k = get_sphere_mesh(xyz=np.zeros(3), radius=0.05)

inside_point_plot = go.Mesh3d(x=x, y=y, z=z,
                              i=i, j=j, k=k,
                              hovertemplate="f(x,y,z)<0<extra><b>Inside</b></extra>",
                              color="#636EFA",
                              showscale=False,
                              visible=False)

# Point outside sphere
# Marker at (0, 1, 1), at distance sqrt(2) > 1 from the origin, labelled f(x,y,z)>0.
x, y, z, i, j, k = get_sphere_mesh(xyz=[0, 1, 1], radius=0.05)

outside_point_plot = go.Mesh3d(x=x, y=y, z=z,
                               i=i, j=j, k=k,
                               hovertemplate="f(x,y,z)>0<extra><b>Outside</b></extra>",
                               color="#EF553B",
                               showscale=False,
                               visible=False)

# Figure
# Trace order: [full sphere, point on sphere, half sphere, inside point,
# outside point]; each button's `visible` list follows that order.
fig = go.Figure([sphere_mesh_plot, point_plot, cut_sphere_mesh_plot, inside_point_plot, outside_point_plot])

buttons = [dict(label="Explicit/Parametric", method="update", args=[dict(visible=[True, True, False, False, False])]),
           dict(label="Implicit", method="update", args=[dict(visible=[False, False, True, True, True])])]

# Hidden axes, orthographic view from +x with z up.
# NOTE(review): `font_size` is not defined in this excerpt; presumably set in an
# earlier notebook cell — confirm.
fig.update_layout(scene=dict(xaxis=dict(visible=False),
                             yaxis=dict(visible=False),
                             zaxis=dict(visible=False),
                             aspectmode="data"),
                  hoverlabel=dict(font_size=font_size),
                  showlegend=False,
                  height=700 if slides else 500,
                  margin=dict(r=0, l=0, b=0, t=0, pad=0),
                  scene_camera=dict(eye=dict(x=1, y=0, z=0),
                                    up=dict(x=0, y=0, z=1),
                                    center=dict(x=0, y=0, z=0),
                                    projection=dict(type="orthographic")),
                  scene_dragmode="orbit",
                  updatemenus=[dict(type="buttons",
                                    xanchor="left",
                                    x=0.01,
                                    y=0.99,
                                    font_size=font_size,
                                    active=0,
                                    buttons=buttons)])

# Display interactively, or export an HTML fragment that does not embed
# plotly.js (the including page must load it).
if show:
    fig.show(config=dict(displayModeBar=False))
else:
    pio.write_html(fig,
                   file=os.path.join(out_dir, "sphere.html"),
                   config=dict(displayModeBar=False),
                   full_html=False,
                   include_plotlyjs=False)

## 2. Complex Implicit Function

In [None]:
def linear_interpolation(min_val: float, max_val: float, data: np.ndarray) -> np.ndarray:
    """Linearly rescale ``data`` so its minimum maps to ``min_val`` and its maximum to ``max_val``.

    Args:
        min_val: Target value for the smallest element of ``data``.
        max_val: Target value for the largest element of ``data``.
        data: Array-like of numbers (converted with ``np.asarray``).

    Returns:
        Array with the same shape as ``data``. An empty input yields an empty
        float array. If every element is equal (zero range), all elements map
        to the midpoint ``(min_val + max_val) / 2`` instead of producing NaN
        from a 0/0 division.
    """
    values = np.asarray(data)
    if values.size == 0:
        # np.min/np.max raise on empty arrays; nothing to rescale.
        return values.astype(float)

    lo = values.min()
    hi = values.max()
    if hi == lo:
        # Degenerate range: there is no ordering to preserve, so pick the centre.
        return np.full(values.shape, (min_val + max_val) / 2)

    return (max_val - min_val) * (values - lo) / (hi - lo) + min_val


def load_sdf(sdf_path: str, resolution: int = 256):
    """Read a binary signed-distance-field file into a dict.

    File layout read here:
      * 3 x int32 header ``(-resolution, resolution, resolution)`` — the first
        entry is stored negated,
      * 6 x float64 bounds,
      * ``(resolution + 1) ** 3`` x float32 SDF values.

    Args:
        sdf_path: Path of the file to read.
        resolution: Expected grid resolution; the value grid has
            ``resolution + 1`` samples per axis.

    Returns:
        ``{"bounds": float32 array of 6 values,
        "values": float32 array of shape (resolution + 1,) * 3}``.
        ``values`` is a read-only view of the file contents.

    Raises:
        ValueError: If the header resolution does not match ``resolution`` or
            the value payload cannot be reshaped to the expected grid.
    """
    int_size = 4
    float_size = 8
    header_end = int_size * 3
    bounds_end = header_end + float_size * 6

    # `with` already closes the file; no extra try/finally needed.
    with open(sdf_path, "rb") as f:
        raw = f.read()

    res = np.frombuffer(raw[:header_end], dtype=np.int32)
    if -res[0] != resolution or res[1] != resolution or res[2] != resolution:
        raise ValueError(f"{sdf_path}: resolution {res.tolist()} not consistent with {resolution}")

    bounds = np.frombuffer(raw[header_end:bounds_end], dtype=np.float64)
    values = np.frombuffer(raw[bounds_end:], dtype=np.float32)
    n = resolution + 1
    return {
        "bounds": bounds.astype(np.float32),
        # np.reshape raises ValueError when the payload size does not match.
        "values": np.reshape(values, (n, n, n)),
    }


def mesh_from_sdf(sdf_path: str, resolution: int = 256, padding: float = 0.1, level: float = 0.0,
                  simplify_voxel_size: float = 0.01):
    """Extract the ``level`` iso-surface of an SDF file as an Open3D triangle mesh.

    The value grid from ``load_sdf`` is treated as spanning a cube of edge
    ``1 + padding``. Marching-cubes vertices are placed with that spacing, and
    the mesh is then shifted by half the edge on every axis so that cube is
    centred on the origin.
    NOTE(review): the bounds stored in the file are not used; this assumes the
    SDF was sampled over that padded unit cube — confirm.

    Args:
        sdf_path: Path forwarded to ``load_sdf``.
        resolution: Grid resolution forwarded to ``load_sdf``.
        padding: Extra edge length added to the unit cube.
        level: Iso-value to extract (0.0 is the zero-level set).
        simplify_voxel_size: Cell size for Open3D vertex-clustering
            simplification (default 0.01). ``None`` or a non-positive value
            skips simplification.

    Returns:
        An ``o3d.geometry.TriangleMesh`` with triangle normals computed.
    """
    sdf_dict = load_sdf(sdf_path, resolution)
    # Copy (the loaded values are a read-only buffer view) and reverse the axis order.
    volume = sdf_dict["values"].copy().transpose((2, 1, 0))

    box_size = 1 + padding
    # N samples per axis span N - 1 intervals across the box.
    voxel_size = box_size / (np.array(volume.shape) - 1)

    vertices, triangles, normals, _ = marching_cubes(volume=volume,
                                                     level=level,
                                                     spacing=voxel_size,
                                                     step_size=1,
                                                     allow_degenerate=False,
                                                     method="lewiner")

    mesh = o3d.geometry.TriangleMesh()
    mesh.vertices = o3d.utility.Vector3dVector(vertices)
    mesh.triangles = o3d.utility.Vector3iVector(triangles)
    mesh.vertex_normals = o3d.utility.Vector3dVector(normals)
    if simplify_voxel_size is not None and simplify_voxel_size > 0:
        mesh = mesh.simplify_vertex_clustering(simplify_voxel_size)
    mesh.compute_triangle_normals()
    # Vertices lie in [0, box_size]^3; move the cube centre to the origin.
    offsets = np.repeat(0.5 * box_size, 3)
    mesh.translate(-offsets)
    return mesh


def plot_from_mesh(mesh: o3d.geometry.TriangleMesh, color: str, name: str = ""):
    """Turn an Open3D triangle mesh into a single-colour Plotly surface trace.

    Args:
        mesh: Mesh whose vertices and triangles are plotted.
        color: Colour passed to ``ff.create_trisurf`` as its colormap.
        name: Legend entry and hover label of the returned trace.

    Returns:
        The first trace of the ``create_trisurf`` figure, with its legend entry
        enabled and hover showing only ``name``.
    """
    xyz = np.asarray(mesh.vertices)
    faces = np.asarray(mesh.triangles)
    trisurf = ff.create_trisurf(x=xyz[:, 0], y=xyz[:, 1], z=xyz[:, 2],
                                simplices=faces,
                                plot_edges=False,
                                colormap=color,
                                show_colorbar=False)
    surface = trisurf.data[0]
    surface.update(showlegend=True, name=name, hoverinfo="name")
    return surface

### 2.1 Binary Classification (Occupancy)

In [None]:
# Inside/outside
cut_samples = uniform_random_samples[uniform_random_samples[:, 0] <= 0]
# NOTE(review): column 3 is read as each sample's signed distance (it is shown
# as "Distance" below); columns 0-2 are plotted as x, y, z.
distances = cut_samples[:, 3]
inside_mask = distances <= 0

# Map distances to pseudo-probabilities: inside (d <= 0) -> [0, 0.5] and
# outside (d > 0) -> [0.5, 1]. Writing through the masks keeps every value
# aligned with its own sample; concatenating the two groups would reorder them
# relative to cut_samples. A group is skipped when it is empty.
probabilities = np.empty(len(distances), dtype=float)
if inside_mask.any():
    probabilities[inside_mask] = linear_interpolation(0, 0.5, distances[inside_mask])
if (~inside_mask).any():
    probabilities[~inside_mask] = linear_interpolation(0.5, 1, distances[~inside_mask])

# Colour by the sign of the distance rather than by thresholding the rescaled
# values, because both groups reach exactly 0.5 at their boundary.
inside_outside_colors = ["#636EFA" if is_inside else "#EF553B" for is_inside in inside_mask]
# One marker size for the initial view and both button states.
marker_size = 3

inside_outside_plot = go.Scatter3d(x=cut_samples[:, 0],
                                   y=cut_samples[:, 1],
                                   z=cut_samples[:, 2],
                                   mode="markers",
                                   marker=dict(size=marker_size,
                                               color=inside_outside_colors,
                                               line=dict(width=0.5,
                                                         color="DarkSlateGrey")),
                                   showlegend=False,
                                   text=distances,
                                   hovertemplate="%{text:.2f}<extra><b>Distance</b></extra>")

# Surface
cut_surface_samples = surface_samples[surface_samples[:, 0] <= 0]

surface_plot = go.Scatter3d(x=cut_surface_samples[:, 0],
                            y=cut_surface_samples[:, 1],
                            z=cut_surface_samples[:, 2],
                            name="Surface",
                            hoverinfo="name",
                            mode="markers",
                            marker=dict(size=3,
                                        color="white",
                                        line=dict(width=0.5,
                                                  color="DarkSlateGrey")),
                            showlegend=True,
                            visible="legendonly")

# Mesh
# Both trisurf traces (surface and edges) share one legend entry "Mesh" and
# start hidden.
vertices = np.asarray(mesh.vertices)
triangles = np.asarray(mesh.triangles)

mesh_plot = ff.create_trisurf(x=vertices[:, 0],
                              y=vertices[:, 1],
                              z=vertices[:, 2],
                              simplices=triangles,
                              plot_edges=True,
                              colormap="#00CC96",
                              show_colorbar=False).data
mesh_plot[0].showlegend = True
mesh_plot[1].showlegend = False
mesh_plot[0].visible = "legendonly"
mesh_plot[1].visible = "legendonly"
mesh_plot[0].legendgroup = "Mesh"
mesh_plot[1].legendgroup = "Mesh"
mesh_plot[0].name = "Mesh"
mesh_plot[1].name = "Mesh"
mesh_plot[0].hoverinfo = "name"
mesh_plot[1].hoverinfo = "none"

# Figure
fig = go.Figure([inside_outside_plot, surface_plot, mesh_plot[0], mesh_plot[1]])

# Both buttons restyle only trace 0 (the samples).
# NOTE(review): with "RdBu", low values (inside) render red and high values
# (outside) blue — the reverse of the Inside/Outside palette; confirm intended.
buttons = list([
    dict(label="Inside/Outside",
         method="restyle",
         args=[{"marker": [dict(size=marker_size,
                                color=inside_outside_colors,
                                colorscale=None,
                                line=dict(width=0.5,
                                          color="DarkSlateGrey"))],
                "text": [distances],
                "hovertemplate": "%{text:.2f}<extra><b>Distance</b></extra>"}, [0]]),
    dict(label="Probabilities",
         method="restyle",
         args=[{"marker": [dict(size=marker_size,
                                color=probabilities,
                                colorscale="RdBu",
                                line=dict(width=0.5,
                                          color="DarkSlateGrey"))],
                "text": [probabilities],
                "hovertemplate": "%{text:.2f}<extra><b>Probability</b></extra>"}, [0]])
])

fig.update_layout(scene=dict(xaxis=dict(visible=False),
                             yaxis=dict(visible=False),
                             zaxis=dict(visible=False),
                             aspectmode="data"),
                  hoverlabel=dict(font_size=font_size),
                  legend=dict(xanchor="right",
                              font_size=font_size,
                              itemsizing="constant",
                              bgcolor="rgba(0, 0, 0, 0)"),
                  height=1200 if slides else 700,
                  margin=dict(r=0, l=0, b=0, t=0, pad=0),
                  scene_camera=dict(eye=dict(x=1, y=0, z=0),
                                    up=dict(x=0, y=0, z=1),
                                    center=dict(x=0, y=0, z=0),
                                    projection=dict(type="orthographic")),
                  scene_dragmode="orbit",
                  updatemenus=[dict(type="buttons",
                                    xanchor="left",
                                    x=0.01,
                                    y=0.99,
                                    font_size=font_size,
                                    active=0,
                                    buttons=buttons)])

if show:
    fig.show(config=dict(displaylogo=False))
else:
    pio.write_html(fig,
                   file=os.path.join(out_dir, "classification.html"),
                   config=dict(displaylogo=False) if slides else dict(displayModeBar=False),
                   full_html=False,
                   include_plotlyjs=False)

### 2.2 Regression (SDF)

In [None]:
# SDF ISO surfaces
# NOTE(review): mesh_iso_0 / mesh_iso_0025 / mesh_iso_005 / mesh_iso_neg_* are not
# defined in this excerpt; presumably iso-surfaces at levels 0, ±0.025, ±0.05
# (matching the legend names below) built in an earlier cell — confirm.
# First crop box keeps x in [-1, 0], so only the back half of each shell is drawn
# and the nested surfaces are visible.
bbox = o3d.geometry.AxisAlignedBoundingBox(min_bound=[-1, -1, -1], max_bound=[0, 1, 1])
back_crop_mesh_iso_0 = mesh_iso_0.crop(bbox)
back_crop_mesh_iso_0025 = mesh_iso_0025.crop(bbox)
back_crop_mesh_iso_005 = mesh_iso_005.crop(bbox)
back_crop_mesh_iso_neg_0025 = mesh_iso_neg_0025.crop(bbox)
back_crop_mesh_iso_neg_005 = mesh_iso_neg_005.crop(bbox)
# Second crop box keeps the complementary x in [0, 1] half of the zero-level set.
bbox = o3d.geometry.AxisAlignedBoundingBox(min_bound=[0, -1, -1], max_bound=[1, 1, 1])
front_crop_mesh_iso_0 = mesh_iso_0.crop(bbox)

# One trace per surface: white for the zero-level set, shades from the RdBu
# diverging palette (indexed from both ends) for the positive and negative levels.
meshes = list()
meshes.append(plot_from_mesh(back_crop_mesh_iso_0, "rgb(255, 255, 255)", "zero-level set"))
meshes.append(plot_from_mesh(back_crop_mesh_iso_0025, px.colors.diverging.RdBu[4], "0.025"))
meshes.append(plot_from_mesh(back_crop_mesh_iso_005, px.colors.diverging.RdBu[2], "0.05"))
meshes.append(plot_from_mesh(back_crop_mesh_iso_neg_0025, px.colors.diverging.RdBu[-4], "-0.025"))
meshes.append(plot_from_mesh(back_crop_mesh_iso_neg_005, px.colors.diverging.RdBu[-2], "-0.05"))
meshes.append(plot_from_mesh(front_crop_mesh_iso_0, "rgb(255, 255, 255)", "zero-level set (left side)"))
# The other half of the zero-level set starts hidden; toggle it via the legend.
meshes[-1].visible = "legendonly"

# Figure
fig = go.Figure(meshes)

# The legend title is only shown in the slides version.
fig.update_layout(scene=dict(xaxis=dict(visible=False),
                             yaxis=dict(visible=False),
                             zaxis=dict(visible=False),
                             aspectmode="data"),
                  hoverlabel=dict(font_size=font_size),
                  legend=dict(xanchor="right",
                              title="ISO Surfaces" if slides else None,
                              title_font_size=font_size + 5,
                              font_size=font_size,
                              itemsizing="constant",
                              bgcolor="rgba(0, 0, 0, 0)"),
                  height=1200 if slides else 700,
                  margin=dict(r=0, l=0, b=0, t=0, pad=0),
                  scene_camera=dict(eye=dict(x=1, y=0, z=0),
                                    up=dict(x=0, y=0, z=1),
                                    center=dict(x=0, y=0, z=0),
                                    projection=dict(type="orthographic")),
                  scene_dragmode="orbit")

if show:
    fig.show(config=dict(displaylogo=False))
else:
    pio.write_html(fig,
                   file=os.path.join(out_dir, "regression.html"),
                   config=dict(displaylogo=False) if slides else dict(displayModeBar=False),
                   full_html=False,
                   include_plotlyjs=False)

## 3. Data Representations

In [None]:
def make_voxel(origin, index=0, voxel_size=1):
    """Return one axis-aligned cube as Plotly ``Mesh3d`` vertex and face lists.

    Args:
        origin: ``(x, y, z)`` of the cube's minimum corner.
        index: Position of this cube within a combined mesh; triangle indices
            are shifted by ``8 * index`` so they address this cube's eight
            vertices after all cubes' vertex lists are concatenated.
        voxel_size: Edge length of the cube.

    Returns:
        ``(x, y, z, i, j, k)``: eight corner coordinates per axis, followed by
        the three corner-index lists of the cube's 12 triangles.
    """
    lo_x, lo_y, lo_z = origin
    hi_x, hi_y, hi_z = lo_x + voxel_size, lo_y + voxel_size, lo_z + voxel_size

    # Corners 0-3 walk the bottom face (z = lo_z); corners 4-7 repeat the same
    # x/y walk on the top face (z = hi_z).
    ring_x = [lo_x, lo_x, hi_x, hi_x]
    ring_y = [lo_y, hi_y, hi_y, lo_y]
    x = ring_x * 2
    y = ring_y * 2
    z = [lo_z] * 4 + [hi_z] * 4

    # Two triangles per cube face, one (i, j, k) row per triangle.
    faces = np.array([[7, 3, 0], [0, 4, 7],   # y = lo_y
                      [0, 1, 2], [0, 2, 3],   # z = lo_z
                      [4, 5, 6], [4, 6, 7],   # z = hi_z
                      [6, 5, 1], [6, 2, 1],   # y = hi_y
                      [4, 0, 5], [0, 1, 5],   # x = lo_x
                      [3, 6, 7], [2, 3, 6]])  # x = hi_x
    i, j, k = (faces + index * 8).T.tolist()

    return x, y, z, i, j, k


def get_voxels(voxel_grid):
    """Merge every voxel of an Open3D ``VoxelGrid`` into one list-based cube mesh.

    Each voxel becomes a cube (via ``make_voxel``) whose minimum corner is the
    voxel's integer ``grid_index``, so the result is in grid-index units. The
    voxel's enumeration position is passed as ``index`` so triangle indices
    address the right slice of the concatenated vertex lists.

    Returns:
        Tuple of six lists ``(x, y, z, i, j, k)`` suitable for ``go.Mesh3d``.
    """
    columns = ([], [], [], [], [], [])
    for position, voxel in enumerate(voxel_grid.get_voxels()):
        cube = make_voxel(origin=voxel.grid_index, index=position)
        for column, values in zip(columns, cube):
            column.extend(values)
    return columns


def _grid_line(xs, ys, zs, pos, color, line_width, visible):
    """Return one lattice segment as a ``Scatter3d`` line trace shifted by ``pos``."""
    return go.Scatter3d(x=np.asarray(xs) + pos[0],
                        y=np.asarray(ys) + pos[1],
                        z=np.asarray(zs) + pos[2],
                        mode="lines",
                        marker=dict(color=color),
                        line=dict(width=line_width, color=color),
                        name="Grid",
                        legendgroup="Grid",
                        showlegend=False,
                        hoverinfo="name",
                        visible=visible)


def make_grid(res: int,
              pos: tuple = (0, 0, 0),
              color: str = "DarkSlateGray",
              line_width: float = 0.5,
              visible: "bool | str" = True):
    """Build the line traces of a ``res`` x ``res`` x ``res`` cubic lattice.

    For every pair ``(i, j)`` in ``0..res``, three segments of length ``res`` are
    emitted, parallel to x, y and z in that order. The result therefore holds
    ``3 * (res + 1) ** 2`` traces. They all share the legend group ``"Grid"``
    and are hidden from the legend, so callers may enable ``showlegend`` on one
    of them to get a single toggle.

    Args:
        res: Number of unit cells per axis.
        pos: Offset added to every segment's coordinates.
        color: Line colour.
        line_width: Line width.
        visible: Plotly ``visible`` value applied to every trace (``True``,
            ``False`` or ``"legendonly"``).

    Returns:
        List of ``go.Scatter3d`` traces.
    """
    lines = []
    for i in range(res + 1):
        for j in range(res + 1):
            for xs, ys, zs in (([0, res], [i, i], [j, j]),   # parallel to x
                               ([i, i], [0, res], [j, j]),   # parallel to y
                               ([i, i], [j, j], [0, res])):  # parallel to z
                lines.append(_grid_line(xs, ys, zs, pos, color, line_width, visible))
    return lines

### 3.1 Image

In [None]:
# Downscaled image
# convert("RGB") guarantees exactly three channels (drops alpha, expands
# greyscale/palette images), which reshape(-1, 3) below relies on. The context
# manager releases the file handle that Image.open keeps.
with Image.open(color_path) as source_image:
    image = source_image.convert("RGB")
image.thumbnail([160, 120])
image = np.asarray(image)
height, width = image.shape[:2]
# One CSS colour string per pixel. The channels are formatted as plain ints on
# purpose: formatting a tuple of NumPy scalars uses their repr, which in NumPy
# >= 2 is "np.uint8(12)" and would produce invalid colours.
color_image = np.array([f"rgb({int(r)}, {int(g)}, {int(b)})" for r, g, b in image.reshape(-1, 3)])
y, x = np.mgrid[0:height, 0:width]

# Each pixel is drawn as a square marker at its (column, row) position.
image_plot = go.Scattergl(x=x.ravel(),
                          y=y.ravel(),
                          mode="markers",
                          marker=dict(size=4 if slides else 2,
                                      color=color_image,
                                      symbol="square"),
                          hovertemplate="%{x}, %{y}<extra><b>Pixel</b></extra>",
                          hoverlabel=dict(bgcolor=color_image),
                          showlegend=False)

# Figure
fig = go.Figure(image_plot)

# 2-D figure: equal axis scaling, y reversed so row 0 is at the top.
fig.update_layout(template="plotly_white",
                  xaxis=dict(constrain="domain",
                             visible=False),
                  yaxis=dict(scaleanchor='x',
                             autorange="reversed",
                             visible=False),
                  hoverlabel=dict(font_size=font_size),
                  width=None if slides else 160,
                  height=4 * height if slides else 200,
                  margin=dict(r=0, l=0, b=0, t=0, pad=0))

if show:
    fig.show(config=dict(displayModeBar=False))
else:
    pio.write_html(fig,
                   file=os.path.join(out_dir, "image.html"),
                   config=dict(displayModeBar=False),
                   full_html=False,
                   include_plotlyjs=False)

### 3.2 Pointcloud

In [None]:
# Pointcloud
# 10,000 points sampled from chair_mesh with Open3D's sample_points_uniformly.
# NOTE(review): chair_mesh and camera are not defined in this excerpt;
# presumably set in earlier notebook cells — confirm.
pointcloud = np.asarray(chair_mesh.sample_points_uniformly(10000).points)

# Markers coloured by the y coordinate; sizes shrink for the small non-slide embed.
pointcloud_plot = go.Scatter3d(x=pointcloud[:, 0],
                               y=pointcloud[:, 1],
                               z=pointcloud[:, 2],
                               mode="markers",
                               marker=dict(size=3 if slides else 1,
                                           color=pointcloud[:, 1],
                                           colorscale="Teal",
                                           line=dict(width=0.5 if slides else 0.1,
                                                     color="DarkSlateGrey")),
                               hovertemplate="<b>x</b> %{x:.2f}<br><b>y</b> %{y:.2f}<br><b>z</b> %{z:.2f}<extra><b>Point</b></extra>")

# Figure
fig = go.Figure(pointcloud_plot)

fig.update_layout(scene=dict(xaxis=dict(visible=False),
                             yaxis=dict(visible=False),
                             zaxis=dict(visible=False),
                             aspectmode="data"),
                  hoverlabel=dict(font_size=font_size),
                  showlegend=False,
                  width=None if slides else 160,
                  height=700 if slides else 200,
                  margin=dict(r=0, l=0, b=0, t=0, pad=0),
                  scene_camera=camera,
                  scene_dragmode="orbit")

if show:
    fig.show(config=dict(displayModeBar=False))
else:
    pio.write_html(fig,
                   file=os.path.join(out_dir, "pcd.html"),
                   config=dict(displayModeBar=False),
                   full_html=False,
                   include_plotlyjs=False)

### 3.3 Mesh

In [None]:
# Mesh
vertices = np.asarray(chair_mesh.vertices)
triangles = np.asarray(chair_mesh.triangles)

# Trisurf with edges and a 7-step teal colormap. With plot_edges=True the
# returned data has at least two traces; indices 0 and 1 are both labelled
# "Mesh" on hover. NOTE(review): 0 is presumably the surface and 1 the edge
# lines — confirm.
chair_mesh_plot = ff.create_trisurf(x=vertices[:, 0],
                                    y=vertices[:, 1],
                                    z=vertices[:, 2],
                                    simplices=triangles,
                                    plot_edges=True,
                                    colormap=['rgb(209, 238, 234)',
                                              'rgb(168, 219, 217)',
                                              'rgb(133, 196, 201)',
                                              'rgb(104, 171, 184)',
                                              'rgb(79, 144, 166)',
                                              'rgb(59, 115, 143)',
                                              'rgb(42, 86, 116)'],
                                    show_colorbar=False).data
chair_mesh_plot[0].hoverinfo = "name"
chair_mesh_plot[0].name = "Mesh"
chair_mesh_plot[1].hoverinfo = "name"
chair_mesh_plot[1].name = "Mesh"

# Figure
fig = go.Figure(chair_mesh_plot)

# NOTE(review): `camera` is not defined in this excerpt; presumably a shared
# scene-camera dict from an earlier cell — confirm.
fig.update_layout(scene=dict(xaxis=dict(visible=False),
                             yaxis=dict(visible=False),
                             zaxis=dict(visible=False),
                             aspectmode="data"),
                  hoverlabel=dict(font_size=font_size),
                  showlegend=False,
                  width=None if slides else 160,
                  height=700 if slides else 200,
                  margin=dict(r=0, l=0, b=0, t=0, pad=0),
                  scene_camera=camera,
                  scene_dragmode="orbit")

if show:
    fig.show(config=dict(displayModeBar=False))
else:
    pio.write_html(fig,
                   file=os.path.join(out_dir, "mesh.html"),
                   config=dict(displayModeBar=False),
                   full_html=False,
                   include_plotlyjs=False)

### 3.4 Voxel Grid

In [None]:
# Voxels
# Voxelise chair_mesh with 0.04-sized voxels and turn each voxel into a cube.
# get_voxels places cubes at the voxels' integer grid indices, so everything
# below is in voxel (grid-index) units.
voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh(chair_mesh, voxel_size=0.04)
x, y, z, i, j, k = get_voxels(voxel_grid)

vertices = np.vstack([x, y, z]).T
triangles = np.vstack([i, j, k]).T

# Build an Open3D mesh from all cubes so coincident corners of neighbouring cubes
# can be merged and duplicated / degenerate faces dropped. The cleanup calls'
# return values are unused; `mesh` is read afterwards, i.e. they are relied on
# to modify it in place.
mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(vertices)
mesh.triangles = o3d.utility.Vector3iVector(triangles)

mesh.remove_duplicated_vertices()
mesh.remove_unreferenced_vertices()

mesh.remove_duplicated_triangles()
mesh.remove_degenerate_triangles()

vertices = np.asarray(mesh.vertices)
triangles = np.asarray(mesh.triangles)
x = vertices[:, 0]
y = vertices[:, 1]
z = vertices[:, 2]
i = triangles[:, 0]
j = triangles[:, 1]
k = triangles[:, 2]

# Cubes coloured by their y coordinate.
voxel_grid_plot = go.Mesh3d(x=x,
                            y=y,
                            z=z,
                            i=i,
                            j=j,
                            k=k,
                            hovertemplate="x: %{x}<br>y: %{y}<br>z: %{z}<extra><b>Voxel</b></extra>",
                            colorscale="Teal",
                            intensity=y,
                            flatshading=False,
                            showscale=False)

# Grid lines
# 28^3 lattice offset by (-2, -3, 0) in voxel units, hidden until toggled. Only
# the last trace gets a legend entry; every grid trace shares the "Grid" legend
# group, so that single entry toggles the whole lattice.
grid_line_plots = make_grid(res=28, pos=(-2, -3, 0), visible="legendonly")
grid_line_plots[-1].showlegend = True

# Figure
# The lattice is only included for slides. The conditional binds looser than
# `+`: (grid_line_plots + [voxel_grid_plot]) if slides else voxel_grid_plot.
plots = grid_line_plots + [voxel_grid_plot] if slides else voxel_grid_plot
fig = go.Figure(plots)

fig.update_layout(scene=dict(xaxis=dict(visible=False),
                             yaxis=dict(visible=False),
                             zaxis=dict(visible=False),
                             aspectmode="data"),
                  hoverlabel=dict(font_size=font_size),
                  showlegend=slides,
                  legend=dict(yanchor="top",
                              y=0.99,
                              xanchor="left",
                              x=0.01,
                              font_size=font_size,
                              itemsizing="constant"),
                  width=None if slides else 160,
                  height=700 if slides else 200,
                  margin=dict(r=0, l=0, b=0, t=0, pad=0),
                  scene_camera=dict(eye=dict(x=1, y=0, z=0),
                                    up=dict(x=0, y=0, z=1),
                                    center=dict(x=0, y=0, z=0),
                                    projection=dict(type="orthographic")),
                  scene_dragmode="orbit")

if show:
    fig.show(config=dict(displayModeBar=False))
else:
    pio.write_html(fig,
                   file=os.path.join(out_dir, "voxel.html"),
                   config=dict(displayModeBar=False),
                   full_html=False,
                   include_plotlyjs=False)

## 4. Problems

In [None]:
# Frame
frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.05)
vertices = np.asarray(frame.vertices)
triangles = np.asarray(frame.triangles)

# NOTE(review): the frame's y axis is mirrored (-y); presumably to match the
# depth points' convention — confirm.
frame_plot = go.Mesh3d(x=vertices[:, 0],
                       y=-vertices[:, 1],
                       z=vertices[:, 2],
                       i=triangles[:, 0],
                       j=triangles[:, 1],
                       k=triangles[:, 2],
                       vertexcolor=np.asarray(frame.vertex_colors),
                       showscale=False,
                       showlegend=True,
                       name="Frame",
                       hoverinfo="name")

# Box
# Wireframe of the cube [-0.5, 0.5]^3: a single-cell lattice offset by -0.5.
bbox_plot = make_grid(1, pos=(-0.5, -0.5, -0.5), line_width=1)


def _depth_points_trace(xyz, visible=True):
    """Return a marker-only Scatter3d of an (N, 3) point array coloured by z."""
    return go.Scatter3d(x=xyz[:, 0],
                        y=xyz[:, 1],
                        z=xyz[:, 2],
                        mode="markers",
                        marker=dict(size=2 if slides else 1,
                                    color=xyz[:, 2],
                                    colorscale="magma",
                                    line=dict(width=0.25 if slides else 0.1,
                                              color="DarkSlateGrey")),
                        hovertemplate="<b>x</b> %{x:.2f}<br><b>y</b> %{y:.2f}<br><b>z</b> %{z:.2f}<extra><b>Point</b></extra>",
                        showlegend=False,
                        visible=visible)


# Rendered depth
points = np.asarray(depth_points.points)
rendered_depth_plot = _depth_points_trace(points)

# Noise
# Simulate partial, noisy sensor data: take 10 windows of up to 2000 points from
# a freshly shuffled copy each time, drop points with z <= 0, then jitter.
noisy_points = points.copy()
centroids = np.random.choice(len(noisy_points), 10)
points_list = list()
for c in centroids:
    np.random.shuffle(noisy_points)
    # Clamp the start index: a negative start wraps around to the end of the
    # array and usually yields an empty slice for centroids below 1000.
    points_list.append(noisy_points[max(c - 1000, 0):c + 1000, :])
noisy_points = np.concatenate(points_list)
noisy_points = noisy_points[noisy_points[:, 2] > 0]
# Zero-mean uniform jitter with a total spread of 0.005, so the noise does not
# shift the whole cloud in one direction.
noisy_points += 0.005 * (np.random.rand(*noisy_points.shape) - 0.5)
noisy_depth_plot = _depth_points_trace(noisy_points, visible=False)

# Centered
# Move the bounding-box centre to the origin. New arrays are created at each
# step, so no earlier trace shares memory with later modifications.
loc = (noisy_points.min(axis=0) + noisy_points.max(axis=0)) / 2
noisy_points = noisy_points - loc
centered_plot = _depth_points_trace(noisy_points, visible=False)

# Scaled
# Divide by the longest bounding-box edge so the cloud fits the unit cube.
scale = (noisy_points.max(axis=0) - noisy_points.min(axis=0)).max()
noisy_points = noisy_points / scale
scaled_plot = _depth_points_trace(noisy_points, visible=False)

# Rotated
# Apply a random rotation about the origin.
noisy_points = (Rotation.random().as_matrix() @ noisy_points.T).T
rotated_points = _depth_points_trace(noisy_points, visible=False)

# Figure
fig = go.Figure(
    bbox_plot + [frame_plot] + [rendered_depth_plot, noisy_depth_plot, centered_plot, scaled_plot, rotated_points])

# The box wireframe and frame stay visible; each button shows exactly one of the
# five point-cloud stages, in the trace order above.
stage_labels = ["Rendered Depth", "Noise", "Centered", "Scaled", "Rotated"]
always_visible = [True] * (len(bbox_plot) + 1)
buttons = [dict(label=label, method="update",
                args=[dict(visible=always_visible + [stage == shown for stage in range(len(stage_labels))])])
           for shown, label in enumerate(stage_labels)]

fig.update_layout(scene=dict(xaxis=dict(visible=False),
                             yaxis=dict(visible=False),
                             zaxis=dict(visible=False),
                             aspectmode="data"),
                  hoverlabel=dict(font_size=font_size),
                  showlegend=False,
                  height=1200 if slides else 500,
                  margin=dict(r=0, l=0, b=0, t=0, pad=0),
                  scene_camera=dict(eye=dict(x=0, y=0, z=2),
                                    up=dict(x=0, y=0, z=1),
                                    center=dict(x=0, y=0, z=0)),
                  scene_dragmode="orbit",
                  updatemenus=[dict(type="buttons",
                                    xanchor="left",
                                    x=0,
                                    active=0,
                                    font_size=font_size,
                                    buttons=buttons)])

if show:
    fig.show(config=dict(displayModeBar=False))
else:
    pio.write_html(fig,
                   file=os.path.join(out_dir, "problems.html"),
                   config=dict(displayModeBar=False),
                   full_html=False,
                   include_plotlyjs=False)

## 5. Results

In [None]:
def _result_trisurf_traces(mesh, color, name, legendgroup):
    """Return the ``(surface, edges)`` trisurf traces for a triangle mesh.

    Both traces start hidden (``visible="legendonly"``) and share ``legendgroup``,
    so one legend entry toggles them together. Only the surface trace appears in
    the legend and reports ``name`` on hover; the edge trace has no hover info.

    Args:
        mesh: Object exposing ``vertices`` and ``triangles`` (presumably an
            Open3D ``TriangleMesh`` — confirm against callers).
        color: Face colour passed to ``ff.create_trisurf``; the edge colour is
            derived with the file's ``colorscale(color, 0.8)`` helper.
        name: Trace name shown in the legend and on hover.
        legendgroup: Legend group shared by the two traces.

    Returns:
        Tuple ``(surface_trace, edge_trace)``.
    """
    vertices = np.asarray(mesh.vertices)
    triangles = np.asarray(mesh.triangles)

    traces = ff.create_trisurf(x=vertices[:, 0],
                               y=vertices[:, 1],
                               z=vertices[:, 2],
                               simplices=triangles,
                               plot_edges=True,
                               colormap=color,
                               edges_color=colorscale(color, 0.8),
                               show_colorbar=False).data
    surface, edges = traces[0], traces[1]

    for trace in (surface, edges):
        trace.visible = "legendonly"
        trace.legendgroup = legendgroup
        trace.name = name
    surface.showlegend = True
    surface.hoverinfo = "name"
    edges.showlegend = False
    edges.hoverinfo = "none"
    return surface, edges


def plot_result(bottle_vis_id: int, scene_id: int = False, input_points=None, generated_mesh=None, pose=None,
                gt_mesh=None, points_iou=None, probabilities=None, probability_threshold: float = 0.2):
    """Plot a reconstruction result and either show it or export it as an HTML fragment.

    The input point cloud (coloured by its y coordinate) and the generated mesh
    are always plotted. When ``scene_id`` is truthy (real-scene result), the
    plot also includes the ground-truth mesh, a coordinate frame transformed by
    ``pose`` and the query points whose probability is at least
    ``probability_threshold``. All layers except the input points start hidden
    and can be toggled from the legend.

    Args:
        bottle_vis_id: Id of a validation-set result. It is used in the output
            file name when ``scene_id`` is falsy. When truthy, it also selects
            the bottle camera view (which takes precedence over the scene view).
        scene_id: Id of a real-scene result. Falsy (default) means a
            validation-set result.
        input_points: Point cloud exposing ``points``. Required.
        generated_mesh: Predicted mesh exposing ``vertices``/``triangles``. Required.
        pose: Transform applied to the coordinate frame (scene results only);
            presumably a 4x4 homogeneous matrix — confirm.
        gt_mesh: Ground-truth mesh (scene results only).
        points_iou: Query point coordinates indexed as ``[rows, 0..2]``
            (scene results only).
        probabilities: Per-point probabilities aligned with the rows of
            ``points_iou`` (scene results only).
        probability_threshold: Minimum probability for a query point to be drawn.

    Raises:
        ValueError: If an input required for the requested plot is missing.
    """
    if input_points is None or generated_mesh is None:
        raise ValueError("input_points and generated_mesh are required")
    if scene_id and any(arg is None for arg in (gt_mesh, pose, points_iou, probabilities)):
        raise ValueError("scene results require gt_mesh, pose, points_iou and probabilities")

    # Input points
    points = np.asarray(input_points.points)

    points_plot = go.Scatter3d(x=points[:, 0],
                               y=points[:, 1],
                               z=points[:, 2],
                               mode="markers",
                               marker=dict(size=3,
                                           color=points[:, 1],
                                           colorscale="magma",
                                           line=dict(width=0.5,
                                                     color="DarkSlateGrey")),
                               hovertemplate="<b>x</b> %{x:.2f}<br><b>y</b> %{y:.2f}<br><b>z</b> %{z:.2f}<extra><b>Point</b></extra>",
                               showlegend=True,
                               name="Input Points")

    # Generated mesh
    mesh_surface, mesh_edges = _result_trisurf_traces(generated_mesh, "#00CC96", "Mesh", "Mesh")
    traces = [points_plot, mesh_surface, mesh_edges]

    # Work on a private copy: assigning into the module-level ``camera`` dict
    # would leak this view into every later figure that uses it.
    scene_camera = copy.deepcopy(camera)

    if scene_id:
        # Ground truth mesh
        gt_surface, gt_edges = _result_trisurf_traces(gt_mesh, "#FFA15A", "Ground Truth Mesh", "Mesh2")

        # Frame
        frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.05).transform(pose)
        frame_vertices = np.asarray(frame.vertices)
        frame_triangles = np.asarray(frame.triangles)

        frame_plot = go.Mesh3d(x=frame_vertices[:, 0],
                               y=frame_vertices[:, 1],
                               z=frame_vertices[:, 2],
                               i=frame_triangles[:, 0],
                               j=frame_triangles[:, 1],
                               k=frame_triangles[:, 2],
                               vertexcolor=np.asarray(frame.vertex_colors),
                               showscale=False,
                               showlegend=True,
                               name="Frame",
                               hoverinfo="name",
                               visible="legendonly")

        # Probabilities: only draw query points that are likely enough to be occupied.
        points_iou = np.asarray(points_iou)
        probabilities = np.asarray(probabilities)
        indices = probabilities >= probability_threshold

        probabilities_plot = go.Scatter3d(x=points_iou[indices, 0],
                                          y=points_iou[indices, 1],
                                          z=points_iou[indices, 2],
                                          mode="markers",
                                          marker=dict(size=3,
                                                      color=probabilities[indices],
                                                      colorscale="Plotly3",
                                                      line=dict(width=0.5,
                                                                color="DarkSlateGrey")),
                                          text=probabilities[indices],
                                          hovertemplate="%{text:.2f}<extra><b>Probability</b></extra>",
                                          showlegend=True,
                                          name="Probabilities",
                                          visible="legendonly")

        traces += [gt_surface, gt_edges, frame_plot, probabilities_plot]
        scene_camera["eye"] = dict(x=2, y=2, z=2)
        scene_camera["up"] = dict(x=0, y=0, z=1)

    # The bottle view is applied last, so it wins when both ids are set.
    if bottle_vis_id:
        scene_camera["eye"] = dict(x=0.5, y=1.8, z=1.8)
        scene_camera["up"] = dict(x=0, y=1, z=0)

    # Figure
    fig = go.Figure(traces)
    fig.update_layout(scene=dict(xaxis=dict(visible=False),
                                 yaxis=dict(visible=False),
                                 zaxis=dict(visible=False),
                                 aspectmode="data"),
                      hoverlabel=dict(font_size=22),
                      legend=dict(yanchor="top",
                                  y=0.99,
                                  xanchor="left",
                                  x=0.01,
                                  font_size=22,
                                  itemsizing="constant",
                                  bgcolor="rgba(0, 0, 0, 0)"),
                      height=600,
                      margin=dict(r=0, l=0, b=0, t=0, pad=0),
                      scene_camera=scene_camera,
                      scene_dragmode="orbit")

    if show:
        fig.show(config=dict(displayModeBar=False))
    else:
        prefix = "real" if scene_id else "val"
        result_id = scene_id if scene_id else bottle_vis_id
        # Export into ``out_dir`` alongside the notebook's other figures.
        pio.write_html(fig,
                       file=os.path.join(out_dir, f"{prefix}_result{result_id}.html"),
                       config=dict(displayModeBar=False),
                       full_html=False,
                       include_plotlyjs=False)