# Using JAX as a backend in NetKet - Feature Preview for v3.0 

In this tutorial we will show how differentiable functions (for example deep networks) written in [JAX](https://github.com/google/jax) can be used as variational quantum states in NetKet. 

This feature will be available in the upcoming major release (version 3.0). While version 3.0 is still in beta development, users can already try this feature. 


## Prerequisites 

To try out the JAX integration, you first need to install the development version of NetKet (v3.0).
We recommend using a virtual environment (either a Python `venv` or a conda environment), for example:

```shell
python3 -m venv nk_env
source nk_env/bin/activate
pip install git+https://github.com/netket/netket@v3.0
```

Frameworks such as JAX and PyTorch are optional add-ons for NetKet 3.0, so they must be installed separately. JAX can be installed with:

```shell
pip install --upgrade jax jaxlib
```
More information can be found in the [JAX installation instructions](https://github.com/google/jax#installation).

## Defining the quantum system 

NetKet allows for full flexibility in defining quantum systems, for example when tackling a ground-state search problem. While there are a few predefined Hamiltonians, it is relatively straightforward to implement new quantum operators and Hamiltonians.

In the following, we consider the case of a transverse-field Ising model defined on a graph with random edges. 

$$ H = \sum_{i\in\textrm{nodes}} \sigma^x_{i} + J \sum_{(i,j)\in\textrm{edges}}\sigma_{i}^{z}\sigma_{j}^{z} $$

```python
import netket as nk
from netket.operator import LocalOperator as Op
from numpy import kron
from numpy.random import choice

# Define a random graph
n_nodes = 10
n_edges = 20
rand_edges = [
    choice(n_nodes, size=2, replace=False).tolist() for _ in range(n_edges)
]
graph = nk.graph.Graph(edges=rand_edges)

# Define the local Hilbert space (spin-1/2 on every node)
hi = nk.hilbert.Spin(graph, s=0.5)

# Pauli matrices
sx = [[0, 1], [1, 0]]
sz = [[1, 0], [0, -1]]

# Define the Hamiltonian as a LocalOperator acting on the given Hilbert space
ha = Op(hi)

# Add a transverse-field term on each node of the graph
for i in range(graph.n_nodes):
    ha += Op(hi, sx, [i])

# Add Ising interactions between the nodes connected by each edge of the graph
J = 0.5
for edge in graph.edges():
    ha += J * Op(hi, kron(sz, sz), edge)
```
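
To make the sum of local operators concrete, the following sketch builds the same kind of Hamiltonian as a dense matrix with plain NumPy, using the same signs as the loops above. It is only an illustration: the three-spin chain, the `embed` helper, and the site ordering of the Kronecker product are assumptions made for this example and do not describe how NetKet stores operators internally.

```python
import numpy as np

# Pauli matrices and the single-site identity
sx = np.array([[0.0, 1.0], [1.0, 0.0]])
sz = np.array([[1.0, 0.0], [0.0, -1.0]])
identity = np.eye(2)


def embed(site_ops, n_sites):
    """Kronecker product that places the {site: 2x2 matrix} operators on
    their sites and the identity on every other site (hypothetical helper)."""
    full = np.eye(1)
    for site in range(n_sites):
        full = np.kron(full, site_ops.get(site, identity))
    return full


# Hypothetical toy graph: an open chain of three spins
n_sites = 3
edges = [(0, 1), (1, 2)]
J = 0.5

dense_ha = np.zeros((2**n_sites, 2**n_sites))
for i in range(n_sites):
    dense_ha += embed({i: sx}, n_sites)  # transverse-field term on site i
for i, j in edges:
    dense_ha += J * embed({i: sz, j: sz}, n_sites)  # Ising coupling on edge (i, j)

print(dense_ha.shape)
```

The dense matrix has $2^n \times 2^n$ entries, which is why this construction is only practical for a handful of spins and why a variational approach is used for larger systems.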


## Defining a JAX module to be used as a wave function

We now want to define a suitable JAX function to be used as a wave-function ansatz. To simplify the discussion, we consider here a simple single-layer fully connected network with complex weights and a $\tanh$ activation function. Such networks are easy to define in JAX, for example as a model built with [STAX](https://github.com/google/jax/tree/master/jax/experimental). The only requirement is that the network takes as input JAX arrays of shape `(batch_size, n)`, where `batch_size` is an arbitrary batch size and `n` is the number of quantum degrees of freedom (the number of spins in the previous example). Note that, regardless of the dimensionality of the problem, the degrees of freedom are always flattened into a single index of length `n`.


```python
import jax
from jax.experimental import stax


# We define a custom layer that sums its inputs over the last axis
def SumLayer():
    def init_fun(rng, input_shape):
        output_shape = (-1, 1)
        return output_shape, ()

    def apply_fun(params, inputs, **kwargs):
        return inputs.sum(axis=-1)

    return init_fun, apply_fun


# We construct a fully connected network with tanh activation
model = stax.serial(stax.Dense(2 * graph.n_nodes), stax.Tanh, SumLayer())

# Wrap the model in a NetKet machine so that samplers, optimizers, and drivers can use it
ma = nk.machine.Jax(hi, model, dtype=complex)
ma.init_random_parameters(seed=1232)
```
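
The shape requirement described above can be checked without NetKet. The sketch below re-implements the same dense, tanh, sum structure in plain NumPy so that it runs without JAX installed; `jax.numpy` offers `tanh`, `sum`, and the `@` operator under the same names. The function names `init_params` and `log_psi` are hypothetical, and reading the output as the logarithm of the wave-function amplitude is an assumption of this example.

```python
import numpy as np


def init_params(rng, n_spins, n_hidden, scale=0.01):
    """Small random complex weights and biases for one dense layer."""

    def complex_normal(shape):
        return scale * (rng.normal(size=shape) + 1j * rng.normal(size=shape))

    return complex_normal((n_spins, n_hidden)), complex_normal((n_hidden,))


def log_psi(params, sigma):
    """Map spin configurations of shape (batch_size, n) to one complex number each."""
    weights, biases = params
    hidden = np.tanh(sigma @ weights + biases)  # shape (batch_size, n_hidden)
    return hidden.sum(axis=-1)  # shape (batch_size,)


rng = np.random.default_rng(0)
n_spins, batch_size = 10, 4
params = init_params(rng, n_spins, n_hidden=2 * n_spins)

# A batch of random spin configurations with entries +1 or -1
sigma = rng.choice([-1.0, 1.0], size=(batch_size, n_spins))
out = log_psi(params, sigma)
print(out.shape, out.dtype)
```

Each row of the input batch yields a single complex value, which is the role played by `SumLayer` in the STAX model.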



## Train the neural network to find an approximate ground state

In order to perform Variational Monte Carlo (VMC), we further need to specify a suitable sampler (to compute expectation values over the variational state) as well as an optimizer. In the following we adopt Stochastic Gradient Descent coupled with quantum natural gradients (a scheme known in the VMC literature as Stochastic Reconfiguration).

```python
# Define a sampler that performs local moves
# NetKet automatically dispatches here to an MCMC sampler written for JAX types
sa = nk.sampler.MetropolisLocal(machine=ma, n_chains=2)

# Use stochastic gradient descent (also dispatches to a JAX optimizer)
op = nk.optimizer.Sgd(ma, learning_rate=0.02)

# Use Stochastic Reconfiguration, a.k.a. quantum natural gradient
# (also dispatches to a pure JAX version)
sr = nk.optimizer.SR(ma, diag_shift=0.1)

# Create the Variational Monte Carlo instance to learn the ground state
vmc = nk.Vmc(
    hamiltonian=ha, sampler=sa, optimizer=op, n_samples=1000, sr=sr
)
```
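
For readers unfamiliar with Stochastic Reconfiguration, the following NumPy sketch shows one textbook form of the update that combines gradient descent with the quantum natural gradient. It is not NetKet's implementation: the function `sr_update`, its argument names, and the random stand-in data are hypothetical, and the defaults simply mirror the `learning_rate` and `diag_shift` values used above. In a real calculation the log-derivatives and local energies would come from the sampler and from automatic differentiation of the model.

```python
import numpy as np


def sr_update(params, log_derivs, local_energies, learning_rate=0.02, diag_shift=0.1):
    """One Stochastic Reconfiguration step from Monte Carlo samples.

    log_derivs:     (n_samples, n_params) derivatives of log psi per sample
    local_energies: (n_samples,) local energy per sample
    """
    n_samples = log_derivs.shape[0]
    # Centre the samples so that the products below are covariances
    d_o = log_derivs - log_derivs.mean(axis=0)
    d_e = local_energies - local_energies.mean()
    s_matrix = d_o.conj().T @ d_o / n_samples  # estimate of the S matrix
    forces = d_o.conj().T @ d_e / n_samples  # estimate of the energy gradient
    # The diagonal shift regularises the linear system
    regularised = s_matrix + diag_shift * np.eye(s_matrix.shape[0])
    delta = np.linalg.solve(regularised, forces)
    return params - learning_rate * delta


# Random stand-in data (hypothetical, for illustration only)
rng = np.random.default_rng(0)
n_samples, n_params = 1000, 5
log_derivs = rng.normal(size=(n_samples, n_params)) + 1j * rng.normal(
    size=(n_samples, n_params)
)
local_energies = rng.normal(size=n_samples) + 0j
params = np.zeros(n_params, dtype=complex)

new_params = sr_update(params, log_derivs, local_energies)
print(new_params.shape)
```

If all local energies are equal, the forces vanish and the parameters are left unchanged, which is the expected behaviour when the variational state is already an eigenstate.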

### Running the training loop 

The latest version of NetKet also allows finer control of the VMC loop. In the simplest case, one can just iterate through the `vmc` object and print the current value of the energy. More sophisticated output schemes based on TensorBoard have also been implemented, but they are not discussed in this tutorial.

```python
# Run 500 optimization steps and print the energy every 50 steps
# (note that the very first iteration is slow because of JIT compilation)
for it in vmc.iter(500, 50):
    print(it, vmc.energy)
```

```text
0 9.97-0.01j ± 0.087 [var=6.4e+00, R_hat=0.9993]
50 -3.2+0.1j ± 0.12 [var=7.9e+00, R_hat=1.0005]
100 -9.68-0.05j ± 0.064 [var=2.7e+00, R_hat=1.0007]
150 -10.65-0.02j ± 0.038 [var=1.0e+00, R_hat=1.0016]
200 -10.94-0.01j ± 0.027 [var=8.8e-01, R_hat=1.0002]
250 -11.08-0.00j ± 0.023 [var=6.0e-01, R_hat=0.9991]
300 -11.20+0.01j ± 0.024 [var=4.6e-01, R_hat=0.9992]
350 -11.21-0.00j ± 0.017 [var=2.9e-01, R_hat=0.9992]
400 -11.26+0.00j ± 0.014 [var=1.8e-01, R_hat=0.9997]
450 -11.28+0.00j ± 0.012 [var=1.3e-01, R_hat=0.9994]
```


## Comparing to exact diagonalization

Since this is a relatively small quantum system, we can still compute the ground-state energy by exact diagonalization. For this purpose, NetKet conveniently exposes a `.to_sparse()` method that converts the Hamiltonian into a `scipy` sparse matrix.
Here we first obtain this sparse matrix and then compute its lowest eigenvalue with SciPy's built-in sparse eigensolver.

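```python
# Import the submodule explicitly: a bare `import scipy` may not load it
import scipy.sparse.linalg

# Smallest algebraic eigenvalue of the sparse Hamiltonian
exact_ens = scipy.sparse.linalg.eigsh(
    ha.to_sparse(), k=1, which="SA", return_eigenvectors=False
)
print("Exact energy is : ", exact_ens[0])
print("Relative error is : ", abs((vmc.energy.mean - exact_ens[0]) / exact_ens[0]))
```

```text
Exact energy is :  -11.297315259167611
Relative error is :  0.0017834662727841972
```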
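
As a sanity check of this procedure, the smallest nontrivial case can be solved by hand. For two spins joined by a single edge, the Hamiltonian $\sigma^x_1 + \sigma^x_2 + J\sigma^z_1\sigma^z_2$ has ground-state energy $-\sqrt{4 + J^2}$. The sketch below verifies this with a dense NumPy diagonalization and then computes a relative error in the same way as the cell above. The variational estimate used here is a made-up number for illustration, not a NetKet result.

```python
import numpy as np

sx = np.array([[0.0, 1.0], [1.0, 0.0]])
sz = np.array([[1.0, 0.0], [0.0, -1.0]])
identity = np.eye(2)
J = 0.5

# Two spins joined by one edge
two_spin_ha = np.kron(sx, identity) + np.kron(identity, sx) + J * np.kron(sz, sz)

# eigvalsh returns the eigenvalues of a Hermitian matrix in ascending order
exact_energy = np.linalg.eigvalsh(two_spin_ha)[0]
closed_form = -np.sqrt(4.0 + J**2)
print(exact_energy, closed_form)

# Hypothetical variational estimate with a small imaginary part
estimate = -2.06 + 0.001j
relative_error = abs((estimate - exact_energy) / exact_energy)
print(relative_error)
```

Taking the absolute value of the complex difference accounts for the small imaginary part that Monte Carlo energy estimates typically carry.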
