In [None]:
import json
import os
import sys

import cmasher as cmr
import equinox as eqx
import flax
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from flax import nnx
from matplotlib.ticker import ScalarFormatter
from scipy.optimize import curve_fit
from scipy.stats import sem

# Make modules in the parent directory importable; '../' is resolved against
# the kernel's current working directory.
sys.path.append('../')

# Render all figure text through LaTeX with a serif font family.
# NOTE: text.usetex requires a working LaTeX installation on this machine.
plt.rcParams.update({'text.usetex': True, 'font.family': 'serif'})

In [None]:
def load_training_data(load_data_path, piter):
    """Load a training log from a JSON file.

    The file read is ``load_data_path + f'training_log_piter={piter}.json'``.
    ``load_data_path`` is concatenated as-is, so it must end with a path
    separator. A leading ``~`` in the resulting path is expanded to the
    user's home directory.

    Args:
        load_data_path: Directory prefix (including trailing separator).
        piter: Tag inserted into the file name (presumably the pruning
            iteration, e.g. an int or ``'dense'`` -- TODO confirm).

    Returns:
        The decoded JSON content of the log file.

    Raises:
        FileNotFoundError: If the log file does not exist. Previously this
            printed a message and then crashed with ``UnboundLocalError``.
    """
    # expanduser returns a new string; the result must be kept to take effect.
    file_path = os.path.expanduser(load_data_path + f'training_log_piter={piter}.json')

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Training log not found: {file_path}")

    # JSON is UTF-8 by specification; don't depend on the platform default.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

In [None]:
def load_sampling_data(load_data_path, piter):
    """Load a sampling log from a JSON file.

    The file read is ``load_data_path + f'sampling_log_piter={piter}.json'``.
    ``load_data_path`` is concatenated as-is, so it must end with a path
    separator. A leading ``~`` in the resulting path is expanded to the
    user's home directory.

    Args:
        load_data_path: Directory prefix (including trailing separator).
        piter: Tag inserted into the file name (the notebook passes
            ``'dense'``; presumably a pruning iteration -- TODO confirm).

    Returns:
        The decoded JSON content of the log file. Callers in this notebook
        read ``['Energy']['Mean']`` and ``['Energy']['Variance']`` from it.

    Raises:
        FileNotFoundError: If the log file does not exist. Previously this
            printed a message and then crashed with ``UnboundLocalError``.
    """
    # expanduser returns a new string; the result must be kept to take effect.
    file_path = os.path.expanduser(load_data_path + f'sampling_log_piter={piter}.json')

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Sampling log not found: {file_path}")

    # JSON is UTF-8 by specification; don't depend on the platform default.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

In [None]:
# Shallow CNN: a single 'dense' sampling log, reduced to one energy and one
# variance value.
cnn_energies = []
cnn_variances = []

file_path = '../../data/tfim_2d/cnn/IMP_WR/'
sampling_data = load_sampling_data(file_path, 'dense')
energy_log = sampling_data['Energy']

# Average each logged list into a single scalar.
energy = np.mean(np.array(energy_log['Mean']))
variance = np.mean(np.array(energy_log['Variance']))

cnn_energies.append(energy)
cnn_variances.append(variance)


In [None]:
# Shallow FFNN: one 'dense' sampling log per alpha value, each reduced to a
# single mean energy and mean variance. The loop replaces three copy-pasted
# blocks; the append order (alpha = 1, 2.5, 5) is unchanged.
ffnn_energies = []
ffnn_variances = []

FFNN_DIR = '../../data/tfim_2d/ffnn/IMP_WR/k=3.04438/'
for ffnn_alpha in ('1', '2.5', '5'):  # strings, to match directory names exactly
    file_path = f'{FFNN_DIR}alpha={ffnn_alpha}/'
    sampling_data = load_sampling_data(file_path, 'dense')

    # Average each logged list into a single scalar.
    energy = np.array(sampling_data['Energy']['Mean']).mean()
    variance = np.array(sampling_data['Energy']['Variance']).mean()
    ffnn_energies.append(energy)
    ffnn_variances.append(variance)

In [None]:
# Res-CNN: one 'dense' sampling log per (nf, nb) configuration, each reduced to
# a single mean energy and mean variance. The nested loop (nf outer, nb inner)
# replaces nine copy-pasted blocks and reproduces their exact append order.
res_cnn_energies = []
res_cnn_variances = []

# NOTE(review): this path uses k=3.04483, while the FFNN paths use k=3.04438
# (transposed digits?). Confirm which value the directories on disk carry.
RES_CNN_DIR = '../../data/tfim_2d/res_cnn/dense_network_reference_energies/k=3.04483/'
for nf in (4, 8, 16):
    for nb in (1, 2, 4):
        file_path = f'{RES_CNN_DIR}nf={nf}_nb={nb}/'
        sampling_data = load_sampling_data(file_path, 'dense')

        # Average each logged list into a single scalar.
        energy = np.array(sampling_data['Energy']['Mean']).mean()
        variance = np.array(sampling_data['Energy']['Variance']).mean()
        res_cnn_energies.append(energy)
        res_cnn_variances.append(variance)

In [None]:
# Appendix figure: variance per spin vs. energy per spin for every architecture,
# with the DMRG reference energy as a dashed vertical line.
N_SPINS = 100                       # N used for the per-spin quantities E/N, sigma^2/N
DMRG_ENERGY = -321.0980007876571    # DMRG reference total energy

ffnn_energies_per_spin = np.array(ffnn_energies) / N_SPINS
ffnn_variances_per_spin = np.array(ffnn_variances) / N_SPINS

cnn_energies_per_spin = np.array(cnn_energies) / N_SPINS
cnn_variances_per_spin = np.array(cnn_variances) / N_SPINS

res_cnn_energies_per_spin = np.array(res_cnn_energies) / N_SPINS
res_cnn_variances_per_spin = np.array(res_cnn_variances) / N_SPINS

fig, ax = plt.subplots(figsize=(8, 6))

ax.scatter(cnn_energies_per_spin, cnn_variances_per_spin, marker='^', s=200, edgecolor='black', label='Shallow CNN', color='orange')
ax.scatter(ffnn_energies_per_spin, ffnn_variances_per_spin, marker='*', s=300, edgecolor='black', label='Shallow FFNN', color='blue')
ax.axvline(x=DMRG_ENERGY / N_SPINS, color='black', linestyle='--', linewidth=2, label='DMRG')
ax.scatter(res_cnn_energies_per_spin, res_cnn_variances_per_spin, s=200, marker='o', edgecolor='black', label='Res-CNN', color='green')

ax.grid(linestyle='-', alpha=0.5)
ax.tick_params(axis='both', which='major', labelsize=26)
ax.legend(fontsize=24, reverse=False, loc='lower right')
ax.set_ylabel(r'$\sigma^2 / N$', size=40)
ax.set_xlabel(r'$E/N$', size=40)
ax.set_yscale('log')

# Show the x-axis (E/N) tick labels with five decimal places so that closely
# spaced energies remain distinguishable.
def format_func(value, tick_number):
    """Tick formatter: '0' for an exact zero, otherwise the value to 5 decimals."""
    if value == 0:
        return '0'
    return f'{value:.5f}'

ax.xaxis.set_major_formatter(plt.FuncFormatter(format_func))

fig.savefig('appendix_variance_per_spin.pdf', format='pdf', bbox_inches='tight')


In [None]:
# Figure 1 inset: energy per spin vs. variance per spin, with the DMRG
# reference energy as a horizontal segment.
N_SPINS = 100                       # N used for the per-spin quantities E/N, sigma^2/N
DMRG_ENERGY = -321.0980007876571    # DMRG reference total energy

ffnn_energies_per_spin = np.array(ffnn_energies) / N_SPINS
ffnn_variances_per_spin = np.array(ffnn_variances) / N_SPINS

cnn_energies_per_spin = np.array(cnn_energies) / N_SPINS
cnn_variances_per_spin = np.array(cnn_variances) / N_SPINS

res_cnn_energies_per_spin = np.array(res_cnn_energies) / N_SPINS
res_cnn_variances_per_spin = np.array(res_cnn_variances) / N_SPINS

fig, ax = plt.subplots(figsize=(7, 6))

ax.scatter(cnn_variances_per_spin, cnn_energies_per_spin, marker='^', s=450, edgecolor='black', label='CNN', color='orange')
# Only the last FFNN entry (the last configuration loaded) is shown in the inset.
ax.scatter(ffnn_variances_per_spin[-1], ffnn_energies_per_spin[-1], marker='*', s=550, edgecolor='black', label='FFNN', color='blue')
# NOTE(review): the original wrote xmin=10e-6, which equals 1e-5 (not 1e-6).
# The value is kept; change it if the segment was meant to start at 1e-6.
ax.hlines(y=DMRG_ENERGY / N_SPINS, xmin=1e-5, xmax=10**(-2.5), color='black', linestyle='-', linewidth=5, label='DMRG')
ax.scatter(res_cnn_variances_per_spin, res_cnn_energies_per_spin, s=450, marker='o', edgecolor='black', label='ResCNN', color='green')

ax.tick_params(axis='both', which='major', labelsize=45)
ax.legend(fontsize=40, reverse=False, loc='upper left')
ax.set_xlabel(r'$\sigma^2 / N$', size=65)
ax.set_ylabel(r'$E/N$', size=66)
ax.set_xscale('log')

# Format the E/N axis with ScalarFormatter. With power limits (-4, 4),
# scientific notation is used only when the exponent is <= -4 or >= 4, i.e.
# OUTSIDE these limits. Any common offset/multiplier is drawn as separate
# offset text, which is resized here to match the figure.
formatter = ScalarFormatter(useMathText=True)
formatter.set_powerlimits((-4, 4))
ax.yaxis.set_major_formatter(formatter)
offset_text = ax.yaxis.get_offset_text()
offset_text.set_fontsize(30)

fig.savefig('fig1_inset.pdf', format='pdf', bbox_inches='tight')


In [None]:
# Indices that order the Res-CNN runs from lowest to highest variance per spin.
sorted_var = np.argsort(res_cnn_variances_per_spin)
res_cnn_energies_by_variance = res_cnn_energies_per_spin[sorted_var]

# NOTE(review): a mean does not depend on ordering, so the sort above does not
# change this value. If only the lowest-variance runs were meant to be
# averaged, slice sorted_var (e.g. sorted_var[:k]) first.
average_energy_per_spin = res_cnn_energies_by_variance.mean()
print(average_energy_per_spin)