In [None]:
import torch

# data loading
from plyfile import PlyData

# plotting
from meshplot import plot

# implementation
from implementation import (
    point_rasterization,
    grid_interpolation,
    unitsphere_scaling,
    minmax_scaling,
)

### Load data and compute rasterization

In [None]:
# Input scan (PLY with per-vertex positions and normals) and the voxel-grid
# resolution used for rasterization, one entry per spatial axis.
SCAN = "./deep_geometric_prior_data/scans/anchor.ply"
GRID_RESOLUTION = 256
GRID = (GRID_RESOLUTION,) * 3

In [None]:
def stack_vertex_fields(vertex, fields):
    """Stack named per-vertex scalar fields into a single ``(N, len(fields))`` tensor.

    Parameters
    ----------
    vertex : plyfile.PlyElement or numpy structured array
        Anything where ``vertex[name]`` returns a 1-D numpy array (assumed 1-D,
        as for a PLY vertex element).
    fields : sequence of str
        Field names to stack, in column order.

    Returns
    -------
    torch.Tensor
        Shape ``(N, len(fields))``; the dtype follows the stored field dtype
        (converted to native byte order).
    """
    columns = []
    for name in fields:
        column = vertex[name]
        # A field of a structured array is a strided view that may be
        # big-endian (binary_big_endian PLY) and/or have a record stride that
        # is not a multiple of the element size (e.g. float xyz + uchar rgb).
        # torch rejects both, so take a fresh, contiguous, native-order copy.
        native = column.astype(column.dtype.newbyteorder("="))
        columns.append(torch.from_numpy(native))
    return torch.stack(columns, dim=-1)


def read_ply_file(file):
    """Read vertex positions and normals from a PLY file.

    Parameters
    ----------
    file : str or file-like
        Passed unchanged to ``PlyData.read``.

    Returns
    -------
    (points, normals) : tuple of torch.Tensor
        Each of shape ``(1, N, 3)``: the vertex ``x/y/z`` and ``nx/ny/nz``
        fields respectively, with a leading batch dimension of size 1.

    Notes
    -----
    No error handling is done here: a missing ``vertex`` element or missing
    field propagates the underlying plyfile/numpy lookup error.
    """
    vertex = PlyData.read(file)["vertex"]

    xyz = stack_vertex_fields(vertex, ("x", "y", "z"))
    nxyz = stack_vertex_fields(vertex, ("nx", "ny", "nz"))
    return xyz.unsqueeze(0), nxyz.unsqueeze(0)


raw_points, normals = read_ply_file(SCAN)
# Normalise the scan with `unitsphere_scaling` (scale=0.9), then squash the
# coordinates through a sigmoid so every value lies in (0, 1).
# NOTE(review): sigmoid is non-linear, so this mildly warps the geometry and
# uses only part of (0, 1); confirm this is intended rather than an affine
# map to the unit cube (`minmax_scaling` is imported but never used).
points = torch.sigmoid(unitsphere_scaling(raw_points, scale=0.9))

In [None]:
# Encode the per-point normals into a voxel grid with resolution GRID.
normal_vector_field = point_rasterization(points, normals, GRID)
# The next cell permutes this tensor over five dimensions; fail fast here with
# a readable message if the rasterizer returns a different rank.
# Assumed layout: (batch, feature, *GRID) -- TODO confirm against point_rasterization.
assert normal_vector_field.dim() == 5, tuple(normal_vector_field.shape)

In [None]:
# Retrieve a normal for every input point by interpolating the rasterized grid.
# permute moves dim 1 (one grid per feature) to the last axis, giving a single
# grid whose cells hold all features. The result gets a NEW name so that
# re-running this cell does not permute an already-permuted tensor.
normal_grid_channels_last = normal_vector_field.permute(0, 2, 3, 4, 1)
n = grid_interpolation(normal_grid_channels_last, points)

### Example Rasterization
Below, we compare the normals interpolated back from the grid with the true normals stored in the scan.

In [None]:
# Colour every point by the normal interpolated back from the grid.
interpolated_colours = n[0].numpy()
plot(points[0].numpy(), c=interpolated_colours)

In [None]:
# Colour every point by its ground-truth normal as read from the PLY file.
true_colours = normals[0].numpy()
plot(points[0].numpy(), c=true_colours)