Skip to content

Commit

Permalink
Simplify np.cross. Add a jit decorator. (#1810)
Browse files Browse the repository at this point in the history
* Simplify np.cross. Add a jit decorator.
  • Loading branch information
hawkinsp committed Dec 4, 2019
1 parent d6b18fb commit 17813ea
Showing 1 changed file with 25 additions and 40 deletions.
65 changes: 25 additions & 40 deletions jax/numpy/lax_numpy.py
Expand Up @@ -2371,48 +2371,33 @@ def outer(a, b, out=None):
raise NotImplementedError("The 'out' argument to outer is not supported.")
return ravel(a)[:, None] * ravel(b)

@partial(jit, static_argnums=(2, 3, 4))
def _cross(a, b, axisa, axisb, axisc):
a = moveaxis(a, axisa, -1)
b = moveaxis(b, axisb, -1)

if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3):
raise ValueError("Dimension must be either 2 or 3 for cross product")

if a.shape[-1] == 2 and b.shape[-1] == 2:
return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]

a0 = a[..., 0]
a1 = a[..., 1]
a2 = a[..., 2] if a.shape[-1] == 3 else zeros_like(a0)
b0 = b[..., 0]
b1 = b[..., 1]
b2 = b[..., 2] if b.shape[-1] == 3 else zeros_like(b0)
c = array([a1 * b2 - a2 * b1, a2 * b0 - a0 * b2, a0 * b1 - a1 * b0])
return moveaxis(c, 0, axisc)

@_wraps(onp.cross)
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
if axis is not None:
axisa = axis
axisb = axis
axisc = axis

a_ndims = len(shape(a))
b_ndims = len(shape(b))
axisa = _canonicalize_axis(axisa, a_ndims)
axisb = _canonicalize_axis(axisb, b_ndims)
a = moveaxis(a, axisa, -1)
b = moveaxis(b, axisb, -1)
a_shape = shape(a)
b_shape = shape(b)

if a_shape[-1] not in (2, 3) or b_shape[-1] not in (2, 3):
raise ValueError("Dimension must be either 2 or 3 for cross product")

if a_shape[-1] == 2 and b_shape[-1] == 2:
return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]

if a_shape[-1] == 2:
a = concatenate((a, zeros(a_shape[:-1] + (1,), dtype=a.dtype)), axis=-1)
elif b_shape[-1] == 2:
b = concatenate((b, zeros(b_shape[:-1] + (1,), dtype=b.dtype)), axis=-1)

a0 = a[..., 0]
a1 = a[..., 1]
a2 = a[..., 2]
b0 = b[..., 0]
b1 = b[..., 1]
b2 = b[..., 2]

c = array([a1 * b2 - a2 * b1,
a2 * b0 - a0 * b2,
a0 * b1 - a1 * b0])

c_ndims = len(shape(c))
axisc = _canonicalize_axis(axisc, c_ndims)

return moveaxis(c, 0, axisc)
if axis is not None:
axisa = axis
axisb = axis
axisc = axis
return _cross(a, b, axisa, axisb, axisc)

@_wraps(onp.kron)
def kron(a, b):
Expand Down

0 comments on commit 17813ea

Please sign in to comment.