In [2]:
#| default_exp trigonometry

In [3]:
#| export

import jax
import jax.numpy as jnp

import functools

import numpy as np

In [4]:
#| export

from jaxtyping import Float

In [5]:

#| hide

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_debug_nans", False)

In [6]:
#| hide

import jaxtyping


In [7]:

#| hide

%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype

## Basic trigonometry

For example, the circumcenter of a triangle, which is the position of the dual Voronoi vertex.

### Coding style notes

Throughout, we will (attempt to) provide a type signature for all functions. To do so for array-based functions, we use  [jaxtyping](https://docs.kidger.site/jaxtyping).

### JAX

The aim is to create a triangulation datastructure compatible with the JAX library for automatic differentiation and numerical computing. In practice, this means that we use `jnp` (=`jax.numpy`) instead of `numpy`, and make sure our code follows JAX's functional programming paradigm (see [JAX- the sharp bits](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html)).

In [8]:
#| export

## trig functions - we can use vmap to vectorize them

@functools.partial(jax.jit, static_argnames=['zero_clip'])
def get_circumcenter(a: Float[jax.Array, " dim"],
                     b: Float[jax.Array, " dim"],
                     c: Float[jax.Array, " dim"], zero_clip: float = 1e-10) -> Float[jax.Array, " dim"]:
    """Return circumcenter coordinates of triangle with vertices a, b, c"""    
    # compute using barycentric coordinates. Start by computing the edge lengths:
    la, lb, lc = (jnp.linalg.norm(b-c), jnp.linalg.norm(c-a), jnp.linalg.norm(a-b,))
    ba, bb, bc = (la**2*(lb**2+lc**2-la**2), lb**2*(lc**2+la**2-lb**2), lc**2*(la**2+lb**2-lc**2))
    u = (ba*a+bb*b+bc*c)/jnp.clip(ba+bb+bc, zero_clip)  # avoid div by zero for degenerate triangles
    return u

def get_oriented_triangle_area(a: Float[jax.Array, " dim"],
                               b: Float[jax.Array, " dim"],
                               c: Float[jax.Array, " dim"]) -> Float[jax.Array, "*"]:
    """Signed area of triangle with vertices a, b, c. If d=2, returns a scalar, if d=3, a vector."""
    return 0.5 * jnp.cross(b - a, c - a)

def get_triangle_area(a: Float[jax.Array, " dim"],
                      b: Float[jax.Array, " dim"],
                      c: Float[jax.Array, " dim"]) -> Float[jax.Array, ""]:
    """Unoriented area of triangle with vertices a, b, c. Works in dim 2 or 3."""
    return 0.5 * jnp.linalg.norm(jnp.cross(b - a, c - a))

def get_polygon_area(pts: Float[jax.Array, "n_vertices 2"]) -> Float[jax.Array, ""]:
    """Area of 2D polygon assuming no self-intersection."""
    return jnp.sum(pts[:,0]*jnp.roll(pts[:,1], 1) - jnp.roll(pts[:,0], 1)*pts[:,1])/2

In [17]:
#| export

def get_signed_angle_between_vectors(a: Float[jax.Array, "2"],
                                     b: Float[jax.Array, "2"]) -> Float[jax.Array, ""]:
    """Signed angle between two 2d vectors"""
    return jnp.atan2(jnp.cross(a, b), jnp.dot(a, b))


def get_angle_between_vectors(a: Float[jax.Array, " dim"],
                              b: Float[jax.Array, " dim"]) -> Float[jax.Array, ""]:
    """Angle between two vectors"""
    inner = a.dot(b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))
    return jnp.arccos(jnp.clip(inner, -1, 1))

def get_cot_between_vectors(a: Float[jax.Array, " dim"],
                              b: Float[jax.Array, " dim"]) -> Float[jax.Array, ""]:
    """Cotangent of angle between two vectors"""
    return jnp.dot(a, b) / jnp.linalg.norm(jnp.cross(a, b))

In [47]:
#| export

@functools.partial(jax.jit, static_argnames=['zero_clip'])
def get_voronoi_corner_area(a: Float[jax.Array, " dim"],
                            b: Float[jax.Array, " dim"],
                            c: Float[jax.Array, " dim"], zero_clip: float=1e-10) -> Float[jax.Array, "*"]:
    """
    Compute oriented Voronoi area at corner a of triangle abc. Returns vector in 3d.
    Returns zero for a degenerate triangle.
    """
    u = get_circumcenter(a, b, c)
    # Voronoi edges are midpoints of triangle edges. the corner area splits into two triangles:
    #a_corner = get_polygon_area(jnp.stack([a, (a-b)/2, u, (a-c)/2], axis=0))
    a_corner = get_oriented_triangle_area(a, (a-c)/2, u)+get_oriented_triangle_area(u, (a-b)/2, a)
    a_triangle = get_triangle_area(a, b, c)
    return jnp.where(a_triangle > zero_clip, a_corner, 0.0)

In [48]:
get_voronoi_corner_area(jnp.array([0.,0.]), jnp.array([0.,1.]),  jnp.array([1.,0.]))

Array(-0.25, dtype=float64)

In [26]:
get_circumcenter(jnp.array([0.,0.]), jnp.array([0.,1.]),  jnp.array([1.,0.]))

Array([0.5, 0.5], dtype=float64)

In [18]:
get_signed_angle_between_vectors(jnp.array([1.,0.]), jnp.array([0.5,0.5])) * (180/jnp.pi)

Array(45., dtype=float64)

In [19]:
get_polygon_area(jnp.array([[0.,0.], [0.,1.], [1.,0.]]) ), get_triangle_area(*jnp.array([[0.,0.], [0.,1.], [1.,0.]]) )

(Array(0.5, dtype=float64), Array(0.5, dtype=float64))

In [20]:
get_circumcenter(jnp.array([0.,0.]), jnp.array([0.,1.]),  jnp.array([1.,0.]))

Array([0.5, 0.5], dtype=float64)

In [21]:
get_circumcenter(jnp.array([1.,0.]), jnp.array([1.,0.]),  jnp.array([0.,1.]))

Array([0., 0.], dtype=float64)

In [22]:
get_circumcenter(jnp.array([1.,0.]), jnp.array([2.,0.]),  jnp.array([1.,0.]))

Array([0., 0.], dtype=float64)

In [27]:
get_voronoi_corner_area(jnp.array([0.,0.]), jnp.array([0.,1.]),  jnp.array([1.,0.]))

Array(-0.25, dtype=float64)

### Normals and rotation matrices in 3D

In [None]:
#| export

def get_rot_mat(theta: float) -> Float[jax.Array, "2 2"]:
    """Get 2D rotation matrix from angle in radians."""
    return jnp.array([[jnp.cos(theta), jnp.sin(theta)],[-jnp.sin(theta), jnp.cos(theta)]])

def get_perp_2d(x: Float[jax.Array, "... 2"]) -> Float[jax.Array, "... 2"]:
    """Get perpendicular vector."""
    return jnp.stack([x[..., 1], -x[..., 0]], axis=-1)

def get_triangle_normal(a: Float[jax.Array, "3"],
                        b: Float[jax.Array, "3"],
                        c: Float[jax.Array, "3"]) -> Float[jax.Array, "3"]:
    """Compute unit normal vector of triangle abc."""
    n = jnp.cross(b - a, c - a)
    return n / jnp.linalg.norm(n)

def quaternion_to_rot_max(q: Float[jax.Array, "4"]) -> Float[jax.Array, "3 3"]:
    """
    Convert unit quaternion into a 3d rotation matrix.
    
    See https://fr.wikipedia.org/wiki/Quaternions_et_rotation_dans_l%27espace
    """
    a, b, c, d = q / jnp.linalg.norm(q)
    return jnp.array([[a**2+b**2-c**2-d**2, 2*b*c-2*a*d, 2*a*c+2*b*d],
                      [2*a*d+2*b*c, a**2-b**2+c**2-d**2, 2*c*d-2*a*b],
                      [2*b*d-2*a*c, 2*a*b+2*c*d, a**2-b**2-c**2+d**2]])

### Barycentric coordinates

In [80]:
#| export

@functools.partial(jax.jit, static_argnames=['zero_clip', 'normalize'])
def get_barycentric_coordinates(point: Float[jax.Array, " dim"],
                                a: Float[jax.Array, " dim"],
                                b: Float[jax.Array, " dim"],
                                c: Float[jax.Array, " dim"],
                                zero_clip: float = 1e-10, normalize: bool = True) -> Float[jax.Array, "3"]:
    """Compute barycentric coordinates of point with respect to triangle abc."""
    bary, _, _, _ = jnp.linalg.lstsq(jnp.stack([a,b,c], axis=1), point)
    if normalize:
        bary = bary / jnp.clip(bary.sum(), zero_clip)
    return bary

In [81]:
vertices = jnp.array([[0., 0 ], [0., 1.], [1., 0.]])
point1 = jnp.array([0.5, 0.2])

get_barycentric_coordinates(point1, *vertices, normalize=False)

Array([0. , 0.2, 0.5], dtype=float64)

In [83]:
vertices2 = jnp.array([[0., 0, 0], [0., 1., 0], [1., 0., 0]])
point2 = jnp.array([0.5, 0.2, 0.])

get_barycentric_coordinates(point2, *vertices2, normalize=False)

Array([0. , 0.2, 0.5], dtype=float64)

In [84]:
point3 = jnp.array([0.5, 0.2, 1.])

get_barycentric_coordinates(point3, *vertices2, normalize=False)

Array([0. , 0.2, 0.5], dtype=float64)