In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cmasher as cmr
from scipy.stats import sem
from scipy.optimize import curve_fit
import os
import sys
import json
import equinox as eqx
import flax
from flax import nnx

sys.path.append('../')
from src.networks.prunable_ffnn import PrunableFFNN

plt.rc('text', usetex=True)
plt.rc('font', family='serif')

In [None]:
def relative_error(x, y):
    """Relative deviation |x - y| / |y| of `x` from the reference `y` (elementwise for arrays)."""
    return abs_error(x, y) / np.abs(y)

def abs_error(x, y):
    """Absolute deviation |x - y| (elementwise for arrays)."""
    deviation = x - y
    return np.abs(deviation)

def abs_difference_per_spin(x, y, N):
    """Absolute deviation |x - y| divided by the number of spins `N`."""
    return abs_error(x, y) / N

def difference(x, y):
    """Signed difference x - y."""
    signed_gap = x - y
    return signed_gap

In [None]:
def load_serialized_network(load_data_path, pruning_iter):
    """Rebuild a PrunableFFNN from its checkpoint file.

    The checkpoint lives at ``load_data_path + 'model_piter=<pruning_iter>.eqx'``.
    Its first line is a JSON object of PrunableFFNN constructor arguments; the
    rest of the file is read by ``eqx.tree_deserialise_leaves`` into the nnx
    state of a freshly constructed template network.

    Parameters
    ----------
    load_data_path : str
        Directory prefix (including trailing separator) of the checkpoint files.
    pruning_iter : int, float or str
        Checkpoint label used in the file name, e.g. ``'dense'`` or ``3``.

    Returns
    -------
    The restored network (result of ``nnx.merge``).
    """
    checkpoint_path = load_data_path + 'model_piter={}.eqx'.format(pruning_iter)

    with open(checkpoint_path, "rb") as stream:
        # Header line: constructor hyperparameters as JSON.
        constructor_kwargs = json.loads(stream.readline().decode())
        # A template model provides the pytree structure the leaves are loaded into.
        template = PrunableFFNN(**constructor_kwargs)
        graph_def, template_state = nnx.split(template)
        restored_state = eqx.tree_deserialise_leaves(stream, template_state)
        restored_model = nnx.merge(graph_def, restored_state)

    return restored_model

In [None]:
def load_sampling_data(load_data_path, piter):
    """Load the JSON sampling log written for one checkpoint of a pruning run.

    Parameters
    ----------
    load_data_path : str
        Directory prefix (including trailing separator) that is concatenated
        with the file name. A leading ``~`` is expanded.
    piter : int, float or str
        Checkpoint label used in the file name, e.g. ``'dense'`` or ``3``.

    Returns
    -------
    dict
        Parsed contents of ``sampling_log_piter=<piter>.json``.

    Raises
    ------
    FileNotFoundError
        If the log file does not exist.
    """
    # expanduser returns a new string, so its result must be kept.
    file_path = os.path.expanduser(load_data_path + f'sampling_log_piter={piter}.json')

    if not os.path.exists(file_path):
        # Fail loudly instead of falling through to an undefined `data`.
        raise FileNotFoundError(f"Error: File not found - {file_path}")

    with open(file_path, 'r') as f:
        return json.load(f)

In [None]:
def load_training_data(load_data_path, piter):
    """Load the JSON training log written for one checkpoint of a pruning run.

    Parameters
    ----------
    load_data_path : str
        Directory prefix (including trailing separator) that is concatenated
        with the file name. A leading ``~`` is expanded.
    piter : int, float or str
        Checkpoint label used in the file name, e.g. ``'dense'`` or ``3``.

    Returns
    -------
    dict
        Parsed contents of ``training_log_piter=<piter>.json``.

    Raises
    ------
    FileNotFoundError
        If the log file does not exist.
    """
    # expanduser returns a new string, so its result must be kept.
    file_path = os.path.expanduser(load_data_path + f'training_log_piter={piter}.json')

    if not os.path.exists(file_path):
        # Fail loudly instead of falling through to an undefined `data`.
        raise FileNotFoundError(f"Error: File not found - {file_path}")

    with open(file_path, 'r') as f:
        return json.load(f)

IMP-WR

In [None]:
# IMP-WR trajectory: load every checkpoint ('dense', then pruning iterations
# 0..p_iters[i]-1) for each width alpha and record parameter count, relative
# energy error and variance per spin.
alphas = np.array([8], dtype=int)
p_iters = np.array([53])

load_data_path = '../../data/toric_code/ffnn/IMP_WR/'

# System size and reference ground-state energy; later cells also use these.
N = 18
e_gs = -18

# One row per alpha; column 0 is the dense network, column k+1 is pruning iteration k.
# NaN-initialised so columns an alpha never reaches are not left as uninitialised memory.
n_columns = np.max(p_iters) + 1
relative_errors = np.full((len(alphas), n_columns), np.nan)
remaining_params = np.full((len(alphas), n_columns), np.nan)
variances = np.full((len(alphas), n_columns), np.nan)

for i, alpha in enumerate(alphas):
    pruning_iterations = p_iters[i]
    file_path = load_data_path + 'alpha={}/'.format(alpha)

    # Checkpoint labels in column order: the dense network first, then the pruned ones.
    checkpoints = ['dense'] + list(range(pruning_iterations))

    for col, piter in enumerate(checkpoints):
        network = load_serialized_network(file_path, piter)
        n_rem = network.get_num_params()
        if col == 0:
            print(f'Number of parameters in dense network: {n_rem}')

        data = load_sampling_data(file_path, piter)
        # Average over the repeated sampling runs stored in the log.
        energy = np.mean(np.array(data['Energy']['Mean']))
        variance = np.mean(np.array(data['Energy']['Variance']))

        remaining_params[i, col] = n_rem
        relative_errors[i, col] = relative_error(energy, e_gs)
        variances[i, col] = variance / N


$T(\theta_{\mathrm{init}}, m_{\mathrm{imp}})$

In [None]:

# T(theta_init, m_imp) runs, evaluated at every third pruning iteration.
# Relies on N, e_gs and the IMP-WR result arrays from the IMP-WR cell.
alphas = np.array([8], dtype=int)
selected_p_iters = np.arange(0, 52, 3)

file_path = '../../data/toric_code/ffnn/T_init_imp/'

# Slot 0: dense IMP-WR reference point; slot k+1: selected_p_iters[k].
n_points = len(selected_p_iters) + 1
t_init_imp_relative_errors = np.empty(n_points)
t_init_imp_remaining_params = np.empty(n_points)
t_init_imp_variances = np.empty(n_points)

t_init_imp_remaining_params[0] = remaining_params[0, 0]
t_init_imp_relative_errors[0] = relative_errors[0, 0]
t_init_imp_variances[0] = variances[0, 0]

for slot, piter in enumerate(selected_p_iters, start=1):
    n_rem = load_serialized_network(file_path, piter).get_num_params()
    data = load_sampling_data(file_path, piter)

    # Average over the repeated sampling runs stored in the log.
    energy = np.array(data['Energy']['Mean']).mean()
    variance = np.array(data['Energy']['Variance']).mean()

    t_init_imp_variances[slot] = variance / N
    t_init_imp_remaining_params[slot] = n_rem
    t_init_imp_relative_errors[slot] = relative_error(energy, e_gs)


$T(\theta_{\mathrm{rand}}, m_{\mathrm{imp}})$

In [None]:

# T(theta_rand, m_imp) runs, evaluated at every third pruning iteration.
# Unlike the other two controls, no dense IMP-WR point is prepended here.
# Relies on N and e_gs from the IMP-WR cell.
alphas = np.array([8], dtype=int)
selected_p_iters = np.arange(0, 52, 3)

file_path = '../../data/toric_code/ffnn/T_rand_imp/'

n_points = len(selected_p_iters)
t_rand_imp_relative_errors = np.empty(n_points)
t_rand_imp_remaining_params = np.empty(n_points)
t_rand_imp_variances = np.empty(n_points)

for slot, piter in enumerate(selected_p_iters):
    n_rem = load_serialized_network(file_path, piter).get_num_params()
    data = load_sampling_data(file_path, piter)

    # Average over the repeated sampling runs stored in the log.
    energy = np.array(data['Energy']['Mean']).mean()
    variance = np.array(data['Energy']['Variance']).mean()

    t_rand_imp_variances[slot] = variance / N
    t_rand_imp_remaining_params[slot] = n_rem
    t_rand_imp_relative_errors[slot] = relative_error(energy, e_gs)


$T(\theta_{\mathrm{init}}, m_{\mathrm{rand}})$

In [None]:
# T(theta_init, m_rand) runs. Here the checkpoint labels are the densities
# (fractions) that appear in the file names, not integer iteration indices.
# Relies on N, e_gs and the IMP-WR result arrays from the IMP-WR cell.
alphas = np.array([8], dtype=int)
selected_p_iters = np.array([
    8.80015432e-01, 5.99537037e-01, 4.08950617e-01, 2.78549383e-01,
    1.89814815e-01, 1.29243827e-01, 8.83487654e-02, 6.05709877e-02,
    4.08950617e-02, 2.77777778e-02, 1.85185185e-02, 1.27314815e-02,
    8.87345679e-03, 6.17283951e-03, 4.24382716e-03, 3.08641975e-03,
    1.92901235e-03, 7.71604938e-04,
])

file_path = '../../data/toric_code/ffnn/T_init_rand/'

# Slot 0: dense IMP-WR reference point; slot k+1: selected_p_iters[k].
n_points = len(selected_p_iters) + 1
t_init_rand_relative_errors = np.empty(n_points)
t_init_rand_remaining_params = np.empty(n_points)
t_init_rand_variances = np.empty(n_points)

t_init_rand_remaining_params[0] = remaining_params[0, 0]
t_init_rand_relative_errors[0] = relative_errors[0, 0]
t_init_rand_variances[0] = variances[0, 0]

for slot, piter in enumerate(selected_p_iters, start=1):
    n_rem = load_serialized_network(file_path, piter).get_num_params()
    data = load_sampling_data(file_path, piter)

    # Average over the repeated sampling runs stored in the log.
    energy = np.array(data['Energy']['Mean']).mean()
    variance = np.array(data['Energy']['Variance']).mean()

    t_init_rand_variances[slot] = variance / N
    t_init_rand_remaining_params[slot] = n_rem
    t_init_rand_relative_errors[slot] = relative_error(energy, e_gs)


In [None]:
fig, ax = plt.subplots(figsize=(8,6), layout='constrained')

colors = ["#000000", "#004488", "#994455", "#997700", "#6d0b52", "#ee99aa", "#eecc66",  "#6699cc"]
shapes = ['o-', 's-', 'D-', '<-', 'P-', 'v-']

# Reference line at rho = n/N = 16.
ax.axvline(x=16, color='black', linestyle='--', linewidth=3, label=r'$n/N=16$', alpha=0.5)

# (parameter counts, relative errors, line style, legend label, colour), in plotting order.
fig4a_curves = [
    (remaining_params[0, :41], relative_errors[0, :41], '-',
     r'IMP-WR', colors[0]),
    (t_init_imp_remaining_params[:-4], t_init_imp_relative_errors[:-4], '--',
     r'$T(\theta_{\mathrm{init}}, m_{\mathrm{imp}})$', "orange"),
    (t_rand_imp_remaining_params[:-4], t_rand_imp_relative_errors[:-4], ':',
     r'$T(\theta_{\mathrm{rand}}, m_{\mathrm{imp}})$', colors[2]),
    (t_init_rand_remaining_params[:-4], t_init_rand_relative_errors[:-4], '-.',
     r'$T(\theta_{\mathrm{init}}, m_{\mathrm{rand}})$', colors[3]),
]

for n_params, rel_errors, linestyle, label, color in fig4a_curves:
    ax.plot(n_params / N, rel_errors, linestyle,
            label=label,
            markersize=7,
            linewidth=4,
            markerfacecolor=color,
            markeredgecolor='black',
            color=color)

ax.grid(linestyle='-', alpha=0.5)
ax.tick_params(axis='both', which='major', labelsize=20)
ax.loglog()
ax.set_xlabel(r'$\rho$', size=30)
ax.legend(fontsize=20, reverse=True, loc='lower left')

ax.set_ylabel(r'$\epsilon_{\mathrm{rel}}(E)$', size=30)

plt.savefig('fig4a.pdf', bbox_inches='tight', format='pdf')

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cmasher as cmr
from scipy.stats import sem
from scipy.optimize import curve_fit
import os
import sys
import json
import equinox as eqx
import flax
import jax
import jax.numpy as jnp
from flax import nnx

sys.path.append('../../../')
from src import *
from src.hamiltonian_ops.toric_code import *
from src.vmc.tc_vmc import *
from src.networks.prunable_ffnn import *
from src.pruning_algorithm.pruner import *
from src.networks.prunable_ffnn import PrunableFFNN

plt.rc('text', usetex=True)
plt.rc('font', family='serif')

In [None]:
def get_psi(network_file_path, pruning_iter):
    """Return the VMC variational state for one saved checkpoint.

    The network stored under ``network_file_path`` for ``pruning_iter`` is
    loaded with ``load_serialized_network``, which rebuilds the architecture
    from the hyperparameters saved in the checkpoint. It is then wrapped in a
    ``VMC`` object for the 3x3 periodic toric code and handed to a ``Pruner``.
    The state of the pruner's VMC driver is returned; callers below use its
    ``.samples`` and ``.log_value(...)``.

    Parameters
    ----------
    network_file_path : str
        Directory prefix (including trailing separator) of the checkpoint files.
    pruning_iter : int or str
        Checkpoint label, e.g. ``'dense'`` or a pruning-iteration index.

    Returns
    -------
    ``pruner.vmc_driver.state`` for the loaded network.
    """
    # System parameters
    Lx, Ly = 3, 3
    periodic_bc = True

    ham_params = {
        'Lx': Lx,
        'Ly': Ly,
        'periodic_bc': periodic_bc,
        }

    # Sampling and optimization parameters forwarded to VMC.
    vmc_params = {
        'num_samples': 1024,
        'chunk_size': 1024,
        'num_discard': 10,
        'num_chains': 16,
        'learning_rate': 0.008,
        'diag_shift': 0.001,
        'use_tc_sampler': True,
        }

    # Pruning parameters forwarded to Pruner. Only the driver's state is used here.
    # NOTE(review): presumably these should match the run that produced the checkpoints; confirm.
    pruner_params = {
        'dense_epochs': 2000,
        'iterative_epochs': 200,
        'pruning_protocol': 'layerwise',
        'pruning_ratio': 0.12,
        'weight_rewinding': True,
        'pruning_its': 20,
        'sampling_its': 10,
        'network_type': 'PrunableFFNN',
        }

    # Construct the Hamiltonian
    operator = ToricCode(**ham_params)

    # The network's size comes from the checkpoint header, so no width/depth
    # hyperparameters are needed here and the same code works for every alpha.
    network = load_serialized_network(network_file_path, pruning_iter)

    # Construct the VMC wavefunction
    vmc = VMC(operator,
            network,
            **vmc_params)

    # Initialize the pruning algorithm and expose its driver's variational state.
    pruner = Pruner(vmc, **pruner_params)
    return pruner.vmc_driver.state


In [None]:
def local_fidelity(psi, phi, psi_samples, phi_samples):
    """Monte Carlo estimate of the fidelity between two variational states.

    Computes

        F = mean_{s in phi_samples}[psi(s) / phi(s)] * mean_{s in psi_samples}[phi(s) / psi(s)],

    with every amplitude ratio formed as ``exp(log_value difference)``.
    ``psi_samples`` and ``phi_samples`` are presumably drawn from |psi|^2 and
    |phi|^2 respectively (callers pass ``state.samples``). Because this is a
    stochastic estimate, results slightly above 1 are possible.

    Parameters
    ----------
    psi, phi
        Objects exposing ``log_value(samples)``.
    psi_samples, phi_samples
        Configurations passed to ``log_value``.

    Returns
    -------
    Scalar jax array with the fidelity estimate.
    """
    # psi/phi on configurations from phi.
    psi_over_phi = jnp.exp(psi.log_value(phi_samples) - phi.log_value(phi_samples))
    # phi/psi on configurations from psi.
    phi_over_psi = jnp.exp(phi.log_value(psi_samples) - psi.log_value(psi_samples))

    return jnp.mean(psi_over_phi) * jnp.mean(phi_over_psi)

The cell below calculates the fidelity between neighboring NQS along the pruning trajectory.
This takes some time to run, so we report the resulting data further below in the notebook.
To run the cell below, choose one of the four network widths $w=\alpha N$ with $\alpha=4,8,16,32$ by setting `selected_network_index` to the corresponding index (0–3).

In [None]:
pruning_iterations = np.array([47,53,58,64], dtype=int)
alphas = np.array([4,8,16,32], dtype=int)

# Index into `alphas`: 0 -> alpha=4, 1 -> 8, 2 -> 16, 3 -> 32.
selected_network_index = 0

network_file_path = '../../data/toric_code/ffnn/IMP_WR/alpha={}/'.format(alphas[selected_network_index])
# `range` needs the scalar count for the selected width, not the whole array.
n_pruning_iterations = int(pruning_iterations[selected_network_index])

# Checkpoints along the trajectory: 'dense', 0, 1, ..., n-1.
# Fidelity k compares checkpoint k with checkpoint k+1 and is labelled with the
# parameter count of checkpoint k.
checkpoints = ['dense'] + list(range(n_pruning_iterations))

# Kept under its own name so this cell does not overwrite the 2-D
# `remaining_params` array from the IMP-WR cell, which the combined figure
# still indexes as remaining_params[0, :41].
fidelity_remaining_params = []
fidelities = []

for psi_iter, phi_iter in zip(checkpoints[:-1], checkpoints[1:]):
    psi = get_psi(network_file_path, psi_iter)
    phi = get_psi(network_file_path, phi_iter)

    F = local_fidelity(psi, phi, psi.samples, phi.samples)
    fidelities.append(F.tolist())

    network = load_serialized_network(network_file_path, psi_iter)
    fidelity_remaining_params.append(network.get_num_params())


In [None]:
# Reported fidelity data (see the markdown above): output of the fidelity cell for each width alpha.
# fidelities_alphaX[k] is F between checkpoint k and k+1 of the sequence ['dense', 0, 1, ...],
# and remaining_params_alphaX[k] is the parameter count of checkpoint k.
# alpha = 4
fidelities_alpha4 = [0.9996702760752937, 0.9997944153804623, 0.9998602252634268, 0.999734199409201, 0.9998393350722603, 0.9999452898051967, 1.000139954851309, 1.0018503308715403, 0.9990737439769088, 1.0002758413752644, 1.0004544830614366, 0.9982817421289151, 0.9992479357804066, 1.001633841508743, 0.9965428984133855, 0.994530853359085, 0.9837225065675967, 0.6580421762349105, 0.5215711111231468, 0.8159755250923004, 0.6121137297484521, 0.5118255299101642, 0.7873946344638836, 0.9103090915074407, 0.42241257975625024, 0.6369669598717014, 0.9319366090281122, 0.9992694742104333, 0.8346253943492538, 0.8359493131794143, 0.9882343806641858, 0.9986546578970215, 0.9997027780427084, 1.0002177253501325, 0.6348796471155733, 0.9286056964387958, 0.9995981829796013, 0.9999972047996486, 0.9993392772440369, 0.9995531228092318, 1.0004479554646222, 0.9999132176697587, 0.9998347147071931, 1.0001733805032855, 0.9998100238812003, 1.0000098420525698, 1.0000059985542469]
remaining_params_alpha4 = [1296, 1140, 1003,  883,  777,  684,  602,  530,  466,  410,  361,  318,  280,  246,
  216,  190,  167,  147,  129,  114,  100,   88,   77,   68,   60,   53,   47,   41,
   36,   32,   28,   25,   22,   19,   17,   15,   13,   11,   10,    9,    8,    7,
    6,    5,    4,    3,    2]


# alpha = 8
fidelities_alpha8 = [0.999862532491192, 1.000055737974618, 1.0002182905835821, 1.0002598565288956, 0.9998411504151883, 1.0001785991953596, 0.9994705643371744, 1.000131636355786, 1.0001342703099374, 0.9991960431430322, 1.0006818138068185, 1.000609375417767, 0.9968723766804263, 1.0004472722964406, 1.0020075654148568, 0.9986258202329994, 1.0015795834473762, 0.9991778296278087, 0.9986290075794304, 1.0003145176800285, 0.9988224963953412, 0.6395058417634442, 0.8466665981010807, 0.5959961517265476, 0.7428196029771675, 0.8576182679696921, 0.45784015081244406, 0.9339728474305595, 0.9882918237428594, 0.9167450658328926, 0.7148799009226959, 0.9885224270400976, 0.9406557930101405, 0.7477522304119106, 0.9949422578286056, 0.9642928544092977, 0.9997273729123073, 1.0007097945829218, 0.999683702279856, 1.0011253396900648, 0.6913650313009206, 0.9907633981739651, 0.9820859671731542, 0.9854156677863622, 0.9990201780183159, 0.9998343196339554, 0.999695376710423, 1.000451705783833, 0.9993026142107837, 1.0004038560899553, 0.9998871923276809, 0.9982610954338756, 1.000210065105975]
remaining_params_alpha8 = [2592, 2281, 2007, 1766, 1554, 1368, 1204, 1060,  933,  821,  722,  635,  559,  492,
  433,  381,  335,  295,  260,  229,  202,  178,  157,  138,  121,  106,   93,   82,
   72,   63,   55,   48,   42,   37,   33,   29,   26,   23,   20,   18,   16,   14,
   12,   11,   10,    9,    8,    7,    6,    5,    4,    3,    2]


# alpha = 16
fidelities_alpha16 = [1.0002080738440289, 0.9998504031501517, 1.0000193456927926, 0.9998718127087551, 0.9999006526494094, 1.000095594822997, 1.000009340644973, 1.000076512739329, 0.9997937308663332, 0.999310041547534, 1.0003222714967628, 1.001489980294017, 0.9994034604731302, 0.9992861142444088, 0.9975754657146486, 1.0023537967511544, 0.9987015283339331, 1.0017961946154026, 0.9980726604014325, 0.9873787908941343, 1.0023800091067727, 1.0007453365722434, 0.964989141082345, 0.9304446701542897, 0.579933544513775, 0.9452532520596008, 0.8851610934355021, 0.6540764501758688, 0.7214311152785188, 0.522064957227501, 0.9615468424899457, 0.9989444194917758, 0.7876785786149474, 0.6647795183451782, 0.6930722219163574, 0.9966884740150438, 0.9990078735548595, 0.9994532696352673, 0.897685942938071, 0.7025891763896476, 0.9950524398916926, 1.0009566347277832, 1.0001688212686606, 0.6118713211670141, 0.9999585212388901, 1.0002722792320615, 0.9783886180776854, 0.9974778016730979, 0.994144737611786, 0.9989923588497355, 0.9987317888257059, 0.9992836063781074, 0.9998455187383875, 0.9995102177189755, 0.9999895631160712, 0.9999402672731095, 1.000047542257251, 1.0002704168382486]
remaining_params_alpha16 = [5184, 4562, 4015, 3533, 3109, 2736, 2408, 2119, 1865, 1641, 1444, 1271, 1118,  984,
  866,  762,  671,  590,  519,  457,  402,  354,  312,  275,  242,  213,  187,  165,
  145,  128,  113,   99,   87,   77,   68,   60,   53,   47,   41,   36,   32,   28,
   25,   22,   19,   17,   15,   13,   11,   10,    9,    8,    7,    6,    5,    4,
    3,    2]

# alpha = 32
fidelities_alpha32 = [0.9998676915927804, 1.0003808192762345, 0.9995915363574529, 1.0001914399236351, 0.9998599255202338, 1.0000811925020905, 1.0002334364581498, 1.0000126830483098, 0.9998693065584444, 0.9999468825065013, 0.9997463839489202, 1.0000029307796092, 0.9985048930158288, 0.9994411429234544, 0.9999698244555631, 1.0005576593305763, 0.9998215382898317, 0.9993106027886267, 0.9967817356397575, 0.9999348661935737, 0.9966365681279113, 1.000001707129311, 0.9987639437852366, 0.9959872627174629, 0.9982505039084688, 0.9868849026303526, 0.9981469431741037, 0.9993019361416047, 0.7742488934830779, 0.5893580920089525, 0.6488835341545299, 0.5621090775863535, 0.9750961684306216, 0.888794034707362, 0.7444514539527612, 0.9932679639305108, 0.6098187360708222, 0.9706232013454639, 1.0003054679466696, 1.0005195844632864, 0.9995551119250486, 0.5819843910410092, 0.8685696517690032, 0.9981673955604005, 0.7626097849457186, 0.9982205609129516, 0.9827474858181453, 0.9950097367961024, 0.9979853961249212, 0.9985663517293124, 0.8513471257534079, 1.0000498771928747, 0.8952081168172055, 0.9999397095931822, 1.0000048985845034, 0.9980187566797184, 0.9995382589033935, 0.9998993246232678, 0.9984809888009983, 1.0001487042947312, 1.000071851922645, 0.9999772295890784, 1.000319581923362, 0.9995959180233106]
remaining_params_alpha32 = [10368,  9124,  8029,  7066,  6218,  5472,  4815,  4237,  3729,  3282,  2888,  2541,
  2236,  1968,  1732,  1524,  1341,  1180,  1038,   913,   803,   707,   622,   547,
   481,   423,   372,   327,   288,   253,   223,   196,   172,   151,   133,   117,
   103,    91,    80,    70,    62,    55,    48,    42,    37,    33,    29,    26,
    23,    20,    18,    16,    14,    12,    11,    10,     9,     8,     7,     6,
     5,     4,     3,     2]

# Convert the hard-coded lists to NumPy arrays so they support elementwise `/ N` when plotting.
fidelities_alpha4, remaining_params_alpha4 = np.array(fidelities_alpha4), np.array(remaining_params_alpha4)
fidelities_alpha8, remaining_params_alpha8 = np.array(fidelities_alpha8), np.array(remaining_params_alpha8)
fidelities_alpha16, remaining_params_alpha16 = np.array(fidelities_alpha16), np.array(remaining_params_alpha16)
fidelities_alpha32, remaining_params_alpha32 = np.array(fidelities_alpha32), np.array(remaining_params_alpha32)

In [None]:
# Palette reversed so alpha = 4 gets the first entry.
colors = ["#000000", "#eecc66", "#004488", "#ee99aa", "#6699cc"][::-1]

fig, ax = plt.subplots(figsize=(8,6), layout='constrained')

# Reference line at rho = n/N = 16.
ax.axvline(x=16, color='black', linestyle='--', linewidth=3, label=r'$n/N=16$', alpha=0.5)

# (parameter counts, fidelities, legend label, colour), plotted widest network first.
fig4b_curves = [
    (remaining_params_alpha32, fidelities_alpha32, r'$\alpha = 32$', colors[3]),
    (remaining_params_alpha16, fidelities_alpha16, r'$\alpha = 16$', colors[2]),
    (remaining_params_alpha8, fidelities_alpha8, r'$\alpha = 8$', colors[1]),
    (remaining_params_alpha4, fidelities_alpha4, r'$\alpha = 4$', colors[0]),
]

for n_params, fid, label, color in fig4b_curves:
    ax.plot(n_params / N, fid,
            '-',
            label=label,
            markersize=7,
            linewidth=3,
            markeredgecolor='black',
            color=color)

ax.legend(fontsize=20, reverse=True, loc='lower right')
ax.grid(linestyle='-', alpha=0.5)
ax.tick_params(axis='both', which='major', labelsize=20)
ax.set_xscale('log')
ax.set_xlabel(r'$\rho$', size=30)
ax.set_ylabel(r'$F(\rho)$', size=30)

fig.savefig('fig4b.pdf', bbox_inches='tight', format='pdf')

In [None]:
import matplotlib.gridspec as gridspec

# Combined Fig. 4: panel a = relative energy error vs rho, panel b = neighbour
# fidelity vs rho, third panel carries the "c)" and "d)" labels.
fig = plt.figure(figsize=(24, 6), layout='constrained')
# One row of three equal-width panels. fig.add_gridspec attaches the grid to
# `fig`, which is the form constrained layout is designed for.
gs = fig.add_gridspec(1, 3, width_ratios=[1, 1, 1])

ax1 = fig.add_subplot(gs[0, 0])  # panel a
ax2 = fig.add_subplot(gs[0, 1])  # panel b
ax3 = fig.add_subplot(gs[0, 2])  # panels c/d

# ---- Panel a: relative energy error of IMP-WR and the three controls ----
colors = ["#000000", "#004488", "#994455", "#997700", "#6d0b52", "#ee99aa", "#eecc66",  "#6699cc"]
shapes = ['o-', 's-', 'D-', '<-', 'P-', 'v-']

ax1.axvline(x=16, color='black', linestyle='--', linewidth=3, label=r'$\rho=16$', alpha=0.5)

# (parameter counts, relative errors, line style, legend label, colour), in plotting order.
panel_a_curves = [
    (remaining_params[0, :41], relative_errors[0, :41], '-',
     r'IMP-WR', colors[0]),
    (t_init_imp_remaining_params[:-4], t_init_imp_relative_errors[:-4], '--',
     r'$T(\theta_{\mathrm{init}}, m_{\mathrm{imp}})$', "orange"),
    (t_rand_imp_remaining_params[:-4], t_rand_imp_relative_errors[:-4], ':',
     r'$T(\theta_{\mathrm{rand}}, m_{\mathrm{imp}})$', colors[2]),
    (t_init_rand_remaining_params[:-4], t_init_rand_relative_errors[:-4], '-.',
     r'$T(\theta_{\mathrm{init}}, m_{\mathrm{rand}})$', colors[3]),
]
for n_params, rel_errors, linestyle, label, color in panel_a_curves:
    ax1.plot(n_params / N, rel_errors, linestyle,
             label=label,
             markersize=7,
             linewidth=4,
             markerfacecolor=color,
             markeredgecolor='black',
             color=color)

ax1.grid(linestyle='-', alpha=0.5)
ax1.tick_params(axis='both', which='major', labelsize=24)
ax1.loglog()
ax1.set_xlabel(r'$\rho$', size=35)
ax1.legend(fontsize=18, reverse=True, loc='lower left')
ax1.set_ylabel(r'$\epsilon_{\mathrm{rel}}(E)$', size=35)

# ---- Panel b: fidelity between neighbouring checkpoints for each width ----
colors = ["#eecc66", "#004488", "#ee99aa", "#000000", "#6699cc"]

ax2.axvline(x=16, color='black', linestyle='--', linewidth=3, label=r'$\rho=16$', alpha=0.5)

# (parameter counts, fidelities, legend label, colour), widest network first.
panel_b_curves = [
    (remaining_params_alpha32, fidelities_alpha32, r'$\alpha = 32$', colors[3]),
    (remaining_params_alpha16, fidelities_alpha16, r'$\alpha = 16$', colors[2]),
    (remaining_params_alpha8, fidelities_alpha8, r'$\alpha = 8$', colors[1]),
    (remaining_params_alpha4, fidelities_alpha4, r'$\alpha = 4$', colors[0]),
]
for n_params, fid, label, color in panel_b_curves:
    ax2.plot(n_params / N, fid,
             '-',
             label=label,
             markersize=7,
             linewidth=3,
             markeredgecolor='black',
             color=color)

ax2.legend(fontsize=18, reverse=True, loc='lower right')
ax2.grid(linestyle='-', alpha=0.5)
ax2.tick_params(axis='both', which='major', labelsize=24)
ax2.set_xscale('log')
ax2.set_xlabel(r'$\rho$', size=35)
ax2.set_ylabel(r'$F(\rho)$', size=35)

# ---- Panel labels, placed in axes coordinates to the left of each panel ----
ax1.text(-0.14, 0.93, r'a)', horizontalalignment='center', verticalalignment='bottom', transform=ax1.transAxes, fontsize=24)
ax2.text(-0.14, 0.93, r'b)', horizontalalignment='center', verticalalignment='bottom', transform=ax2.transAxes, fontsize=24)
ax3.text(-0.14, 0.93, r'c)', horizontalalignment='center', verticalalignment='bottom', transform=ax3.transAxes, fontsize=24)
ax3.text(-0.14, 0.34, r'd)', horizontalalignment='center', verticalalignment='bottom', transform=ax3.transAxes, fontsize=24)