In [1]:
from jax import numpy as jnp, random, lax, jit, grad, scipy as jsp, value_and_grad
from jax import Array
import pandas as pd

In [2]:
def ret_x(x):
    """Identity function: hand the argument back unchanged."""
    return x

In [None]:
# The gradient of the identity function f(x) = x is always 1.0,
# whatever the input value (i.e. d/dx x = 1, a constant - not x itself)

# Differentiate ret_x with respect to its argument and evaluate the derivative at x = 6.0
grad(ret_x)(6.0)

In [None]:
# Adding a constant does not change the slope, so the gradient of this function is also 1.0

def add_one(x):
    """Return the input shifted up by 1.0."""
    shifted = x + 1.0
    return shifted

# Differentiate add_one with respect to its argument and evaluate the derivative at x = 5.0
grad(add_one)(5.0)

In [5]:
# Multiply 2 values

def just_multiply(x, y):
    """Return the product of the two inputs."""
    product = x * y
    return product

In [None]:
# Gradient of both variables (args 0,1)
# The result is the input arguments swapped, i.e. (4.0, 1.5);
# Since the function is purely multiplicative, the rate at which
# change in x produces change in the output is the value of y, and vice versa

# argnums=(0, 1) requests the partial derivatives with respect to x and y, returned as a pair
grad(just_multiply, argnums=(0, 1))(1.5, 4.0)

In [26]:
def random_walk(mean, sd, k, N) -> Array:
    """Simulate a one-dimensional Gaussian random walk.

    N standard-normal draws are scaled by ``sd`` and shifted by ``mean`` to
    form the step increments; the walk is the running total of those steps.

    Args:
        mean: Mean of each step increment
        sd: Standard deviation of each step increment
        k: Jax PRNGKey used to draw the increments
        N: Number of timesteps (length of the returned array)

    Returns:
        Array of length N holding the walk's value after each step
    """
    noise = random.normal(k, shape=(N,))
    increments = noise * sd + mean
    return jnp.cumsum(increments)

In [None]:
# Example random walk: 128 steps with mean 0.0 and sd 1.0.
# Try changing the seed value of the key to draw a different walk.

k = random.PRNGKey(0)
# Title and axis labels let the figure stand alone when the notebook is skimmed;
# the trailing ';' hides the matplotlib Axes repr so only the figure is displayed.
pd.Series(random_walk(0.0, 1.0, k, 128)).plot(
    title="Gaussian random walk (mean=0.0, sd=1.0)",
    xlabel="timestep",
    ylabel="cumulative sum",
);

In [None]:
# Value and gradient of our random walk function
# We evaluate this on the last value of the random walk (gradients can only be computed over a scalar output)
# Note that gradient of the mean (first argument) is always the length of the sequence
# Since the random walk is just a cumulative sum of N steps, this is exactly as expected
# When the mean is 0.0 (and sd is 1.0, as here), the gradient of the standard deviation is equal to the value of
# the output (in general output = sd * gradient when the mean is 0.0); ie changes in the output are produced entirely by this parameter
# These values of course change with different random keys (different random walks)
# When the mean is not 0.0, the output is a combination of the 2 inputs (with gradients reflecting this)

# Thus we can think of the gradient of a stochastic function as being deterministic for a given value of k

# Fixed key, so the same noise draws (and hence the same walk) are used on every run
k = random.PRNGKey(0)
N = 128

# The two parameters we differentiate with respect to
mean, sd = 0.0, 1.0

# Scalar objective: the last value of the walk.
# argnums=(0, 1) selects mean and sd; k and N are passed through undifferentiated.
value_and_grad(
    lambda mean, sd, k, N: random_walk(mean, sd, k, N)[-1],
    argnums=(0, 1),
)(mean, sd, k, N)