# Exploring Shastry-Sutterland model with neural network variational wave function

In [44]:
import netket as nk
import numpy as np
import time
import json
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import jax, flax, optax
from sys import version as pyvers

print("Python version: {}".format(pyvers))
print("NetKet version: {}".format(nk.__version__))
print("NumPy version: {}".format(np.__version__))

Python version: 3.8.10 (default, Nov 26 2021, 20:14:08) 
[GCC 9.3.0]
NetKet version: 3.3.3
NumPy version: 1.20.3


### Setup relevant parameters of the simulation 

In [45]:
"""lattice"""
SITES    = 8              # 4, 8, 16, 20, 36, 64, 100 ... number of particles
JEXCH1   = .2             # nn interaction (denoted J)
JEXCH2   = 1              # nnn interaction (denoted J')
H_Z      = 0              # external magnetic field (denoted h)

"""neural network"""
MACHINE  = "RBM"          # RBM, RBM-b, sRBM, pRBM, GCNN, pmRBM, Jastrow, Jastrow+b
DTYPE    = np.complex128  # data-type of weights in neural network (pmRBM uses always just floats)
ALPHA    = 32             # size of the RBM, alpha = N_hidden / N_visible

"""machine learning"""
TOTAL_SZ = None           # 0, None ... restriction of Hilbert space's magnetization
ETA      = .01            # learning rate (0.01 usually works)
SIGMA    = .01            # initial variance of parameters (distributed according to a normal distribution)
SAMPLER  = 'exact'        # 'local' = MetropolisLocal, 'exchange' = MetropolisExchange, 'exact' = ExactSampler
SAMPLES  = 1000           # number of Monte Carlo samples
NUM_ITER = 200            # number of convergence iterations

"""in case of GCNN machine, shape specification is needed"""
N_LAYERS = 2              # number of layers
FEATURES = (8,4)          # dimensions of layers

"""output file name"""
OUT_NAME = "SSM_"+str(SITES)+"j1="+str(JEXCH1)

### Lattice definition
Basic structure of tiled lattice tiles is implemented in `lattice_and_ops.py` file. Class `Lattice` implements relative positional functions, 
- e.g. *`rt(node)` returns the index of the right neighbour of a site with index `node`*

The `for` loop constructs full Shastry-Sutherland lattice (with PBC) using these auxiliary positional functions.

In [46]:
from lattice_and_ops import Lattice
lattice = Lattice(SITES)

# Construction of custom graph according to tiled lattice structure defined in the Lattice class.
edge_colors = []
for node in range(SITES):
    edge_colors.append([node,lattice.rt(node), 1])  # horizontal connections
    edge_colors.append([node,lattice.bot(node), 1]) # vertical connections
    row, column = lattice.position(node)
    if column%2 == 0:
        if row%2 == 0:
            edge_colors.append([node,lattice.lrt(node),2]) # diagonal bond
        else:
            edge_colors.append([node,lattice.llft(node),2]) # diagonal bond

g = nk.graph.Graph(edges=edge_colors)
N = g.n_nodes

hilbert = nk.hilbert.Spin(s=.5, N=g.n_nodes, total_sz=TOTAL_SZ)

### Characters of the symmetries
In case of GCNN or pRBM, we need to specify the characters of the symmetry transformations.
- DS phase anti-symmetric w.r.t. permutations with negative sign and symmetric w.r.t. permutaions with postive sign
- PS is more complicated due to double degeneracy and is not implemented here
- AF phase is always symmetric for all permutations  

In [47]:
if MACHINE in ("GCNN","pRBM","pRBM_transl"):
    print("There are", len(g.automorphisms()), "in the full symmetry group.")
    if JEXCH1 < 0.6: # deciding point between DS and AF phase is set to 0.6
        # DS phase is partly anti-symmetric, the (anti)symmetry is given by the permutation parity
        characters = []
        from lattice_and_ops import permutation_sign
        for perm in g.automorphisms():
            characters.append(permutation_sign(np.asarray(perm)))
        characters_1 = np.asarray(characters,dtype=complex)
    else:
        # AF phase if fully symmetric, hence we set characters to ones
        characters_1 = np.ones((len(g.automorphisms()),), dtype=complex)
    characters_2 = np.ones((len(g.automorphisms()),), dtype=complex)

### Translations

If we want to include only translations, we have to exclude some symmetries from the set `g.automorphisms()` of all symmetry permutations.


⚠️ WARNING ⚠️ This part is not fully automated for general lattice yet. Translations are currently being picked by hand from the group of all automorphisms.

In [48]:
if MACHINE in ("sRBM_transl","GCNN_transl","pRBM_transl"):
    if not N in (4,16):
        raise NotImplementedError("Extraction of translations from the group of automorphisms is not implemented yet.")
    translations = []
    for perm in g.automorphisms():
        aperm = np.asarray(perm)
        if N == 4:
            # there is no translation
            if (aperm[0],aperm[1]) in ((0,1),(1,0),(2,3),(3,2)): 
              translations.append(nk.utils.group._permutation_group.Permutation(aperm))
        elif N==16:
            # there are 3 translations
            if (aperm[0],aperm[1],aperm[3]) in ((0,1,3),(2,3,1),(8,9,11),(10,11,9)): 
                translations.append(nk.utils.group._permutation_group.Permutation(aperm))
    translation_group = nk.utils.group._permutation_group.PermutationGroup(translations,degree=SITES)
    print("Out of", len(g.automorphisms()), "permutations,",len(translation_group), "translations were picked.")

## Hamoltonian definition
*Note.* Tha Hamiltonian used here differs by a factor of 4 from the Hamiltonian used in the thesis. Hence you need to devide all of the reuslts by 4 to get the same values as in the thesis. 
$$ H = J_{1} \sum\limits_{\langle i,j \rangle}^{L} \vec{\sigma}_{i} \cdot \vec{\sigma}_{j} + J_{2} \sum\limits_{\langle i,j \rangle'}^{L}  \vec{\sigma}_{i} \cdot \vec{\sigma}_{j} + 2h\sum\limits_{i} \sigma^z_{i}\,. $$



Axiliary constant operators used to define hamiltonian are loaded from the external file. They are pre-defined in the `HamOps` class.

In [49]:
from lattice_and_ops import HamOps
ho = HamOps()
ha_1 = nk.operator.GraphOperator(hilbert, graph=g, bond_ops=ho.bond_operator(JEXCH1,JEXCH2, h_z=H_Z, use_MSR=False), bond_ops_colors=ho.bond_color)
ha_2 = nk.operator.GraphOperator(hilbert, graph=g, bond_ops=ho.bond_operator(JEXCH1,JEXCH2, h_z=H_Z, use_MSR=True), bond_ops_colors=ho.bond_color)


### Magnetization operator definition
$$ \hat{m}_z := \sum\limits_i \sigma_i^z $$

In [50]:
m_z = sum(nk.operator.spin.sigmaz(hilbert, i) for i in range(hilbert.size))
from lattice_and_ops import Operators, Lattice
ops = Operators(lattice,hilbert,ho.mszsz,ho.exchange)

## Exact diagonalization in case of $N<20$

In [51]:
if g.n_nodes < 20:
    start = time.time()
    if g.n_nodes < 15:
        evals, eigvects = nk.exact.full_ed(ha_1, compute_eigenvectors=True)
    else:
        evals, eigvects = nk.exact.lanczos_ed(ha_1, k=3, compute_eigenvectors=True)
    end = time.time()
    diag_time = end - start
    print("Ground state energy:",evals[0], "\nIt took ", round(diag_time,2), "s =", round((diag_time)/60,2),"min")
else:
    print("System is too large for exact diagonalization. Setting exact_ground_energy = 0 (which is wrong)")
    evals = [0,0,0]
    eigvects = None 
exact_ground_energy = evals[0]

Ground state energy: -23.99999999999996 
It took  1.15 s = 0.02 min


# Machine learning

## Machine definition and other auxiliary `netket` objects
We define two sets of these objects, usually: 
- variables ending with ...`_1` are calculated in the normal basis,
- variables ending with ...`_2` are calculated in the MSR basis.

But they can be modified to be used in a different way when we need to compare two different models.

In [52]:
# Selection of machine type
if MACHINE == "RBM":
    machine_1 = nk.models.RBM(dtype=DTYPE, alpha=ALPHA)
    machine_2 = nk.models.RBM(dtype=DTYPE, alpha=ALPHA)
elif MACHINE == "RBM-b":
    machine_1 = nk.models.RBM(dtype=DTYPE, alpha=ALPHA, use_visible_bias=False) 
    machine_2 = nk.models.RBM(dtype=DTYPE, alpha=ALPHA, use_visible_bias=False)
elif MACHINE == "Jastrow":
    machine_1 = nk.models.Jastrow(dtype=DTYPE)
    machine_2 = nk.models.Jastrow(dtype=DTYPE)
elif MACHINE == "Jastrow+b":
    from lattice_and_ops import Jastrow_b
    machine_1 = Jastrow_b()
    machine_2 = Jastrow_b()
elif MACHINE == "sRBM":
    machine_1 = nk.models.RBMSymm(g.automorphisms(), dtype=DTYPE, alpha=ALPHA)
    machine_2 = nk.models.RBMSymm(g.automorphisms(), dtype=DTYPE, alpha=ALPHA)
elif MACHINE == "sRBM_transl":
    machine_1 = nk.models.RBMSymm(translation_group, dtype=DTYPE, alpha=ALPHA)
    machine_2 = nk.models.RBMSymm(translation_group, dtype=DTYPE, alpha=ALPHA)
elif MACHINE == "pRBM":
    from pRBM import pRBM
    machine_1 = pRBM(symmetries=g.automorphisms(), dtype=DTYPE, layers=1, features=ALPHA, characters=characters_1, output_activation=nk.nn.log_cosh, use_bias=True, use_visible_bias=True)
    machine_2 = pRBM(symmetries=g.automorphisms(), dtype=DTYPE, layers=1, features=ALPHA, characters=characters_2, output_activation=nk.nn.log_cosh, use_bias=True, use_visible_bias=True)
elif MACHINE == "pRBM_transl":
    from pRBM import pRBM
    machine_1 = pRBM(symmetries=translation_group, dtype=DTYPE, layers=1, features=ALPHA, characters=characters_1, output_activation=nk.nn.log_cosh, use_bias=True, use_visible_bias=True)
    machine_2 = pRBM(symmetries=translation_group, dtype=DTYPE, layers=1, features=ALPHA, characters=characters_2, output_activation=nk.nn.log_cosh, use_bias=True, use_visible_bias=True)
elif MACHINE == "GCNN":
    machine_1 = nk.models.GCNN(symmetries=g.automorphisms(), dtype=DTYPE, layers=N_LAYERS, features=FEATURES, characters=characters_1)
    machine_2 = nk.models.GCNN(symmetries=g.automorphisms(), dtype=DTYPE, layers=N_LAYERS, features=FEATURES, characters=characters_2)
elif MACHINE == "pmRBM":
    machine_1 = nk.models.RBMModPhase(alpha=ALPHA, use_hidden_bias=True, dtype=np.float64)
    machine_2 = nk.models.RBMModPhase(alpha=ALPHA, use_hidden_bias=True, dtype=np.float64)
    # A linear schedule varies the learning rate from 0 to 0.01 across 600 steps.
    modulus_schedule_1=optax.linear_schedule(0,0.01,NUM_ITER)
    modulus_schedule_2=optax.linear_schedule(0,0.01,NUM_ITER)
    # The phase starts with a larger learning rate and then is decreased.
    phase_schedule_1=optax.linear_schedule(0.05,0.01,NUM_ITER)
    phase_schedule_2=optax.linear_schedule(0.05,0.01,NUM_ITER)
    # Combine the linear schedule with SGD
    optm_1=optax.sgd(modulus_schedule_1)
    optp_1=optax.sgd(phase_schedule_1)
    optm_2=optax.sgd(modulus_schedule_2)
    optp_2=optax.sgd(phase_schedule_2)
    # The multi-transform optimizer uses different optimisers for different parts of the parameters.
    optimizer_1 = optax.multi_transform({'o1': optm_1, 'o2': optp_1}, flax.core.freeze({"Dense_0":"o1", "Dense_1":"o2"}))
    optimizer_2 = optax.multi_transform({'o1': optm_2, 'o2': optp_2}, flax.core.freeze({"Dense_0":"o1", "Dense_1":"o2"}))
else:
    raise Exception(str("undefined MACHINE: ")+str(MACHINE))

# Selection of sampler type
if SAMPLER == 'local':
    sampler_1 = nk.sampler.MetropolisLocal(hilbert=hilbert)
    sampler_2 = nk.sampler.MetropolisLocal(hilbert=hilbert)
elif SAMPLER == 'exact':
    sampler_1 = nk.sampler.ExactSampler(hilbert=hilbert)
    sampler_2 = nk.sampler.ExactSampler(hilbert=hilbert)
else:
    sampler_1 = nk.sampler.MetropolisExchange(hilbert=hilbert, graph=g)
    sampler_2 = nk.sampler.MetropolisExchange(hilbert=hilbert, graph=g)
    if SAMPLER != 'exchange':
        print("Warning! Undefined cf.SAMPLER:", SAMPLER, ", dafaulting to MetropolisExchange cf.SAMPLER")

if MACHINE != "pmRBM":
    optimizer_1 = nk.optimizer.Sgd(learning_rate=ETA)
    optimizer_2 = nk.optimizer.Sgd(learning_rate=ETA)

# Stochastic Reconfiguration as a preconditioner
sr_1  = nk.optimizer.SR(diag_shift=0.01)
sr_2  = nk.optimizer.SR(diag_shift=0.01)

# The variational state (former name: nk.variational.MCState)
vs_1 = nk.vqs.MCState(sampler_1 , machine_1 , n_samples=SAMPLES)
vs_2 = nk.vqs.MCState(sampler_2 , machine_2 , n_samples=SAMPLES)
vs_1.init_parameters(jax.nn.initializers.normal(stddev=SIGMA))
vs_2.init_parameters(jax.nn.initializers.normal(stddev=SIGMA))
print("The",MACHINE,"machine has",vs_1.n_parameters,"variational parameters.")

gs_1 = nk.VMC(hamiltonian=ha_1 ,optimizer=optimizer_1 ,preconditioner=sr_1 ,variational_state=vs_1)
gs_2 = nk.VMC(hamiltonian=ha_2 ,optimizer=optimizer_2 ,preconditioner=sr_2 ,variational_state=vs_2) 

The sRBM machine has 137 variational parameters.


# Calculation
We let the calculation run for `NUM_ITERS` iterations for both cases _1 and _2 (without MSR and with MSR). If only one case is desired, set `runs` variable to `[1,0]` or `[0,1]`.

In [53]:
runs = [1,1] # run_MSR, run_normal
no_of_runs = np.sum(runs)
run_only_2 = (runs[1]==1 and runs[0]==0) # useful only in case of no_of_runs=1
print("J_1 =", JEXCH1,"; H_Z =",H_Z, end="; ")
if exact_ground_energy != 0:
    print("Expected exact energy:", exact_ground_energy)
for i,gs in enumerate([gs_1,gs_2][run_only_2:run_only_2+no_of_runs]):
    start = time.time()
    gs.run(out=OUT_NAME+str(i), n_iter=int(NUM_ITER))
    end = time.time()
    print("The calculation for {} in {} basis took {} min".format(MACHINE, "MSR" if i else "normal", (end-start)/60))


J_1 = 0.2 ; H_Z = 0; Expected exact energy: -23.99999999999996


100%|██████████| 200/200 [04:43<00:00,  1.42s/it, Energy=-23.999976-0.000014j ± 0.000088 [σ²=0.000007]]


The calculation for sRBM in normal basis took 4.8384570201237995 min


100%|██████████| 200/200 [04:56<00:00,  1.48s/it, Energy=-8.000e+00+1.727e-21j ± 1.708e-36 [σ²=4.149e-31]]

The calculation for sRBM in MSR basis took 4.959580759207408 min





## Energy Convergence Plotting
In case that the machine did not converge, we can re-run the previous cell and than skip the next cell. This way, the replotting just appends the new results and does not erase the previoius results. 

In [54]:
# Exact Energy Line
no_of_all_iters = NUM_ITER
figure = go.Figure(
    data=[
        go.Scatter(
            x=(0,no_of_all_iters),
            y=(exact_ground_energy,exact_ground_energy),
            mode="lines",line=go.scatter.Line(color="#000000",width=1), name="exact energy")],
    layout=go.Layout(
        template="simple_white",
        xaxis=dict(title="Iteration", mirror=True, showline=True),
        yaxis=dict(title="Energy", mirror=True, showline=True),
        title=("<b>"+"S-S"+" model </b>, N="+str(SITES)+", J2 ="+str(JEXCH2)+ ", J1 ="+str(JEXCH1)+" , η="+str(ETA)+", α="+str(ALPHA)+", samples="+str(SAMPLES)))
    ).add_hline(y=exact_ground_energy, opacity=1, line_width=1)


In [55]:
# import the data from log file
OUT_NAME_suffixless=OUT_NAME
data = []
names = ["normal","MSR"]
for i in range(no_of_runs):
    data.append(json.load(open(OUT_NAME_suffixless+str(i)+".log")))
if type(data[0]["Energy"]["Mean"]) == dict:
    energy_convergence = [data[i]["Energy"]["Mean"]["real"] for i in range(no_of_runs)]
else:
    energy_convergence = [data[i]["Energy"]["Mean"] for i in range(no_of_runs)]
# plot the energy dependance on the iteration number
for i in range(no_of_runs):
    figure.add_trace(go.Scatter(
        x=np.array(data[i]["Energy"]["iters"]), y=energy_convergence[i],
        name=names[i],
    ))

figure.update_layout(xaxis_title="Iteration",yaxis_title="Energy")
figure.show()

## Assessment of the other simulation results

Here, we calculate of how long it took to reach 99.5% of exact energy. The first value under 0.5% deviation counts as a converged state.

We also print the values of the order parameters.

In [56]:
threshold_energy = 0.995*exact_ground_energy
data = []
for i in range(no_of_runs):
    data.append(json.load(open(OUT_NAME+str(i)+".log")))
if type(data[0]["Energy"]["Mean"]) == dict:
    energy_convergence = [data[i]["Energy"]["Mean"]["real"] for i in range(no_of_runs)]
else:
    energy_convergence = [data[i]["Energy"]["Mean"] for i in range(no_of_runs)]
steps_until_convergence = [next((i for i,v in enumerate(energy_convergence[j]) if v < threshold_energy), -1) for j in range(no_of_runs)]

In [57]:
# Evaluation of order parameters.
from lattice_and_ops import Operators, Lattice
ops = Operators(lattice,hilbert,ho.mszsz,ho.exchange)
for i,gs in enumerate([gs_1,gs_2][run_only_2:run_only_2+no_of_runs]):
    print("Trained RBM with MSR:" if (i+run_only_2) else "Trained RBM without MSR:")
    print("m_d^2 =", gs.estimate(ops.m_dimer_op))
    print("m_p(MSR) =", gs.estimate(ops.m_plaquette_op_MSR))
    print("m_s^2 =", gs.estimate(ops.m_s2_op_MSR))
    print("m_s^2(MSR) =", gs.estimate(ops.m_s2_op))

Trained RBM without MSR:
m_d^2 = 0.9999975+0.0000011j ± 0.0000027 [σ²=0.0000000]
m_p(MSR) = -0.003+0.000j ± 0.018 [σ²=0.287]
m_s^2 = 0.0000001-0.0000010j ± 0.0000023 [σ²=0.0000000]
m_s^2(MSR) = 0.0000001-0.0000010j ± 0.0000023 [σ²=0.0000000]
Trained RBM with MSR:
m_d^2 = 3.333e-01-7.197e-23j ± 9.813e-18 [σ²=3.081e-33]
m_p(MSR) = -1.763e-44-9.199e-45j ± 4.039e-45 [σ²=1.605e-86]
m_s^2 = 1.250e-01+1.349e-23j ± 1.041e-38 [σ²=1.366e-73]
m_s^2(MSR) = 1.250e-01+1.349e-23j ± 1.047e-38 [σ²=1.371e-73]


### Standardised logging and printing of final results

In [58]:
from lattice_and_ops import log_results
log_results(JEXCH1,gs_1,gs_2,ops,SAMPLES,NUM_ITER,exact_ground_energy,steps_until_convergence,filename="out_"+OUT_NAME+".txt")

 0.200  -23.99998  0.00009    -8.00000  0.00000   -0.0030  0.0183   1.0000  0.0000   0.0000  0.0000   -0.0000  0.0000   0.3333  0.0000   0.1250  0.0000   -24.00000  1000   200 87, -1


In [59]:
print("{:6.3f} {:10.5f} {:8.5f}  {:7.4f} {:7.4f}  {:7.4f} {:7.4f}  {:7.4f} {:7.4f}  {:5.0f} {:5.0f}".format(JEXCH1, 
    gs_1.energy.mean.real,                          gs_1.energy.error_of_mean,
    gs_1.estimate(ops.m_dimer_op).mean.real,        gs_1.estimate(ops.m_dimer_op).error_of_mean,
    gs_1.estimate(ops.m_plaquette_op).mean.real,    gs_1.estimate(ops.m_plaquette_op).error_of_mean,
    gs_1.estimate(ops.m_s2_op).mean.real,           gs_1.estimate(ops.m_s2_op).error_of_mean, 
    SAMPLES, NUM_ITER, steps_until_convergence[0], sep='    '))

 0.200  -23.99998  0.00009   1.0000  0.0000  -0.0030  0.0183   0.0000  0.0000   1000   200
