In [None]:
import torch

# data loading
import pandas as pd
from plyfile import PlyData

# plotting
from meshplot import plot
from ipywidgets import interact, FloatSlider

# implementation
from implementation import minmax_scaling, centerzoom_scaling, DPSR_forward

### Load data and set parameters

In [None]:
SCAN = "./deep_geometric_prior_data/scans/dc.ply"

## configuration
# Resolution of the regular grid, one entry per axis.
GRID = (256,) * 3
# Passed to DPSR_forward as `sigma`; presumably the smoothing bandwidth — TODO confirm.
SIGMA = 1


# The authors use a "normalization" + sigmoid scheme as preprocessing.
# Any transform works as long as the preprocessed points lie inside the
# open interval (0, 1) along each dimension.
def PREPROC(x):
    """Center/zoom-scale ``x`` (scale=0.9), then squash it into (0, 1) with a sigmoid."""
    return torch.sigmoid(centerzoom_scaling(x, scale=0.9))


# Alternative we tried: torch.sigmoid(minmax_scaling(x, eps=1e-5)).
# It gave noticeably worse results than the scheme above.

In [None]:
def read_ply_file(path):
    """Load vertex positions and normals from the PLY file at ``path``.

    Reads the ``x``/``y``/``z`` and ``nx``/``ny``/``nz`` properties of the
    ``vertex`` element and stacks each triple along the last axis.

    Returns:
        (positions, normals): two tensors, each with an added leading batch
        dimension of size 1. Assuming every property is a 1-D array of N
        values, both have shape (1, N, 3).
    """
    vertex = PlyData.read(path)["vertex"]

    def _stack(fields):
        # One column per field, e.g. (x, y, z) -> (N, 3).
        return torch.stack([torch.tensor(vertex[name]) for name in fields], dim=-1)

    positions = _stack(("x", "y", "z"))
    normals = _stack(("nx", "ny", "nz"))
    return positions[None], normals[None]

In [None]:
# Load the scan; read_ply_file adds a leading batch dimension of size 1.
points, normals = read_ply_file(SCAN)
# Sanity check: one normal per point, 3 components each.
assert points.shape == normals.shape and points.shape[-1] == 3, (points.shape, normals.shape)

### Point cloud colored by normals

In [None]:
# Show the first (only) cloud of the batch: positions rescaled with
# minmax_scaling (eps=0), per-point colors taken directly from the normals.
plot(
    v=minmax_scaling(points[0], eps=0).numpy(),
    c=normals[0].numpy(),
);

### Indicator function $\chi$

In [None]:
# Passed to DPSR_forward as `eps`; guards against division by zero.
eps = 1e-6
# Passed to DPSR_forward as `m`; scaling factor for the indicator function.
m = 0.5


class SelfMockup:
    """Empty attribute container handed to DPSR_forward as its first (``self``-like) argument."""


# Run our DPSR_forward implementation on the preprocessed point cloud.
mockup = SelfMockup()
vars(mockup).update(grid=GRID, sigma=SIGMA, eps=eps, m=m)
chi = DPSR_forward(mockup, PREPROC(points), normals)

### Distribution of grid values (external / internal / border voxels)

In [None]:
# Voxels with |chi| <= atol count as "border" (close to the zero level set);
# the remaining voxels are split by sign. The external/internal labels assume
# chi is negative inside the surface — TODO confirm against DPSR_forward.
atol = 0.05
status = {
    "external": int((chi > atol).sum()),
    "internal": int((chi < -atol).sum()),
    "border": int((chi.abs() <= atol).sum()),
}
# The three classes partition the grid; only NaN entries can escape all of them.
assert sum(status.values()) == chi.numel(), "chi contains NaN values"
status

### Plot Indicator Function

In [None]:
# Drop the leading batch dimension; assumes DPSR_forward returned a
# (1, *GRID) tensor — TODO confirm against the implementation.
chi = chi.squeeze(0)

# Defaults shared by the initial figure and the widgets, so the first render
# and later slider updates use the same coordinate frame and point size.
DEFAULT_THRESH = 1e-1
DEFAULT_POINTSIZE = 1e-2
shading = {"point_color": "green", "point_size": DEFAULT_POINTSIZE}


def _select_voxels(grid, what, thresh):
    """Return normalized coordinates of the voxels of ``grid`` selected by ``what``.

    ``what`` is "external" (grid > thresh), "internal" (grid < -thresh) or,
    for any other value, the band |grid| <= thresh around the zero level set.
    Voxel indices are divided by the grid size along each axis, so every
    coordinate lies in [0, 1).
    """
    if what == "external":
        mask = grid > thresh
    elif what == "internal":
        mask = grid < -thresh
    else:
        mask = grid.abs() <= thresh
    size = torch.tensor(grid.shape, dtype=torch.float)
    return mask.nonzero().type(torch.float) / size


fig = plot(
    v=_select_voxels(chi, "internal", DEFAULT_THRESH).numpy(),
    return_plot=True,
    shading=shading,
)


@interact(
    what=["indicator", "internal", "external"],
    thresh=FloatSlider(
        value=DEFAULT_THRESH, min=1e-2, max=2e-1, step=1e-2, description="threshold"
    ),
    pointsize=FloatSlider(
        value=DEFAULT_POINTSIZE, min=1e-2, max=5e-2, step=1e-3, description="pointsize"
    ),
)
def plot_points(what, thresh, pointsize):
    """Redraw ``fig`` with the voxels of ``chi`` chosen by the widgets.

    Args:
        what: "external", "internal", or "indicator" (the |chi| <= thresh band).
        thresh: Threshold on chi used to classify voxels.
        pointsize: Point size written into the shared ``shading`` dict.
    """
    selected = _select_voxels(chi, what, thresh)
    if selected.numel() == 0:
        # Nothing to draw; leave the figure as is rather than plotting an empty cloud.
        print(f"No voxels selected for {what!r} at threshold {thresh:g}.")
        return
    shading.update({"point_size": pointsize})
    plot(selected.numpy(), shading=shading, plot=fig)