In [16]:
import jax.numpy as jnp
from jax import jacfwd, jacrev, random
import time

# 1.2 Part (i)

In [None]:
# start with linear parts
A1 = jnp.sin(jnp.arange(1000 * 5000).reshape(1000, 5000) * 0.001)
A2 = jnp.sin(jnp.arange(30 * 1000).reshape(30, 1000) * 0.001)
A3 = jnp.sin(jnp.arange(1 * 30).reshape(1, 30) * 0.001)

# non linear affines
b1 = jnp.tanh(jnp.sum(A1, axis=1))
b2 = jnp.tanh(jnp.sum(A2, axis=1))
b3 = jnp.tanh(jnp.sum(A3, axis=1))

def F(x):
    y1 = A1 @ x + b1
    y1 = jnp.tanh(y1) + 0.1 * jnp.sin(y1)

    y2 = A2 @ y1 + b2
    y2 = jnp.tanh(y2) + 0.1 * y2**2

    y3 = A3 @ y2 + b3
    return y3.squeeze()



In [None]:
# generate random 5000 dimensional vector x
key = random.PRNGKey(0)
random_x = random.normal(key, shape=(5000,))
# compute DF(x) using both jacfwd and jacrev
start = time.perf_counter()
DF_forward = jacfwd(F)(random_x)
end = time.perf_counter()
print(f"Forward time: {end - start}")

start = time.perf_counter()
DF_forward = jacrev(F)(random_x)
end = time.perf_counter()
print(f"Reverse time: {end - start}")

Forward time: 1.3830949170514941
Reverse time: 0.41518862498924136


# 1.2 part (ii)

In [20]:
def G(z):
    z = jnp.atleast_1d(z)        

    y2 = A3.T @ (z - b3)         
    y2 = jnp.tanh(y2) + 0.1 * jnp.sin(y2)

    y1 = A2.T @ (y2 - b2)   
    y1 = jnp.tanh(y1) + 0.1 * y1**2

    x  = A1.T @ (y1 - b1)  
    return x


In [22]:
# generate random 5000 dimensional vector x
key = random.PRNGKey(0)
random_x = random.normal(key, shape=(1,))
# compute DF(x) using both jacfwd and jacrev
start = time.perf_counter()
DF_forward = jacfwd(G)(random_x)
end = time.perf_counter()
print(f"Forward time: {end - start}")

start = time.perf_counter()
DF_forward = jacrev(G)(random_x)
end = time.perf_counter()
print(f"Reverse time: {end - start}")

Forward time: 0.6012854169821367
Reverse time: 1.1691453750245273


# 1.2 part (iii)

Explain the differences in computation time: 

The fastest way to compute a series of matrix multiplications (ABCDE...Z) is to begin on the side with the smaller dimension on the outside (row_count for A, col_count for Z).  This is because the subsequent matrix multiplications will mainting the complexity of that outer dimension.  In the case that A is (1, m), the entire matrix multiplication results to a bunch of multiplications like (1,m) x (n, l), which are cheaper than the other side, which might be something like a repeated matrix multiplication looking like (m, n) x (l, 5000).

In this specific context, The forward does slowly on F because F has a large first dimension on A and a small dimension on Z.  This is similar but reversed for G

