In [None]:
import torch

# data loading
import pandas as pd
from plyfile import PlyData

# plotting
from meshplot import plot
from ipywidgets import interact, FloatSlider

# implementation
from sap import ShapeAsPoints

# generate initial pointcloud
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes

### Load data and set parameters

In [None]:
# Which model to reconstruct, and where its input scan (.ply) and
# ground-truth point cloud (.xyz) are stored on disk.
MODEL_NAME = "gargoyle"
DATA_ROOT = "./deep_geometric_prior_data"
SCAN = f"{DATA_ROOT}/scans/{MODEL_NAME}.ply"
GT = f"{DATA_ROOT}/ground_truth/{MODEL_NAME}.xyz"

In [None]:
def read_ply_file(path):
    """Load vertex positions and vertex normals from a PLY file.

    Args:
        path: Path to a PLY file whose ``vertex`` element carries the
            properties ``x, y, z`` and ``nx, ny, nz``.

    Returns:
        ``(points, normals)``: each property column is stacked along the
        last axis and a leading batch axis of size 1 is added. Assumes each
        property is a 1-D array of length N, giving two ``(1, N, 3)`` tensors.
    """
    vertex = PlyData.read(path)["vertex"]

    def _stack(props):
        # One tensor per named property, combined column-wise.
        return torch.stack([torch.tensor(vertex[p]) for p in props], dim=-1)

    points = _stack(("x", "y", "z"))
    normals = _stack(("nx", "ny", "nz"))
    return points[None], normals[None]

def read_xyz_file(path):
    """Load a whitespace-separated ``.xyz`` point cloud.

    Each line holds the three coordinates ``x y z``. Because column names are
    supplied explicitly, no header row is expected.

    Args:
        path: Path to the ``.xyz`` file.

    Returns:
        Tensor of shape ``(1, N, 3)`` with a leading batch axis. The dtype is
        whatever pandas infers for the columns (float64 for decimal data), so
        callers typically convert with ``.float()``.
    """
    # `sep=r"\s+"` is the documented replacement for the deprecated
    # `delim_whitespace=True`; it splits on any run of spaces/tabs.
    table = pd.read_table(path, sep=r"\s+", names=["x", "y", "z"])
    return torch.tensor(table.to_numpy()).unsqueeze(0)

In [None]:
# Fitting target. The dense ground truth (GT) would be the better target, but
# computing the loss on its ~2M points without subsampling takes too long, so
# the raw scan is used by default. Set USE_GT = True to fit the ground truth.
USE_GT = False

gt_points, normals = read_ply_file(SCAN)
if USE_GT:
    # The .xyz file only provides positions (x, y, z).
    # NOTE(review): in this mode `normals` still come from the scan and will
    # not correspond point-for-point to `gt_points`.
    gt_points = read_xyz_file(GT).float()

In [None]:
# Seed torch's global RNG so the randomly sampled initial point cloud is
# reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

# We start with a sphere: an icosphere mesh from which an initial point cloud
# (with normals) is sampled.
SPHERE_LEVEL = 5
N_INIT_POINTS = 50000

sphere = ico_sphere(SPHERE_LEVEL)
init_points, init_normals = sample_points_from_meshes(
    sphere, num_samples=N_INIT_POINTS, return_normals=True
)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Pass the detected device through; hard-coding "cuda" fails on CPU-only
# machines. resolution=32 matches the first stage of the sap.train schedule
# below. NOTE(review): on CPU, both should be reduced substantially.
sap = ShapeAsPoints(init_points, init_normals, gt_points, resolution=32, device=device)

### Optimization approach example
Note: if no GPU is available, the grid resolutions (the initial `resolution` and the per-stage resolutions in the training schedule below) should be reduced substantially.

In [None]:
# Initial view: the icosphere template, coloured by its per-vertex normals.
# `fig` is reused as the target viewer by update_figure below.
sphere_v = sphere.verts_packed().numpy()
sphere_f = sphere.faces_packed().numpy()
sphere_c = sphere.verts_normals_packed().numpy()
fig = plot(v=sphere_v, f=sphere_f, c=sphere_c)
def update_figure(rec):
    """Redraw ``fig`` with the mesh carried in a training record.

    ``rec["mesh"]`` is unpacked as (vertices, faces, normals). Only index 0 of
    each (presumably the first batch element) is shown, converted to a CPU
    numpy array; the normals are used as vertex colours.
    """
    def to_np(batch):
        return batch[0].cpu().numpy()

    verts, faces, vert_normals = rec["mesh"]
    plot(to_np(verts), f=to_np(faces), c=to_np(vert_normals), plot=fig)

# Loss value from every callback invocation, in call order.
save = []
# Refresh the figure once every PLOT_EVERY epochs.
PLOT_EVERY = 100

def callback(rec):
    """Training hook: record the loss, and periodically refresh the figure.

    The figure is redrawn when ``rec["epoch"] - 1`` is a multiple of
    PLOT_EVERY (i.e. epochs 1, 101, 201, ... if epochs are counted from 1).
    """
    save.append(rec["loss"])
    if (rec["epoch"] - 1) % PLOT_EVERY != 0:
        return
    update_figure(rec)

# Coarse-to-fine optimisation schedule with increasing grid resolution.
# Each entry is presumably (grid resolution, number of epochs) — TODO confirm
# against ShapeAsPoints.train.
train_schedule = [(32, 2500), (64, 500), (128, 500), (256, 500)]
sap.train(train_schedule, callback=callback)