# Exploring the ability of (non-symm) RBM to remember phases

In [1]:
import netket as nk
import numpy as np
import time
import json
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import jax
import flax
import optax
import pprint
import sys
sys.path.append("..")  # make modules located in the parent folder importable
# Report the library versions this notebook was run with.
for template, module in (("NetKet version: {}", nk), ("NumPy version: {}", np)):
    print(template.format(module.__version__))

NetKet version: 3.3.1
NumPy version: 1.20.3


Set up the relevant parameters

In [2]:
# Global configuration read by the cells below. Several later cells reassign
# NUM_ITER and JEXCH1, so the values here are only the initial ones.
"""lattice"""
SITES    = 8             # 4, 8, 16, 20 ... number of vertices in a tile determines the tile shape (passed to Lattice(...))
JEXCH1   = 1.2            # nn interaction (J_1 in the Hamiltonian)
JEXCH2   = 1            # nnn interaction (J_2 in the Hamiltonian)
#USE_MSR  = True        # should we use a Marshall sign rule?
"""machine learning"""
# Ansatz selector. Values handled by the model cell: "RBM", "RBMModPhase", "GCNN",
# "RBMSymm", "RBMSymm_transl", "Jastrow"; anything else raises there.
MACHINE = "Jastrow"
TOTAL_SZ = None            # 0, None ... restriction of Hilbert space (passed as total_sz to nk.hilbert.Spin)
DTYPE = np.complex128      #np.float64 #double #np.complex128   # type of weights in neural network
SAMPLER = 'exact'       # 'local' = MetropolisLocal, 'exact' = ExactSampler, 'exchange' = MetropolisExchange (also the fallback for unknown values)
ALPHA = 16               # N_hidden / N_visible
ETA   = .01             # learning rate (0.01 usually works); used by nk.optimizer.Sgd
# Number of samples passed as n_samples to nk.vqs.MCState.
SAMPLES = 1000
# Number of optimisation steps; also the length of the optax schedules. Reassigned in later cells.
NUM_ITER = 2000
num_layers = 2 #2          # number of layers in G-CNN
feature_dims = (8,4) #(8,4) #(8,8,8,8) # dimensions of layers in G-CNN

# Prefix of the output files; the calculation cell appends the run index to it.
# NOTE(review): the prefix says "RBM" regardless of the selected MACHINE.
OUT_NAME = "SS-RBM_ops"+str(SITES)+"j1="+str(JEXCH1) # output file name

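Lattice and Hamiltonian definition: &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; $ H = J_{1} \sum\limits_{\langle i,j \rangle}^{L} \vec{\sigma}_{i} \cdot \vec{\sigma}_{j} + J_{2} \sum\limits_{\langle\langle i,j \rangle\rangle_{SS}}^{L}  \vec{\sigma}_{i} \cdot \vec{\sigma}_{j}\,, $ where $\langle i,j \rangle$ denotes the nearest-neighbour bonds (coupling `JEXCH1`) and $\langle\langle i,j \rangle\rangle_{SS}$ the next-nearest-neighbour bonds of the S-S model (coupling `JEXCH2`).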

In [3]:
from lattice_and_ops import Lattice

lattice = Lattice(SITES)

# Build the coloured edge list of the custom graph:
#   colour 1 ... horizontal and vertical connections of every node,
#   colour 2 ... one extra bond for nodes in even columns, whose partner
#                depends on the parity of the row.
# NOTE(review): lrt / llft presumably return the lower-right / lower-left
# neighbour; confirm against lattice_and_ops.
edge_colors = []
for node in range(SITES):
    edge_colors.extend([
        [node, lattice.rt(node), 1],   # horizontal connection
        [node, lattice.bot(node), 1],  # vertical connection
    ])
    row, column = lattice.position(node)
    if column % 2 != 0:
        continue  # odd columns carry no colour-2 bond
    diagonal_partner = lattice.lrt(node) if row % 2 == 0 else lattice.llft(node)
    edge_colors.append([node, diagonal_partner, 2])

# NetKet graph object and the spin-1/2 Hilbert space living on its nodes
g = nk.graph.Graph(edges=edge_colors)  # ,n_nodes=3)
N = g.n_nodes

hilbert = nk.hilbert.Spin(s=0.5, N=N, total_sz=TOTAL_SZ)

## Hamiltonian

In [4]:
from lattice_and_ops import HamOps

ho = HamOps()

# Two Hamiltonians on the same graph: `ha` is built from bond operators with
# use_MSR=False, `ha_MSR` from bond operators with use_MSR=True.
ha, ha_MSR = (
    nk.operator.GraphOperator(
        hilbert,
        graph=g,
        bond_ops=ho.bond_operator(JEXCH1, JEXCH2, use_MSR=msr_flag),
        bond_ops_colors=ho.bond_color,
    )
    for msr_flag in (False, True)
)

## Exact diagonalization

In [5]:
# Exact reference spectrum of `ha`:
#   fewer than 15 nodes ... full diagonalization,
#   15 to 19 nodes       ... Lanczos for the 3 lowest states,
#   20 nodes or more     ... skipped, placeholder zeros are stored instead.
if g.n_nodes >= 20:
    print("System is too large for exact diagonalization. Setting exact_ground_energy = 0 (which is wrong)")
    evals = [0,0,0]
    eigvects = None
else:
    start = time.time()
    if g.n_nodes >= 15:
        evals, eigvects = nk.exact.lanczos_ed(ha, k=3, compute_eigenvectors=True)
    else:
        evals, eigvects = nk.exact.full_ed(ha, compute_eigenvectors=True)
    end = time.time()
    diag_time = end - start
    print("Ground state energy:", evals[0], "\nIt took ", round(diag_time, 2), "s =", round(diag_time / 60, 2), "min")

# Lowest eigenvalue; equals 0 when the diagonalization was skipped.
exact_ground_energy = evals[0]
# -36.2460684609957  (value noted by the author; its origin is not shown in this notebook)

Ground state energy: -24.800000000000004 
It took  3.38 s = 0.06 min


In [6]:
# Show the eigenvalues obtained in the exact-diagonalization cell.
print(evals)
# Optional: print the ground-state vector with amplitudes below 0.001 replaced by 0.
# print([eigvects[i,0] if np.abs(eigvects[i,0]) > 0.001 else 0 for i in range(len(eigvects[:,0]))])

[-2.48000000e+01 -2.00000000e+01 -2.00000000e+01 -2.00000000e+01
 -1.44000000e+01 -1.44000000e+01 -1.44000000e+01 -1.44000000e+01
 -1.44000000e+01 -1.44000000e+01 -1.44000000e+01 -1.44000000e+01
 -1.44000000e+01 -1.44000000e+01 -1.44000000e+01 -1.44000000e+01
 -1.36000000e+01 -1.36000000e+01 -1.36000000e+01 -1.36000000e+01
 -1.20000000e+01 -1.04000000e+01 -1.04000000e+01 -1.04000000e+01
 -1.04000000e+01 -1.04000000e+01 -1.04000000e+01 -1.04000000e+01
 -1.04000000e+01 -1.04000000e+01 -1.04000000e+01 -1.04000000e+01
 -9.60000000e+00 -9.60000000e+00 -9.60000000e+00 -9.60000000e+00
 -8.80000000e+00 -8.80000000e+00 -8.80000000e+00 -8.80000000e+00
 -8.80000000e+00 -8.80000000e+00 -8.80000000e+00 -8.80000000e+00
 -8.80000000e+00 -8.80000000e+00 -8.80000000e+00 -8.80000000e+00
 -8.00000000e+00 -8.00000000e+00 -8.00000000e+00 -8.00000000e+00
 -8.00000000e+00 -8.00000000e+00 -8.00000000e+00 -8.00000000e+00
 -8.00000000e+00 -8.00000000e+00 -8.00000000e+00 -8.00000000e+00
 -5.60000000e+00 -4.80000

## RBM

In [7]:
def _make_mod_phase_optimizer(n_steps):
    """Build an optax optimizer that trains two sub-networks with different SGD schedules.

    Parameters
    ----------
    n_steps : int
        Number of steps over which both linear schedules are interpolated.

    Returns
    -------
    An ``optax.multi_transform`` optimizer in which the parameters under
    ``"Dense_0"`` use the modulus schedule (learning rate 0 -> 0.01) and the
    parameters under ``"Dense_1"`` use the phase schedule (0.05 -> 0.01).
    """
    # The modulus learning rate grows linearly from 0 to 0.01 across n_steps steps.
    modulus_schedule = optax.linear_schedule(0, 0.01, n_steps)
    # The phase starts with a larger learning rate, which is then decreased.
    phase_schedule = optax.linear_schedule(0.05, 0.01, n_steps)
    # The multi-transform optimizer uses different optimisers for different
    # parts of the parameters.
    # NOTE(review): assumes the model exposes exactly the sub-modules
    # "Dense_0" (modulus) and "Dense_1" (phase); verify against the model.
    return optax.multi_transform(
        {'o1': optax.sgd(modulus_schedule), 'o2': optax.sgd(phase_schedule)},
        flax.core.freeze({"Dense_0": "o1", "Dense_1": "o2"}),
    )


# The model cell only knows the name "RBMModPhase"; the original test for
# "ModPhase" alone could therefore never select this branch together with a
# valid model. Both spellings are accepted for backward compatibility.
if MACHINE in ("ModPhase", "RBMModPhase"):
    optimizer = _make_mod_phase_optimizer(NUM_ITER)
    optimizer_MSR = _make_mod_phase_optimizer(NUM_ITER)
else:
    # Plain SGD with a constant learning rate for every other ansatz.
    optimizer = nk.optimizer.Sgd(learning_rate=ETA)
    optimizer_MSR = nk.optimizer.Sgd(learning_rate=ETA)

In [8]:
# Definition of the model, the sampler, the preconditioner, the variational
# states and the VMC drivers. Every object exists twice: the plain one is used
# with `ha`, the one with the _MSR suffix with `ha_MSR`.
if MACHINE == "RBM":
    machine = nk.models.RBM(dtype=DTYPE, alpha=ALPHA)  # , use_visible_bias=False)
    machine_MSR = nk.models.RBM(dtype=DTYPE, alpha=ALPHA)  # , use_visible_bias=False)
elif MACHINE == "RBMModPhase":
    machine = nk.models.RBMModPhase(alpha=ALPHA, use_hidden_bias=True, dtype=DTYPE)
    machine_MSR = nk.models.RBMModPhase(alpha=ALPHA, use_hidden_bias=True, dtype=DTYPE)
elif MACHINE == "GCNN":
    # Not usable yet: the GCNN needs symmetry-group characters
    # (characters_dimer / characters_dimer_msr) that are not defined anywhere
    # in this notebook. Once they exist, build the models with
    # nk.models.GCNN(symmetries=g.automorphisms(), dtype=DTYPE, layers=num_layers,
    #                features=feature_dims, characters=...).
    raise Exception("You need to define the characters of symmetry group.")
elif MACHINE == "RBMSymm":
    machine = nk.models.RBMSymm(g.automorphisms(), dtype=DTYPE, alpha=ALPHA)  # , use_visible_bias=False)
    machine_MSR = nk.models.RBMSymm(g.automorphisms(), dtype=DTYPE, alpha=ALPHA)  # , use_visible_bias=False)
elif MACHINE == "RBMSymm_transl":
    # Not usable yet: requires a `translation_group` that is not defined in
    # this notebook; it would replace g.automorphisms() in nk.models.RBMSymm.
    raise Exception("The restriction to translational group is not yet implemented.")
elif MACHINE == "Jastrow":
    from lattice_and_ops import Jastrow
    machine = Jastrow()
    machine_MSR = Jastrow()
else:
    raise Exception("undefined MACHINE: " + str(MACHINE))

# Sampler selection; any unknown SAMPLER value falls back to MetropolisExchange
if SAMPLER == 'local':
    sampler = nk.sampler.MetropolisLocal(hilbert=hilbert)
    sampler_MSR = nk.sampler.MetropolisLocal(hilbert=hilbert)
elif SAMPLER == 'exact':
    sampler = nk.sampler.ExactSampler(hilbert=hilbert)
    sampler_MSR = nk.sampler.ExactSampler(hilbert=hilbert)
else:
    sampler = nk.sampler.MetropolisExchange(hilbert=hilbert, graph=g)
    sampler_MSR = nk.sampler.MetropolisExchange(hilbert=hilbert, graph=g)
    if SAMPLER != 'exchange':
        print("Warning! Undefined fq.SAMPLER:", SAMPLER, ", dafaulting to MetropolisExchange fq.SAMPLER")


# Stochastic Reconfiguration
sr = nk.optimizer.SR(diag_shift=0.01)
sr_MSR = nk.optimizer.SR(diag_shift=0.01)

# The variational state (former name: nk.variational.MCState)
vss = nk.vqs.MCState(sampler, machine, n_samples=SAMPLES)
vs_MSR = nk.vqs.MCState(sampler_MSR, machine_MSR, n_samples=SAMPLES)
# Start both states from small normally distributed parameters.
vss.init_parameters(jax.nn.initializers.normal(stddev=0.001))
vs_MSR.init_parameters(jax.nn.initializers.normal(stddev=0.001))


# VMC drivers: gs_1 optimises against `ha`, gs_MSR against `ha_MSR`.
gs_1 = nk.VMC(hamiltonian=ha, optimizer=optimizer, preconditioner=sr, variational_state=vss)                      # 0 ... without MSR
gs_MSR = nk.VMC(hamiltonian=ha_MSR, optimizer=optimizer_MSR, preconditioner=sr_MSR, variational_state=vs_MSR)     # 1 ... with MSR



In [9]:
# Number of variational parameters of the state optimised against `ha` (no MSR).
vss.n_parameters

72

# Model loading

In [10]:
# Restore previously saved variables into both variational states.
# File suffix "0.mpack" belongs to the state without MSR, "1.mpack" to the MSR one.
# NOTE(review): the recorded run of this cell ended with KeyError: 'kernel';
# presumably the checkpoint was written by a different model than the one
# currently selected. Verify before relying on this cell.
j = 0.8
for target_state, file_suffix in ((vss, "0.mpack"), (vs_MSR, "1.mpack")):
    with open("SS-RBM_ops4j1="+str(j)+file_suffix, "rb") as data_file:
        byte_data = data_file.read()
    target_state.variables = flax.serialization.from_bytes(target_state.variables, byte_data)

KeyError: 'kernel'

# Calculation

In [None]:
# Run the optimisation for the selected drivers and time each run.
no_of_runs = 2  # 2 ... the second variant (using MSR) is computed as well
use_MSR = 0     # index of the first driver to run (0 = gs_1, 1 = gs_MSR); matters for a single run
NUM_ITER = 1000
print("J_1 =", JEXCH1)
if exact_ground_energy != 0:  # 0 is the placeholder stored when exact diagonalization was skipped
    print("Expected exact energy:", exact_ground_energy)

selected_drivers = [gs_1, gs_MSR][use_MSR:use_MSR+no_of_runs]
# start=use_MSR keeps the index tied to the driver (0 = gs_1, 1 = gs_MSR) even
# when only gs_MSR is selected; otherwise its log would be written under the
# name of the non-MSR run and reported as type 0.
for i, gs in enumerate(selected_drivers, start=use_MSR):
    start = time.time()
    gs.run(out=OUT_NAME+str(i), n_iter=int(NUM_ITER))  # , obs={'symmetry':P(0,1)})
    end = time.time()
    print("The type {} of RBM calculation took {} min".format(i, (end-start)/60))


J_1 = 0.1
Expected exact energy: -6.000000000000002


100%|██████████| 1000/1000 [01:03<00:00, 15.70it/s, Energy=-5.98093+0.00195j ± 0.00036 [σ²=0.00013, R̂=1.4086]]   


The type 0 of RBM calculation took 1.0630114436149598 min


100%|██████████| 1000/1000 [01:02<00:00, 15.90it/s, Energy=-2.092623+0.000146j ± 0.000070 [σ²=0.000005, R̂=1.4086]]

The type 1 of RBM calculation took 1.0490299503008524 min





In [None]:
# Append a comment line describing the current settings to out.txt.
NUM_ITER = 1500
header_template = "# N = {:1.0f}, samples = {:1.0f}, iters = {:1.0f}, sampler = {}, TOTAL_SZ = {}, machine = {}, dtype = {}, alpha = {:1.0f}, eta = {:1.5f}\n"
with open("out.txt", "a") as out_file:
    header_fields = (SITES, SAMPLES, NUM_ITER, SAMPLER, str(TOTAL_SZ), MACHINE, DTYPE, ALPHA, ETA)
    out_file.write(header_template.format(*header_fields))

## Loop

In [None]:
# Sweep J1 downwards while re-using the same variational state `vss`, so each
# optimisation starts from the parameters reached for the previous J1.
# One line "j1  exact_energy  variational_energy" is appended to out.txt per J1.
# Note that `ha` and `gs_1` defined in earlier cells are overwritten here.
NUM_ITER = 500
with open("out.txt", "a") as out_file:
    out_file.write("# N = {:1.0f}, samples = {:1.0f}, iters = {:1.0f}, sampler = {}, TOTAL_SZ = {}, machine = {}, dtype = {}, alpha = {:1.0f}, eta = {:1.5f}\n".format(SITES,SAMPLES,NUM_ITER,SAMPLER, str(TOTAL_SZ), MACHINE, DTYPE, ALPHA, ETA))

for j1 in np.arange(0.8,-0.01, step=-0.1):
# for j1 in np.arange(0.2,1.21, step=0.1): 
    # define a new hamiltonian
    ha = nk.operator.GraphOperator(hilbert, graph=g, bond_ops=ho.bond_operator(j1,JEXCH2, use_MSR=False), bond_ops_colors=ho.bond_color)

    # get exact energy; for large systems it is not computed, so log NaN
    # instead of failing with a NameError or silently re-using a stale value
    if g.n_nodes < 20:
        exact_energy = nk.exact.lanczos_ed(ha, k=1, compute_eigenvectors=False)[0]
    else:
        exact_energy = float("nan")

    # redefine the variational driver
    gs_1 = nk.VMC(hamiltonian=ha ,optimizer=optimizer,preconditioner=sr,variational_state=vss)

    # run the optimisation
    gs_1.run(out="RBM", n_iter=int(NUM_ITER))

    # log the results
    with open("out.txt", "a") as out_file:
        out_file.write("{:9.5f}     {:9.5f}    {:9.5f} \n".format(j1, exact_energy, gs_1.energy.mean.real))

100%|██████████| 500/500 [00:04<00:00, 111.27it/s, Energy=-11.9937-0.0011j ± 0.0048 [σ²=0.0237]]        
100%|██████████| 500/500 [00:04<00:00, 122.45it/s, Energy=-11.999975-0.000017j ± 0.000023 [σ²=0.000001]]
 83%|████████▎ | 415/500 [00:03<00:00, 118.36it/s, Energy=-12.000020+0.000007j ± 0.000018 [σ²=0.000000]]


SystemError: CPUDispatcher(<function LocalOperator._get_conn_flattened_kernel at 0x7f88b93e48b0>) returned a result with an error set

In [11]:
## More efficient loop
# Sweep J1 re-using the same variational state `vss`. After a short warm-up of
# N_PRE_ITER steps the long optimisation (NUM_ITER steps) is run only when the
# energy is not yet close to the exact one.
N_PRE_ITER = 30
# NUM_ITER = 5500
# NOTE(review): NUM_ITER is not set in this cell; it keeps whatever value the
# most recently executed cell assigned to it.
MSR = False

# for j1 in np.arange(0.0,5.21, step=1.1):
# for j1 in np.arange(1.9,-0.01, step=-0.1):
for j1 in np.arange(0.0,1.21, step=0.1): 
    # define a new hamiltonian
    ha = nk.operator.GraphOperator(hilbert, graph=g, bond_ops=ho.bond_operator(j1,JEXCH2, use_MSR=MSR), bond_ops_colors=ho.bond_color)

    # get exact energy; NaN marks "unknown" for systems too large to diagonalize
    if g.n_nodes < 20:
        exact_energy = nk.exact.lanczos_ed(ha, k=1, compute_eigenvectors=False)[0]
    else:
        exact_energy = float("nan")

    # redefine the variational driver
    gs_1 = nk.VMC(hamiltonian=ha ,optimizer=optimizer,preconditioner=sr,variational_state=vss)
    print("J1 =", j1, "expected energy =", exact_energy)
    # short warm-up run
    gs_1.run(out="RBM", n_iter=int(N_PRE_ITER))
    # For a negative exact energy this condition means "more than 5 % above the
    # exact value". Without a reference energy the long run is always performed.
    if np.isnan(exact_energy) or gs_1.energy.mean.real > 0.95*exact_energy:
        gs_1.run(out="RBM", n_iter=int(NUM_ITER))

    # log the results
    with open("out.txt", "a") as out_file:
        out_file.write("{:9.5f}     {:9.5f}    {:9.5f} \n".format(j1, exact_energy, gs_1.energy.mean.real))

J1 = 0.0 expected energy = -12.0


  0%|          | 0/30 [00:00<?, ?it/s]

x=Traced<ShapedArray(complex128[8,8])>with<DynamicJaxprTrace(level=0/1)>  target=Traced<ShapedArray(complex128[8,8])>with<DynamicJaxprTrace(level=0/1)>
x=Traced<ShapedArray(complex128[8])>with<DynamicJaxprTrace(level=0/1)>  target=Traced<ShapedArray(complex128[8])>with<DynamicJaxprTrace(level=0/1)>
x=Traced<ShapedArray(complex128[8,8])>with<DynamicJaxprTrace(level=1/2)>  target=Traced<ShapedArray(complex128[8,8])>with<DynamicJaxprTrace(level=1/2)>
x=Traced<ShapedArray(complex128[8])>with<DynamicJaxprTrace(level=1/2)>  target=Traced<ShapedArray(complex128[8])>with<DynamicJaxprTrace(level=1/2)>
x=Traced<ShapedArray(complex128[8,8])>with<DynamicJaxprTrace(level=1/2)>  target=Traced<ShapedArray(complex128[8,8])>with<DynamicJaxprTrace(level=1/2)>
x=Traced<ShapedArray(complex128[8])>with<DynamicJaxprTrace(level=1/2)>  target=Traced<ShapedArray(complex128[8])>with<DynamicJaxprTrace(level=1/2)>
x=Traced<ShapedArray(complex128[8,8])>with<DynamicJaxprTrace(level=2/2)>  target=Traced<ShapedArray(

100%|██████████| 30/30 [00:00<00:00, 147.05it/s, Energy=3.9936+0.0053j ± 0.0066 [σ²=0.0377]]  
100%|██████████| 5500/5500 [00:35<00:00, 156.84it/s, Energy=-11.9999989+0.0000000j ± 0.0000061 [σ²=0.0000001]]    


J1 = 0.1 expected energy = -11.999999999999991


100%|██████████| 30/30 [00:00<00:00, 123.56it/s, Energy=-11.9999980-0.0000001j ± 0.0000019 [σ²=0.0000000]]


J1 = 0.2 expected energy = -12.000000000000007


100%|██████████| 30/30 [00:00<00:00, 113.01it/s, Energy=-11.9999988+0.0000004j ± 0.0000040 [σ²=0.0000000]]


J1 = 0.30000000000000004 expected energy = -11.999999999999998


100%|██████████| 30/30 [00:00<00:00, 116.87it/s, Energy=-12.0000039-0.0000012j ± 0.0000069 [σ²=0.0000000]]


J1 = 0.4 expected energy = -11.999999999999995


100%|██████████| 30/30 [00:00<00:00, 113.74it/s, Energy=-11.9999985+0.0000000j ± 0.0000012 [σ²=0.0000000]]


J1 = 0.5 expected energy = -11.999999999999998


100%|██████████| 30/30 [00:00<00:00, 116.54it/s, Energy=-12.0000024-0.0000000j ± 0.0000093 [σ²=0.0000001]]


J1 = 0.6000000000000001 expected energy = -11.999999999999993


100%|██████████| 30/30 [00:00<00:00, 116.31it/s, Energy=-11.9999993-0.0000001j ± 0.0000020 [σ²=0.0000000]]


J1 = 0.7000000000000001 expected energy = -12.800000000000011


100%|██████████| 30/30 [00:00<00:00, 115.92it/s, Energy=-12.0000064+0.0000006j ± 0.0000053 [σ²=0.0000000]]    
  1%|          | 32/5500 [00:00<00:48, 112.18it/s, Energy=-11.9999956+0.0000012j ± 0.0000046 [σ²=0.0000000]]


KeyboardInterrupt: 

In [None]:
from lattice_and_ops import Operators, Lattice
ops = Operators(lattice,hilbert,ho.mszsz,ho.exchange)
# Estimate the order-parameter operators on the selected drivers.
# start=use_MSR keeps the index tied to the driver (0 = gs_1, 1 = gs_MSR), so
# the printed heading stays correct when only the MSR driver is selected.
for i,gs in enumerate([gs_1,gs_MSR][use_MSR:use_MSR+no_of_runs], start=use_MSR):
    print("Trained RBM with MSR:" if i else "Trained RBM without MSR:")
    print("m_d^2 =", gs.estimate(ops.m_dimer_op))
    print("m_p =", gs.estimate(ops.m_plaquette_op_MSR))
    # m_s^2 is evaluated with both operator variants; the second line uses the
    # one without the _MSR suffix.
    print("m_s^2 =", gs.estimate(ops.m_s2_op_MSR))
    print("m_s^2 =", gs.estimate(ops.m_s2_op), "<--- no MSR!!")

Trained RBM without MSR:
m_d^2 = -0.33382-0.00172j ± 0.00038 [σ²=0.00014, R̂=1.2783]
m_p = 1.000e+00+0.000e+00j ± nan [σ²=0.000e+00]
m_s^2 = 0.630+0.001j ± 0.015 [σ²=0.234, R̂=1.3665]
m_s^2 = 2.00009+0.00184j ± 0.00088 [σ²=0.00079, R̂=1.1442] <--- no MSR!!
Trained RBM with MSR:
m_d^2 = -0.333513-0.000002j ± 0.000030 [σ²=0.000001, R̂=1.4086]
m_p = 1.000e+00+0.000e+00j ± nan [σ²=0.000e+00]
m_s^2 = 2.000333+0.000003j ± 0.000071 [σ²=0.000005, R̂=1.4086]
m_s^2 = 0.750+0.000j ± 0.014 [σ²=0.188, R̂=1.4086] <--- no MSR!!


In [None]:
# Sanity check: estimate twice the 2x2 identity matrix acting on site 0.
print(gs_1.estimate(2*nk.operator.LocalOperator(hilbert,operators=[[1,0],[0,1]],acting_on=[0])))

2.000e+00+0.000e+00j ± nan [σ²=0.000e+00]


In [None]:
# Settings for the plotting cells below: a single log is plotted, and JEXCH1 is
# overridden here, which changes the J1 value shown in the figure title.
no_of_runs = 1
# exact_ground_energy = -6
JEXCH1 = 0.2

## Energy

In [None]:
# Base figure with two horizontal reference lines spanning iterations 0..NUM_ITER:
# the exact ground-state energy and 99.5 % of it.
reference_traces = [
    go.Scatter(
        x=(0, NUM_ITER),
        y=(level, level),
        mode="lines",
        line=go.scatter.Line(color="#000000", width=1),
        name=label,
    )
    for level, label in (
        (exact_ground_energy, "exact energy"),
        (.995*exact_ground_energy, "99.5 % of exact energy"),
    )
]
plot_title = ("<b>"+"S-S"+" model </b>, L="+str(SITES)+", J2 ="+str(JEXCH2)+ ", J1 ="+str(JEXCH1)+" , η="+str(ETA)+", α="+str(ALPHA)+", samples="+str(SAMPLES))
figure = go.Figure(
    data=reference_traces,
    layout=go.Layout(
        template="simple_white",
        xaxis=dict(title="Iteration", mirror=True, showline=True),
        yaxis=dict(title="Energy", mirror=True, showline=True),
        title=plot_title,
    ),
)

In [None]:
# Load the optimisation log(s) and add the energy-convergence curves to `figure`.
#name = OUT_NAME+str(i)+".log"
name = "RBM.log"
data = []
for i in range(no_of_runs):
    # The context manager closes the log file as soon as it has been parsed.
    with open(name) as log_file:
        data.append(json.load(log_file))
names = ["AF init (j1=0.8)","MSR basis"]

# The mean energy is stored either as a dict with a "real" entry or as a plain
# list, depending on the log; handle both layouts for every run.
energy_convergence = []
for run_log in data:
    energy_mean = run_log["Energy"]["Mean"]
    energy_convergence.append(energy_mean["real"] if isinstance(energy_mean, dict) else energy_mean)
# (A "symmetry" observable could be extracted and plotted the same way if it was logged.)

for i in range(no_of_runs):
    figure.add_trace(go.Scatter(
        x=data[i]["Energy"]["iters"], y=energy_convergence[i],
        name=names[i]
    ))

#figure.add_hline(y=exact_gs_energy)
figure.update_layout(xaxis_title="Iteration",yaxis_title="Energy")
figure.show()

# Calculating symmetrizations