In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cmasher as cmr
from scipy.stats import sem
from scipy.optimize import curve_fit
import os
import sys
import json
import jax.numpy as jnp
import equinox as eqx
import flax
from flax import nnx

# Make the parent directory importable so the local `src` package resolves
# (assumes the notebook is run from one level below the directory containing `src` — TODO confirm).
sys.path.append('../')
from src.networks.prunable_ffnn import PrunableFFNN

# Render all figure text through LaTeX with a serif font (needs a working LaTeX install).
plt.rc('text', usetex=True)
plt.rc('font', family='serif')

In [None]:
# Reference ground-state data from DMRG (file names carry `k=kc`, presumably the
# critical coupling). Judging by the names, 4x4..7x7 share one CSV and 8x8, 9x9,
# 10x10 each have their own.
dmrg_data_path = '../../data/tfim_2d/dmrg/'
dmrg_file_names = ('N=4x4_7x7_k=kc.csv', 'N=8x8_k=kc.csv', 'N=9x9_k=kc.csv', 'N=10x10_k=kc.csv')
dmrg_N_4_7, dmrg_N_8, dmrg_N_9, dmrg_N_10 = [
    pd.read_csv(dmrg_data_path + fname) for fname in dmrg_file_names
]

# Stack into one table with a fresh 0..n-1 index. Rows are later looked up
# positionally (`.iloc[i]`), so the file order above matters.
dmrg_data_df = pd.concat([dmrg_N_4_7, dmrg_N_8, dmrg_N_9, dmrg_N_10], ignore_index=True)
dmrg_data_df

In [None]:
def relative_error(x, y):
    """Relative deviation |x - y| / |y| of `x` from the reference `y` (elementwise for arrays)."""
    deviation = np.abs(x - y)
    scale = np.abs(y)
    return deviation / scale

def abs_error(x, y):
    """Absolute deviation |x - y| between `x` and `y` (elementwise for arrays)."""
    delta = x - y
    return np.abs(delta)

def abs_difference_per_spin(x, y, N):
    """Absolute difference |x - y| divided by the number of spins `N`."""
    total_gap = np.abs(x - y)
    return total_gap / N

def difference(x, y):
    """Signed difference `x - y` (positive when `x` exceeds `y`)."""
    signed_gap = x - y
    return signed_gap

In [None]:
def load_serialized_network(load_data_path, pruning_iter):
    """Rebuild a PrunableFFNN from a serialized checkpoint file.

    The file ``<load_data_path>model_piter=<pruning_iter>.eqx`` is read as one
    JSON line of constructor hyperparameters, followed by the equinox-serialized
    leaves of the model state.

    Args:
        load_data_path: Directory prefix. It is string-concatenated with the file
            name, so it should end with a path separator.
        pruning_iter: Tag used in the file name (an int iteration or e.g. 'dense').

    Returns:
        The network with its state restored from the file.
    """
    checkpoint = load_data_path + f'model_piter={pruning_iter}.eqx'

    with open(checkpoint, "rb") as stream:
        # First line: JSON hyperparameters used to build a template network.
        header = stream.readline().decode()
        template = PrunableFFNN(**json.loads(header))
        # The template's state supplies the leaf structure that the
        # serialized leaves are read into; the graph is reattached afterwards.
        graphdef, template_state = nnx.split(template)
        restored_state = eqx.tree_deserialise_leaves(stream, template_state)
        return nnx.merge(graphdef, restored_state)

def _load_json_log(load_data_path, prefix, piter):
    """Parse ``<load_data_path><prefix>_piter=<piter>.json`` and return its JSON content.

    Args:
        load_data_path: Directory prefix (string-concatenated, so it should end in a
            separator). A leading '~' is expanded to the user's home directory.
        prefix: Log-file prefix, e.g. 'sampling_log' or 'training_log'.
        piter: Pruning-iteration tag used in the file name (an int or e.g. 'dense').

    Returns:
        The decoded JSON object.

    Raises:
        FileNotFoundError: if the log file does not exist.
    """
    # os.path.expanduser returns a new string, so its result must be kept.
    file_path = os.path.expanduser(load_data_path + f'{prefix}_piter={piter}.json')
    if not os.path.exists(file_path):
        # Fail loudly instead of printing and then hitting an undefined variable.
        raise FileNotFoundError(f"Error: File not found - {file_path}")
    with open(file_path, 'r') as f:
        return json.load(f)


def load_sampling_data(load_data_path, piter):
    """Load the sampling log ``sampling_log_piter=<piter>.json`` from `load_data_path`.

    Returns:
        The parsed JSON log.

    Raises:
        FileNotFoundError: if the log file does not exist.
    """
    return _load_json_log(load_data_path, 'sampling_log', piter)


def load_training_data(load_data_path, piter):
    """Load the training log ``training_log_piter=<piter>.json`` from `load_data_path`.

    Returns:
        The parsed JSON log.

    Raises:
        FileNotFoundError: if the log file does not exist.
    """
    return _load_json_log(load_data_path, 'training_log', piter)

In [None]:
system_sizes = np.array([4,5,6,7,8,9,10]) # Lx
# Network width scaling factor (not used in this cell)
alpha = 8
# Number of pruning iterations per system size (aligned with `system_sizes`)
pruning_iterations = np.array([51, 58, 64, 69, 65, 74, 74])

load_data_path = '../../data/tfim_2d/ffnn/IMP_WR_system_size_scaling/'

# Layout: column 0 = dense network, column k+1 = pruning iteration k.
# Smaller systems have fewer iterations. Their trailing columns stay NaN
# (instead of np.empty garbage), so they are clearly marked as missing.
n_cols = np.max(pruning_iterations) + 1
abs_differences = np.full((len(system_sizes), n_cols), np.nan)
remaining_params = np.full((len(system_sizes), n_cols), np.nan)

# Load checkpoints (for parameter counts) and sampling logs (for energies)
for i, Lx in enumerate(system_sizes):

    print(f'Loading system size {Lx}x{Lx}')
    p_its = pruning_iterations[i]
    N = Lx**2
    # NOTE(review): assumes row i of dmrg_data_df is the DMRG result for the Lx x Lx
    # lattice, i.e. the concatenated CSVs list sizes in `system_sizes` order — confirm.
    e_gs = dmrg_data_df['energy'].iloc[i]

    file_path = load_data_path + f'N={Lx}x{Lx}/'

    # 'dense' fills column 0; pruning iterations 0..p_its-1 fill columns 1..p_its.
    for col, piter in enumerate(['dense', *range(p_its)]):
        network = load_serialized_network(file_path, piter)
        n_params = network.get_num_params()
        if col == 0:
            print(f'Number of parameters in dense network: {n_params}')

        data = load_sampling_data(file_path, piter)
        # Average the logged energy means over the whole sampling log.
        energy = np.mean(np.array(data['Energy']['Mean']))

        remaining_params[i, col] = n_params
        abs_differences[i, col] = abs_difference_per_spin(energy, e_gs, N)


In [None]:
# Load magnetization data
# Same layout as the energy arrays: column 0 = dense network, column k+1 =
# pruning iteration k. Unused trailing columns stay NaN.
n_cols = np.max(pruning_iterations) + 1
Mz_totals = np.full((len(system_sizes), n_cols), np.nan)
Mx_totals = np.full((len(system_sizes), n_cols), np.nan)


def _lattice_average(data, op, N):
    """Mean over sites 0..N-1 of the time-averaged ``data['<op>_<site>']['Mean']['real']``.

    Each site's logged means are averaged first. Those N per-site values are then
    averaged, which equals the original sum-over-sites divided by N.
    """
    return np.mean([np.mean(np.array(data[f'{op}_{site}']['Mean']['real'])) for site in range(N)])


# Load sampling logs
for i, Lx in enumerate(system_sizes):

    print(f'Loading system size {Lx}x{Lx}')
    p_its = pruning_iterations[i]
    N = Lx**2
    file_path = load_data_path + f'N={Lx}x{Lx}/'

    # Only the dense checkpoint is deserialised, to report its size. The pruned
    # checkpoints are not needed because magnetizations come from the sampling logs.
    network = load_serialized_network(file_path, 'dense')
    n_dense = network.get_num_params()
    print(f'Number of parameters in dense network: {n_dense}')

    for col, piter in enumerate(['dense', *range(p_its)]):
        data = load_sampling_data(file_path, piter)
        Mz_totals[i, col] = _lattice_average(data, 'Z', N)
        Mx_totals[i, col] = _lattice_average(data, 'X', N)


In [None]:
import matplotlib.gridspec as gridspec

fig = plt.figure(figsize=(16, 8), layout='constrained')
# 2x2 grid with equal column widths and equal row heights.
# NOTE(review): this GridSpec is not created with `figure=fig`, and plt.tight_layout()
# is called further down, so the 'constrained' layout requested above is probably
# not what positions the panels. Confirm before changing either.
gs = gridspec.GridSpec(2, 2, width_ratios=[1, 1], height_ratios=[1,1])

# ax1 top-left, ax2 top-right, ax3 bottom-left, ax4 bottom-right; all share ax1's x axis.
ax1 = fig.add_subplot(gs[0, 0])
ax2, ax3, ax4 = (fig.add_subplot(gs[row, col], sharex=ax1) for row, col in ((0, 1), (1, 0), (1, 1)))

# Dashed vertical reference line at x = 4 on every panel except the top-right one.
for ax in (ax1, ax3, ax4):
    ax.axvline(4, color='black', linestyle='--', linewidth=2, alpha=0.5)

################### top-left panel (ax1): energy error per spin
# Palette reversed, so system index 0 (4x4) gets the last listed colour.
colors = ["#000000", "#004488", "#994455", "#997700", "#6d0b52", "#ee99aa", "#eecc66",  "#6699cc"][::-1]

# One curve per system size: |E - E_DMRG| / N against remaining parameters / N,
# covering the dense network (column 0) and every pruning iteration.
for i, Lx in enumerate(system_sizes):
    n_pts = pruning_iterations[i] + 1
    ax1.plot(remaining_params[i, :n_pts] / (Lx**2),
             abs_differences[i, :n_pts],
             '-',
             linewidth=3,
             label=r'$N = {} \times {}$'.format(Lx, Lx),
             color=colors[i])

ax1.grid(linestyle='-', alpha=0.5)
ax1.tick_params(axis='both', which='major', labelsize=20)
ax1.loglog()
ax1.legend(fontsize=20, reverse=False, loc='lower left')
# x tick labels hidden here; the bottom row carries the shared x axis labels.
ax1.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
ax1.set_ylabel(r'$\Delta E / N$', size=30)

################ fidelity in second plot
N_4x4 = [1.0000050659450932, 0.9999676242692559, 0.9999666661364286, 0.9999867331832297, 1.0000788518786647, 0.9998810941662946, 0.9999420967876141, 1.0000810890424374, 1.0003895611133675, 1.0001235269372095, 0.999656740973385, 1.001000567605306, 1.0009545054445215, 0.9994141463094983, 0.9992116769472136, 0.998285200964079, 0.9994049360819917, 1.0016213433393901, 0.9974773711757056, 0.9973343968849565, 1.0000127503197804, 0.9994245324798146, 0.9993472643416184, 1.0000909738245243, 0.9999384145745116, 0.9991088947720133, 0.993043623163283, 0.9996254234735482, 0.9567363476210032, 0.9549836819543831, 0.9919245251980147, 0.966247877654268, 0.9810943515376572, 0.9871302093217297, 0.9953588412531283, 0.985942181316777, 0.9914573739487543, 0.9988454882714244, 0.9842172565910932, 1.001202396611198, 0.9974149920925672, 0.9980383618740414, 1.0006630484989782, 0.9943241510822589, 1.0000846524984566, 0.9891882871133441, 0.9992530785711682, 0.9978175027196943, 0.9991383783619674, 1.000000021696325]
N_5x5 = [0.99983260838824, 0.9999482329277788, 1.000255177543154, 0.9999237243447057, 0.9998659068391117, 1.0004396494648955, 1.0000475826702642, 1.0001938527324685, 1.00013931293819, 0.9998140052068587, 1.0001153308082689, 1.0001826776457594, 1.0006560247783918, 1.000672508763977, 0.9991325494069477, 0.9997308451527472, 1.0009276436759322, 0.9980341976692199, 1.0014141046951535, 0.9994403377003371, 1.002154656234265, 0.9989284127084488, 0.9960713498208301, 0.9978992079807635, 0.9980025146572585, 0.9916110622795333, 0.9805619129317688, 0.9985899687694256, 0.9990367852820775, 0.9802929058175055, 0.9983968476019285, 0.9915187953586598, 0.939603356376805, 0.982164100396639, 0.9435623670093027, 0.9819977457795868, 0.968715376596777, 0.9628538764461955, 0.977402852222437, 0.9983132494214558, 0.9919808067167504, 0.9746595044025265, 0.9938063225246054, 0.9807789908525788, 0.9933079037293743, 0.9881046811287746, 0.9968205479966632, 0.9982901775521047, 0.9966891331742109, 0.9988410973329205, 0.9917828971496494, 0.9893005194708367, 0.999422183926607, 0.9994283616395958, 0.9989466510794771, 0.9923462501394084, 0.9995805449196116]
N_6x6 = [1.0001603163884272, 0.9998002139558612, 0.9998914779255185, 1.0001869030277044, 0.9998240506356944, 0.999747588283312, 0.9999033485642634, 0.9999992179649893, 1.000460920967973, 1.001296294336467, 1.0007660771011222, 0.9996423187121249, 1.0006737436002402, 1.0000428636325351, 0.9983469371122747, 0.9974664611305485, 1.0014795265165894, 0.9993726007894336, 0.9983445867239025, 0.9993242439486972, 0.997278922039163, 1.0030802092740105, 1.0008162042776594, 0.9993853585677375, 0.998867155666137, 1.0002559525635333, 0.9982176739191974, 0.9964058692991117, 0.9974776343329587, 0.995232360073297, 1.0031404057363895, 0.999070915206645, 0.9718454272544597, 0.9590091970728785, 0.9310028220247081, 0.9709241319990777, 0.9082339484139589, 0.9598770116558015, 0.9722605722383383, 0.9727515578906182, 0.9740947180446912, 0.9786046041894614, 0.9792843119263717, 0.9673197910939397, 0.9707565196200809, 0.9707723082864498, 0.9993036447869427, 0.9847925020030992, 0.9944258873861385, 0.9893755201394704, 0.9988510750157166, 0.9942142407641726, 0.9811098600809737, 0.9983069687141164, 0.996019138172593, 0.9845202755321721, 0.9987552877659309, 0.9996684882074521, 0.9927241930029288, 0.9997890073198873, 0.9974095447259274, 0.9974911887338704, 0.9985208057802814]
N_7x7 = [0.999545109890897, 0.9997963240559827, 0.9997467940716876, 0.999948236212461, 0.9998516128011822, 1.0001359052591339, 1.0001217067740857, 0.9998447063870627, 1.0001698131138104, 0.9993817884421331, 0.9994446700444573, 1.0005873510232075, 0.9991091085039299, 1.0009314009451813, 0.9973005352250545, 0.9974957655636287, 0.9991972343211083, 0.9995701211868129, 0.999579831524541, 0.9991238164268497, 0.9945675473675664, 0.9990860066450412, 0.9997882010373983, 0.9986528978636108, 1.0014087772880034, 0.9954749478943062, 0.9945001992427975, 0.9951628461529237, 0.9952353968946747, 0.9969442195508691, 0.9944234196948012, 0.989135678382357, 1.0033744800330553, 0.9985943388625208, 0.9313784931896806, 0.9358922012274034, 0.6617906885587097, 0.9802084576808552, 0.7905429983049352, 0.974542025460391, 0.9293577415191129, 0.9715592355351955, 0.9250012523081961, 0.9901567308352174, 0.8970423263647394, 0.946905599936753, 0.9512473000134802, 0.9585109256593121, 0.9267455884133925, 0.958842938185676, 0.9848473569882602, 0.9332153421406082, 0.9953528841084404, 0.97077180967722, 0.9582561843750688, 0.934572266813522, 0.9703958386271734, 0.9940124933078929, 0.9981931675775699, 0.9951547103538962, 0.9966594298668879, 1.000673590829976, 0.998570434326455, 0.9989479770015618, 0.9998991889892579, 0.9976970085033944, 0.9982732795149948, 0.9998604223025145]
N_8x8 = [1.0002820126537397, 0.9996673360114515, 0.9997932509806633, 0.9999340686843823, 1.0000889451671338, 0.9999115198304588, 1.0000099555376285, 1.000003696443033, 0.9999152195648258, 0.9993345874068519, 1.0000617454936926, 0.9988030936167676, 1.0007502913644715, 1.0009594293011952, 0.9988413926739714, 0.9979463919484461, 0.9990377629352375, 0.9971322763217995, 0.9958458015852248, 0.9968492340739245, 0.9993719435971649, 1.0013673497646498, 0.9942265276595253, 0.9975915415725908, 0.9914103970643395, 0.9983506210775962, 0.9985196959874353, 0.9986523977953756, 0.9938966253746792, 0.998721014434733, 0.9962242954727141, 0.9920796873044301, 0.9915192703629553, 0.9837218068296436, 0.9967792869020694, 0.9991922980001244, 0.9015919092653596, 0.8079317271010876, 0.8768019390066876, 0.807085291864187, 0.9770151259534382, 0.8444061845201282, 0.9146489640512729, 0.9199546236509923, 0.931588109160767, 0.9630284155343893, 0.9452459675334568, 0.9852274874743866, 0.9412130979123285, 0.974043987536382, 0.9815833998381833, 0.9662596820140487, 0.9765583248989124, 0.9832700365318738, 0.9892976049268065, 0.9922393066318005, 0.9798601077331734, 0.9990102651806482, 1.0002394337481937, 0.9834159591805369, 0.9763611339745716, 1.0003332114086225, 0.9932088245131457, 0.9811501046525911]
N_9x9 = [1.0000339192827048, 1.00013814180612, 1.000258747797163, 0.9998061005076345, 1.0000109687939347, 0.999720636180108, 1.000286340148213, 1.0000251093289036, 0.9999849364420068, 1.0003917346704523, 0.9996119197693885, 0.9998004488313131, 1.000468962082834, 0.9998129759354251, 0.9989200170861036, 1.0002714669864259, 1.0002927229125815, 0.9984055595077853, 0.9995738803043945, 0.9972135691719372, 0.9893033119416705, 0.9938497695330695, 0.9951724970062326, 0.9970116774467636, 0.9952727320680124, 1.0013481442330652, 0.998053919299702, 0.9926765643114625, 0.9887792982440661, 0.9891839326410972, 0.9918550643628669, 0.9928511472164377, 0.9925949935866851, 0.9811142292509669, 0.9696896347042626, 0.9794379277092287, 0.9852245848263516, 0.9307429409544757, 0.9153128533533491, 0.803983656866624, 0.6770849236435652, 0.8103109364362103, 0.8962334559816942, 0.8691267776754426, 0.8923896495964692, 0.9252253661323026, 0.8221453238189291, 0.8744034237374989, 0.8615033555496692, 0.9803723483973427, 0.8839396235683638, 0.8776693023903317, 0.700138309046849, 0.9445903810477578, 0.9240462853680047, 0.9317677572098478, 0.929831880489473, 0.9160310917058143, 0.9846082668745564, 0.9502942424103913, 0.9262030815731066, 0.984909649297679, 0.9312667269388629, 0.9970479666915465, 0.9761410723331156, 0.9885085882833503, 0.9977055705894846, 0.9967010816623443, 0.9946271950046689, 0.9959573256117636, 0.997505697254573, 0.9985250841548035, 0.9997018358428638]
N_10x10 = [0.9967550862601042, 0.998193132170432, 0.9955225671855583, 0.9995159912826357, 0.9975577561602625, 1.0010921141704674, 0.9996811838400061, 1.000780786610947, 1.0009757616958792, 1.0007413014454838, 0.9979057264507223, 0.9995089854392735, 1.000027962988467, 0.9998466650612193, 1.0018983874959482, 0.9965214034332454, 0.9998696687039734, 0.9993912732761007, 0.999125977945774, 1.0008114153671166, 0.9981692823559659, 0.9945337370495073, 0.99527834912118, 0.975711409941924, 0.9786392427229232, 0.9982913986965182, 0.9881020127335866, 0.9868157261306518, 0.9970689715327936, 0.9912413500396843, 0.9855509170956932, 0.9882117164272278, 0.9832991224565247, 0.9775322788597779, 0.990374110127681, 0.9904392261966838, 0.9908799802543957, 0.9874731556890513, 0.9928474670380663, 0.9146614317581869, 0.8343696999638713, 0.5747404024092901, 0.7507128719528168, 0.773664342888019, 0.858130191939006, 0.8436647350729962, 0.898139196632921, 0.8136587455498449, 0.8370009752748057, 0.8696453141837666, 0.8545745881832539, 0.8489208062654684, 0.8130752130943659, 0.8288570536954423, 0.7815027021174666, 0.8981635012057367, 0.8189355468324073, 0.9417652260113262, 0.9472506337247908, 0.9754544163621827, 0.9238856060446803, 0.9686425247946452, 0.8978195224272445, 0.9757458976447861, 0.9530580922507738, 0.9366931904918602, 0.9897841123703374, 0.9728513119387817, 0.9890964349197976, 0.9881709558772703, 0.9945180088564083, 0.9990341073628707, 0.9954369320809768]
# Hard-coded fidelity series, one list per system size in `system_sizes` order.
# NOTE(review): the provenance of these values is not recorded in the notebook.
all_fidelities = [N_4x4, N_5x5, N_6x6, N_7x7, N_8x8, N_9x9, N_10x10]

colors = ["#000000", "#004488", "#994455", "#997700", "#6d0b52", "#ee99aa", "#eecc66",  "#6699cc"][::-1]
indexes = np.array([6,5,4,3,2,1,0])

# Largest system drawn first, so smaller systems end up on top.
# fidelity[k] is plotted against column k+1 of remaining_params (pruning iteration k),
# for k = 0 .. pruning_iterations[i]-2. The dense column 0 and the final pruning
# iteration have no fidelity entry.
for i in indexes:
    Lx = system_sizes[i]
    n_rem = remaining_params[i, 1:pruning_iterations[i]] / (Lx**2)
    ax2.plot(n_rem, all_fidelities[i],
             '-',
             linewidth=2.5,
             label=r'$N = {} \times {}$'.format(Lx, Lx),
             color=colors[i])


ax2.grid(linestyle='-', alpha=0.5)
ax2.tick_params(axis='both', which='major', labelsize=20)
ax2.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
ax2.set_xscale('log')
ax2.set_ylabel(r'$F(\rho)$', size=30)


################ bottom-right panel (ax4): total x-magnetization
colors = ["#000000", "#004488", "#994455", "#997700", "#6d0b52", "#ee99aa", "#eecc66",  "#6699cc"][::-1]
indexes = np.array([6,5,4,3,2,1,0])

# Largest system first; dense network plus every pruning iteration for each size.
for i in indexes:
    Lx = system_sizes[i]
    n_pts = pruning_iterations[i] + 1
    n_rem = remaining_params[i, :n_pts] / (Lx**2)
    Mx_tot = Mx_totals[i, :n_pts]
    ax4.plot(n_rem, Mx_tot,
             '-',
             linewidth=3,
             label=r'$N = {} \times {}$'.format(Lx, Lx),
             color=colors[i])

ax4.grid(linestyle='-', alpha=0.5)
ax4.tick_params(axis='both', which='major', labelsize=20)
ax4.set_xscale('log')
ax4.set_xlabel(r'$\rho$', size=30)

ax4.set_ylabel(r'$M_x^{\mathrm{tot}}$', size=30)
# NOTE(review): the figure was created with layout='constrained'. This call switches to a
# one-off tight layout, and it runs before the bottom-left panel is labelled. Confirm intended.
plt.tight_layout()

# bottom-left panel (ax3): total z-magnetization
colors = ["#000000", "#004488", "#994455", "#997700", "#6d0b52", "#ee99aa", "#eecc66",  "#6699cc"][::-1]
indexes = np.array([6,5,4,3,2,1,0])

# Largest system first; dense network plus every pruning iteration for each size.
for i in indexes:
    Lx = system_sizes[i]
    n_pts = pruning_iterations[i] + 1
    n_rem = remaining_params[i, :n_pts] / (Lx**2)
    Mz_tot = Mz_totals[i, :n_pts]
    ax3.plot(n_rem, Mz_tot,
             '-',
             linewidth=3,
             label=r'$N = {} \times {}$'.format(Lx, Lx),
             color=colors[i])


ax3.grid(linestyle='-', alpha=0.5)
ax3.tick_params(axis='both', which='major', labelsize=20)
ax3.set_xscale('log')
ax3.set_xlabel(r'$\rho$', size=30)
ax3.set_ylabel(r'$M_z^{\mathrm{tot}}$', size=30)

# Panel letters, all placed in ax1's axes coordinates: x = -0.14 is left of the
# left column, x = 1.06 is just right of ax1; y = 0.9 is the top row, y = -0.16
# sits below ax1.
panel_labels = {
    r'a)': (-0.14, 0.9),
    r'b)': (1.06, 0.9),
    r'c)': (-0.14, -0.16),
    r'd)': (1.06, -0.16),
}
for letter, (x_pos, y_pos) in panel_labels.items():
    ax1.text(x_pos, y_pos, letter,
             horizontalalignment='center', verticalalignment='bottom',
             transform=ax1.transAxes, fontsize=24)

fig.savefig('fig3_abcd.pdf', bbox_inches='tight', format='pdf')