In [1]:
# Francisco Dominguez Mateos
# 06/08/2020
# Naive divergence
# from: https://github.com/google/jax/issues/173

# Snippets for divergence naive version 

Hi, first, I'm not sure that the issues section is the best place for this, so I apologise beforehand. There's not much of a forum community online for JAX yet (no mailing list?), and this can almost be seen as a feature request:

I am trying to implement the div operator (the trace of the Jacobian). For a given parametric map (vector field) of the form:

$$ f_{\theta} : \mathbb{R}^n \rightarrow \mathbb{R}^{n} $$

I want to apply the div operator to it, which is:

$$ \nabla \cdot f_{\theta} = \sum_{i=1}^n \partial_{x_i} f_{\theta, i} $$

where $f_{\theta, i}$ is the $i$-th component of $f_{\theta}$. Note the derivatives are with respect to the inputs of the function, $\partial_{x_i} f_{\theta, i}(x_1, \dots, x_n)$, rather than $\theta$ ($\theta$ is held constant).

I spent some time reading the documentation and thinking about the problem. Using 'grad' and 'vmap' does not seem feasible, since by definition this function takes an input of size $n$ and produces an output of the same size. There's no way of making it a scalar nicely: you would have to create $n$ functions, each returning one index $i$ of $f$, which is quite tedious.
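For reference, the tedious construction described above can be sketched as follows. This is only an illustration, not the approach the snippet in this notebook uses: the helper name `divergence_via_grad` and the test field `field` are hypothetical, and $\theta$ is assumed to be already closed over so that `f` takes only `x`.

```python
import jax
import jax.numpy as jnp

def divergence_via_grad(f, x):
    """Sum of d f_i / d x_i, built from n scalar-valued component functions."""
    n = x.shape[0]
    total = 0.0
    for i in range(n):
        # f_i is the scalar function returning the i-th output of f.
        f_i = lambda z, i=i: f(z)[i]
        # grad(f_i)(x) is a full gradient; only entry i contributes to the trace.
        total = total + jax.grad(f_i)(x)[i]
    return total

# Hypothetical test field: f(x) = x**2 elementwise, so div f = sum(2 * x).
field = lambda x: x ** 2
x = jnp.array([1.0, 2.0, 3.0])
div_grad = divergence_via_grad(field, x)
print(div_grad)  # 12.0
```

Each call to `jax.grad` computes a whole gradient and then discards all but one entry, which is why this route costs $n$ backward passes and feels wasteful.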

As a Jacobian-vector product, I can't see how there's a vector $v$ that would produce the trace of the Jacobian (a single pullback won't give you the trace).
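While no single vector $v$ yields the trace, summing $e_i^\top J e_i$ over the standard basis vectors $e_i$ does. A minimal sketch of that idea using `jax.jvp` is below; the helper name `divergence_via_jvp` and the test field are hypothetical, and `f` is again assumed to take only `x`.

```python
import jax
import jax.numpy as jnp

def divergence_via_jvp(f, x):
    """Trace of the Jacobian as sum_i e_i . (J e_i), one JVP per basis vector."""
    basis = jnp.eye(x.shape[0], dtype=x.dtype)

    def diag_entry(e):
        # jax.jvp returns (f(x), J @ e); the dot with e picks out J[i, i].
        _, col = jax.jvp(f, (x,), (e,))
        return jnp.dot(e, col)

    # vmap runs the n JVPs as one batched computation.
    return jnp.sum(jax.vmap(diag_entry)(basis))

# Hypothetical test field with a non-diagonal Jacobian: f_i(x) = x_i * x_0.
field = lambda x: x * x[0]
x = jnp.array([1.0, 2.0, 3.0])
div_jvp = divergence_via_jvp(field, x)
# Reference value from the full Jacobian.
div_ref = jnp.trace(jax.jacfwd(field)(x))
print(div_jvp, div_ref)
```

This still needs $n$ Jacobian-vector products, so it is not cheaper than forming the Jacobian; it only makes explicit why one product is not enough.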

In [12]:
import jax
import jax.numpy as np 

X = np.arange(0,10).reshape(1, 10).astype("float32")

theta = np.eye(10,10).astype("float32") 

def f(theta, X):
    out = X.dot(theta)
    return out

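def divergence(f, theta_, X_):
    # Divergence of f with respect to its second argument, one value per row of X_.

    def my_div(f_):
        # Jacobian of f_ with respect to argument 1 (the input x, not theta).
        jac = jax.jacrev(f_, 1)
        # The divergence is the trace of that (n, n) Jacobian.
        return lambda t, x_: np.trace(jac(t, x_))

    # vmap over the rows of X_ while sharing theta_ across the batch.
    div = jax.vmap(my_div(f), in_axes=(None, 0))(theta_, X_)
    return div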

Y=f(theta,X)
print(Y.shape)
div = divergence(f, theta, X)
print(div.shape)  # (1,): a scalar div per datapoint (row) in X

(1, 10)
(1,)
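
As a sanity check on the naive version, a linear field $f_\theta(x) = x\theta$ has Jacobian $\theta^\top$, so its divergence should equal $\mathrm{tr}(\theta)$ at every point. The sketch below checks this on a batch of several rows; `batched_divergence`, the 4x4 `theta`, and the batch size of 5 are illustrative choices, and `jax.jacfwd` is swapped in for `jax.jacrev` (either works for a square Jacobian).

```python
import jax
import jax.numpy as jnp

def f(theta, x):
    # Linear vector field evaluated at a single point x of shape (n,).
    return x @ theta

def batched_divergence(f, theta, X):
    # Jacobian with respect to x (argument 1), traced, then mapped over rows of X.
    div_one = lambda t, x: jnp.trace(jax.jacfwd(f, argnums=1)(t, x))
    return jax.vmap(div_one, in_axes=(None, 0))(theta, X)

n = 4
theta = jnp.arange(n * n, dtype=jnp.float32).reshape(n, n)
X = jnp.ones((5, n), dtype=jnp.float32)

divs = batched_divergence(f, theta, X)
print(divs.shape)  # (5,)
# Every entry should match the trace of theta (0 + 5 + 10 + 15 = 30).
print(jnp.allclose(divs, jnp.trace(theta)))  # True
```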
