## Computing Virtual Cones for the given implicit surface

### Objective
Given an implicit surface and a set of camera rays, we compute the virtual cone for each camera ray in parallel.

We start by importing the required packages and defining some helper functions.

In [1]:
%matplotlib widget
import torch
from torch import autograd
from functools import partial
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

dot = lambda x,y: (x*y).sum(-1, keepdim=True)
norm = lambda x: torch.sqrt((x**2).sum(-1, keepdim=True))
transp = lambda x: x.transpose(-2,-1)
normalize = lambda x: x/(1e-7+norm(x))

### Given quantities
<img src="media/viz_3D_1.jpg" width="200">



**1. Implicit Surface**

Here we take the implicit surface to be a quadric, but all the computations in this notebook also hold for other implicit surfaces, such as signed distance functions.  
Examples of quadric surfaces include spheres, ellipsoids, cylinders, hyperboloids and paraboloids. Quadrics centered at the origin (which excludes paraboloids, as they need a linear term) can be written in matrix form as
$$ \mathbf{p}^TQ\mathbf{p} = 1$$
where $\mathbf{p} \in \mathbb{R}^3$ is a point on the surface and $Q$ is a symmetric $3\times 3$ matrix representing the quadric.

In [2]:
def implicit_quadric(Q_mat, p_vec):
    p_vec = p_vec.requires_grad_(True)
    return transp(p_vec[...,None])@Q_mat@p_vec[...,None] - 1

Here are the $Q$ matrices for some common quadrics.

In [3]:
# Sphere with radius 5
# Q_mat = torch.Tensor([[1,0,0],
#                       [0,1,0],
#                       [0,0,1.]])/25
# Cylinder with radius 5 and axis along y
# Q_mat = torch.Tensor([[1,0,0],
#                       [0,0,0],
#                       [0,0,1.]])/25
# Ellipsoid with semi-axes 5 along x and z, 2.5 along y
Q_mat = torch.Tensor([[1,0,0],
                      [0,4,0],
                      [0,0,1]])/25
# partial fixes Q_mat, giving a function of the point p only
f_func = partial(implicit_quadric, Q_mat)
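For origin-centered quadrics with axis-aligned semi-axes $a, b, c$, the matrix is simply $Q = \mathrm{diag}(1/a^2, 1/b^2, 1/c^2)$, and letting a semi-axis go to infinity zeroes its entry, which is how the cylinder above loses its $y$ term. The helper below is a hypothetical convenience (not part of the original notebook) that rebuilds the three matrices from their semi-axes:

```python
import math
import torch

def axis_aligned_quadric(a, b, c):
    """Hypothetical helper: Q for x^2/a^2 + y^2/b^2 + z^2/c^2 = 1.
    Pass math.inf for an unbounded axis (e.g. the axis of a cylinder)."""
    return torch.diag(torch.tensor([1 / a**2, 1 / b**2, 1 / c**2]))

Q_sphere = axis_aligned_quadric(5, 5, 5)            # sphere of radius 5
Q_cylinder = axis_aligned_quadric(5, math.inf, 5)   # cylinder, axis along y
Q_ellipsoid = axis_aligned_quadric(5, 2.5, 5)       # ellipsoid used in this notebook
```

Hyperboloids need negative diagonal entries and so fall outside this helper.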

**2. Camera ray parameters**

Suppose we are querying $N$ pixels. Each pixel corresponds to a primary ray shot from the pinhole through the center of the pixel. The primary ray can be represented by its source $\mathbf{s}$ and direction $\mathbf{d}$. In this example we consider $N=2$ pixels: one whose ray travels along the x axis and another whose ray travels along the z axis.


In [4]:
s = torch.Tensor([[10, 0, 0], 
                  [0,  0, 10]]) # 2 x 3
d = normalize(torch.Tensor([[-1., 0, 0],
                            [ 0,  0, -1]]))

**3. Ray surface intersection**

$\mathbf{p}$ is the intersection of the primary ray with the implicit surface. For ORCA, the VolSDF component gives us multiple values of $\mathbf{p}$. For an analytical SDF, it can be found by sphere tracing; for a quadric, by solving the ray-quadric intersection. Here I simply hard-code the intersection points, which are the same for all three $Q$ matrices above. If you change $Q$ to any other value, the intersection points will differ.


In [5]:
p = torch.Tensor([[5, 0, 0], 
                  [0, 0, 5]])
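The hard-coded points can be checked by solving the ray-quadric intersection directly. Substituting $\mathbf{x} = \mathbf{s} + t\mathbf{d}$ into $\mathbf{x}^TQ\mathbf{x} = 1$ gives, for symmetric $Q$, the quadratic $(\mathbf{d}^TQ\mathbf{d})t^2 + 2(\mathbf{s}^TQ\mathbf{d})t + \mathbf{s}^TQ\mathbf{s} - 1 = 0$. The sketch below (the function name is hypothetical) takes the smaller root as the first hit, assuming $\mathbf{d}^TQ\mathbf{d} > 0$ and ray sources outside the surface; it returns NaN for rays that miss.

```python
import torch

def intersect_ray_quadric(Q, s, d):
    """First hit of rays x = s + t d with x^T Q x = 1, for symmetric Q.
    s, d: (N, 3). Assumes d^T Q d > 0 and ray sources outside the surface."""
    Qd = d @ Q                                   # row i is (Q d_i)^T because Q is symmetric
    a = (d * Qd).sum(-1, keepdim=True)           # d^T Q d
    b = 2 * (s * Qd).sum(-1, keepdim=True)       # 2 s^T Q d
    c = (s * (s @ Q)).sum(-1, keepdim=True) - 1  # s^T Q s - 1
    # Smaller root = first hit; sqrt of a negative discriminant gives NaN (ray misses)
    t = (-b - torch.sqrt(b**2 - 4 * a * c)) / (2 * a)
    return s + t * d

Q_ell = torch.diag(torch.tensor([1., 4., 1.])) / 25
s_chk = torch.tensor([[10., 0., 0.], [0., 0., 10.]])
d_chk = torch.tensor([[-1., 0., 0.], [0., 0., -1.]])
p_chk = intersect_ray_quadric(Q_ell, s_chk, d_chk)
```

For the ellipsoid, and equally for the sphere and cylinder matrices, `p_chk` comes out as approximately `[[5, 0, 0], [0, 0, 5]]`, matching the hard-coded values.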

**4. Primary cone rays**

As the pixel has a finite size, each primary ray corresponds to a cone. Connecting the corners of the pixel to the pinhole gives four corner rays of the cone; appending the central ray, we have 5 rays that define the cone for each pixel. Here I hard-code these rays (offset by the pixel size `sp` along the two image axes), but MipNeRF has code to compute them.

In [6]:
d_primary = d
sp = 0.1 # Size of pixel
# 2 x 5 x 3: central ray followed by the four offset rays for each pixel
d = normalize(torch.Tensor([[[-1., 0, 0], [-1., sp, 0.], [-1. , 0., sp], [-1, -sp, 0.], [-1., 0, -sp]],
                            [[0., 0, -1.], [0., sp, -1.], [sp , 0., -1], [0., -sp, -1.], [-sp, 0, -1.]]]))
# Add singleton dimensions to s and p
s = s[...,None,:]
p = p[...,None,:]
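MipNeRF's own ray-generation code is not reproduced here. As a rough stand-in, the hypothetical helper below builds the same five-ray layout for an arbitrary central direction: the central ray plus four rays offset by the pixel size along two unit vectors orthogonal to it. For the two rays above it yields the same set of offset rays as the hard-coded tensor, in a different order.

```python
import torch

def normalize(x):
    return x / (1e-7 + x.norm(dim=-1, keepdim=True))

def make_cone_rays(d, sp):
    """Hypothetical helper: central ray plus four rays offset by sp along two
    unit vectors orthogonal to d. d: (N, 3) unit directions -> (N, 5, 3)."""
    # Helper axis that is not (nearly) parallel to d: y, unless d is close to y
    helper = torch.where(d[..., 1:2].abs() > 0.9,
                         torch.tensor([1., 0., 0.]), torch.tensor([0., 1., 0.]))
    e1 = normalize(torch.cross(d, helper, dim=-1))
    e2 = torch.cross(d, e1, dim=-1)  # (approximately) unit: d and e1 are orthonormal
    offsets = [torch.zeros_like(d), sp * e1, sp * e2, -sp * e1, -sp * e2]
    return torch.stack([normalize(d + o) for o in offsets], dim=-2)

d_center = normalize(torch.tensor([[-1., 0., 0.], [0., 0., -1.]]))
cone = make_cone_rays(d_center, 0.1)   # 2 x 5 x 3
```

Offsetting along axes gives a diamond of rays rather than true pixel corners; a corner layout would use `±sp*e1 ± sp*e2` instead.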

### Estimate Normal and Curvature of implicit surface
<img src="media/viz_3D_2.jpg" width="300">

**1. Surface Normals**

Similar to VolSDF, we can compute normals from the gradient of the implicit function $f(\mathbf{x})$ (here $f(\mathbf{x}) = \mathbf{x}^TQ\mathbf{x} - 1$):
$$ N = \dfrac{\nabla f}{\| \nabla f\|}$$

In [7]:
def get_grad(f_val, x, has_grad=True):
    return autograd.grad(
        f_val, 
        x, 
        torch.ones_like(f_val, device=x.device),
        create_graph = has_grad, 
        retain_graph= has_grad, 
        only_inputs=True)[0]
        

def get_surface_normals(f_func, x, has_grad=True):
    f_val = f_func(x)
    grad_f = get_grad(f_val,x,has_grad)
    return grad_f/norm(grad_f)

N = get_surface_normals(f_func, p)
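For the quadric, the gradient is also available in closed form: $\nabla f = (Q + Q^T)\mathbf{p} = 2Q\mathbf{p}$ for symmetric $Q$. The self-contained check below (re-declaring the ellipsoid $Q$) compares it with autograd normals; the third point $(3, 2, 0)$ is an extra point on the ellipsoid chosen for illustration.

```python
import torch
from torch import autograd

Q_ell = torch.diag(torch.tensor([1., 4., 1.])) / 25
# Points on the ellipsoid: the two hard-coded hits plus (3, 2, 0)
pts = torch.tensor([[5., 0., 0.], [0., 0., 5.], [3., 2., 0.]])

# Autograd normals of f(p) = p^T Q p - 1
x_leaf = pts.clone().requires_grad_(True)
f_vals = ((x_leaf @ Q_ell) * x_leaf).sum(-1) - 1
grad_auto, = autograd.grad(f_vals.sum(), x_leaf)
N_auto = grad_auto / grad_auto.norm(dim=-1, keepdim=True)

# Closed form: grad f = (Q + Q^T) p = 2 Q p for symmetric Q (row-vector form below)
grad_exact = 2 * pts @ Q_ell
N_exact = grad_exact / grad_exact.norm(dim=-1, keepdim=True)
```

Autograd is kept in the notebook because it also covers non-quadric surfaces such as SDFs, where no closed form exists.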

**2. Curvature**

For curvature, we first have to define the shape operator $dN$, the differential of the normal map. $dN$ is defined on the tangent plane at a point $\mathbf{p}$. For more details and intuition on the shape operator, refer to [this lecture](https://www.youtube.com/watch?v=UewzuzaPlxA) by Prof. Justin Solomon.

For an implicit function $f$, [this paper](https://arxiv.org/pdf/2201.09263.pdf) defines $dN$ as the following linear map:

$$dN = (I-NN^T)\frac{\mathbf{H}f}{\|\nabla f\|} $$
where $\mathbf{H}f$ is the Hessian of $f$ and $I$ is the $3\times 3$ identity matrix.
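As a sanity check of this formula, take a sphere of radius $R$, i.e. $Q = I/R^2$. On the surface $\|\mathbf{p}\| = R$ and $N = \mathbf{p}/R$, so

$$\nabla f = \frac{2\mathbf{p}}{R^2}, \qquad \|\nabla f\| = \frac{2}{R}, \qquad \mathbf{H}f = \frac{2I}{R^2} \quad\Longrightarrow\quad dN = (I-NN^T)\,\frac{2I/R^2}{2/R} = \frac{1}{R}\,(I-NN^T)$$

For a unit tangent vector $\mathbf{v}$ ($N^T\mathbf{v} = 0$) this gives $dN\,\mathbf{v} = \mathbf{v}/R$: the normal map turns at rate $1/R$ in every tangent direction, as expected for a sphere.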



In [8]:
def get_H(f_func, x, has_grad=True):
    # TODO: Vectorize Hessian computation
    x_flat = x.reshape(-1,3)
    b_len = x_flat.shape[0]
    H = torch.zeros(b_len, 3, 3).to(x.device).float()
    for idx in range(b_len):
        H[idx] = autograd.functional.hessian(f_func, x_flat[idx],
                                             create_graph=has_grad, 
                                             vectorize=True)
    return H.reshape(x.shape[:-1]+(3,3))

def get_shape_operator(f_func, x, has_grad=True):
    f_val = f_func(x)
    Hf = get_H(f_func, x, has_grad) # ..., 3x3
    grad_f = get_grad(f_val, x, has_grad)# ..., 3,
    N = get_surface_normals(f_func, x, has_grad)[...,None] # ..., 3,1
    I = torch.eye(3)
    if x.ndim > 1:
        # Add singleton based on batch dimensions of x
        I = I.reshape((1,)*(x.ndim-1)+I.shape)
    return (I-N@transp(N))@Hf/norm(grad_f)[...,None]

dN = get_shape_operator(f_func, p)

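From the shape operator, we can find the normal curvature along any unit vector $\mathbf{v}$ in the tangent plane as:
$$ \kappa_{\mathbf{v}} = \left<-dN\,\mathbf{v},\mathbf{v}\right>$$
where $\left<\cdot,\cdot\right>$ is the inner product. For a non-unit $\mathbf{v}$ the result scales with $\|\mathbf{v}\|^2$, so $\mathbf{v}$ must be normalized first.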


In [9]:
def get_normal_curvature(dN, v):
    # The formula assumes a unit tangent vector, so normalize v first
    v = normalize(v)
    return dot((-dN@(v[...,None]))[...,0], v)
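To sanity-check the curvature formula, the self-contained snippet below evaluates the closed-form shape operator of the ellipsoid ($\nabla f = 2Q\mathbf{p}$, $\mathbf{H}f = 2Q$) at $\mathbf{p} = (5, 0, 0)$. The $xy$ cross-section is an ellipse with semi-axes $a = 5$ and $b = 2.5$, whose curvature at the vertex $(a, 0)$ is $a/b^2 = 0.8$; the $xz$ cross-section is a circle of radius 5 with curvature $0.2$. Both come out negative because the surface is convex.

```python
import torch

Q_ell = torch.diag(torch.tensor([1., 4., 1.])) / 25   # semi-axes 5, 2.5, 5
p0 = torch.tensor([5., 0., 0.])

# Closed-form quantities for f(p) = p^T Q p - 1 with symmetric Q
grad0 = 2 * Q_ell @ p0
N0 = grad0 / grad0.norm()
H0 = 2 * Q_ell
dN0 = (torch.eye(3) - torch.outer(N0, N0)) @ H0 / grad0.norm()

def normal_curvature(dN, v):
    v = v / v.norm()              # the formula assumes a unit tangent vector
    return torch.dot(-dN @ v, v)

k_y = normal_curvature(dN0, torch.tensor([0., 1., 0.]))   # xy-plane ellipse at its vertex
k_z = normal_curvature(dN0, torch.tensor([0., 0., 1.]))   # xz-plane circle of radius 5
```

Skipping the normalization and passing, say, $(0, 0.5, 0)$ would scale $\kappa$ by $0.25$ and give a radius four times too large.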

### Primary cone rays and osculating surface intersection
Along which tangent vectors should we query the curvature? We intersect the primary cone rays with the tangent plane at $\mathbf{p}$ and use the directions from $\mathbf{p}$ to these intersection points.

<img src="media/viz_3D_3.jpg" width="300">

In [10]:
def intersect_ray_plane(s, d, N, p):
    # Tangent plane given by <x-p,N>=0
    # Ray given by x = s+td, t > 0
    # Substituting the ray into the plane: t = <p-s,N> / <d,N>
    t_int = dot(p-s,N)/dot(d,N)
    return s+t_int*d
# Intersect cone rays and tangent plane
u = intersect_ray_plane(s, d, N, p)
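Substituting $\mathbf{x} = \mathbf{s} + t\mathbf{d}$ into $\left<\mathbf{x}-\mathbf{p}, N\right> = 0$ gives $t = \left<\mathbf{p}-\mathbf{s}, N\right> / \left<\mathbf{d}, N\right>$; the division by $\left<\mathbf{d}, N\right>$ is what places oblique rays on the plane. The standalone check below uses the first pixel's cone ray offset along $y$, which should hit the tangent plane $x = 5$ at $(5, 0.5, 0)$:

```python
import torch

def ray_plane_hit(s, d, N, p):
    # Substitute x = s + t d into <x - p, N> = 0  =>  t = <p - s, N> / <d, N>
    t = ((p - s) * N).sum(-1, keepdim=True) / (d * N).sum(-1, keepdim=True)
    return s + t * d

s0 = torch.tensor([[10., 0., 0.]])
N0 = torch.tensor([[1., 0., 0.]])     # surface normal at p0
p0 = torch.tensor([[5., 0., 0.]])
d0 = torch.tensor([[-1., 0.1, 0.]])
d0 = d0 / d0.norm(dim=-1, keepdim=True)
u0 = ray_plane_hit(s0, d0, N0, p0)
```

`u0 - p0` is then the tangent vector $(0, 0.5, 0)$ along which the curvature is queried.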

Querying the curvature along the direction from $\mathbf{p}$ to each intersection point $\mathbf{u}$ gives the radius of curvature, and hence the osculating circle, for each cone ray.

In [11]:
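# Repeat dN along cone rays dimension
dN_exp = dN.expand(-1,5,-1,-1)
# Normal curvature along the tangent direction u-p of each cone ray
Ku = get_normal_curvature(dN_exp, u-p)
# Radius of curvature (negative if convex)
Ru = 1/Ku
# The central ray meets the tangent plane at p itself, so u-p = 0 and Ku = 0.
# A convex circle (negative radius) tangent to the surface at p is first hit at p,
# where its normal is N, so give such rays a finite fallback radius of -1
central = norm(u-p) < 1e-4
Ru = torch.where(central, -torch.ones_like(Ru), Ru)
# Center of the osculating circle
cu = p + Ru*N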

Next, we intersect each cone ray with its osculating circle.

In [12]:
def intersect_ray_circle(s, d, c, R):
    # Circle given by <x-c,x-c>=R**2
    # Ray given by x = s+td
    # Point of intersection t follows:
    # a2*t^2 + a1*t + a0 = 0 with the following values (a2 = 1 because d is unit)
    a2 = 1.
    a1 = 2*dot(d, s-c)
    a0 = norm(s-c)**2 - R**2
    # Using quadratic formula finding the smaller of the two roots
    t_int = (-a1 - torch.sqrt(a1**2 - 4*a2*a0))/(2*a2)

    return s+t_int*d

# Intersect ray and osculating circle
c_int = intersect_ray_circle(s, d, cu, Ru)
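With $\|\mathbf{d}\| = 1$, substituting the ray into $\left<\mathbf{x}-\mathbf{c},\mathbf{x}-\mathbf{c}\right> = R^2$ gives $t^2 + 2\left<\mathbf{d},\mathbf{s}-\mathbf{c}\right>t + \|\mathbf{s}-\mathbf{c}\|^2 - R^2 = 0$. As a small standalone check, take the $xy$-plane osculating circle of the ellipsoid at $(5, 0, 0)$: curvature $-0.8$ gives $R_u = -1.25$ and center $\mathbf{p} + R_u N = (3.75, 0, 0)$. The central ray should hit it exactly at $(5, 0, 0)$, and the offset ray at a point on the circle slightly past the tangent plane.

```python
import torch

def ray_circle_hit(s, d, c, R):
    # |s + t d - c|^2 = R^2 with |d| = 1  =>  t^2 + 2<d, s-c> t + |s-c|^2 - R^2 = 0
    a1 = 2 * (d * (s - c)).sum(-1, keepdim=True)
    a0 = ((s - c) ** 2).sum(-1, keepdim=True) - R**2
    t = (-a1 - torch.sqrt(a1**2 - 4 * a0)) / 2   # nearer root = first hit
    return s + t * d

src = torch.tensor([[10., 0., 0.]])
dirs = torch.tensor([[-1., 0., 0.], [-1., 0.1, 0.]])   # central ray, offset ray
dirs = dirs / dirs.norm(dim=-1, keepdim=True)
center = torch.tensor([[3.75, 0., 0.]])   # p + R_u N with p = (5, 0, 0), N = (1, 0, 0)
radius = torch.tensor([[-1.25]])          # R_u = 1 / (-0.8), negative because convex
hit = ray_circle_hit(src, dirs, center, radius)
```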


### Virtual cone parameters
<img src="media/viz_3D_4.jpg" width="300">


**1. Virtual Cone directions**

From the osculating circles we know the normal vectors: the normal $N_u$ at the intersection point $c_{int}$ on an osculating circle points from the circle's center $c_u$ to $c_{int}$.

In [13]:
Nu = normalize(c_int-cu)

The virtual cone directions can then be found using the reflection formula $\mathbf{d}_r = \mathbf{d} - 2\left<\mathbf{d}, N_u\right>N_u$:

In [14]:
d_r = d - 2*dot(d,Nu)*Nu
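The reflection formula keeps the tangential component of $\mathbf{d}$ and flips its normal component, so it preserves length (unit $\mathbf{d}$ gives unit $\mathbf{d}_r$) and satisfies $\left<\mathbf{d}_r, N_u\right> = -\left<\mathbf{d}, N_u\right>$. A minimal standalone check, with a unit normal $(0.6, 0.8, 0)$ chosen for illustration:

```python
import torch

def reflect(d, n):
    # Mirror d about the unit normal n: keep the tangential part, flip the normal part
    return d - 2 * (d * n).sum(-1, keepdim=True) * n

n_demo = torch.tensor([[0.6, 0.8, 0.0]])   # unit normal, chosen for illustration
d_in = torch.tensor([[-1.0, 0.0, 0.0]])    # unit incoming direction
d_out = reflect(d_in, n_demo)
```

Because the length is preserved, `d_r` in the notebook needs no re-normalization as long as `d` and `Nu` are unit vectors.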

**2. Virtual Cone Apex**

We approximate the virtual cone apex as the point with the least total squared distance to all the reflected cone rays. This can be posed as a linear least-squares (pseudo-inverse) problem, as described [here](https://math.stackexchange.com/a/55286).

In [16]:
def closest_pt_to_lines(c_int, d_r):
    # Written in terminology described in the link
    w_i = c_int[...,None] # N x 5 x 3 x 1
    u_i = d_r[...,None] # N x 5 x 3 x 1
    I = torch.eye(3)
    # Add singleton based on batch dimensions of w_i
    I = I.reshape((1,)*(w_i.ndim-2)+I.shape) # N x 5 x 3 x 3
    A_i = I - u_i@transp(u_i) # N x 5 x 3 x 3
    p_i = A_i @ w_i # N x 5 x 3 x 1

    # Sum along cone rays
    A = A_i.sum(-3, keepdim=True) # N x 1 x 3 x 3
    p = p_i.sum(-3, keepdim=True) # N x 1 x 3 x 1

    # Solve A o' = p (more stable than forming the inverse explicitly)
    o_prime = torch.linalg.solve(A, p) # N x 1 x 3 x 1

    return o_prime[...,0]

o_prime = closest_pt_to_lines(c_int, d_r)
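Minimizing $\sum_i \|(I - \mathbf{u}_i\mathbf{u}_i^T)(\mathbf{x} - \mathbf{w}_i)\|^2$ over $\mathbf{x}$ leads to the linear system $\left(\sum_i A_i\right)\mathbf{x} = \sum_i A_i\mathbf{w}_i$ with $A_i = I - \mathbf{u}_i\mathbf{u}_i^T$, which is the system `closest_pt_to_lines` builds. The standalone sketch below (function name hypothetical) solves it with `torch.linalg.solve` and checks it on three lines that meet at the known point $(1, 2, 3)$:

```python
import torch

def least_squares_apex(w, u):
    """Point minimizing the summed squared distance to lines x = w_i + t u_i.
    w, u: (..., K, 3) with unit u_i -> (..., 3)."""
    A_i = torch.eye(3) - u[..., :, None] * u[..., None, :]   # projectors I - u_i u_i^T
    A = A_i.sum(-3)                                           # (..., 3, 3)
    b = (A_i @ w[..., None]).sum(-3)                          # (..., 3, 1)
    return torch.linalg.solve(A, b)[..., 0]

# Three lines along the coordinate axes, all passing through (1, 2, 3)
apex = torch.tensor([1., 2., 3.])
axes = torch.eye(3)
w_pts = apex + 4.0 * axes   # one (different) point on each line
o_est = least_squares_apex(w_pts, axes)
```

If all the lines are parallel, the summed matrix is singular and no finite apex exists, so nearly parallel reflected rays will give a poorly conditioned solve.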

**3. Virtual Cone Radius**

From the virtual camera position $o'$ and the virtual cone directions $d_r$, we can compute the cone radius $r'$ in a similar manner to MipNeRF.
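The radius computation itself is not spelled out here. In MipNeRF, the cone radius at the image plane (unit distance from the camera center) is the pixel width scaled by $2/\sqrt{12}$. One possible sketch under that assumption, treating the mean distance between the central reflected direction and the four offset ones as the pixel width of the virtual cone (a hypothetical choice, not necessarily the author's), is:

```python
import math
import torch

def virtual_cone_radius(d_r):
    """Hypothetical sketch: MipNeRF-style base radius of a cone at unit distance
    from its apex. d_r: (..., 5, 3) unit directions, index 0 = central ray."""
    center = d_r[..., :1, :]                                   # (..., 1, 3)
    offsets = d_r[..., 1:, :]                                  # (..., 4, 3)
    # Assumed pixel width: mean distance from the central to the offset directions
    dx = (offsets - center).norm(dim=-1).mean(-1, keepdim=True)
    return dx * 2 / math.sqrt(12)                              # 2/sqrt(12) scaling as in MipNeRF

# Example on the first pixel's primary cone (sp = 0.1)
rays_demo = torch.tensor([[[-1., 0, 0], [-1., 0.1, 0], [-1., 0, 0.1],
                           [-1., -0.1, 0], [-1., 0, -0.1]]])
rays_demo = rays_demo / rays_demo.norm(dim=-1, keepdim=True)
r_demo = virtual_cone_radius(rays_demo)   # shape (1, 1)
```

In the notebook this would be applied as `virtual_cone_radius(d_r)`, giving one radius per pixel (shape N x 1). Whether the mean, the maximum, or separate per-axis widths best summarize an anisotropic virtual cone is left open here.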