In [None]:
import imageio
import nvdiffrast.torch as dr
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
from umbra import MeshViewer

from diffshadow.common import to_display_image, vflip
from diffshadow.transformation import create_lookat_matrix, create_orthographic_projection_matrix
from diffshadow.simple_renderer import Camera, Mesh

from drpg import *
from drpg.rendering import render_silhouette

def circle(t, radius: float = 1.0):
    """Evaluate points on a circle of the given radius in the z = 0 plane.

    Args:
        t: Tensor of angles in radians.
        radius: Radius of the circle.

    Returns:
        Tensor with the shape of ``t`` plus one trailing dimension of size 3,
        holding ``(radius * cos(t), radius * sin(t), 0)`` for every angle.
    """
    coordinates = (radius * t.cos(), radius * t.sin(), torch.zeros_like(t))
    return torch.stack(coordinates, dim=-1)

# Device on which the curves, the rasterization context and the camera matrices are created
device = torch.device('cuda:0')

# Viewer that later cells update with the tessellated mesh and the control points
viewer = MeshViewer()

In [None]:
num_curves = 16
n          = 4 # control points per curve (degree + 1)
scale      = 0.4

# Sample the control points of all curves on a circle of radius `scale`.
# Consecutive curves share one end point, so the total number of points is
# (num_curves-1) * (n-1) + n. The first and last sample (angles 0 and 2*pi)
# describe the same position on the circle.
# torch.pi is used instead of math.pi because `math` is never imported
# explicitly in this notebook.
num_points = (num_curves - 1) * (n - 1) + n
t = torch.linspace(0, 2 * torch.pi, num_points)
x = circle(t, scale)

# Build the (num_curves, n) table of control point indices.
# Curve i starts at index (n-1)*i and uses n consecutive points, so that the
# last point of one curve is the first point of the next one, e.g. for n=4:
#   [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9], ...]
first_index = (n - 1) * torch.arange(0, num_curves)
a = first_index[:, None] + torch.arange(n)[None, :]

curves = ParametricCurves(n, d=3, device=device)
curves.add(x[a])
# Shared end points were added once per curve above
# (presumably this call collapses them into single vertices - TODO confirm)
curves.merge_duplicate_vertices()

# Sweep a circular profile along the curves to obtain a triangle mesh
profile_radius = 0.03
profile        = circle_profile(profile_radius, n=32, device=device)
tessellator = CurveTessellator(u_resolution=16)
# Only the first two outputs (vertices and faces) are needed here
v, f = tessellator.tessellate(curves, profile)[:2]

viewer.set_mesh(v.detach().cpu().numpy(), f.cpu().numpy(), object_name='surface_1')

In [None]:
# Reference silhouette that the curves are optimized to reproduce
reference_path = "data/outline_star.png"
# Alternative target: "data/outline_bear.png"

# Keep the first three channels and convert the values to floats by dividing by 255
reference_rgb = torch.from_numpy(imageio.imread(reference_path))[:, :, :3]
reference_rgb = reference_rgb.to(device, dtype=torch.float32) / 255.

# Average the color channels into a single channel and flip the image vertically
silhouette_ref = vflip(reference_rgb.mean(dim=-1, keepdim=True))

plt.imshow(to_display_image(silhouette_ref.detach().cpu(), grayscale_to_rgb=True))

In [None]:
# Rasterization context shared by all silhouette renderings in this notebook
render_context = dr.RasterizeGLContext(device=device)

# Orthographic camera
# (presumably the look-at arguments are eye, target and up vector - TODO confirm)
view_matrix       = create_lookat_matrix([0, 0, 2], [0, 0, 0], [0, 1, 0], device=device)
projection_matrix = create_orthographic_projection_matrix(1, 0.01, 10, device=device)
camera = Camera(view_matrix=view_matrix, projection_matrix=projection_matrix)

# Render the silhouette of the initial tessellation; the faces are passed as int32
mesh_init = Mesh(v, f.to(dtype=torch.int32))
silhouette_init = render_silhouette(render_context, mesh_init, camera, (256, 256), spp=2)

plt.imshow(to_display_image(silhouette_init.detach().cpu(), grayscale_to_rgb=True))

In [None]:
# Curves to optimize, created from a clone of the initial control points
curves_opt = curves.with_control_points(curves.V.clone())

# Optimization settings
use_large_steps    = True   # optimize a differential parameterization (largesteps) instead of raw control points
large_steps_lambda = 0.5    # passed as lambda_ to compute_lsig_matrix
c1_loss_weight     = 20     # weight of the C1 curve loss (0 disables it)
optimize_radius    = True   # also optimize the radius of the circular profile

if not use_large_steps:
    lr = 0.005

    # Optimize the control points directly
    curves_opt.V.requires_grad_(True)
    optimizer = torch.optim.Adam([curves_opt.V], lr=lr)
else:
    lr = 0.1

    from largesteps.parameterize import to_differential, from_differential
    from largesteps.optimize import AdamUniform

    # System matrix built from the graph representation of the curves
    V, F = convert_curves_to_graph(curves_opt)
    M    = compute_lsig_matrix(V, F, lambda_=large_steps_lambda)

    # The optimizer updates the differential coordinates u_opt;
    # the control points are recovered from them with from_differential
    u_opt = to_differential(M, curves_opt.V)
    u_opt.requires_grad_(True)

    optimizer = AdamUniform([u_opt], lr=lr)

if optimize_radius:
    # The profile radius gets its own optimizer with a much smaller learning rate
    radius_opt = torch.tensor(profile_radius, requires_grad=True, device=device)
    optimizer_radius = torch.optim.Adam([radius_opt], lr=0.0001)

def compute_mesh(t_prev, n_prev, b_prev):
    """Tessellate the current state of the optimized curves into a mesh.

    Reads the notebook-level optimization state. When ``use_large_steps`` is
    set, ``curves_opt.V`` is overwritten with the control points recovered
    from ``u_opt``. When ``optimize_radius`` is set, the profile is rebuilt
    from ``radius_opt``; otherwise the initial ``profile`` is used.

    Args:
        t_prev, n_prev, b_prev: Values forwarded to the tessellator under the
            same keyword names (``None`` on the first call and after a reset).

    Returns:
        Tuple ``(mesh, t_prev, n_prev, b_prev)`` with the tessellated mesh and
        the three trailing outputs of the tessellator, to be passed to the
        next call.
    """
    if use_large_steps:
        curves_opt.V = from_differential(M, u_opt, 'Cholesky')

    profile_opt = circle_profile(radius_opt, n=32, device=device) if optimize_radius else profile

    tessellation = tessellator.tessellate(curves_opt, profile_opt, t_prev=t_prev, n_prev=n_prev, b_prev=b_prev)
    v_opt, f_opt, t_next, n_next, b_next = tessellation

    # The faces are passed to the mesh as int32, as for the initial mesh
    return Mesh(v_opt, f_opt.to(dtype=torch.int32)), t_next, n_next, b_next

tangent_map = build_curve_tangent_map(curves)

num_iterations       = 1200
frame_reset_interval = 100   # iterations between resets of the tessellator frame state
min_radius           = 0.01  # lower bound enforced on the optimized profile radius

t_prev, n_prev, b_prev = (None, None, None)

progress_iterator = tqdm(range(num_iterations))
for it in progress_iterator:
    if it % frame_reset_interval == 0:
        # Update rotation minimizing frames in this iteration: drop the state
        # carried over from the previous iteration
        t_prev, n_prev, b_prev = (None, None, None)

    mesh_opt, t_prev, n_prev, b_prev = compute_mesh(t_prev, n_prev, b_prev)

    silhouette_opt = render_silhouette(render_context, mesh_opt, camera, (256, 256), spp=2)

    # Mean squared error between the rendered and the reference silhouette
    loss = ((silhouette_opt - silhouette_ref)**2).mean()

    if c1_loss_weight > 0:
        loss += c1_loss_weight*c1_curve_loss(curves_opt, tangent_map)

    # Clear the gradients of both optimizers before the backward pass
    if optimize_radius:
        optimizer_radius.zero_grad()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if optimize_radius:
        optimizer_radius.step()
        # Keep the radius from dropping below the lower bound; the in-place
        # clamp is done outside of autograd
        with torch.no_grad():
            radius_opt.clamp_min_(min_radius)

    viewer.set_points(curves_opt.V[curves_opt.F].reshape(-1, 3).detach().cpu().numpy(), object_name='control_points')
    viewer.set_mesh(mesh_opt.vertices.detach().cpu().numpy(), mesh_opt.faces.cpu().numpy(), object_name='surface_1')

# The last rendering lives on the GPU and is attached to the autograd graph;
# detach it and move it to the CPU before display, as in the earlier cells
plt.imshow(to_display_image(silhouette_opt.detach().cpu(), grayscale_to_rgb=True))