# Symmetries and VMC
**Author: Louis Sharma**


In this tutorial, we will learn how to use the symmetries of a quantum Hamiltonian to enhance the performance of a VMC calculation. More specifically, you will learn:


* How to symmetrize a `vstate` with respect to lattice symmetries. 
* How to implement custom symmetries.


## Heisenberg antiferromagnet
We will consider the Heisenberg antiferromagnet on a two-dimensional $L \times L$ square lattice:

$$
\hat H = J \sum_{\langle ij \rangle} \hat{\vec S}_i \cdot \hat{\vec S}_j
$$

where $J>0$ is the antiferromagnetic exchange coupling, $\langle ij \rangle$ denotes pairs of nearest-neighbor sites on the lattice and $\hat{\vec S}_i = \frac 12 ( \hat \sigma_i^x, \hat \sigma_i^y, \hat \sigma_i^z)$ is the spin operator for site $i$.
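As a quick sanity check of the convention $\hat{\vec S}_i = \frac12 \hat{\vec \sigma}_i$, the NumPy sketch below (independent of NetKet) builds the two-site exchange operator $\hat{\vec S}_1 \cdot \hat{\vec S}_2$ from Pauli matrices and confirms the familiar singlet/triplet spectrum $\{-3/4, 1/4, 1/4, 1/4\}$. It also shows that the same operator written directly with Pauli matrices is larger by a factor of 4, which matters when building $\hat H$ in code.

```python
import numpy as np

# Pauli matrices
sx = np.array([[0, 1], [1, 0]], dtype=complex)
sy = np.array([[0, -1j], [1j, 0]], dtype=complex)
sz = np.array([[1, 0], [0, -1]], dtype=complex)

# Spin-1/2 operators: S = sigma / 2
S = [0.5 * s for s in (sx, sy, sz)]

# Two-site exchange S_1 . S_2 on the 4-dimensional product space
SdotS = sum(np.kron(s, s) for s in S)

# The same exchange written with Pauli matrices
sigma_dot_sigma = sum(np.kron(s, s) for s in (sx, sy, sz))

eigenvalues = np.sort(np.linalg.eigvalsh(SdotS))
print(eigenvalues)  # singlet at -3/4, threefold triplet at +1/4
```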




## Setting up the problem

### Defining the lattice
First, we need to define a square lattice. This is done using the ```netket.graph.Square``` class. For this tutorial, we will pick $L=4$.


In [1]:
#import netket
import netket as nk
import jax


In [None]:
seed = jax.random.PRNGKey(1234) # For reproducibility

In [None]:
#define the square lattice
square_lattice = nk.graph.Square(length=4, pbc=True) # 4x4 square lattice with periodic boundary conditions

In [None]:
square_lattice.draw()

### Defining the Hilbert space
Next, we define the Hilbert space of the model. The relevant class here is `netket.hilbert.Spin`.

In [None]:
import numpy as np

hilbert = nk.hilbert.Spin(s=0.5, N=square_lattice.n_nodes) #16 spin-1/2 particles

In [None]:
print('Number of local degrees of freedom: {}'.format(hilbert.size))
print('Number of states: {}'.format(hilbert.n_states))

### Constructing the Hamiltonian
Now we will construct the Hamiltonian of this system.

In [None]:
#import the Pauli spin operators sigma^x, sigma^y and sigma^z
from netket.operator.spin import sigmax, sigmaz, sigmay

In [None]:
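# S_i = sigma_i / 2, so each product of Pauli matrices carries a factor J / 4
J = 1.0
H = nk.operator.LocalOperator(hilbert, dtype=complex)  # start from the zero operator

for i,j in square_lattice.edges():
    H += (J / 4) * sigmaz(hilbert=hilbert, site=i) * sigmaz(hilbert=hilbert, site=j)
    H += (J / 4) * sigmax(hilbert=hilbert, site=i) * sigmax(hilbert=hilbert, site=j)
    H += (J / 4) * sigmay(hilbert=hilbert, site=i) * sigmay(hilbert=hilbert, site=j)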

## Exact diagonalization (ED)

Since $\hat H$ is sparse and the Hilbert space has only $2^{16} = 65\,536$ states, we can perform ED to get the exact ground-state energy.

In [None]:
from scipy.sparse.linalg import eigsh

In [None]:
H_sp = H.to_sparse()
evals, evecs = eigsh(H_sp, k=1, which='SA')  # 'SA' means smallest algebraic eigenvalue
print("Ground state energy (exact diagonalization): ", evals[0])
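For intuition about the ED step, here is a small self-contained SciPy sketch on a 4-site periodic Heisenberg ring (not the $4 \times 4$ lattice above), built from $\hat S = \hat \sigma/2$ with Kronecker products. For this ring the ground-state energy is known exactly: writing $\hat H = J (\hat{\vec S}_0 + \hat{\vec S}_2)\cdot(\hat{\vec S}_1 + \hat{\vec S}_3)$ and using total-spin algebra gives $E_0 = -2J$. The helper `site_op` is hypothetical and defined only for this example.

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import eigsh

# Spin-1/2 operators S = sigma / 2 as sparse matrices
sx = sp.csr_matrix([[0, 0.5], [0.5, 0]], dtype=complex)
sy = sp.csr_matrix([[0, -0.5j], [0.5j, 0]], dtype=complex)
sz = sp.csr_matrix([[0.5, 0], [0, -0.5]], dtype=complex)

def site_op(op, site, n_sites):
    """Hypothetical helper: embed a single-site operator at `site` in an n_sites system."""
    mats = [sp.identity(2, dtype=complex, format="csr")] * n_sites
    mats[site] = op
    out = mats[0]
    for m in mats[1:]:
        out = sp.kron(out, m, format="csr")
    return out

n_sites, J = 4, 1.0
edges = [(i, (i + 1) % n_sites) for i in range(n_sites)]  # periodic ring

H_ring = sp.csr_matrix((2**n_sites, 2**n_sites), dtype=complex)
for i, j in edges:
    for s in (sx, sy, sz):
        H_ring = H_ring + J * (site_op(s, i, n_sites) @ site_op(s, j, n_sites))

# Smallest algebraic eigenvalue, as in the cell above
e0 = eigsh(H_ring, k=1, which="SA", return_eigenvectors=False)[0]
print(e0)
```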

## Symmetries
The Hamiltonian commutes with the set of operators that correspond to a *representation* of the space group of the lattice. From representation theory, we know that we can use the irreducible representations (irreps) of the group to block diagonalize $\hat H$ and restrict the search for the ground state to a particular irrep. 


The first step in doing this is to select the relevant group. 

Here, we will consider the translation group of the lattice. 

In [None]:
translation_group = square_lattice.translation_group()
for g in translation_group: 
    print(g)

The elements of `translation_group` correspond to permutations of the lattice sites. 

In [None]:
print("Permutation corresponding to a translation by R= [0,1]: ", translation_group[1].permutation_array)
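To make "translations are permutations" concrete, the NumPy sketch below builds the permutation array of a translation on an $L \times L$ torus and applies it to a spin configuration through $x \mapsto x \circ g$, i.e. `config[g]`. The row-major site labelling `i = a * L + b` for the site at coordinates `(a, b)` is an assumption made for illustration; NetKet's own site ordering, and therefore its `permutation_array` values, may differ.

```python
import numpy as np

L = 4

def translation_permutation(R, L):
    """Permutation array g where g[i] is the site reached from site i after a shift by R.

    Assumes row-major labelling i = a * L + b for the site at coordinates (a, b).
    """
    a, b = np.divmod(np.arange(L * L), L)
    return ((a + R[0]) % L) * L + (b + R[1]) % L

g = translation_permutation((0, 1), L)
print(g)

# A spin configuration x: {0, ..., L*L-1} -> {-1, +1}
rng = np.random.default_rng(0)
config = rng.choice([-1, 1], size=L * L)

# Composition x ∘ g: (x ∘ g)[i] = x[g[i]]
shifted = config[g]

# Composing the translation with itself L times gives back the identity
g_L = np.arange(L * L)
for _ in range(L):
    g_L = g[g_L]
```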

We can also view the characters which classify different irreducible representations.

In [None]:
print("Second row of the character table:", translation_group.character_table()[1])

As it turns out, the characters of the translation group can all be written in the form:

$$\chi_{\vec k}(\vec R) = e^{i \vec k \cdot \vec R}$$

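where $\chi_{\vec k}(\vec R)$ is the character corresponding to the translation by a lattice vector $\vec R$ and $\vec k$ is a vector in the first Brillouin zone. With periodic boundary conditions on an $L \times L$ lattice, only the $L^2$ momenta $\vec k = \frac{2\pi}{L}(m_1, m_2)$ with integer $m_1, m_2$ are allowed, one for each irrep.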
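The character formula can be checked numerically. This NumPy sketch tabulates $\chi_{\vec k}(\vec R) = e^{i \vec k \cdot \vec R}$ for the 16 allowed momenta $\vec k = \frac{2\pi}{4}(m_1, m_2)$ of the periodic $4 \times 4$ lattice and verifies the orthogonality relation $\sum_{\vec R} \chi^\ast_{\vec k}(\vec R)\, \chi_{\vec k'}(\vec R) = |G|\, \delta_{\vec k \vec k'}$. The row ordering here is arbitrary and need not match the ordering of `translation_group.character_table()`.

```python
import numpy as np

L = 4
# Allowed momenta k = 2*pi*(m1, m2) / L
ms = np.array([(m1, m2) for m1 in range(L) for m2 in range(L)])
ks = 2 * np.pi * ms / L
# Lattice translation vectors R = (R1, R2)
Rs = np.array([(r1, r2) for r1 in range(L) for r2 in range(L)])

# Character table: rows labelled by k, columns by R
chi = np.exp(1j * ks @ Rs.T)

# gram[k, k'] = sum_R chi_k(R)^* chi_k'(R)
gram = chi.conj() @ chi.T
print(np.allclose(gram, L * L * np.eye(L * L)))
```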

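In principle, the true ground state may lie at any value of $\vec k$. However, the point-group symmetries of the square lattice (rotations and reflections) map momentum sectors onto one another with identical spectra, so we can restrict our search for the ground state to the *irreducible Brillouin zone* (red triangle on the plot below).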

In [None]:
import matplotlib.pyplot as plt

# Allowed momenta on the 4x4 lattice are multiples of pi/2; both -pi and pi are drawn
# (they are equivalent) to show the full zone boundary
kx = np.linspace(-np.pi, np.pi, 4+1, endpoint=True)
ky = np.linspace(-np.pi, np.pi, 4+1, endpoint=True)

Kx, Ky = np.meshgrid(kx, ky)
plt.scatter(Kx.flatten(), Ky.flatten(), s=10, color='black')
# Boundary of the irreducible wedge 0 <= ky <= kx <= pi
plt.plot([0,np.pi], [0, np.pi], color='r', lw=0.5)
plt.plot([0,np.pi], [0, 0], color='r', lw=0.5)
plt.plot([np.pi, np.pi], [0, np.pi], color='r', lw=0.5)
plt.gca().set_aspect('equal', adjustable='box')
plt.xlabel(r'$k_x$')
plt.ylabel(r'$k_y$')
plt.title('Brillouin zone of the square lattice')
plt.show()

The next step is to construct `netket.operator` objects from the group elements that act on the states of the Hilbert space; this map from group elements to operators is known as a *representation*. For lattice symmetries, `netket` has built-in methods that construct such representations as `Representation` objects.

In [None]:
translation_group_representation = square_lattice.translation_group_representation(hilbert=hilbert)

In `netket`, a `Representation` object can be constructed from a dictionary that maps each group element to the operator representing it. In this example, the elements of the translation group are represented by `PermutationOperator` objects, which encode how permutations of the lattice sites act on the local degrees of freedom of the Hilbert space.

For more details on the mathematical foundations of permutation operators, see the [symmetry documentation](../advanced/symmetry.md).


`PermutationOperator` objects are equipped with a `get_conn_padded` method, which lets them act on basis states.

In the next section, we will discuss in more detail how to construct `Representation` for custom groups.

## VMC

Within the VMC scheme, wavefunction amplitudes can be made symmetric with respect to a group $G$ and an irrep $\mu$ using the following projection formula:

$$\psi_\mu(x) = \frac{d_\mu}{|G|} \sum_{g \in G} \chi_\mu^\ast(g) \psi( g^{-1}x)$$
where $x$ is the encoding of a basis state, $d_\mu$ is the dimension of the irrep, $|G|$ is the number of elements in the group and $\chi_\mu(g)$ is the character of the irrep $\mu$ evaluated on element $g$. Here $\psi(g^{-1}x) = \langle x | \hat U_g | \psi \rangle$ is the *left action* of $g$ on the wavefunction $\psi$, evaluated at $x$, where $\hat U_g$ is the operator representing $g$ on the Hilbert space (with $\hat U_g |x\rangle = |gx\rangle$).

In the context of permutations and spin systems, we can view $g$ as a function $\{0,\ldots, n-1\} \to \{0, \ldots, n-1\}$ and $x$ as a function $\{0,\ldots, n-1\} \to \{-1,+1\}$ assigning a spin value to each site, so that the left action is the composition:


$$\psi(g^{-1}x) = \psi(x \circ g)$$ 

For more mathematical details on permutations and their representations, see the [symmetry documentation](../advanced/symmetry.md).
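The projection formula and the $x \circ g$ convention can be sketched in plain Python on a small example: a random wavefunction on a 4-site ring, projected onto momentum $k$ with the cyclic translation group $g_n(i) = (i + n) \bmod 4$ and characters $\chi_k(g_n) = e^{ikn}$. Because the group is abelian, $d_\mu = 1$. The sketch checks two consequences of the formula: the projected amplitudes pick up a phase $e^{ik}$ under a one-site translation, and summing over all $k$ recovers the original wavefunction. The names `psi`, `project` and `translations` are hypothetical helpers for this illustration, not NetKet API.

```python
import itertools
import numpy as np

N = 4  # sites on the ring
rng = np.random.default_rng(1)

# All basis states x: {0, ..., N-1} -> {-1, +1}
configs = [np.array(c) for c in itertools.product([-1, 1], repeat=N)]
# A random (unsymmetrized) wavefunction stored as a lookup table
table = {tuple(c): rng.normal() + 1j * rng.normal() for c in configs}

def psi(x):
    return table[tuple(x)]

# Cyclic translations g_n(i) = (i + n) mod N, stored as permutation arrays
translations = [(np.arange(N) + n) % N for n in range(N)]
momenta = [2 * np.pi * m / N for m in range(N)]

def project(psi, k):
    """psi_k(x) = (1/|G|) sum_n chi_k(g_n)^* psi(x ∘ g_n), with chi_k(g_n) = exp(i k n)."""
    def psi_k(x):
        return sum(
            np.exp(-1j * k * n) * psi(x[g])  # x[g] implements x ∘ g
            for n, g in enumerate(translations)
        ) / len(translations)
    return psi_k

projected = {k: project(psi, k) for k in momenta}
```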

In the following, we will compare the performance of an unsymmetrized ansatz versus a symmetrized one.

In [None]:
import netket.nn as nknn
import flax.linen as nn

import jax.numpy as jnp


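class FFNN(nn.Module):
    """Feed-forward network with one hidden layer; returns log(psi(x))."""

    @nn.compact
    def __call__(self, x):
        # Complex-valued dense layer with 2 hidden units per site
        x = nn.Dense(
            features=2 * x.shape[-1],
            use_bias=True,
            param_dtype=np.complex128,
            kernel_init=nn.initializers.normal(stddev=0.01),
            bias_init=nn.initializers.normal(stddev=0.01),
        )(x)
        # log(cosh(.)) nonlinearity, then sum over hidden units
        x = nknn.log_cosh(x)
        x = jnp.sum(x, axis=-1)
        return x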


model = FFNN()

In [None]:
#model = nk.models.RBM(alpha=4) #alternative: a restricted Boltzmann machine with hidden unit density alpha=4
sampler = nk.sampler.MetropolisLocal(hilbert=hilbert) #Metropolis-Hastings local sampler
optimizer = nk.optimizer.Sgd(learning_rate=5e-3) #stochastic gradient descent optimizer
solver = nk.optimizer.solver.cholesky #linear solver used by the preconditioner
preconditioner = nk.optimizer.SR(solver=solver, diag_shift=1e-6) #stochastic reconfiguration preconditioner
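For readers unfamiliar with stochastic reconfiguration (SR), the NumPy sketch below shows the linear solve at its core, in its common textbook form: with log-derivatives $O_k(x) = \partial_{\theta_k} \log \psi(x)$ estimated on samples, form the quantum geometric tensor $S$ and the force $F$ from local energies, then solve $(S + \epsilon\, \mathbb{1})\, \delta\theta = F$ by Cholesky factorization, where $\epsilon$ plays the role of `diag_shift`. The random `O` and `E_loc` arrays are placeholders for Monte Carlo data, and NetKet's internal implementation may differ in normalization and other details.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_params = 512, 8
diag_shift = 1e-6

# Placeholder Monte Carlo data: log-derivatives O[s, k] and local energies E_loc[s]
O = rng.normal(size=(n_samples, n_params)) + 1j * rng.normal(size=(n_samples, n_params))
E_loc = rng.normal(size=n_samples) + 1j * rng.normal(size=n_samples)

# Center the log-derivatives and local energies
O_c = O - O.mean(axis=0)
E_c = E_loc - E_loc.mean()

# Quantum geometric tensor S_kl = <O_k^* O_l>_c and force F_k = <O_k^* E_loc>_c
S = O_c.conj().T @ O_c / n_samples
F = O_c.conj().T @ E_c / n_samples

# Solve (S + diag_shift * I) dtheta = F with a Cholesky factorization A = C C^H
A = S + diag_shift * np.eye(n_params)
C = np.linalg.cholesky(A)
y = np.linalg.solve(C, F)                # forward solve C y = F
dtheta = np.linalg.solve(C.conj().T, y)  # backward solve C^H dtheta = y
```

The optimizer then updates the parameters along $-\delta\theta$, scaled by its learning rate.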

In [None]:
#train the unsymmetrized model
unsymmetrized_log = nk.logging.RuntimeLog()
unsymmetrized_vstate = nk.vqs.MCState(sampler, model, n_samples = 512, seed=seed)
driver = nk.VMC(H, optimizer, variational_state=unsymmetrized_vstate)
driver.preconditioner = preconditioner
driver.run(n_iter=300, out=unsymmetrized_log)


To construct a symmetry-projected wavefunction in `netket`, we use the `project` method of the `Representation` class.
This method takes a variational state and an integer `character_index` identifying the irrep (a row of the character table).
It then returns a state whose amplitudes are given by the projection formula above.

In [None]:
vstate=nk.vqs.MCState(sampler, model, n_samples = 512, seed=seed) #reinitialize the variational state to symmetrize it
symmetrize_vstate0 = translation_group_representation.project(state=vstate, character_index=0) #project onto the trivial representation

Then we just run the optimization as usual!

In [None]:
driver = nk.VMC(H, optimizer, variational_state=symmetrize_vstate0) 
driver.preconditioner = preconditioner
driver.run(n_iter=300, out=nk.logging.RuntimeLog()) #takes a bit of time to run

**Important**
In general, the ground state is not necessarily in the trivial sector. In practice, all relevant sectors should be scanned to find which one is energetically favorable. 
In the case of the translation group, this means all the points in the irreducible Brillouin zone should be checked. 
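For the $4 \times 4$ lattice, the allowed momenta are multiples of $\pi/2$ along each axis, and the irreducible wedge $0 \le k_y \le k_x \le \pi$ (the red triangle in the plot above) contains only six of them. A minimal sketch listing them; mapping each point to the corresponding `character_index` of `translation_group_representation` depends on NetKet's ordering of irreps and is not assumed here.

```python
import numpy as np

L = 4
# Allowed momenta along one axis, folded into (-pi, pi]
k1d = [2 * np.pi * m / L for m in range(L)]
k1d = [k - 2 * np.pi if k > np.pi else k for k in k1d]

# Keep the momenta inside the irreducible wedge 0 <= ky <= kx <= pi
ibz = [(kx, ky) for kx in k1d for ky in k1d if 0 <= ky <= kx <= np.pi]
for kx, ky in sorted(ibz):
    print(f"k = ({kx / np.pi:.2f} pi, {ky / np.pi:.2f} pi)")
```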

We can now compare the performance of the two models. 

In [None]:
def relative_error(approx, exact):
    return np.abs((approx - exact) / exact)

E_no_symm = float(unsymmetrized_vstate.expect(H).mean.real)
E_symm = float(symmetrize_vstate0.expect(H).mean.real)

print("Energy without symmetrization: {:.4f} Relative error without symmetrization: {:.4f} %".format(E_no_symm, 100*relative_error(E_no_symm, evals[0])))
print("Energy with symmetrization in the trivial irrep: {:.4f} Relative error with symmetrization in the trivial irrep: {:.4f} %".format(E_symm, 100*relative_error(E_symm, evals[0])))   

## Implementing custom symmetries
The Heisenberg Hamiltonian also commutes with the spin flip operator:

$$\hat \sigma^x = \bigotimes_i \hat \sigma_i^x$$

The operators $\{\hat I, \hat \sigma^x\}$ form a representation of the group $\mathbb Z_2$, since $(\hat \sigma^x)^2 = \hat I$.
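Both properties can be checked explicitly on a small system. The NumPy sketch below uses a 4-site Heisenberg ring as a stand-in for the $4 \times 4$ lattice, to keep matrices small. Conjugating by $\hat \sigma^x$ maps $\hat \sigma^y_i \to -\hat \sigma^y_i$ and $\hat \sigma^z_i \to -\hat \sigma^z_i$, so every two-site product in $\hat H$ is invariant. The helper `site_op` is hypothetical, written for this illustration.

```python
import numpy as np
from functools import reduce

sx = np.array([[0, 1], [1, 0]], dtype=complex)
sy = np.array([[0, -1j], [1j, 0]], dtype=complex)
sz = np.array([[1, 0], [0, -1]], dtype=complex)
I2 = np.eye(2, dtype=complex)

n_sites = 4

def site_op(op, site):
    """Hypothetical helper: embed a single-site operator into the n_sites-spin space."""
    return reduce(np.kron, [op if i == site else I2 for i in range(n_sites)])

# Heisenberg ring written with S = sigma / 2 (dense matrices are fine for 4 sites)
H_ring = sum(
    0.25 * site_op(s, i) @ site_op(s, (i + 1) % n_sites)
    for i in range(n_sites)
    for s in (sx, sy, sz)
)

# Global spin flip: sigma^x on every site
P = reduce(np.kron, [sx] * n_sites)

commutator = H_ring @ P - P @ H_ring
```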

In this part of the tutorial, we will see how to construct a representation of this group in `netket` on our spin Hilbert space.

The `Representation` class needs two fundamental things to function:

* A `FiniteGroup` object. This is just the group.
* A `dict` object mapping elements of the group to operators on the Hilbert space. 

The group $\mathbb Z_2$ is a set of two elements $\{e, g\}$ with a single rule: $g^2 = e$. A simple group with this structure is the symmetric group $\mathcal S_2$ of permutations of two objects, which can be implemented in `netket` using the `PermutationGroup` class.

In [None]:
from netket.utils.group import PermutationGroup, Permutation, Identity
from netket.symmetry import Representation

In [None]:
# S_2 realized as permutations of two abstract points: the identity and the swap
e = np.array([0,1])
g = np.array([1,0])
group = PermutationGroup(elems=[Permutation(e, name='Identity()'), Permutation(g, name="SpinFlip()")], degree = 2)

From this object, we can extract the character table and other group-theoretic data.

In [None]:
group.character_table()


Next, we need to define the operators $\hat I$ and $\hat \sigma^x$ that furnish the representation of our group on the spin Hilbert space.

In [None]:
spin_flip = 1.0
for i in range(square_lattice.n_nodes):
    spin_flip *= sigmax(hilbert=hilbert, site=i)

identity = nk.operator.spin.identity(hilbert=hilbert) #identity operator

To check that this all works as expected, recall that each state of the computational basis of our spin-$1/2$ system is encoded as a `jax.Array`:

* `array[i]` refers to the spin on the $i$th site.
* `array[i] = -1` for spin down or `+1` for spin up.

Therefore, we can generate a random basis state and apply our spin flip operator to it using the `get_conn_padded` method.
The resulting array should map every $-1$ to $+1$ and every $+1$ to $-1$, so its element-wise sum with the original array should vanish.

In [None]:
state = hilbert.random_state(key=seed, size=1) #generate a random state of the basis
new_state, matrix_element = spin_flip.get_conn_padded(state)
print("Original state: ", state)
print("State after spin flip: ", new_state)

print("Sum of element-wise entries:", new_state + state)

Now we create our dictionary and pass it to instantiate a `Representation` object. Note that the keys of the dictionary must be identical to the elements of the `group` argument.

In [None]:

representation_dict = {Permutation(e, name='Identity()'): identity, Permutation(g, name="SpinFlip()"): spin_flip}
spin_flip_representation = Representation(group=group, representation_dict=representation_dict)
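What a representation does with such a dictionary can be sketched in NumPy for a small system: combine the operators with the characters, $\hat P_\mu = \frac{d_\mu}{|G|} \sum_g \chi_\mu^\ast(g)\, \hat U_g$. For $\mathbb Z_2$ on 4 spins this yields $\hat P_\pm = \frac12 (\hat I \pm \hat \sigma^x)$. The string keys and the `projector` helper are hypothetical stand-ins for this sketch (not the NetKet `Representation` internals), and the character table is written out by hand with the trivial irrep first.

```python
import numpy as np
from functools import reduce

n_sites = 4
sx = np.array([[0.0, 1.0], [1.0, 0.0]])

# Representation as a dictionary: group element -> operator on the Hilbert space
rep_dict = {
    "Identity": np.eye(2**n_sites),
    "SpinFlip": reduce(np.kron, [sx] * n_sites),
}
elements = list(rep_dict)

# Character table of Z_2 (rows: irreps, columns: elements in the order above)
character_table = np.array([[1, 1], [1, -1]])

def projector(character_index, d_mu=1):
    """Hypothetical helper: P_mu = d_mu / |G| * sum_g chi_mu(g)^* U_g."""
    chi = character_table[character_index]
    return d_mu / len(elements) * sum(
        np.conj(c) * rep_dict[g] for c, g in zip(chi, elements)
    )

P_even, P_odd = projector(0), projector(1)
```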

Now we can optimize a new vstate which is symmetric with respect to this group.

In [None]:
vstate=nk.vqs.MCState(sampler, model, n_samples = 512, seed=seed) #reinitialize the variational state to symmetrize it
symmetrize_vstate_spin_flip = spin_flip_representation.project(state=vstate, character_index=0) #project onto the trivial representation
driver = nk.VMC(H, optimizer, variational_state=symmetrize_vstate_spin_flip)
driver.preconditioner = preconditioner
driver.run(n_iter=300, out=nk.logging.RuntimeLog())

In [None]:
E_spin_flip = float(symmetrize_vstate_spin_flip.expect(H).mean.real)
print("Energy with symmetrization in the trivial irrep of the spin flip group: {:.4f} Relative error with symmetrization in the trivial irrep of the spin flip group: {:.4f} %".format(E_spin_flip, 100*relative_error(E_spin_flip, evals[0])))

## Combining representations
When two groups $G_1$ and $G_2$ have representations $\hat U$ and $\hat V$ on the same Hilbert space $\mathcal{H}$ that commute, $\hat U_{g_1} \hat V_{g_2} = \hat V_{g_2} \hat U_{g_1}$ for all $g_1 \in G_1$, $g_2 \in G_2$, we can define the *product* representation $\hat \Gamma$ of the direct product group $G_1 \times G_2$ on $\mathcal{H}$ by $\hat \Gamma(g_1, g_2) = \hat U_{g_1} \hat V_{g_2}$.
The irreps of $G_1 \times G_2$ are labelled by pairs $(\mu, \nu)$ of irreps of $G_1$ and $G_2$, with characters $\chi_{\mu, \nu}(g_1, g_2) = \chi_\mu(g_1) \chi_\nu(g_2)$. Consequently, the projector onto the $(\mu, \nu)$ sector factorizes into the product of the individual projectors, $\hat P_{\mu\nu} = \hat P^{(1)}_\mu \hat P^{(2)}_\nu$.

We can apply this to the translation group and the spin-flip group! 
Essentially, we can combine these two groups and classify the eigenstates by their momentum and their spin flip parity. 
To do this in `netket`, we first project the state onto an irrep of one group, then project the result onto an irrep of the other group.
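The factorization of the projector can be verified on a 4-site ring, using the cyclic translation $\hat T$ (with $\hat T^4 = \hat I$) and the global flip $\hat \sigma^x$. The NumPy sketch builds momentum projectors $\hat P_k$ and parity projectors $\hat P_\pm$, checks that they commute, and compares their product with the projector of $\mathbb Z_4 \times \mathbb Z_2$ built directly from the product characters $\chi_k(\hat T^n)\, \chi_\pm((\hat\sigma^x)^m) = e^{ikn} (\pm 1)^m$. The helpers `P_momentum`, `P_parity` and `P_product` are hypothetical, written for this illustration.

```python
import itertools
import numpy as np
from functools import reduce

N = 4
dim = 2**N
sx = np.array([[0.0, 1.0], [1.0, 0.0]])

# Basis states as 0/1 tuples, ordered consistently with np.kron (site 0 most significant)
basis = list(itertools.product([0, 1], repeat=N))
index = {b: n for n, b in enumerate(basis)}

# Translation by one site, |x> -> |roll(x, 1)>: a permutation matrix with T^N = 1
T = np.zeros((dim, dim))
for b in basis:
    T[index[tuple(int(v) for v in np.roll(b, 1))], index[b]] = 1.0

# Global spin flip, sigma^x on every site
X = reduce(np.kron, [sx] * N)

def P_momentum(k):
    return sum(np.exp(-1j * k * n) * np.linalg.matrix_power(T, n) for n in range(N)) / N

def P_parity(s):  # s = +1 (even) or -1 (odd)
    return (np.eye(dim) + s * X) / 2

def P_product(k, s):
    """Projector of Z_N x Z_2 from the product characters exp(i k n) * s**m."""
    return sum(
        np.conj(np.exp(1j * k * n) * s**m)
        * np.linalg.matrix_power(T, n) @ np.linalg.matrix_power(X, m)
        for n in range(N) for m in range(2)
    ) / (2 * N)

# Sequential projection onto the trivial sector (k = 0, even parity)
P_seq = P_momentum(0.0) @ P_parity(+1)
```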

In [None]:
vstate=nk.vqs.MCState(sampler, model, n_samples = 512, seed=seed) #fresh vstate
projector_T = translation_group_representation.projector(character_index=0) #projector onto the trivial representation of the translation group
projector_S = spin_flip_representation.projector(character_index=0) #projector onto the trivial representation of the spin-flip group

## Extension to fermionic systems

The concepts we've discussed so far naturally extend to fermionic systems, but there are two subtleties to bear in mind:

* On top of "spatial" degrees of freedom, fermionic Hilbert spaces may have additional degrees of freedom, like spin. 
* Fermionic states are antisymmetric with respect to particle exchange, so permuting orbitals can introduce extra signs in the matrix elements of permutation operators.
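A minimal sketch of where the extra signs come from: a basis state is an ordered product $c^\dagger_{i_1} \cdots c^\dagger_{i_n} |0\rangle$ with $i_1 < \dots < i_n$. Relabelling the orbitals with a permutation $\pi$ produces $c^\dagger_{\pi(i_1)} \cdots c^\dagger_{\pi(i_n)} |0\rangle$, and restoring ascending order costs a factor of $-1$ per transposition. The function `permute_fermion_state` is hypothetical, written for this illustration, and assumes the convention that orbital $i$ is sent to $\pi(i)$.

```python
import numpy as np

def permute_fermion_state(occupation, perm):
    """Apply the orbital relabelling i -> perm[i] to an occupation-number vector.

    Returns the new occupation vector and the fermionic sign picked up when the
    creation operators are reordered into ascending orbital order.
    """
    occupied = [i for i, n in enumerate(occupation) if n == 1]  # ascending order
    targets = [int(perm[i]) for i in occupied]
    # Sign of the reordering = (-1) ** (number of inversions)
    inversions = sum(
        1 for a in range(len(targets)) for b in range(a + 1, len(targets))
        if targets[a] > targets[b]
    )
    new_occupation = np.zeros_like(occupation)
    new_occupation[targets] = 1
    return new_occupation, (-1) ** inversions

# Example: cyclic shift of 4 orbitals, fermions in orbitals 2 and 3
occ = np.array([0, 0, 1, 1])
shift = np.array([1, 2, 3, 0])  # orbital i -> i + 1 mod 4
new_occ, sign = permute_fermion_state(occ, shift)
```

Here the fermion in orbital 3 wraps around to orbital 0 and has to be moved past the other one, so the translated state carries a sign of $-1$; for spins, the same permutation would act without any sign.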

In `netket` fermionic Hilbert spaces are handled by the `netket.hilbert.SpinOrbitalFermions` class.
Let's define a spin-$1/2$ fermion Hilbert space on the square lattice.

In [None]:
fermion_hilbert = nk.hilbert.SpinOrbitalFermions(n_orbitals=square_lattice.n_nodes, s = 1/2, n_fermions_per_spin=(8,8)) #half-filling, zero spin
print("Number of local degrees of freedom: ", fermion_hilbert.size)
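Note that `fermion_hilbert.size` counts local degrees of freedom, not basis states. With 8 fermions of each spin species on 16 orbitals, the number of basis states is $\binom{16}{8}^2$, roughly 2,500 times the $2^{16}$ states of the spin model. The count is pure combinatorics and does not rely on any NetKet attribute:

```python
from math import comb

n_orbitals = 16
n_up, n_down = 8, 8

# Choose which orbitals are occupied independently for each spin species
n_states = comb(n_orbitals, n_up) * comb(n_orbitals, n_down)
print(n_states, n_states / 2**16)
```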