In [3]:
import sys
sys.path.insert(0, "../..")

import optax
import jax
import jax.numpy as jnp
from functools import reduce
import jax.numpy as np
import jax.numpy as jnp
from matplotlib import pyplot as plt
import netket
import jax.numpy as jnp
import netket.jax
from jax import jit
import math
from functools import partial
import module.misc.atomic_orbitals as atom
import module.wavefunctions

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


## Make jvp and vjp work for complex-valued vector functions

Consider a function $f:\mathbb{R}^n \to \mathbb{C}^m$.

### 1. jvp (Jacobian-vector product) (works already)
Takes a tangent vector in $\mathbb{R}^n$ and maps it to a complex vector in $\mathbb{C}^m$.

In [60]:
# Complex 2x2 matrix defining the linear test map x -> A x.
A = jnp.array([
    [1, 1j],
    [-1, 2j],
])
print(A)


def f(x):
    """Apply the linear map defined by the module-level matrix ``A`` to ``x``."""
    return jnp.matmul(A, x)

[[ 1.+0.j  0.+1.j]
 [-1.+0.j  0.+2.j]]


In [61]:
# Forward mode: evaluate f at (1, 1) and push the tangent (1, 1) through it.
# `out` is the value of f, `D` the directional derivative along the tangent.
out, D = jax.jvp(
    f,
    primals=(jnp.array([1., 1.]),),
    tangents=(jnp.array([1., 1.], dtype="float64"),),
)
print(out, D)

[ 1.+1.j -1.+2.j] [ 1.+1.j -1.+2.j]


### 2. vjp (vector-Jacobian product, i.e. covector times Jacobian)

In [69]:
# Reverse mode via netket's vjp wrapper: evaluate f at the origin.
# `out` is the value f(0); `D` is the callable applied to a covector in the next cell.
out, D = netket.jax.vjp(f, jnp.array([0.,0.]))
print(out)

[0.+0.j 0.+0.j]


In [71]:
# Apply the vjp to the complex covector w = (1, 1j).
# For the linear map f(x) = A x the displayed result equals w^T A (no complex conjugation).
D(jnp.array([1.,1.j]))

(Array([ 1.-1.j, -2.+1.j], dtype=complex128),)

## Calculate the Hessian in a nice way

In [4]:
def f(x):
    """Quadratic test function x0^2 + x1^2 + i * x2^2 of a 3-component input.

    The first two components contribute to the real part, the third to the
    imaginary part, so the second derivatives along the axes are 2, 2 and 2j.
    """
    real_part = x[0] ** 2 + x[1] ** 2
    return real_part + 1.j * x[2] ** 2

In [33]:
# Evaluation point for the derivative examples below.
x = jnp.array([5.,0.,0.])
# Direction of differentiation: unit vector along the third coordinate.
v = jnp.array([0.,0.,1.])

In [34]:
# First directional derivative of f at x along v.
# jvp returns the function value and the tangent (directional derivative).
primals, tangents = jax.jvp(f, primals=(x,), tangents=(v,))
print(primals, tangents, sep="\n")

(25+0j)
0j


In [35]:
def g(x):
    """Directional derivative of f at ``x`` along the fixed global direction ``v``."""
    return jax.jvp(f, (x,), (v,))[1]


# Differentiating g along v once more gives the second directional derivative
# of f in `tangents`; `primals` holds the first directional derivative g(x).
primals, tangents = jax.jvp(g, primals=(x,), tangents=(v,))

In [36]:
# Second directional derivative of f along v computed in the previous cell;
# for f = x0^2 + x1^2 + 1j*x2^2 and v = e_2 the expected value is 2j.
tangents

Array(0.+2.j, dtype=complex128)

In [None]:
# put together

In [68]:
@partial(jax.jit, static_argnames=("f",))
def _d2(f, x, v):
    """Second directional derivative of ``f`` at ``x`` along ``v``.

    Uses forward-over-forward differentiation: the inner ``jax.jvp`` gives the
    first directional derivative along ``v`` as a function of the point, and
    the outer ``jax.jvp`` differentiates that function along ``v`` again.

    Args:
        f: function to differentiate (static under ``jit``, so it must be hashable).
        x: point at which the derivative is evaluated.
        v: direction, with the same shape and dtype as ``x``.

    Returns:
        The value of v^T H_f(x) v, with the same shape and dtype as ``f(x)``.
    """
    def first_derivative(point):
        # jax.jvp returns (value, tangent); only the tangent is needed here.
        return jax.jvp(f, (point,), (v,))[1]

    _, second_derivative = jax.jvp(first_derivative, (x,), (v,))
    return second_derivative


# Batched variant of _d2: f and x are shared, the directions are stacked
# along the leading axis of the third argument.
_d2_ = jax.vmap(_d2, in_axes=(None, None, 0))


@partial(jax.jit, static_argnames=("f",))
def hess_diag(f, x):
    """Diagonal of the Hessian of ``f`` at the 1-D point ``x``.

    Entry ``i`` is the second derivative of ``f`` along the i-th coordinate
    axis. The unit vectors are built from the length of ``x`` (previously the
    dimension was hard-coded to 3) and share its dtype, as ``jax.jvp`` requires
    tangents and primals to have matching dtypes.

    Args:
        f: function to differentiate (static under ``jit``).
        x: 1-D array-like evaluation point of any length.

    Returns:
        Array whose leading axis has length ``len(x)``, one second derivative
        per coordinate.
    """
    x = jnp.asarray(x)
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    return _d2_(f, x, basis)


@partial(jax.jit, static_argnames=("f",))
def hessian(f, x):
    """Sum of the diagonal second derivatives of ``f`` at ``x``.

    NOTE: despite the name, this is the trace of the Hessian (the Laplacian),
    not the full Hessian matrix; the name is kept for existing callers.
    Works for 1-D inputs of any length (previously only length 3).

    Args:
        f: function to differentiate (static under ``jit``).
        x: 1-D array-like evaluation point.

    Returns:
        The sum over coordinates of the second derivatives, with the shape of ``f(x)``.
    """
    # Sum over the coordinate axis only, so a vector-valued f keeps its shape.
    return jnp.sum(hess_diag(f, x), axis=0)

In [69]:
# Second derivatives of f along each coordinate axis at x;
# for f = x0^2 + x1^2 + 1j*x2^2 these are 2, 2 and 2j.
hess_diag(f, x)

Array([2.+0.j, 2.+0.j, 0.+2.j], dtype=complex128)

In [65]:
# Sanity check: the rows are the unit vectors used as differentiation directions.
jax.nn.one_hot(jnp.arange(3), 3)

Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float64)

In [45]:
# Second derivative of f along the first coordinate axis: d^2/dx0^2 of x0^2 is 2.
_d2(f, x, jnp.array([1.,0.,0.]))

Array(2.+0.j, dtype=complex128)

In [47]:
# `hessian` adds up the second derivatives along the three axes,
# so for this f the result is 2 + 2 + 2j.
hessian(f, jnp.array([0.,0.,0.]))

Array(4.+2.j, dtype=complex128)

In [48]:
# Vectorised `hessian`: the function argument is shared across the batch,
# the evaluation points are stacked along the leading axis.
hessian_broad = jax.vmap(hessian, in_axes=(None, 0))

In [53]:
# Evaluate at a batch of four points; f is quadratic, so its second
# derivatives (and hence the result) are the same at every point.
hessian_broad(f, jnp.array([[0.,0.,0.],[0.,0.,0.],[0.,0.,0.],[0.,1.,0.]]))

Array([4.+2.j, 4.+2.j, 4.+2.j, 4.+2.j], dtype=complex128)