In [None]:
import torch
import numpy as np
import pandas as pd

# from pytorch3d.structures import Pointclouds
from plyfile import PlyData

import meshplot

In [None]:
# Input files for the "anchor" shape in ./deep_geometric_prior_data.
# GROUNDTRUTH is defined (not read here) so that the ground-truth loading code,
# read_xyz_file(GROUNDTRUTH), works when it is enabled.
GROUNDTRUTH = "./deep_geometric_prior_data/ground_truth/anchor.xyz"
SCAN = "./deep_geometric_prior_data/scans/anchor.ply"

In [None]:
def read_ply_file(file):
    """Load vertex positions and normals from a PLY file.

    Args:
        file: path (or anything else ``PlyData.read`` accepts) of the PLY file.

    Returns:
        ``(xyz, nxyz)``: two tensors of shape ``(1, N, 3)``. Positions come from
        the ``x``/``y``/``z`` vertex properties and normals from ``nx``/``ny``/``nz``.
        Each keeps the dtype stored in the file, converted to native byte order.
    """
    vertex = PlyData.read(file)["vertex"]

    def stack_columns(names):
        # Stack the named vertex properties into an (N, len(names)) tensor.
        columns = []
        for name in names:
            values = np.asarray(vertex[name])
            # torch cannot wrap arrays in non-native byte order
            # (e.g. from binary_big_endian PLY files); normalize first.
            values = values.astype(values.dtype.newbyteorder("="))
            columns.append(torch.from_numpy(values))
        return torch.stack(columns, dim=-1)

    xyz = stack_columns(("x", "y", "z"))
    nxyz = stack_columns(("nx", "ny", "nz"))
    return xyz.unsqueeze(0), nxyz.unsqueeze(0)


def read_xyz_file(path):
    """Load an ASCII ``.xyz`` point file with whitespace-separated x y z columns.

    Args:
        path: path of the file to read.

    Returns:
        Tensor of shape ``(1, N, 3)`` with the dtype pandas infers for the columns.
    """
    # `sep=r"\s+"` is the replacement for `delim_whitespace=True`, which is
    # deprecated since pandas 2.2. Passing `names` means no header row is expected.
    table = pd.read_table(path, sep=r"\s+", header=None, names=["x", "y", "z"])
    return torch.tensor(table.to_numpy()).unsqueeze(0)

In [None]:
# Load the raw scan. Both returned tensors carry a leading dimension of size 1:
# `points` holds the vertex positions, `normals` the vertex normals.
# Loading the ground truth (read_xyz_file(GROUNDTRUTH)) and wrapping the data
# in pytorch3d `Pointclouds` are currently disabled.
points, normals = read_ply_file(file=SCAN)

In [None]:
from implementation import minmax_scaling

In [None]:
# Preview the scan. The authors use a normalization followed by a sigmoid; we
# prefer a simpler min-max scaling because it retains the original shape.
# Normals have components in [-1, 1] (for unit normals), which are not valid RGB
# values. Remap them to [0, 1] and clamp, in case they are not unit length.
meshplot.plot(
    minmax_scaling(points[0], eps=1e-5).numpy(),
    c=((normals[0] + 1) / 2).clamp(0.0, 1.0).numpy(),
);

In [None]:
from implementation import DPSR_forward
from oracle import forward, spec_gaussian_filter

In [None]:
## configuration
# Indicator-grid resolution: 64 cells along each of the 3 axes.
GRID = (64,) * 3
# Passed to the oracle's spec_gaussian_filter and as `sigma` to our DPSR_forward.
SIGMA = 5
# Extra hyper-parameters read by our DPSR_forward as `eps` / `m`; their exact
# role lives in `implementation` (not visible here).
eps = 1e-6
m = 0.5


def PREPROC(x):
    """Cast points to float64, then min-max scale them with eps=1e-5."""
    return minmax_scaling(x.type(torch.float64), eps=1e-5)


class SelfMockup:
    """Empty attribute holder passed in place of ``self`` to the forward functions below."""


# Run our implementation. DPSR_forward receives a bare SelfMockup standing in
# for `self`, carrying the configuration as attributes.
mockup = SelfMockup()
vars(mockup).update(grid=GRID, sigma=SIGMA, eps=eps, m=m)
ourchi = DPSR_forward(mockup, PREPROC(points), normals)

# Run the authors' reference `forward` on the same inputs, with its own
# attribute set on a fresh stand-in object.
mockup = SelfMockup()
vars(mockup).update(
    dim=len(GRID),
    res=GRID,
    shift=True,
    scale=True,
    G=spec_gaussian_filter(GRID, SIGMA),
)
testchi = forward(mockup, PREPROC(points), normals)

In [None]:
# Check our implementation against the authors' oracle. Report the largest
# ABSOLUTE deviation: a signed `.max()` hides cells where ours exceeds the oracle.
# The strict check is left disabled:
#     assert torch.allclose(testchi, ourchi, atol=1e-6)
(testchi - ourchi).abs().max()

In [None]:
# Count indicator cells per band, with `atol` as the half-width of the band
# around zero: > atol is "external", < -atol is "internal", and
# |chi| <= atol is the "border".
atol = 0.05
status = dict(
    external=(ourchi > atol).sum(),
    internal=(ourchi < -atol).sum(),
    border=(ourchi.abs() <= atol).sum(),
)
status

In [None]:
from ipywidgets import interact, FloatSlider

In [None]:
# Interactive view of the indicator grid. Voxels are selected by thresholding
# `chi` and drawn as points at their grid indices divided by the resolution.
# The initial figure and every widget update therefore share one normalized frame.
chi = ourchi.squeeze(0)
# Initial point size matches the "size" slider default (normalized units).
shading = {"point_color": "green", "point_size": 1e-3}


def voxel_coords(mask):
    """Return the indices of the True entries of ``mask`` as floats scaled by 1 / GRID[0].

    GRID is cubic in this notebook, so a single scale factor covers every axis.
    """
    return mask.nonzero().type(torch.float) / GRID[0]


fig = meshplot.plot(
    v=voxel_coords(chi < -1e-1).numpy(),
    return_plot=True,
    shading=shading,
)


@interact(
    plot=["indicator", "internal", "external"],
    thresh=FloatSlider(
        value=1e-1, min=1e-2, max=2e-1, step=1e-2, description="threshold"
    ),
    psize=FloatSlider(value=1e-3, min=1e-3, max=1e-1, step=1e-3, description="size"),
)
def plot_points(plot, thresh, psize):
    """Redraw `fig` with the voxels of one band of `chi`.

    Args:
        plot: "external" (chi > thresh), "internal" (chi < -thresh), or
            "indicator" (the band -thresh <= chi <= thresh around zero).
        thresh: half-width of the band around zero.
        psize: point size, in the same normalized units as the coordinates.
    """
    if plot == "external":
        mask = chi > thresh
    elif plot == "internal":
        mask = chi < -thresh
    else:
        mask = (chi >= -thresh) & (chi <= thresh)

    # NOTE(review): an empty point set presumably cannot be drawn by meshplot
    # (it needs a bounding box), so report and skip instead.
    if not mask.any():
        print(f"no voxels selected for plot={plot!r} at threshold {thresh}")
        return

    shading.update({"point_size": psize})
    meshplot.plot(voxel_coords(mask).numpy(), shading=shading, plot=fig)