# Chain rule for matrix derivatives in JAX
In which order are the Jacobians multiplied when applying the chain rule?

In [1]:
import jax
import jax.numpy as jnp

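Take $h(x) = f(g(x))$, where $f: \mathbb{R}^M \to \mathbb{R}$, $g: \mathbb{R}^N \to \mathbb{R}^M$, and fix some $x_0 \in \mathbb{R}^N$. We seek to compute $\nabla_x h(x_0)$. By the chain rule, the Jacobians compose with the outer function on the left, $J_h(x_0) = J_f\big(g(x_0)\big)\, J_g(x_0)$, a $(1 \times M)(M \times N) = 1 \times N$ product whose single row is $\nabla_x h(x_0)^\top$. (In the code below $f$ returns a length-1 array, so `jax.jacfwd` produces these Jacobians as $1 \times M$ and $1 \times N$ arrays.)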

In [57]:
# Reproducible randomness: one root key, split into three independent
# streams (base point, inner projection, outer projection).
key = jax.random.PRNGKey(314)
key_0, key_g, key_f = jax.random.split(key, num=3)

# Dimensions: x lives in R^N and g maps R^N -> R^M.
N = 5
M = 3

x0 = jax.random.normal(key_0, shape=(N,))          # base point for the derivatives
G_proj = jax.random.normal(key_g, shape=(M, N))    # linear map used inside g
F_proj = jax.random.normal(key_f, shape=(1, M))    # single-row linear map used inside f

In [59]:
def g(x):
    """Inner map g: R^N -> R^M, g(x) = sin(G_proj @ x).

    Reads the module-level (M, N) matrix ``G_proj``; the sine is applied
    elementwise to the projected vector.
    """
    return jnp.sin(G_proj @ x)

def f(x):
    """Outer map f(u) = cos(F_proj @ u), with u = x in R^M (the output of g).

    ``F_proj`` has shape (1, M), so the result is a length-1 array rather than a
    true scalar; ``jax.jacfwd(f)`` therefore yields a (1, M) Jacobian.
    """
    return jnp.cos(F_proj @ x)


def h(x):
    """Composite h = f ∘ g: apply the inner map g first, then the outer map f."""
    u = g(x)
    return f(u)

In [61]:
# Jacobian of the composite h, obtained directly by forward-mode AD.
# h returns a length-1 array, so H0 comes out with shape (1, N).
# (With a single output, jax.jacrev would need only one reverse pass;
# jacfwd is kept here because its jaxpr is inspected below.)
Hx = jax.jacfwd(h, argnums=0)
H0 = Hx(x0)
H0

Array([[-0.02170205,  0.00988508, -0.00630814,  0.01018986, -0.00282953]],      dtype=float32)

In [62]:
# Trace the Jacobian function at x0 and show its jaxpr. The leading inputs
# a:(3, 5) and b:(1, 3) appear to be the closed-over G_proj and F_proj.
jaxpr_H = jax.make_jaxpr(Hx)(x0)
jaxpr_H

{ lambda a:f32[3,5] b:f32[1,3]; c:f32[5]. let
    d:i32[5,5] = iota[dimension=0 dtype=int32 shape=(5, 5)]
    e:i32[5,5] = add d 0
    f:i32[5,5] = iota[dimension=1 dtype=int32 shape=(5, 5)]
    g:bool[5,5] = eq e f
    h:f32[5,5] = convert_element_type[new_dtype=float32 weak_type=False] g
    i:f32[5,5] = slice[limit_indices=(5, 5) start_indices=(0, 0) strides=None] h
    j:f32[3] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a c
    k:f32[3,5] = dot_general[dimension_numbers=(([1], [1]), ([], []))] a i
    l:f32[3] = sin j
    m:f32[3] = cos j
    n:f32[5,3] = transpose[permutation=(1, 0)] k
    o:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] m
    p:f32[5,3] = mul n o
    q:f32[1] = dot_general[dimension_numbers=(([1], [0]),

In [63]:
# The two chain-rule factors, each evaluated at its own base point:
#   G0 = J_g(x0),  shape (M, N)
#   U0 = g(x0),    the point at which f is linearised
#   F0 = J_f(U0),  shape (1, M)
G0 = jax.jacfwd(g)(x0)
U0 = g(x0)
F0 = jax.jacfwd(f)(U0)

In [64]:
# Chain rule with the outer Jacobian on the left:
#   J_h(x0) = J_f(g(x0)) @ J_g(x0),  i.e. (1, M) @ (M, N) -> (1, N).
# Check it against H0 = jax.jacfwd(h)(x0) instead of comparing printed arrays by eye.
chain_rule_jac = F0 @ G0
assert chain_rule_jac.shape == H0.shape, (chain_rule_jac.shape, H0.shape)
# Tolerances allow for float32 round-off between the two evaluation orders.
assert jnp.allclose(chain_rule_jac, H0, rtol=1e-5, atol=1e-6), "F0 @ G0 disagrees with jacfwd(h)"
chain_rule_jac

Array([[-0.02170205,  0.00988508, -0.00630814,  0.01018986, -0.00282953]],      dtype=float32)

In [65]:
# Jaxpr of f's Jacobian on its own, traced at U0 = g(x0). The leading
# (1, 3) input `a` appears to be the closed-over F_proj.
jaxpr_F = jax.make_jaxpr(jax.jacfwd(f))(U0)
jaxpr_F

{ lambda a:f32[1,3]; b:f32[3]. let
    c:i32[3,3] = iota[dimension=0 dtype=int32 shape=(3, 3)]
    d:i32[3,3] = add c 0
    e:i32[3,3] = iota[dimension=1 dtype=int32 shape=(3, 3)]
    f:bool[3,3] = eq d e
    g:f32[3,3] = convert_element_type[new_dtype=float32 weak_type=False] f
    h:f32[3,3] = slice[limit_indices=(3, 3) start_indices=(0, 0) strides=None] g
    i:f32[1] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a b
    j:f32[1,3] = dot_general[dimension_numbers=(([1], [1]), ([], []))] a h
    _:f32[1] = cos i
    k:f32[1] = sin i
    l:f32[3,1] = transpose[permutation=(1, 0)] j
    m:f32[1,1] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] k
    n:f32[3,1] = mul l m
    o:f32[3,1] = neg n
    p:f32[1,3] = transpose[permutation=(1, 0)]

In [66]:
# Jaxpr of g's Jacobian on its own, traced at x0. The leading (3, 5)
# input `a` appears to be the closed-over G_proj.
jaxpr_G = jax.make_jaxpr(jax.jacfwd(g))(x0)
jaxpr_G

{ lambda a:f32[3,5]; b:f32[5]. let
    c:i32[5,5] = iota[dimension=0 dtype=int32 shape=(5, 5)]
    d:i32[5,5] = add c 0
    e:i32[5,5] = iota[dimension=1 dtype=int32 shape=(5, 5)]
    f:bool[5,5] = eq d e
    g:f32[5,5] = convert_element_type[new_dtype=float32 weak_type=False] f
    h:f32[5,5] = slice[limit_indices=(5, 5) start_indices=(0, 0) strides=None] g
    i:f32[3] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a b
    j:f32[3,5] = dot_general[dimension_numbers=(([1], [1]), ([], []))] a h
    _:f32[3] = sin i
    k:f32[3] = cos i
    l:f32[5,3] = transpose[permutation=(1, 0)] j
    m:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] k
    n:f32[5,3] = mul l m
    o:f32[3,5] = transpose[permutation=(1, 0)] n
    p:f32[3,5] = sli