In [None]:
import logging
import math
from os import path

import imageio
import torch
from pytorch3d.renderer.points.pulsar import Renderer

# Probe pixel inspected by the cells below (indexed as [y, x]): column x, row y.
x = 650
y = 300

In [None]:
torch.manual_seed(1)
n_points = 10
width = 1_000
height = 1_000
device = torch.device("cuda")
# The PyTorch3D system is right handed; in pulsar you can choose the handedness.
# For easy reproducibility we use a right handed coordinate system here.
renderer = Renderer(width, height, n_points, right_handed_system=True).to(device)
# Generate sample data: sphere centers with x, y in [-5, 5) and z in [25, 35).
# NOTE: the three torch.rand calls below must keep this order so the seeded
# sample data stays reproducible.
vert_pos = torch.rand(n_points, 3, dtype=torch.float32, device=device) * 10.0
vert_pos[:, 2] += 25.0
vert_pos[:, :2] -= 5.0
vert_col = torch.rand(n_points, 3, dtype=torch.float32, device=device, requires_grad=True)
# Radii in [0, 1).
vert_rad = torch.rand(n_points, dtype=torch.float32, device=device)
# Uniform opacity of 0.1 per point, created as a leaf tensor so its gradient
# can be inspected after backward().
opacity = torch.full(
    (n_points,), 0.1, dtype=torch.float32, device=device, requires_grad=True
)
cam_params = torch.tensor(
    [
        0.0,
        0.0,
        0.0,  # Position 0, 0, 0 (x, y, z).
        0.0,
        math.pi,  # Because of the right handed system, the camera must look 'back'.
        0.0,  # Rotation 0, pi, 0 (in axis-angle format): half a turn about y.
        5.0,  # Focal length in world size.
        2.0,  # Sensor size in world size (width).
    ],
    dtype=torch.float32,
    device=device,
)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams["figure.figsize"] = (20,20)

vert_col.grad = None
opacity.grad = None

# Render.
image = renderer(
    vert_pos,
    vert_col,
    vert_rad,
    cam_params,
    1.0e-1,  # Renderer blending parameter gamma, in [1., 1e-5].
    45.0,  # Maximum depth.
    opacity=opacity,
    mode=2,
)
print (y, x, image[y, x])
image[y, x].sum().backward()
print (vert_col.grad)
print (opacity.grad)

image = (image.cpu().detach() * 255.0).to(torch.uint8).numpy()

plt.imshow(image)
plt.show()

In [None]:
from pytorch3d.transforms import axis_angle_to_matrix

# Unpack pulsar's camera vector: position (3), axis-angle rotation (3),
# focal length and sensor width (both in world units).
cam_pos = cam_params[0:3]
cam_R = axis_angle_to_matrix(cam_params[3:6])
# Focal length in pixels = focal length / sensor width * image width.
cam_focal = cam_params[6] / cam_params[7] * width
print ("cam_pos:", cam_pos)
print ("cam_R:", cam_R)
print ("cam_focal:", cam_focal)

# cam_R is used as the world-to-camera rotation: x_cam = R @ x_world + T,
# with T = -R @ cam_pos.
cam_T = - torch.matmul(cam_R, cam_pos)
# Inverting [R | T] gives [R^T | -R^T @ T] = [R^T | cam_pos]: the
# camera-to-world translation is the camera position itself (not R @ cam_pos).
camtoworld = torch.cat([cam_R.t(), cam_pos[:, None]], dim=1)
camtoworld = torch.cat([
    camtoworld,
    torch.tensor(
        [[0., 0., 0., 1.]], dtype=camtoworld.dtype, device=camtoworld.device
    ),
], dim=0)
print ("camtoworld", camtoworld.shape)

In [None]:
import numpy as np
import collections


Rays = collections.namedtuple("Rays", ("origins", "directions", "viewdirs"))


def generate_rays(w, h, focal, camtoworlds):
    """Generate pinhole camera rays, one per pixel, with the principal point at the center.

    Each ray passes through its pixel's center. In camera space x points right,
    y points up (image rows are flipped) and the camera looks along -z.

    Args:
        w: int, image width in pixels.
        h: int, image height in pixels.
        focal: focal length in pixels. Any scalar accepted by ``float()``
            (Python/NumPy scalar or a one-element torch tensor).
        camtoworlds: np.ndarray [B, 4, 4], camera-to-world homogeneous poses.

    Returns:
        Rays namedtuple whose fields are arrays of shape [B, h, w, 3]:
            origins: translation column of each pose (read-only broadcast view),
            directions: world-space ray directions (not normalized),
            viewdirs: ``directions`` scaled to unit length.
    """
    # A plain float keeps all arithmetic in NumPy; passing a torch tensor here
    # would otherwise mix torch and NumPy operands.
    focal = float(focal)

    x, y = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
        np.arange(w, dtype=np.float32),  # X-Axis (columns)
        np.arange(h, dtype=np.float32),  # Y-Axis (rows)
        indexing="xy",
    )

    # Camera-space direction of every pixel center: [h, w, 3].
    camera_dirs = np.stack(
        [
            (x + 0.5 - w * 0.5) / focal,
            -(y + 0.5 - h * 0.5) / focal,
            -np.ones_like(x),
        ],
        axis=-1,
    )

    # Rotate into world space: [B, 1, 1, 3, 3] @ [1, h, w, 3, 1] -> [B, h, w, 3].
    c2w = camtoworlds[:, None, None, :3, :3]
    directions = np.matmul(c2w, camera_dirs[None, ..., None])[..., 0]
    origins = np.broadcast_to(camtoworlds[:, None, None, :3, -1], directions.shape)
    viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
    return Rays(origins=origins, directions=directions, viewdirs=viewdirs)


# One pose only: add a leading batch axis to the 4x4 camera-to-world matrix.
rays = generate_rays(
    w=width,
    h=height,
    focal=cam_focal.cpu(),
    camtoworlds=camtoworld.unsqueeze(0).cpu().numpy(),
)

In [None]:
def ray_sphere_intersections(rays, vert_pos, vert_rad, debug_pixel=None):
    """Test every camera ray against every sphere.

    Args:
        rays: ``Rays`` with ``origins`` / ``viewdirs`` of shape [B, H, W, 3].
            Assumes a single pose (B == 1) so they broadcast against the N
            spheres; ``viewdirs`` are assumed unit length (``generate_rays``
            normalizes them).
        vert_pos: [N, 3] sphere centers (torch tensor or array-like).
        vert_rad: [N] sphere radii (torch tensor or array-like).
        debug_pixel: optional ``(row, col)``. When given, print every sphere's
            center offset and projected distance for that pixel.

    Returns:
        hits: [N, H, W] bool. True where the sphere center lies in front of the
            ray origin and within one radius of the ray.
        hits_depth: [N, H, W] distance along the ray to the point closest to the
            sphere center (not the ray's entry point into the sphere), for all
            pairs including misses.
    """
    def as_numpy(values):
        # Detach first so tensors that require grad can still be converted.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    centers = as_numpy(vert_pos)
    radii = as_numpy(vert_rad)

    # Vector from the ray origin to each sphere center: [N, H, W, 3].
    o__sphere_ = centers[:, None, None, :] - rays.origins
    # Projection of that vector onto the view direction: [N, H, W].
    o__p1_dist = np.sum(o__sphere_ * rays.viewdirs, axis=-1)
    # Closest point on the ray (p1) relative to the origin: [N, H, W, 3].
    o__p1_ = rays.viewdirs * o__p1_dist[..., None]
    # Distance from the sphere center to the ray: [N, H, W].
    p1__sphere_dist = np.linalg.norm(o__sphere_ - o__p1_, axis=-1)

    # Whether an intersection happens.
    hits = np.logical_and(
        o__p1_dist > 0, p1__sphere_dist <= radii[:, None, None]
    )
    if debug_pixel is not None:
        row, col = debug_pixel
        print ("center:", o__sphere_[:, row, col])
        print ("o__p1:", o__p1_dist[:, row, col])
    hits_depth = o__p1_dist
    return hits, hits_depth

hits, hits_depth = ray_sphere_intersections(rays, vert_pos, vert_rad, debug_pixel=(y, x))

In [None]:
# Give every missed (sphere, pixel) pair infinite depth, in place, so misses
# never win the nearest-depth reduction below.
np.putmask(hits_depth, np.logical_not(hits), np.inf)
# Per-pixel depth of the closest hit sphere (inf where no sphere is hit).
mask = hits_depth.astype(np.float32).min(axis=0)
plt.imshow(mask)
plt.show()

In [None]:
# Hand-written NeRF-style front-to-back accumulation for the probe pixel
# (row y, column x), used to cross-check the renderer's value and gradients.
vert_col.grad = None
opacity.grad = None

# Per-sphere depth along the probe ray, with misses set to +inf so they sort
# after every hit. Built from a copy, so this cell does not depend on another
# cell having masked `hits_depth` in place.
pixel_hits = hits[:, y, x]
pixel_depth = hits_depth[:, y, x].copy()
pixel_depth[~pixel_hits] = np.inf
sorted_sphere_idxs = np.argsort(pixel_depth, axis=0)

bg_col = torch.ones((3,), device=device)
# Scope anomaly detection to this cell. Calling set_detect_anomaly(True) as a
# plain function would leave it enabled (and slow) for the rest of the kernel.
with torch.autograd.set_detect_anomaly(True):
    light_intensity = 1.0  # transmittance accumulated so far (printed as T)
    color = torch.zeros((3,), device=device)
    for i, sphere_id in enumerate(sorted_sphere_idxs):
        t = pixel_depth[sphere_id]
        if t == float("inf") or not pixel_hits[sphere_id]:
            continue
        if i < len(sorted_sphere_idxs) - 1:
            t_next = pixel_depth[sorted_sphere_idxs[i + 1]]
        else:
            t_next = float("inf")
        # Distance to the next hit, capped so a missing successor (inf) stays
        # finite. A plain Python float keeps NumPy scalars out of the autograd graph.
        delta_t = float(min(abs(t_next - t), 1e10))
        sigma = opacity[sphere_id]
        att = torch.exp(- delta_t * sigma)
        weight = light_intensity * (1. - att)
        color = color + weight * vert_col[sphere_id]

        print (
            "render|nerf accum. i(%d), sphere_id(%d) t(%.5f), delta_t(%.5f) "
            "sigma(%.5f), att(%.5f), alpha(%.5f), "
            "T(%.5f), weight(%.5f), "
            "result(%.5f, %.5f, %.5f), "
            "col_ptr(%.5f, %.5f, %.5f) \n" % (
                i, sphere_id, t, delta_t,
                sigma, att, 1. - att,
                light_intensity, weight,
                color[0], color[1], color[2],
                vert_col[sphere_id][0], vert_col[sphere_id][1], vert_col[sphere_id][2]
            )
        )
        light_intensity = light_intensity * att

    # Whatever transmittance remains is filled by the background color.
    color = color + light_intensity * bg_col

    color.sum().backward()
print ("backward|nerf color grad.", vert_col.grad)
print ("backward|nerf opacity grad.", opacity.grad)