In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cmasher as cmr
from scipy.stats import sem
from scipy.optimize import curve_fit
import os
import sys
import json
import equinox as eqx
import flax
from flax import nnx

sys.path.append('../')
from src.networks.prunable_res_cnn import PrunableResCNN
from src.networks.prunable_ffnn import PrunableFFNN

# Global matplotlib style: render all figure text through LaTeX, using a serif font.
# NOTE(review): usetex=True presumably needs a working LaTeX installation on the
# machine that executes the notebook — confirm before running elsewhere.
plt.rc('text', usetex=True)
plt.rc('font', family='serif')

In [None]:
def relative_error(x, y):
    """Return |x - y| / |y|, i.e. the deviation of x from y relative to the size of y.

    Works element-wise on NumPy arrays as well as on scalars. No guard is
    applied for y == 0.
    """
    deviation = np.abs(x - y)
    reference_size = np.abs(y)
    return deviation / reference_size

def abs_error(x, y):
    """Return the absolute deviation |x - y| (element-wise for arrays)."""
    gap = x - y
    return np.abs(gap)

def abs_difference_per_spin(x, y, N):
    """Return |x - y| divided by N.

    Callers in this notebook pass an energy, a reference energy and the
    number of spins, so the result is an absolute energy difference per spin.
    """
    total_gap = np.abs(x - y)
    return total_gap / N

def difference(x, y):
    """Return the signed difference x - y (element-wise for arrays)."""
    signed_gap = x - y
    return signed_gap

In [None]:
def load_serialized_ffnn(load_data_path, pruning_iter):
    """Restore a PrunableFFNN from a serialized checkpoint file.

    The checkpoint is ``load_data_path + 'model_piter=<pruning_iter>.eqx'``.
    Its first line is a JSON object with the constructor keyword arguments;
    the rest of the file holds the serialized leaves of the network state.

    Args:
        load_data_path: Directory prefix; it is concatenated directly with the
            file name, so it has to end with a path separator.
        pruning_iter: Value inserted into the file name (callers pass a
            pruning iteration index or the string 'dense').

    Returns:
        The network rebuilt from the stored hyperparameters and state.
    """
    checkpoint_path = load_data_path + r'model_piter={}.eqx'.format(pruning_iter)

    with open(checkpoint_path, "rb") as handle:
        # Header line: JSON-encoded hyperparameters.
        header = handle.readline().decode()
        config = json.loads(header)

        # A freshly constructed network provides the state structure that the
        # stored leaves are deserialised into.
        template = PrunableFFNN(**config)
        graph_def, template_state = nnx.split(template)
        restored_state = eqx.tree_deserialise_leaves(handle, template_state)

        return nnx.merge(graph_def, restored_state)

def load_serialized_rescnn(load_data_path, pruning_iter):
    """Restore a PrunableResCNN from a serialized checkpoint file.

    Reads ``load_data_path + 'model_piter=<pruning_iter>.eqx'``: the first line
    is a JSON object with the constructor keyword arguments, the remaining
    bytes are the serialized leaves of the network state.

    Args:
        load_data_path: Directory prefix; it is concatenated directly with the
            file name, so it has to end with a path separator.
        pruning_iter: Value inserted into the file name (callers pass a
            pruning iteration index or the string 'dense').

    Returns:
        The network rebuilt from the stored hyperparameters and state.
    """
    checkpoint_path = load_data_path + r'model_piter={}.eqx'.format(pruning_iter)

    with open(checkpoint_path, "rb") as handle:
        # The hyperparameters are stored as one JSON line ahead of the leaves.
        config = json.loads(handle.readline().decode())

        # Split an untrained instance to obtain the graph definition and a
        # state of the right structure, then fill that state from the file.
        graph_def, template_state = nnx.split(PrunableResCNN(**config))
        restored_state = eqx.tree_deserialise_leaves(handle, template_state)

        return nnx.merge(graph_def, restored_state)

def load_sampling_data(load_data_path, piter):
    """Load the JSON sampling log of one pruning iteration.

    Args:
        load_data_path: Directory prefix; it is concatenated directly with the
            file name, so it has to end with a path separator. A leading
            ``~`` is expanded to the user's home directory.
        piter: Value inserted into the file name
            ``sampling_log_piter=<piter>.json`` (callers pass a pruning
            iteration index or the string 'dense').

    Returns:
        The parsed JSON content of the log file.

    Raises:
        FileNotFoundError: If the log file does not exist. Previously a message
            was printed and the function then failed with an UnboundLocalError
            on the return statement.
    """
    # expanduser returns a new string; its result was discarded before.
    file_path = os.path.expanduser(load_data_path + f'sampling_log_piter={piter}.json')

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Error: File not found - {file_path}")

    with open(file_path, 'r') as f:
        return json.load(f)

def load_training_data(load_data_path, piter):
    """Load the JSON training log of one pruning iteration.

    Args:
        load_data_path: Directory prefix; it is concatenated directly with the
            file name, so it has to end with a path separator. A leading
            ``~`` is expanded to the user's home directory.
        piter: Value inserted into the file name
            ``training_log_piter=<piter>.json`` (a pruning iteration index or
            a label such as 'dense').

    Returns:
        The parsed JSON content of the log file.

    Raises:
        FileNotFoundError: If the log file does not exist. Previously a message
            was printed and the function then failed with an UnboundLocalError
            on the return statement.
    """
    # expanduser returns a new string; its result was discarded before.
    file_path = os.path.expanduser(load_data_path + f'training_log_piter={piter}.json')

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Error: File not found - {file_path}")

    with open(file_path, 'r') as f:
        return json.load(f)

Load ResCNN reference energies for the transverse field strengths $\kappa \in \{0.1, 1, 2, 4, 5, 6\}$; the reference energy at the critical point $\kappa_c$ is set separately further below.

In [None]:
# Field strengths for which a dense ResCNN reference run is loaded.
kappas = np.array([0.1, 1, 2, 4, 5, 6])

# Root directory of the reference runs. It does not depend on kappa, so it is
# built once instead of on every loop iteration.
load_data_path = '../../data/tfim_2d/res_cnn/dense_network_reference_energies/'

dense_energies = []
dense_variances = []

for kappa in kappas:
    print(f'Loading data for kappa = {kappa}')

    # The sub-directory is 'k=<kappa>/': every value except 0.1 is formatted
    # as an integer ('k=1/', 'k=2/', ...).
    if kappa != 0.1:
        kappa = int(kappa)

    file_path = load_data_path + 'k={}/'.format(kappa)

    # The first two entries returned by get_num_params() are summed into n_dense.
    network = load_serialized_rescnn(file_path, 'dense')
    n_dense = network.get_num_params()
    n_dense = n_dense[0] + n_dense[1]

    # Average the logged energy means and variances of the sampling run.
    data = load_sampling_data(file_path, 'dense')
    energy = np.mean(np.array(data['Energy']['Mean']))
    variance = np.mean(np.array(data['Energy']['Variance']))

    dense_energies.append(energy)
    dense_variances.append(variance)

print(dense_energies)
print(dense_variances)


In [None]:
# ResCNN reference energies, one per field strength, in the order
# kappa = 0.1, 1, 2, 4, 5, 6.
# NOTE(review): presumably copied from the printed output of the dense ResCNN
# loading cell — confirm when the reference runs are regenerated.
ground_state_energies = np.array([
    -180.14327137818287,  # kappa = 0.1
    -194.44212670669265,  # kappa = 1
    -239.0234337701538,   # kappa = 2
    -411.94084029073827,  # kappa = 4
    -509.3311469812418,   # kappa = 5
    -607.6855556865164,   # kappa = 6
])

# The individual values stay available under their original names.
e_k_01, e_k_1, e_k_2, e_k_4, e_k_5, e_k_6 = ground_state_energies.tolist()

Load the FFNN IMP-WR data and compute the absolute energy difference per spin with respect to these ResCNN reference energies

In [None]:
kappas = np.array([0.1, 1, 2, 4, 5, 6])
# Number of pruning iterations
pruning_iterations = 65
# Number of spins used to normalise the energy difference
# (previously the bare literal 10**2, repeated twice).
n_spins = 10**2

load_data_path = '../../data/tfim_2d/ffnn/IMP_WR/'

# One row per kappa. Column 0 holds the dense network, column p+1 holds pruning
# iteration p. The buffers start as NaN so that an entry which was never
# written cannot be mistaken for data.
result_shape = (len(kappas), pruning_iterations+1)
abs_differences = np.full(result_shape, np.nan)
remaining_params = np.full(result_shape, np.nan)
variances = np.full(result_shape, np.nan)

for i, kappa in enumerate(kappas):

    print(f'Loading data for kappa = {kappa}')

    # Directory names use integers except for the non-integer field strengths.
    if kappa != 0.1 and kappa != 3.04438:
        kappa = int(kappa)

    # Reference energy for this field strength (same ordering as kappas).
    e_gs = ground_state_energies[i]

    file_path = load_data_path + 'k={}/'.format(kappa)

    # The dense network and every pruning iteration are processed identically;
    # only the label that selects the files differs.
    for column, piter in enumerate(['dense', *range(pruning_iterations)]):

        network = load_serialized_ffnn(file_path, piter)
        n_params = network.get_num_params()
        if piter == 'dense':
            n_dense = n_params
            print(f'Number of parameters in dense network: {n_dense}')

        # Average the logged energy means and variances of the sampling run.
        data = load_sampling_data(file_path, piter)
        energy = np.mean(np.array(data['Energy']['Mean']))
        variance = np.mean(np.array(data['Energy']['Variance']))

        remaining_params[i, column] = n_params
        variances[i, column] = variance
        abs_differences[i, column] = abs_difference_per_spin(energy, e_gs, n_spins)


In [None]:
# Reference energy at the critical point
# (computed as an average over several ResCNNs for the critical point).
# NOTE(review): the factor 100 presumably is the number of spins — confirm.
e_critical = -3.211395417101927*100

In [None]:
alphas = np.array([1,2.5,5]) # Network width scaling factors
p_iters = np.array([54,60,65]) # Number of pruning iterations for each network
N=100 # Number of spins in the system

load_data_path = '../../data/tfim_2d/ffnn/IMP_WR/k=3.04438/'

# One row per alpha. Column 0 holds the dense network, column p+1 holds pruning
# iteration p. Rows with fewer than max(p_iters) iterations are only partly
# filled, so the buffers start as NaN rather than as uninitialised memory.
result_shape = (len(alphas), np.max(p_iters)+1)
critical_abs_differences = np.full(result_shape, np.nan)
critical_remaining_params = np.full(result_shape, np.nan)
critical_variances = np.full(result_shape, np.nan)

# 3.04438 is the field strength of this data set (see load_data_path), not a
# value of alpha as the message previously claimed.
print('Loading data for kappa=3.04438')

for i, alpha in enumerate(alphas):

    pruning_iters = p_iters[i]

    # Directory names are 'alpha=1/', 'alpha=2.5/', 'alpha=5/'.
    if alpha == 1 or alpha == 5:
        alpha = int(alpha)

    file_path = load_data_path + 'alpha={}/'.format(alpha)

    # The dense network and every pruning iteration are processed identically;
    # only the label that selects the files differs.
    for column, piter in enumerate(['dense', *range(pruning_iters)]):

        network = load_serialized_ffnn(file_path, piter)
        n_params = network.get_num_params()
        if piter == 'dense':
            n_dense = n_params
            print(f'Number of parameters in dense network: {n_dense}')

        # Average the logged energy means and variances of the sampling run.
        data = load_sampling_data(file_path, piter)
        energy = np.mean(np.array(data['Energy']['Mean']))
        variance = np.mean(np.array(data['Energy']['Variance']))

        # NOTE(review): the variance is divided by N here, whereas the cell
        # that fills `variances` stores it unnormalised — confirm which
        # convention is intended.
        critical_variances[i, column] = variance/N
        critical_remaining_params[i, column] = n_params
        critical_abs_differences[i, column] = abs_difference_per_spin(energy, e_critical, N)


In [None]:
import matplotlib.gridspec as gridspec


def _kappa_legend_text(kappa):
    """Return the legend text for a field strength.

    0.1 is shown as a LaTeX fraction, every other value as an integer.
    """
    if kappa == 0.1:
        # Must be a raw string: in a normal literal '\f' is a form feed
        # character, which corrupts the LaTeX command.
        return r'\frac{1}{10}'
    return int(kappa)


# No layout engine is requested here: plt.tight_layout() at the end of the cell
# determines the final layout and would replace a constrained-layout engine
# (emitting a warning) anyway.
fig = plt.figure(figsize=(16, 7))
gs = gridspec.GridSpec(2, 2, width_ratios=[1.2, 1], height_ratios=[1,1])  # left column wider

# Left plot (takes both rows in first column)
ax1 = fig.add_subplot(gs[:, 0])

# Right column (top and bottom)
ax2 = fig.add_subplot(gs[0, 1])
ax3 = fig.add_subplot(gs[1, 1])

cmap = plt.get_cmap('Purples',len(alphas)+3)

# Dashed vertical marker lines at fixed parameter counts.
ax1.axvline(x=10**(3.65),  color='purple', linestyle='--', linewidth=3, alpha=1)
ax1.axvline(x=10**(2.6),  color='purple', linestyle='--', linewidth=3,alpha=1)

ax2.axvline(x=10**(2.9), color='tab:red', linestyle='--', linewidth=3,alpha=1)
ax2.axvline(x=10**(2.3), color='tab:red', linestyle='--', linewidth=3,alpha=1)

ax3.axvline(x=10**(3.65), color='purple', linestyle='--',linewidth=3, alpha=1)
ax3.axvline(x=10**(2.6), color='purple', linestyle='--', linewidth=3,alpha=1)

# Panel a): critical point, one curve per width factor alpha.
for i, alpha in enumerate(alphas):
    if alpha == 2.5:
        alpha_text = r'\frac{5}{2}'
    elif alpha == 1 or alpha == 5:
        alpha_text = int(alpha)
    else:
        alpha_text = alpha

    # Row i holds the dense network in column 0 plus p_iters[i] pruning
    # iterations, i.e. p_iters[i] + 1 valid entries. The previous slice
    # [:p_iters[i]] silently dropped the last pruning iteration.
    n_valid = p_iters[i] + 1

    ax1.plot(critical_remaining_params[i,:n_valid], critical_abs_differences[i,:n_valid],
        '-',
        linewidth=3,
        color=cmap(i+2),
        label=r'$\alpha = {}$'.format(alpha_text),
        )

# Panel b): the first three field strengths.
cmap = plt.get_cmap('Reds_r',len(kappas))

for i, kappa in enumerate(kappas[:-3]):
    ax2.plot(remaining_params[i,:], abs_differences[i,:],
        '-',
        linewidth=3,
        color=cmap(2*i),
        label=r'$\kappa = {}$'.format(_kappa_legend_text(kappa)),
        )

# Panel c): the last three field strengths (rows 3.. of the result arrays).
# The colormap is loop-invariant and therefore created once.
cmap = plt.get_cmap('Blues',len(kappas))

for i, kappa in enumerate(kappas[3:]):
    ax3.plot(remaining_params[i+3,:], abs_differences[i+3,:],
        '-',
        linewidth=3,
        color=cmap(2*i+1),
        label=r'$\kappa = {}$'.format(_kappa_legend_text(kappa)),
        )

# Shade three ranges of remaining-parameter counts, taken from row 2
# (alpha = 5, the row with the most pruning iterations).
# The y-limits for the middle band are read at this point, before the axes are
# switched to log scale below; the statement order is kept as it was.
ax1.fill_between(critical_remaining_params[2,:20], 0, 0.1, color='tab:green', alpha=0.2)
ymin, ymax = ax1.get_ylim()
ax1.fill_between(critical_remaining_params[2,19:39], ymin, ymax, color='yellow', alpha=0.2)
ax1.fill_between(critical_remaining_params[2,38:], 0, 0.1, color='tab:red', alpha=0.2)

ax2.grid(linestyle='-', alpha=0.5)
ax2.tick_params(axis='both', which='major', labelsize=20)
# Panel b) shows no x tick marks or labels.
ax2.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
ax2.loglog()
ax2.legend(fontsize=22, reverse=False, loc='lower left')

ax1.set_ylabel(r'$\Delta E/N$', size=30)
ax1.grid(linestyle='-', alpha=0.5)
ax1.tick_params(axis='both', which='major', labelsize=20)
ax1.loglog()
ax1.set_xlabel(r'$n$', size=30)
ax1.legend(fontsize=22, reverse=False, loc='lower left')

# Region labels, positioned in axes coordinates.
ax1.text(.85, 0.15, r'LEP', horizontalalignment='center', verticalalignment='bottom', transform=ax1.transAxes, fontsize=24)
ax1.text(0.2, 0.84, r'HEP', horizontalalignment='center', verticalalignment='bottom', transform=ax1.transAxes, fontsize=24)
ax1.text(0.55, 0.5, r'APL', horizontalalignment='center', verticalalignment='bottom', transform=ax1.transAxes, fontsize=24)

# Panel letters, placed just outside the left edge of each axes.
ax1.text(-0.12, 0.94, r'a)', horizontalalignment='center', verticalalignment='bottom', transform=ax1.transAxes, fontsize=24)
ax2.text(-0.09, 0.88, r'b)', horizontalalignment='center', verticalalignment='bottom', transform=ax2.transAxes, fontsize=24)
ax3.text(-0.09, 0.88, r'c)', horizontalalignment='center', verticalalignment='bottom', transform=ax3.transAxes, fontsize=24)

ax3.grid(linestyle='-', alpha=0.5)
ax3.tick_params(axis='both', which='major', labelsize=20)
ax3.loglog()
ax3.set_xlabel(r'$n$', size=30)
ax3.legend(fontsize=22, reverse=False, loc='lower left')

plt.tight_layout()
plt.savefig('fig2_abc.pdf', bbox_inches='tight', format='pdf')