In [1]:
import copy
import os
import sys
import time

import jax
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from helper import df_to_latex, fig_path, set_figsize, tab_path

%matplotlib inline
%config InlineBackend.figure_format = 'retina'


# Import code from src
sys.path.insert(0, '../nqs/')
import nqs  # noqa

# Set plot and dataframe style
sns.set(context="paper", style='darkgrid', rc={"axes.facecolor": "0.96"})

fontsize = "large"
params = {"font.family": "serif",
          "font.sans-serif": ["Computer Modern"],
          "axes.labelsize": fontsize,
          "legend.fontsize": fontsize,
          "xtick.labelsize": fontsize,
          "ytick.labelsize": fontsize,
          "legend.handlelength": 2
          }

plt.rcParams.update(params)
plt.rc('text', usetex=True)

pd.set_option('display.max_columns', 50)

jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'cpu')

In [2]:
# System definition: number of particles, spatial dimensionality and number
# of hidden neurons in the network (the log below confirms an RBM is built).
nparticles, dim, nhidden = 1, 1, 2

system = nqs.NQS(
    nparticles,
    dim,
    nhidden=nhidden,
    interaction=False,
    mcmc_alg='rwm',
    nqs_repr='psi',
    backend='numpy',
    log=True,
)

INFO:NQS      Neural Network Quantum State initialized as RBM with 2 hidden neurons


In [3]:
# Initialize the network parameters. `sigma2` and `scale` are forwarded
# as-is to `NQS.init`; the chosen `scale` is echoed in the sampling summary.
system.init(
    sigma2=1.0,
    scale=3.0,
)

In [4]:
# Optimize the network parameters with Adam.
# `train` accepts a `seed`; passing a fixed value (instead of None) makes the
# run reproducible under Restart & Run All. Set `train_seed = None` to draw
# fresh randomness on every run.
train_seed = 42
system.train(max_iter=500_000,
             batch_size=5_000,
             gradient_method='adam',
             eta=0.05,
             beta1=0.9,
             beta2=0.999,
             epsilon=1e-8,
             early_stop=False,
             # NOTE(review): rtol/atol are presumably only consulted when
             # early_stop=True — confirm against NQS.train.
             rtol=1e-05,
             atol=1e-08,
             seed=train_seed
             )

[Training progress]:   0%|          | 0/500000 [00:00<?, ?it/s]

In [5]:
# Sample the trained state with several MCMC chains; the summary below has
# one row per chain, each reporting nsamples=32768.
# NOTE(review): `mcmc_alg=None` presumably reuses the sampler chosen at
# construction ('rwm', as the summary's mcmc_alg column shows) — confirm.
# A fixed seed (instead of None) keeps the estimates reproducible; set
# `sample_seed = None` for fresh randomness on every run.
nsamples = 2**15   # 32768; `2**15` is already an int, so no cast is needed
nchains = 8
sample_seed = 2024
df = system.sample(nsamples,
                   nchains=nchains,
                   seed=sample_seed,
                   mcmc_alg=None
                   )

INFO:NQS      Sampling done


In [6]:
# Sampling summary: one row per MCMC chain with its energy estimate, standard
# error, variance and acceptance rate, alongside the run configuration
# (eta, scale, nhidden, mcmc_alg, nsamples, training cycles/batch).
df

Unnamed: 0,chain_id,nparticles,dim,energy,std_error,variance,accept_rate,eta,scale,nvisible,nhidden,mcmc_alg,nsamples,training_cycles,training_batch
0,1,1,1,0.500056,9.5e-05,5.1e-05,0.284943,0.05,3.0,1,2,rwm,32768,500000,5000
1,2,1,1,0.50011,9.4e-05,5e-05,0.283783,0.05,3.0,1,2,rwm,32768,500000,5000
2,3,1,1,0.500017,9.1e-05,5e-05,0.279419,0.05,3.0,1,2,rwm,32768,500000,5000
3,4,1,1,0.499903,9.2e-05,5.1e-05,0.276459,0.05,3.0,1,2,rwm,32768,500000,5000
4,5,1,1,0.500049,9.4e-05,5.1e-05,0.280701,0.05,3.0,1,2,rwm,32768,500000,5000
5,6,1,1,0.499958,9.8e-05,5.2e-05,0.280243,0.05,3.0,1,2,rwm,32768,500000,5000
6,7,1,1,0.500186,0.000101,5.4e-05,0.281769,0.05,3.0,1,2,rwm,32768,500000,5000
7,8,1,1,0.499981,9.6e-05,5.1e-05,0.280151,0.05,3.0,1,2,rwm,32768,500000,5000
