In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cmasher as cmr
from scipy.stats import sem
from scipy.optimize import curve_fit
import os
import sys
import json
import equinox as eqx
import flax
from flax import nnx
sys.path.append('../')

from src.networks.prunable_ffnn import PrunableFFNN
from src.networks.prunable_cnn import PrunableCNN
from src.networks.prunable_res_cnn import PrunableResCNN

# Typeset all figure text with LaTeX, using a serif font family.
plt.rcParams.update({'text.usetex': True, 'font.family': 'serif'})

In [None]:
def relative_error(x, y):
    """Relative deviation |x - y| / |y| of ``x`` from the reference value ``y``.

    Works element-wise on numpy arrays as well as on scalars.
    """
    deviation = np.abs(x - y)
    scale = np.abs(y)
    return deviation / scale

def abs_error(x, y):
    """Absolute deviation |x - y| (element-wise for arrays)."""
    deviation = x - y
    return np.abs(deviation)

def abs_difference_per_spin(x, y, N):
    """Absolute deviation |x - y| divided by the number of spins ``N``."""
    total_deviation = np.abs(x - y)
    return total_deviation / N

def difference(x, y):
    """Signed difference ``x - y`` (no absolute value, no normalisation)."""
    signed_gap = x - y
    return signed_gap

In [None]:
def _load_serialized_network(load_data_path, pruning_iter, network_cls):
    """Rebuild a ``network_cls`` instance from ``model_piter={pruning_iter}.eqx``.

    The file's first line is a JSON object of constructor keyword arguments;
    the remaining bytes are the state leaves read by
    ``eqx.tree_deserialise_leaves``.

    Args:
        load_data_path: Directory prefix, concatenated as-is with the file name
            (so it must end with a path separator).
        pruning_iter: Checkpoint label (e.g. ``'dense'``, an int, or a density).
        network_cls: Class used to build the template network.

    Returns:
        The network with its deserialised state merged back in.
    """
    file_name = load_data_path + f'model_piter={pruning_iter}.eqx'

    with open(file_name, "rb") as f:
        hyperparams = json.loads(f.readline().decode())
        # Build a template with the right structure, then overwrite its state
        # leaves with the values stored in the file.
        template = network_cls(**hyperparams)
        graph, template_state = nnx.split(template)
        loaded_state = eqx.tree_deserialise_leaves(f, template_state)

    return nnx.merge(graph, loaded_state)


def load_serialized_FFNN(load_data_path, pruning_iter):
    """Load a ``PrunableFFNN`` saved as ``model_piter={pruning_iter}.eqx``."""
    return _load_serialized_network(load_data_path, pruning_iter, PrunableFFNN)


def load_serialized_CNN(load_data_path, pruning_iter):
    """Load a ``PrunableCNN`` saved as ``model_piter={pruning_iter}.eqx``."""
    return _load_serialized_network(load_data_path, pruning_iter, PrunableCNN)


def load_serialized_ResCNN(load_data_path, pruning_iter):
    """Load a ``PrunableResCNN`` saved as ``model_piter={pruning_iter}.eqx``."""
    return _load_serialized_network(load_data_path, pruning_iter, PrunableResCNN)

def _load_json_log(load_data_path, log_name, piter):
    """Parse ``{log_name}_piter={piter}.json`` located under ``load_data_path``.

    ``load_data_path`` is concatenated as-is with the file name, so it must end
    with a path separator; a leading ``~`` is expanded to the user's home.

    Raises:
        FileNotFoundError: if the log file does not exist.
    """
    # expanduser returns a new string; it must be assigned to take effect.
    file_path = os.path.expanduser(load_data_path + f'{log_name}_piter={piter}.json')

    if not os.path.exists(file_path):
        # Fail loudly here instead of falling through to an unbound result.
        raise FileNotFoundError(f"Error: File not found - {file_path}")

    with open(file_path, 'r') as f:
        return json.load(f)


def load_sampling_data(load_data_path, piter):
    """Load the sampling log ``sampling_log_piter={piter}.json``.

    Returns the parsed JSON object. The cells in this notebook read
    ``data['Energy']['Mean']`` and ``data['Energy']['Variance']`` from it.

    Raises:
        FileNotFoundError: if the log file does not exist.
    """
    return _load_json_log(load_data_path, 'sampling_log', piter)


def load_training_data(load_data_path, piter):
    """Load the training log ``training_log_piter={piter}.json``.

    Raises:
        FileNotFoundError: if the log file does not exist.
    """
    return _load_json_log(load_data_path, 'training_log', piter)

In [None]:
# Reference ground-state energy of the N = 10x10 2D TFIM at k = 3.04438,
# written as a per-spin value times the 100 spins of the lattice.
e_gs = 100 * -3.211395417101927

# Shallow FFNN data

IMP-WR

In [None]:
# Number of IMP-WR pruning iterations stored on disk (piter = 0 .. 64)
pruning_iterations = 65
# Transverse fields
tfs = np.array([3.04438])

file_path = '../../data/tfim_2d/ffnn/IMP_WR/k=3.04438/alpha=5/'

N = 100

# Checkpoint labels in array order: slot 0 is the dense network,
# slot p + 1 is pruning iteration p.
checkpoints = ['dense', *range(pruning_iterations)]
n_points = len(checkpoints)

relative_errors = np.empty(n_points)
remaining_params = np.empty(n_points)
variances = np.empty(n_points)
all_energies = np.empty(n_points)

for slot, piter in enumerate(checkpoints):
    network = load_serialized_FFNN(file_path, piter)
    n_rem = network.get_num_params()
    if slot == 0:
        n_dense = n_rem
        print(f'Number of parameters in dense network: {n_dense}')

    data = load_sampling_data(file_path, piter)
    # Average every entry of the logged energy means / variances.
    energy = np.mean(np.array(data['Energy']['Mean']))
    variance = np.mean(np.array(data['Energy']['Variance']))

    remaining_params[slot] = n_rem
    relative_errors[slot] = relative_error(energy, e_gs)
    # FFNN variances are stored per spin.
    variances[slot] = variance / N
    all_energies[slot] = energy


$T(\theta_{\mathrm{init}}, m_{\mathrm{imp}})$

In [None]:
# Subset of IMP pruning iterations evaluated for this run: every third, 0 .. 60.
selected_piters = np.arange(0, 61, 3)

file_path = '../../data/tfim_2d/ffnn/T_init_imp/'

# One extra slot in front for the dense network, reused from the IMP-WR arrays.
n_points = len(selected_piters) + 1
t_init_imp_relative_errors = np.empty(n_points)
t_init_imp_remaining_params = np.empty(n_points)
t_init_imp_variances = np.empty(n_points)

t_init_imp_remaining_params[0] = remaining_params[0]
t_init_imp_relative_errors[0] = relative_errors[0]
t_init_imp_variances[0] = variances[0]

for slot, piter in enumerate(selected_piters, start=1):
    network = load_serialized_FFNN(file_path, piter)
    n_rem = network.get_num_params()

    data = load_sampling_data(file_path, piter)
    energy = np.mean(np.array(data['Energy']['Mean']))
    variance = np.mean(np.array(data['Energy']['Variance']))

    t_init_imp_remaining_params[slot] = n_rem
    t_init_imp_relative_errors[slot] = relative_error(energy, e_gs)
    t_init_imp_variances[slot] = variance / N


$T(\theta_{\mathrm{rand}}, m_{\mathrm{imp}})$

In [None]:
# Same every-third subset of IMP pruning iterations (0 .. 60); no dense slot here.
selected_piters = np.arange(0, 61, 3)

file_path = '../../data/tfim_2d/ffnn/T_rand_imp/'

n_points = len(selected_piters)
t_rand_imp_relative_errors = np.empty(n_points)
t_rand_imp_remaining_params = np.empty(n_points)
t_rand_imp_variances = np.empty(n_points)

for slot, piter in enumerate(selected_piters):
    network = load_serialized_FFNN(file_path, piter)
    n_rem = network.get_num_params()

    data = load_sampling_data(file_path, piter)
    energy = np.mean(np.array(data['Energy']['Mean']))
    variance = np.mean(np.array(data['Energy']['Variance']))

    t_rand_imp_remaining_params[slot] = n_rem
    t_rand_imp_relative_errors[slot] = relative_error(energy, e_gs)
    t_rand_imp_variances[slot] = variance / N


$T(\theta_{\mathrm{init}}, m_{\mathrm{rand}})$

In [None]:
# Densities matching selected IMP pruning iterations; each value is also the
# `piter` label used in the checkpoint and log file names.
selected_piters = np.array([8.8000e-01, 5.9970e-01, 4.0868e-01, 2.7850e-01, 1.8980e-01, 1.2934e-01,
  8.8140e-02, 6.0060e-02, 4.0940e-02, 2.7900e-02, 1.9020e-02, 1.2980e-02,
  8.8400e-03, 6.0200e-03, 4.1000e-03, 2.7800e-03, 1.8800e-03, 1.2800e-03,
  8.6000e-04, 5.8000e-04, 4.0000e-04])

file_path = '../../data/tfim_2d/ffnn/T_init_rand/'

# Slot 0 is the dense network, reused from the IMP-WR arrays.
n_points = len(selected_piters) + 1
t_init_rand_relative_errors = np.empty(n_points)
t_init_rand_remaining_params = np.empty(n_points)
t_init_rand_variances = np.empty(n_points)

t_init_rand_remaining_params[0] = remaining_params[0]
t_init_rand_relative_errors[0] = relative_errors[0]
t_init_rand_variances[0] = variances[0]

for slot, piter in enumerate(selected_piters, start=1):
    network = load_serialized_FFNN(file_path, piter)
    n_rem = network.get_num_params()

    data = load_sampling_data(file_path, piter)
    energy = np.mean(np.array(data['Energy']['Mean']))
    variance = np.mean(np.array(data['Energy']['Variance']))

    t_init_rand_remaining_params[slot] = n_rem
    t_init_rand_relative_errors[slot] = relative_error(energy, e_gs)
    t_init_rand_variances[slot] = variance / N
    

In [None]:
fig, ax = plt.subplots(figsize=(8,6))

colors = ["#000000", "#004488", "#994455", "#997700", "#6d0b52", "#ee99aa", "#eecc66",  "#6699cc"]
shapes = ['o-', 's-', 'D-', '<-', 'P-', 'v-']

# One entry per curve, in drawing order: (x, y, line style, label, line width, colour).
# The last four IMP-WR points are left out of the plot.
curves = [
    (remaining_params[:-4], relative_errors[:-4], '-',
     r'IMP-WR', 3, colors[0]),
    (t_init_imp_remaining_params, t_init_imp_relative_errors, '--',
     r'$T(\theta_{\mathrm{init}}, m_{\mathrm{imp}})$', 4, "orange"),
    (t_rand_imp_remaining_params, t_rand_imp_relative_errors, ':',
     r'$T(\theta_{\mathrm{rand}}, m_{\mathrm{imp}})$', 4, colors[2]),
    (t_init_rand_remaining_params, t_init_rand_relative_errors, '-.',
     r'$T(\theta_{\mathrm{init}}, m_{\mathrm{rand}})$', 4, colors[3]),
]

for xs, ys, line_style, curve_label, width, curve_color in curves:
    ax.plot(xs, ys,
        line_style,
        label=curve_label,
        markersize=7,
        linewidth=width,
        markerfacecolor=curve_color,
        markeredgecolor='black',
        color=curve_color
        )

ax.loglog()
ax.grid(linestyle='-', alpha=0.5)
ax.tick_params(axis='both', which='major', labelsize=26)
ax.set_xlabel(r'$n$', size=40)
ax.set_ylabel(r'$\epsilon_{\mathrm{rel}}(E)$', size=40)
ax.legend(loc='lower left', fontsize=20, reverse=True)
ax.text(0.81, 0.85, r'Shallow FFNN',
        horizontalalignment='center', verticalalignment='bottom',
        transform=ax.transAxes, fontsize=26)

plt.savefig('fig1a.pdf', bbox_inches='tight', format='pdf')

# Shallow CNN data

IMP-WR

In [None]:
n_filters = 4
# Number of IMP-WR pruning iterations stored on disk (piter = 0 .. 30)
p_iters = 31

file_path = '../../data/tfim_2d/cnn/IMP_WR/'

# Slot 0 is the dense network, slot p + 1 is pruning iteration p.
checkpoints = ['dense', *range(p_iters)]
n_points = len(checkpoints)

relative_errors = np.empty(n_points)
remaining_params = np.empty(n_points)
variances = np.empty(n_points)
all_energies = np.empty(n_points)

for slot, piter in enumerate(checkpoints):
    network = load_serialized_CNN(file_path, piter)
    n_rem = network.get_num_params()
    if slot == 0:
        n_dense = n_rem
        print(f'Number of parameters in dense network: {n_dense}')

    data = load_sampling_data(file_path, piter)
    energy = np.mean(np.array(data['Energy']['Mean']))
    variance = np.mean(np.array(data['Energy']['Variance']))

    # CNN variances are stored as logged (not divided by the number of spins).
    variances[slot] = variance
    remaining_params[slot] = n_rem
    relative_errors[slot] = relative_error(energy, e_gs)
    all_energies[slot] = energy
    

$T(\theta_{\mathrm{init}}, m_{\mathrm{imp}})$

In [None]:
# Selected pruning iterations: every third up to 27, then 28, 29 and 30.
selected_p_iters = np.array([*range(0, 28, 3), 28, 29, 30])

file_path = '../../data/tfim_2d/cnn/T_init_imp/'

# Slot 0 is the dense network, reused from the CNN IMP-WR arrays.
n_points = len(selected_p_iters) + 1
t_init_imp_relative_errors = np.empty(n_points)
t_init_imp_remaining_params = np.empty(n_points)
t_init_imp_variances = np.empty(n_points)

t_init_imp_variances[0] = variances[0]
t_init_imp_remaining_params[0] = remaining_params[0]
t_init_imp_relative_errors[0] = relative_errors[0]

for slot, piter in enumerate(selected_p_iters, start=1):
    network = load_serialized_CNN(file_path, piter)
    n_rem = network.get_num_params()

    data = load_sampling_data(file_path, piter)
    energy = np.mean(np.array(data['Energy']['Mean']))
    variance = np.mean(np.array(data['Energy']['Variance']))

    t_init_imp_variances[slot] = variance
    t_init_imp_remaining_params[slot] = n_rem
    t_init_imp_relative_errors[slot] = relative_error(energy, e_gs)


$T(\theta_{\mathrm{rand}}, m_{\mathrm{imp}})$

In [None]:
# Selected pruning iterations: every third up to 27, then 28, 29 and 30 (no dense slot).
selected_p_iters = np.array([*range(0, 28, 3), 28, 29, 30])

file_path = '../../data/tfim_2d/cnn/T_rand_imp/'

n_points = len(selected_p_iters)
t_rand_imp_relative_errors = np.empty(n_points)
t_rand_imp_remaining_params = np.empty(n_points)
t_rand_imp_variances = np.empty(n_points)

for slot, piter in enumerate(selected_p_iters):
    network = load_serialized_CNN(file_path, piter)
    n_rem = network.get_num_params()

    data = load_sampling_data(file_path, piter)
    energy = np.mean(np.array(data['Energy']['Mean']))
    variance = np.mean(np.array(data['Energy']['Variance']))

    t_rand_imp_variances[slot] = variance
    t_rand_imp_remaining_params[slot] = n_rem
    t_rand_imp_relative_errors[slot] = relative_error(energy, e_gs)


$T(\theta_{\mathrm{init}}, m_{\mathrm{rand}})$

In [None]:
# Selected densities corresponding to certain pruning iterations; each value is
# also the `piter` label used in the checkpoint and log file names.
selected_piters = np.array([0.94444444, 0.77777778, 0.69444444, 0.61111111, 0.52777778, 0.44444444,
                             0.36111111, 0.27777778, 0.19444444, 0.11111111, 0.02777778])

file_path = '../../data/tfim_2d/cnn/T_init_rand/'

# Slot 0 is the dense network, reused from the CNN IMP-WR arrays.
n_points = len(selected_piters) + 1
t_init_rand_relative_errors = np.empty(n_points)
t_init_rand_remaining_params = np.empty(n_points)
t_init_rand_variances = np.empty(n_points)

t_init_rand_remaining_params[0] = remaining_params[0]
t_init_rand_relative_errors[0] = relative_errors[0]
t_init_rand_variances[0] = variances[0]

for i, piter in enumerate(selected_piters):

    # Load the network
    network = load_serialized_CNN(file_path, piter)
    n_rem = network.get_num_params()

    data = load_sampling_data(file_path, piter)

    energy = np.mean(np.array(data['Energy']['Mean']))
    e_rel = relative_error(energy, e_gs)

    variance = np.mean(np.array(data['Energy']['Variance']))

    t_init_rand_remaining_params[i+1] = n_rem
    t_init_rand_relative_errors[i+1] = e_rel
    # Stored un-normalised, like every other CNN variance (including slot 0 copied
    # from `variances[0]`); dividing by N here mixed total and per-spin values
    # in one array and relied on `N` from the FFNN section.
    t_init_rand_variances[i+1] = variance


In [None]:
fig, ax = plt.subplots(figsize=(8,6))

colors = ["#000000", "#004488", "#994455", "#997700", "#6d0b52", "#ee99aa", "#eecc66",  "#6699cc"]
shapes = ['o-', 's-', 'D-', '<-', 'P-', 'v-']

# One entry per curve, in drawing order: (x, y, line style, label, line width, colour).
curves = [
    (remaining_params, relative_errors, '-',
     r'IMP-WR', 3, colors[0]),
    (t_init_imp_remaining_params, t_init_imp_relative_errors, '--',
     r'$T(\theta_{\mathrm{init}}, m_{\mathrm{imp}})$', 4, "orange"),
    (t_rand_imp_remaining_params, t_rand_imp_relative_errors, ':',
     r'$T(\theta_{\mathrm{rand}}, m_{\mathrm{imp}})$', 4, colors[2]),
    (t_init_rand_remaining_params, t_init_rand_relative_errors, '-.',
     r'$T(\theta_{\mathrm{init}}, m_{\mathrm{rand}})$', 4, colors[3]),
]

for xs, ys, line_style, curve_label, width, curve_color in curves:
    ax.plot(xs, ys,
        line_style,
        label=curve_label,
        markersize=7,
        linewidth=width,
        markerfacecolor=curve_color,
        markeredgecolor='black',
        color=curve_color
        )

ax.loglog()
ax.grid(linestyle='-', alpha=0.5)
ax.tick_params(axis='both', which='major', labelsize=26)
ax.set_xlabel(r'$n$', size=40)
ax.set_ylabel(r'$\epsilon_{\mathrm{rel}}(E)$', size=40)
ax.legend(loc='lower left', fontsize=20, reverse=True)
ax.text(0.81, 0.85, r'Shallow CNN',
        horizontalalignment='center', verticalalignment='bottom',
        transform=ax.transAxes, fontsize=26)

plt.savefig('fig1b.pdf', bbox_inches='tight', format='pdf')

# ResCNN data

IMP-WR

In [None]:
# Number of IMP-WR pruning iterations stored on disk (piter = 0 .. 42)
pruning_iterations = 43

file_path = '../../data/tfim_2d/res_cnn/IMP_WR/'

# Slot 0 is the dense network, slot p + 1 is pruning iteration p.
checkpoints = ['dense', *range(pruning_iterations)]
n_points = len(checkpoints)

relative_errors = np.empty(n_points)
remaining_params = np.empty(n_points)
variances = np.empty(n_points)
# Re-allocated and filled here so it no longer silently holds the shallow-CNN
# energies from the previous section.
all_energies = np.empty(n_points)
# NOTE(review): allocated but never filled or read in this notebook.
stdevs = np.empty(n_points)

for slot, piter in enumerate(checkpoints):
    network = load_serialized_ResCNN(file_path, piter)
    # get_num_params() is indexed as a pair here and its first two entries are
    # summed into the total count (presumably two parameter groups — TODO confirm).
    n_rem = network.get_num_params()
    n_rem = n_rem[0] + n_rem[1]
    if slot == 0:
        n_dense = n_rem
        print(f'Number of parameters in dense network: {n_dense}')

    data = load_sampling_data(file_path, piter)
    energy = np.mean(np.array(data['Energy']['Mean']))
    variance = np.mean(np.array(data['Energy']['Variance']))

    # Stored as logged (not divided by the number of spins), as in the CNN section.
    variances[slot] = variance
    remaining_params[slot] = n_rem
    relative_errors[slot] = relative_error(energy, e_gs)
    all_energies[slot] = energy


$T(\theta_{\mathrm{init}}, m_{\mathrm{imp}})$

In [None]:
# Selected pruning iterations: every third up to 39, then 40, 41 and 42.
selected_pruning_iterations = np.array([*range(0, 40, 3), 40, 41, 42])

file_path = '../../data/tfim_2d/res_cnn/T_init_imp/'

# Slot 0 is the dense network, reused from the ResCNN IMP-WR arrays.
n_points = len(selected_pruning_iterations) + 1
t_init_imp_relative_errors = np.empty(n_points)
t_init_imp_remaining_params = np.empty(n_points)
t_init_imp_variances = np.empty(n_points)

t_init_imp_variances[0] = variances[0]
t_init_imp_remaining_params[0] = remaining_params[0]
t_init_imp_relative_errors[0] = relative_errors[0]

for slot, piter in enumerate(selected_pruning_iterations, start=1):
    network = load_serialized_ResCNN(file_path, piter)
    param_counts = network.get_num_params()
    n_rem = param_counts[0] + param_counts[1]

    data = load_sampling_data(file_path, piter)
    energy = np.mean(np.array(data['Energy']['Mean']))
    variance = np.mean(data['Energy']['Variance'])

    t_init_imp_variances[slot] = variance
    t_init_imp_remaining_params[slot] = n_rem
    t_init_imp_relative_errors[slot] = relative_error(energy, e_gs)


$T(\theta_{\mathrm{rand}}, m_{\mathrm{imp}})$

In [None]:
# Selected pruning iterations: every third up to 39, then 40, 41 and 42 (no dense slot).
selected_pruning_iterations = np.array([*range(0, 40, 3), 40, 41, 42])

file_path = '../../data/tfim_2d/res_cnn/T_rand_imp/'

n_points = len(selected_pruning_iterations)
t_rand_imp_relative_errors = np.empty(n_points)
t_rand_imp_remaining_params = np.empty(n_points)
t_rand_imp_variances = np.empty(n_points)

for slot, piter in enumerate(selected_pruning_iterations):
    network = load_serialized_ResCNN(file_path, piter)
    param_counts = network.get_num_params()
    n_rem = param_counts[0] + param_counts[1]

    data = load_sampling_data(file_path, piter)
    energy = np.mean(np.array(data['Energy']['Mean']))
    variance = np.mean(data['Energy']['Variance'])

    t_rand_imp_variances[slot] = variance
    t_rand_imp_remaining_params[slot] = n_rem
    t_rand_imp_relative_errors[slot] = relative_error(energy, e_gs)


$T(\theta_{\mathrm{init}}, m_{\mathrm{rand}})$

In [None]:
# Densities used for the random-mask runs; each value is also the `piter`
# label used in the checkpoint and log file names.
selected_pruning_iterations = np.array([0.88092162, 0.60282084, 0.41332903, 0.28413006, 0.19611326, 0.13614341,
 0.0952304,  0.06734496, 0.04834195, 0.03542205, 0.02659345, 0.02056417,
 0.01647287, 0.01372739, 0.01302756, 0.01238157, 0.01, 0.009, 0.008, 0.007, 0.006, 0.004, 0.002])

file_path = '../../data/tfim_2d/res_cnn/T_init_rand/'

# Slot 0 is the dense network, reused from the ResCNN IMP-WR arrays.
n_points = len(selected_pruning_iterations) + 1
t_init_rand_relative_errors = np.empty(n_points)
t_init_rand_remaining_params = np.empty(n_points)
t_init_rand_variances = np.empty(n_points)

t_init_rand_variances[0] = variances[0]
t_init_rand_remaining_params[0] = remaining_params[0]
t_init_rand_relative_errors[0] = relative_errors[0]

for slot, piter in enumerate(selected_pruning_iterations, start=1):
    network = load_serialized_ResCNN(file_path, piter)
    param_counts = network.get_num_params()
    n_rem = param_counts[0] + param_counts[1]

    data = load_sampling_data(file_path, piter)
    energy = np.mean(np.array(data['Energy']['Mean']))
    variance = np.mean(data['Energy']['Variance'])

    t_init_rand_variances[slot] = variance
    t_init_rand_remaining_params[slot] = n_rem
    t_init_rand_relative_errors[slot] = relative_error(energy, e_gs)


In [None]:
fig, ax = plt.subplots(figsize=(8,6))

colors = ["#000000", "#004488", "#994455", "#997700", "#6d0b52", "#ee99aa", "#eecc66",  "#6699cc"]
shapes = ['o-', 's-', 'D-', '<-', 'P-', 'v-']

# One entry per curve, in drawing order: (x, y, line style, label, line width, colour).
curves = [
    (remaining_params, relative_errors, '-',
     r'IMP-WR', 3, colors[0]),
    (t_init_imp_remaining_params, t_init_imp_relative_errors, '--',
     r'$T(\theta_{\mathrm{init}}, m_{\mathrm{imp}})$', 4, "orange"),
    (t_rand_imp_remaining_params, t_rand_imp_relative_errors, ':',
     r'$T(\theta_{\mathrm{rand}}, m_{\mathrm{imp}})$', 4, colors[2]),
    (t_init_rand_remaining_params, t_init_rand_relative_errors, '-.',
     r'$T(\theta_{\mathrm{init}}, m_{\mathrm{rand}})$', 4, colors[3]),
]

for xs, ys, line_style, curve_label, width, curve_color in curves:
    ax.plot(xs, ys,
        line_style,
        label=curve_label,
        markersize=7,
        linewidth=width,
        markerfacecolor=curve_color,
        markeredgecolor='black',
        color=curve_color
        )

ax.loglog()
ax.grid(linestyle='-', alpha=0.5)
ax.tick_params(axis='both', which='major', labelsize=26)
ax.set_xlabel(r'$n$', size=40)
ax.set_ylabel(r'$\epsilon_{\mathrm{rel}}(E)$', size=40)
ax.legend(loc='upper center', fontsize=20, reverse=True)
ax.text(0.84, 0.85, r'ResCNN',
        horizontalalignment='center', verticalalignment='bottom',
        transform=ax.transAxes, fontsize=26)

plt.savefig('fig1c.pdf', bbox_inches='tight', format='pdf')