# NetKet's infrastructure

## Jax

In [1]:
# import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"  # Before importing Jax

import jax
import jax.numpy as jnp
import numpy as np


https://jax.readthedocs.io/en/latest/

Jax is "accelerated NumPy". It combines:
* **XLA (accelerated linear algebra)**: can perform calculations on CPU, GPU, clusters, ...
* **just-in-time compilation**: needs pure functions 
* **automatic differentiation**: the backbone of machine learning
* **additional functionality**: processing PyTrees, `vmap`, ...
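
A minimal sketch of how these pieces combine, using a made-up quadratic `loss` function (the function and its inputs are illustrative assumptions, not part of NetKet):

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Simple scalar function of the "parameters" w and the "data" x
    return jnp.sum((w * x) ** 2)

# Automatic differentiation (w.r.t. the first argument) + JIT compilation
grad_loss = jax.jit(jax.grad(loss))

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])
g = grad_loss(w, x)  # analytically: 2 * w * x**2

# vmap: evaluate the loss for a batch of inputs x without writing a loop
x_batch = jnp.stack([x, 2 * x])
batched_loss = jax.vmap(loss, in_axes=(None, 0))(w, x_batch)
```

Here `in_axes=(None, 0)` tells `vmap` to keep `w` fixed and map over the first axis of `x_batch`.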

#### Jax Numpy

`jax.numpy` is almost identical to NumPy (e.g., you can use `jnp.sin` instead of `np.sin`).
Jax arrays are agnostic to the device used (CPU or GPU).
You can also mix Jax and NumPy arrays, although this is not advised because it can cause frequent array transfers.
There is also a lower-level `jax.lax` API.

The main difference is that Jax arrays are immutable.
Think about it this way: moving data to and from the GPU is expensive.
You want to put the data on the GPU once and then just do calculations there.


In [2]:
# One can easily convert Numpy array to Jax array
jax_array = jnp.array(np.random.rand(6))

# Standard numpy functions
print(jnp.sin(jax_array > 0.5))

# One can mix Jax and NumPy array (but shouldn't)
print(jax_array + np.random.rand(6))

# Jax arrays are immutable
# jax_array[2] = 1  # -- does not work
print(jax_array.at[2].set(10))  # This returns a copy
print(jax_array)  # No change in original array


[0.84147096 0.         0.84147096 0.84147096 0.         0.        ]
[1.8338766  0.91905963 0.9378662  0.85064036 0.7876469  0.7818876 ]
[ 0.99444187  0.28290102 10.          0.6057231   0.33909068  0.07518139]
[0.99444187 0.28290102 0.86436254 0.6057231  0.33909068 0.07518139]




#### JIT

Jax JIT works by sending a "fake" tracer array, with the same shape and dtype as the input array, through the function.
All array shapes have to be known at compile time.

Otherwise, we get an error like `Abstract tracer value encountered where concrete value is expected`.
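
One possible way to satisfy the static-shape requirement is to fix the output length in advance. The sketch below uses the `size` and `fill_value` arguments of `jnp.nonzero`; the function name and the choice `size=3` are assumptions made for illustration:

```python
import jax
import jax.numpy as jnp

@jax.jit
def first_large_elements(x):
    # size=3 fixes the number of returned indices at compile time;
    # missing entries are padded with fill_value=-1
    inds = jnp.nonzero(x > 0.5, size=3, fill_value=-1)[0]
    valid = inds >= 0
    # Padded positions are replaced by 0 instead of indexing with -1
    values = jnp.where(valid, x[inds], 0.0)
    return inds, values

inds, values = first_large_elements(jnp.array([0.1, 0.9, 0.7, 0.2]))
```

The price is that the caller has to choose an upper bound on the number of elements and handle the padding.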

Jax functions also have to be pure functions:
- same output for the same inputs
- no side effects (no change to inputs, no prints)
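
A small sketch of what happens to side effects under `jax.jit`: Python-level side effects run only while the function is being traced, whereas `jax.debug.print` runs on every call. The names `trace_log` and `scaled_sum` are made up for this example:

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled by a Python side effect inside the jitted function

@jax.jit
def scaled_sum(x):
    trace_log.append(x.shape)  # executed only during tracing
    jax.debug.print("sum = {s}", s=jnp.sum(x))  # executed on every call
    return 2.0 * jnp.sum(x)

a = scaled_sum(jnp.ones(3))  # first call: traces and compiles
b = scaled_sum(jnp.ones(3))  # same shape and dtype: cached, no new trace
c = scaled_sum(jnp.ones(4))  # new shape: traced again
```

After these three calls, `trace_log` contains only two entries, one per distinct input shape.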

In [41]:
# @jax.jit
def get_large_elements(x):
  print("Getting indices")
  inds = jnp.where(x > 0.5)[0]  # -- fails here under jit: the length of inds is not known at compile time
  print("Getting values")
  y = x[inds]
  return y

get_large_elements(np.random.rand(10))


Getting indices
Getting values


array([0.73435697])

In [49]:
@jax.jit
def get_large_elements(x):
  print("**Getting indices**")  # → jax.debug.print
  y = jnp.where(x > 0.5, x, 0)  # -- Now we know size of y at compile time

  # We cannot modify x in place if it is a Jax array or if the function is compiled -- pure function
  # print(type(x))
  if isinstance(x, np.ndarray):
    x[0] = 0
  else:
    x = x.at[0].set(.0)  # the name x now refers to a new array at a different place in memory!

  return y

print(get_large_elements)

print("First run")
x = np.array(np.random.rand(10))
get_large_elements(x)
print(f"First element of x: {x[0]:f}")  # -- no change to x!

# jax.make_jaxpr(get_large_elements)(x)

# print("\nSecond run with the same-size array")
# get_large_elements(np.random.rand(10))
# print("\nRun with a different size array")
# get_large_elements(np.random.rand(12))


<PjitFunction of <function get_large_elements at 0x1255afa30>>
First run
**Getting indices**
First element of x: 0.150757


In [None]:
# Similar example:  (C++ would work here!)
@jax.jit
def jitted_if(x):
  if x > 0:  # -- problem, same for loops
    return True
  else:
    return False

jitted_if(0.7)


TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function jitted_if at /var/folders/yl/hmw36wjn7n91fv3_nqz4n34w0000gn/T/ipykernel_46230/2172352681.py:2 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
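
Two common ways around this error, sketched under the assumption that we either want an element-wise selection or are willing to recompile for every new value of the argument (the function names are illustrative):

```python
from functools import partial

import jax
import jax.numpy as jnp

# Option 1: express the branch as an array operation
@jax.jit
def keep_positive(x):
    return jnp.where(x > 0, x, 0.0)

# Option 2: mark the argument as static, so its concrete value is known
# while tracing (the function is recompiled for every new value of x)
@partial(jax.jit, static_argnums=0)
def static_if(x):
    if x > 0:
        return True
    else:
        return False

positive = keep_positive(0.7)
negative = keep_positive(-2.0)
flag_true = static_if(0.7)
flag_false = static_if(-1.0)
```

Static arguments have to be hashable, so option 2 works for Python scalars but not for arrays.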

In [4]:
# Second example
@jax.jit
def jitted_if(x):
  print("Compiling...")
  if len(x.shape) > 1:
    return True
  else:
    return False


In [8]:
jitted_if(jax.random.normal(jax.random.PRNGKey(0), shape=(24,1)))  # Actually, compiled every time the shape is different


Compiling...


Array(True, dtype=bool)

#### Helper functions

Jax and NetKet have several utility functions to process trees:
* `jax.tree_util.tree_map` (also exposed as `jax.tree_map` in older Jax versions; used in updating parameters!)
* `jax.tree_util.tree_reduce` 
* `netket.jax.tree_ravel`
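
A minimal sketch of these utilities on a small hand-made tree. The `params` and `grads` dictionaries and the learning rate are made up, and `jax.flatten_util.ravel_pytree` is used here as a Jax-only stand-in for `netket.jax.tree_ravel`:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"dense": {"kernel": jnp.ones((3, 2)), "bias": jnp.zeros(2)}}
grads = jax.tree_util.tree_map(jnp.ones_like, params)  # fake gradients

# Apply a function to every leaf
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, params)

# Gradient-descent-like update: leaves of both trees are paired up
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

# Reduce over all leaves: total number of parameters
n_params = jax.tree_util.tree_reduce(lambda acc, leaf: acc + leaf.size, params, 0)

# Flatten the tree into one vector and get a function that restores it
flat, unravel = ravel_pytree(params)
restored = unravel(flat)
```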

In [9]:
import netket as nk

model = nk.models.MLP(hidden_dims=[10,4,5])
variables = model.init(jax.random.PRNGKey(0), jnp.ones(shape=(10,3)))
variables


{'params': {'MLP_0': {'Dense_0': {'kernel': Array([[-0.5679199 , -0.25000479, -0.33221843,  0.5674229 ,  0.48125366,
            -0.78459463,  0.73191763, -0.10213185,  0.21821541, -0.76410572],
           [-0.32451288, -0.06710989,  0.79808229,  0.49930144,  0.85064123,
            -0.18705461, -0.21961051, -0.38433781, -1.10589124,  0.47019935],
           [-0.18644778,  0.09819171,  0.28896983, -0.81544511,  0.96405829,
            -0.24335677, -1.04170972,  0.66292027,  0.13683804,  0.67856372]],      dtype=float64),
    'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float64)},
   'Dense_1': {'kernel': Array([[-4.87397305e-01, -1.82726739e-01, -2.70678735e-01,
             6.30144251e-02],
           [-4.98449327e-01,  3.24769067e-01,  6.21437443e-01,
             4.44547692e-01],
           [ 3.61663103e-01,  3.77743065e-01, -2.62377913e-01,
             3.36596660e-01],
           [-2.67408498e-01,  3.82106092e-01,  1.60904711e-01,
            -3.05882532e-01],
   

In Flax (and thus in NetKet), parameters are stored in a nested Python dictionary (a PyTree):

In [24]:
variables["params"].keys()


dict_keys(['MLP_0'])

In [25]:
# Get shapes of all tree elements
jax.tree_util.tree_map(lambda leaf: leaf.shape, variables)


{'params': {'MLP_0': {'Dense_0': {'bias': (10,), 'kernel': (3, 10)},
   'Dense_1': {'bias': (4,), 'kernel': (10, 4)},
   'Dense_2': {'bias': (5,), 'kernel': (4, 5)},
   'Dense_3': {'kernel': (5, 1)}}}}

In [26]:
# Number of variables in a tree
jax.tree_util.tree_reduce(lambda acc, leaf: acc + np.prod(leaf.shape), variables, 0)


114

In [27]:
numbers, unravel_fun = nk.jax.tree_ravel(variables)
print("Shape of all unraveled parameters:", numbers.shape)
unravel_fun(numbers)  # Same structures as variables


Shape of all unraveled parameters: (114,)


{'params': {'MLP_0': {'Dense_0': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float64),
    'kernel': Array([[-0.5679199 , -0.25000479, -0.33221843,  0.5674229 ,  0.48125366,
            -0.78459463,  0.73191763, -0.10213185,  0.21821541, -0.76410572],
           [-0.32451288, -0.06710989,  0.79808229,  0.49930144,  0.85064123,
            -0.18705461, -0.21961051, -0.38433781, -1.10589124,  0.47019935],
           [-0.18644778,  0.09819171,  0.28896983, -0.81544511,  0.96405829,
            -0.24335677, -1.04170972,  0.66292027,  0.13683804,  0.67856372]],      dtype=float64)},
   'Dense_1': {'bias': Array([0., 0., 0., 0.], dtype=float64),
    'kernel': Array([[-4.87397305e-01, -1.82726739e-01, -2.70678735e-01,
             6.30144251e-02],
           [-4.98449327e-01,  3.24769067e-01,  6.21437443e-01,
             4.44547692e-01],
           [ 3.61663103e-01,  3.77743065e-01, -2.62377913e-01,
             3.36596660e-01],
           [-2.67408498e-01,  3.82106092e-01

## Flax


In [11]:
from typing import Any

import jax
import jax.numpy as jnp
import flax
from flax import linen as nn


https://flax.readthedocs.io/en/latest/

Flax is a framework based on Jax for implementing neural network models "using a functional approach".

* *jax*: Always use `jax.numpy` instead of `numpy` when defining models.
* *functional*: The model does not store parameters (variables); it only provides information on how to initialize the parameters and how to apply the transformation
$$
f(v_{in}, x) \rightarrow v_{out}, y
$$
* A typical error is `Can't call compact methods on unbound modules`.
* The first axis of $x$ can index different samples. In NetKet, inputs can have more than two dimensions (MPI, chunk size).
* Variables are stored in a dictionary tree.
* If you are using complex parameters, use NetKet's linen: `netket.nn`
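
The functional pattern $f(v_{in}, x) \rightarrow v_{out}, y$ can be sketched in plain Jax, without Flax. The helpers `init_variables` and `apply` below are hypothetical and only mimic the idea; they are not how Flax is implemented:

```python
import jax
import jax.numpy as jnp

def init_variables(key, x_dim, features):
    # The "model" keeps nothing; all variables live in this tree
    kernel = jax.random.normal(key, (x_dim, features)) / jnp.sqrt(x_dim)
    return {
        "params": {"kernel": kernel, "bias": jnp.zeros(features)},
        "state": {"calls": jnp.zeros((), dtype=jnp.int32)},
    }

def apply(variables, x):
    # Pure function: reads v_in, returns (v_out, y), modifies nothing in place
    p = variables["params"]
    y = x @ p["kernel"] + p["bias"]
    new_state = {"calls": variables["state"]["calls"] + 1}
    return {"params": p, "state": new_state}, y

v_in = init_variables(jax.random.PRNGKey(0), x_dim=3, features=4)
v_out, y = apply(v_in, jnp.ones((5, 3)))
```

The input tree `v_in` is left untouched; any change (here a call counter) shows up only in the returned `v_out`.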

In [12]:
layer = nn.Dense(features=4)  # One layer FFN with 4 neurons

# Initialize parameters
# layer.init(jax.random.key(0))  # -- does not work, we need to provide an input (for its shape)
x_dim = 3
num_samples = 5
x_in = jnp.ones((num_samples, x_dim))
variables = layer.init(jax.random.key(0), x_in)

# There are no parameters stored in the object
# layer.variables  # -- does not work
variables


{'params': {'kernel': Array([[ 0.4087802 ,  0.43891278, -0.23872387, -0.8494273 ],
         [ 0.41122693, -0.5888459 , -0.55229884,  0.49776074],
         [ 0.3480036 , -0.7046275 , -0.30813402, -1.21659   ]],      dtype=float32),
  'bias': Array([0., 0., 0., 0.], dtype=float32)}}

In [13]:
# layer(variables, x_in)  # -- does not work
layer.apply(variables, x_in)


Array([[ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659]],      dtype=float64)

Of course, we can change the number of samples in `x_in`, but not the dimension of the `x_in` space.
The only check is on shapes: a wrong input dimension fails with a shape mismatch!
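
The shape logic can be illustrated with a hand-written dense layer in plain Jax (the `kernel`, `bias` and `dense` names are assumptions for this sketch; the exact exception type raised on a mismatch may depend on the Jax version):

```python
import jax.numpy as jnp

kernel = jnp.ones((3, 4))  # (x_dim, features)
bias = jnp.zeros(4)

def dense(x):
    return x @ kernel + bias

out_small = dense(jnp.ones((5, 3)))     # 5 samples
out_large = dense(jnp.ones((10, 3)))    # 10 samples: same kernel works
out_extra = dense(jnp.ones((2, 5, 3)))  # extra leading axes are fine too

try:
    dense(jnp.ones((5, 4)))  # wrong x_dim: last axis does not match the kernel
    mismatch_detected = False
except (TypeError, ValueError):
    mismatch_detected = True
```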

In [14]:
layer.apply(variables, jnp.ones(shape=(2*num_samples, x_dim)))
# layer.apply(variables, jnp.ones(shape=(num_samples, x_dim + 1)))  # -- does not work


Array([[ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659]],      dtype=float64)

Finally, we can bind specific parameters to a model, which gives an object that can be called directly.
NetKet never does that.


In [15]:
bound_model = layer.bind(variables)
bound_model(x_in)


Array([[ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659],
       [ 1.16801071, -0.85456064, -1.09915674, -1.56825659]],      dtype=float64)

The same functional property holds for NetKet, but there we have the helper function `log_value`:

In [22]:
import netket as nk

g = nk.graph.Chain(length=4, pbc=True)
hi = nk.hilbert.Spin(s=1/2, N=4)
ham = nk.operator.Ising(hilbert=hi, graph=g, h=1,J=-1)
model = nk.models.RBM(alpha=1)
vqs = nk.vqs.FullSumState(hi, model)
σ = vqs._all_states


In [26]:
vqs.log_value(σ)
?? vqs.log_value


Signature: vqs.log_value(σ: jax.Array) -> jax.Array
Source:   
    def log_value(self, σ: jnp.ndarray) -> jnp.ndarray:
        r"""
        Evaluate the variational state for a batch of states and returns
        the logarithm of the amplitude of the quantum state.

        For pure states, this is :math:`\log(\langle\sigma|\psi\rangle)`,
        whereas for mixed states
        this is :math:`\log(\langle\sigma_r|\rho|\sigma_c\rangle)`, where
        :math:`\psi` and :math:`\rho` are respectively a pure state
        (wavefunction) and a m

The easiest way to implement a model is to subclass `flax.linen.Module`:

In [15]:
from typing import Any

class TwoLayerFFN(nn.Module):
  # Here we define the model's hyperparameters (dataclass fields)
  # Type annotations are required if we want to set the values at initialization
  param_dtype: Any = jnp.float64
  num_features_1: int = 4
  num_features_2: int = 3
  use_bias: bool = True
  activation_fun: Any = nn.relu

  # nn.compact decorator provides a compact way of defining a model
  @nn.compact
  def __call__(self, x):
    # Let us do first layer using Flax:
    x = nn.Dense(name="Layer1", features=self.num_features_1, use_bias=self.use_bias, dtype=self.param_dtype)(x)
    x = self.activation_fun(x)

    # And second layer by hand
    # We have to provide 4 arguments:
    #   - parameter name
    #   - initializer function (nn.initializers.normal() returns a function);
    #     the initializer takes the random generator key and the shape as inputs.
    #     We provide the former when we call model.init(...)
    #   - tensor shape
    #   - tensor dtype
    W = self.param("Layer2/Weights", nn.initializers.normal(), (self.num_features_1, self.num_features_2), self.param_dtype)  # Shape is (_1, _2), not (_2, _1) [*]
    b = self.param("Layer2/Bias", nn.initializers.normal(), (self.num_features_2,), self.param_dtype)
    x = x@W + b
    x = self.activation_fun(x)  # First dimension of x can be num_of_samples [*]

    # Finally we just sum outputs from the second layer
    return jnp.sum(x, axis=1)


In [16]:
model = TwoLayerFFN(num_features_1=6)  # activation_fun=jnp.sin
variables = model.init(jax.random.PRNGKey(0), jnp.ones(shape=(10,3)))

model.apply(variables, jnp.ones(shape=(6,3)))



Array([0.01679351, 0.01679351, 0.01679351, 0.01679351, 0.01679351,
       0.01679351], dtype=float32)

## Optax

https://optax.readthedocs.io/en/latest/

For the moment, NetKet's optimizers suffice.

## Plum – multiple dispatch

"Polymorphism where function overloading happens at runtime and not at compilation time"

Python already has single dispatch (the function is chosen based on the type of the first argument, e.g. `self`).
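
For comparison, single dispatch on free functions is available in the standard library via `functools.singledispatch`. A minimal sketch (the `describe` function is made up for illustration):

```python
from functools import singledispatch

@singledispatch
def describe(x):
    # Fallback used when no registered type matches
    return "no idea"

@describe.register
def _(x: int):
    return f"int: {x}"

@describe.register
def _(x: float):
    return f"float: {x}"

results = [describe(2), describe(2.5), describe("a")]
```

Only the type of the first argument is inspected, which is exactly the limitation that multiple dispatch removes.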

In [1]:
from plum import dispatch


In [2]:
@dispatch
def div(x, y):
  print("No idea")

@dispatch
def div(x: float, y: int):
  print(f"float/int: {x/y}")

@dispatch
def div(x: int, y: int):
  print(f"int/int: {x//y}")

div(2.3, 4.0)
div(4.3, 2)
div(5, 2)


No idea
float/int: 2.15
int/int: 2
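
To make the mechanism concrete, here is a toy sketch of multiple dispatch using a dictionary keyed by the runtime types of all arguments. This is only an illustration, not how Plum is implemented: the `register` helper and `_registry` are hypothetical, and the lookup matches exact types only, ignoring subclass relationships.

```python
# Hypothetical registry: tuple of argument types -> implementation
_registry = {}

def register(*types):
    def decorator(func):
        _registry[types] = func
        return func
    return decorator

def _fallback(x, y):
    return "No idea"

def div(x, y):
    # Choose the implementation from the runtime types of *all* arguments
    impl = _registry.get((type(x), type(y)), _fallback)
    return impl(x, y)

@register(float, int)
def _div_float_int(x, y):
    return f"float/int: {x / y}"

@register(int, int)
def _div_int_int(x, y):
    return f"int/int: {x // y}"

outputs = [div(2.3, 4.0), div(4.3, 2), div(5, 2)]
```

The call `div(2.3, 4.0)` falls back to the generic version because no implementation is registered for `(float, float)`.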
