In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path
import matplotlib.pyplot as plt
# `color_codes` is a boolean flag; the "darkgrid" theme is selected via `style`.
sns.set_theme(style="darkgrid")

In [None]:
# Dataset whose results are analysed below; must be one of SUPPORTED_DATASETS.
SUPPORTED_DATASETS = ("prototypes", "fmnist", "cifar10tl", "imagenettetl", "cifar100tl")
dataset = "fmnist"
# Fail fast on a typo: later cells branch on this name and would otherwise
# skip their configuration or reuse stale variables from an earlier run.
if dataset not in SUPPORTED_DATASETS:
    raise ValueError(f"Unknown dataset {dataset!r}; expected one of {SUPPORTED_DATASETS}")

In [None]:
# Load every method's result table for the selected dataset.
# The FP-input SGD table is not loaded for the "prototypes" dataset.
has_fp_sgd = dataset != "prototypes"

our_all = pd.read_csv(f'out/{dataset}/our.csv')
sgd_bin_m1_all = pd.read_csv(f'out/{dataset}/sgd_bin.csv')
if has_fp_sgd:
    sgd_small_all = pd.read_csv(f'out/{dataset}/sgd.csv')
baldassi_all = pd.read_csv(f'out/{dataset}/baldassi.csv')

# Rescale the SGD accuracies by 100 so they share the percent scale of the other
# tables (assumes the SGD CSVs store fractions — confirm against the writer).
acc_cols = ['mean_val_acc', 'std_val_acc']
sgd_bin_m1_all[acc_cols] = sgd_bin_m1_all[acc_cols] * 100
if has_fp_sgd:
    sgd_small_all[acc_cols] = sgd_small_all[acc_cols] * 100

In [None]:
bs = 100  # batch size assumed by the memory / operation estimates below

# Per-dataset (input_dim, dataset_size, output_dim). Anything not listed here
# (i.e. fmnist) uses the 784-feature / 50k-sample / 10-class default.
_dataset_dims = {
    "cifar10tl": (9216, 40000, 10),
    "cifar100tl": (9216, 40000, 100),
    "imagenettetl": (2304, 10000, 10),
    "prototypes": (1000, 10000, 10),
}
input_dim, dataset_size, output_dim = _dataset_dims.get(dataset, (784, 50000, 10))
    
# Compute memory requirements
def compute_memory(layers, baldassi=False, adam=False, *, in_dim=None, out_dim=None, batch_size=None):
    """Count the tensor elements needed to train an MLP with the given hidden layers.

    Counts are numbers of elements, not bits; callers multiply each one by the
    bit width they assume for that tensor.

    Parameters
    ----------
    layers : str
        Underscore-separated hidden-layer widths, e.g. ``"525_525"``.
    baldassi : bool, default False
        Treat odd-indexed entries of ``layers`` as grouping layers: their
        incoming weights are counted in ``mem_w_gl`` instead of ``mem_w``.
    adam : bool, default False
        Triple ``mem_w`` (presumably weights plus Adam's two moment buffers).
    in_dim, out_dim, batch_size : int, optional
        Override the notebook-level ``input_dim``, ``output_dim`` and ``bs``
        (those globals are read at call time when the overrides are omitted).

    Returns
    -------
    tuple of five single-entry dicts ``{layers: count}``:
        mem_w    -- weights of all non-grouping layers plus the output layer
        mem_w_ll -- largest local-loss head, ``max(width) * out_dim``
        mem_w_gl -- grouping-layer weights (0 unless ``baldassi``)
        mem_h    -- pre-activations of one batch (hidden + output units)
        mem_d    -- largest single hidden-layer weight-gradient buffer
    """
    in_dim = input_dim if in_dim is None else in_dim
    out_dim = output_dim if out_dim is None else out_dim
    batch_size = bs if batch_size is None else batch_size

    layer_sizes = [int(x) for x in layers.split('_')]
    w = w_ll = w_gl = h = d = 0
    prev_size = in_dim
    for i, size in enumerate(layer_sizes):
        if baldassi and i % 2 == 1:
            # Grouping-layer weights are tracked separately.
            w_gl += prev_size * size
        else:
            w += prev_size * size

        # Pre-activations of this layer for a whole batch.
        h += size * batch_size
        # Local-loss head: only the largest one is kept (heads are not held simultaneously).
        w_ll = max(w_ll, size * out_dim)
        # Gradient buffer: only the largest one is kept.
        d = max(d, prev_size * size)

        prev_size = size

    # Output layer weights and pre-activations.
    w += layer_sizes[-1] * out_dim
    h += out_dim * batch_size

    if adam:
        w *= 3

    return {layers: w}, {layers: w_ll}, {layers: w_gl}, {layers: h}, {layers: d}

# Compute computational requirements
def compute_operations(layers, binary=False, localloss=False, optimizer='sgd', *,
                       in_dim=None, out_dim=None, batch_size=None):
    """Estimate the operation counts of one training step on one batch.

    Parameters
    ----------
    layers : str
        Underscore-separated hidden-layer widths, e.g. ``"525_525"``.
    binary : bool, default False
        Count Boolean operations (XNOR / popcount / increment) instead of
        floating-point multiplications and additions.
    localloss : bool, default False
        Binary mode only: add the cost of a local classifier of size
        ``width * out_dim`` after every layer.
    optimizer : {'sgd', 'adam'}, default 'sgd'
        Floating-point mode only: weight-update cost model (case-insensitive).
    in_dim, out_dim, batch_size : int, optional
        Override the notebook-level ``input_dim``, ``output_dim`` and ``bs``.

    Returns
    -------
    ``(xnor, popcount, increment)`` if ``binary`` else ``(mult, add)``.

    Raises
    ------
    ValueError
        If ``binary`` is False and ``optimizer`` is neither 'sgd' nor 'adam'
        (previously this crashed with UnboundLocalError).
    """
    in_dim = input_dim if in_dim is None else in_dim
    out_dim = output_dim if out_dim is None else out_dim
    batch_size = bs if batch_size is None else batch_size

    if not binary:
        opt = optimizer.lower()
        if opt not in ('sgd', 'adam'):
            raise ValueError(f"Unsupported optimizer {optimizer!r}; expected 'sgd' or 'adam'")

    # Full chain of widths: input -> hidden layers -> output.
    sizes = [in_dim] + [int(x) for x in layers.split('_')] + [out_dim]
    last = len(sizes) - 2  # index of the output layer's weight matrix

    total_xnor = total_popcount = total_increment = 0
    total_mult = total_add = 0

    for i in range(len(sizes) - 1):
        input_size, output_size = sizes[i], sizes[i + 1]
        elems = input_size * output_size

        if binary:
            # Backward / weight-update costs are not counted for the output layer.
            trains = i != last
            forward_xnor = forward_popcount = batch_size * elems
            local_ops = batch_size * output_size * out_dim if localloss else 0

            backward_xnor = 9 * batch_size * output_size if trains else 0
            weight_update_xnor = 2 * elems if trains else 0
            weight_update_popcount = elems if trains else 0
            weight_update_increment = 4 * elems if trains else 0

            total_xnor += forward_xnor + backward_xnor + weight_update_xnor + local_ops
            total_popcount += forward_popcount + weight_update_popcount + local_ops
            total_increment += weight_update_increment
        else:
            # Forward: batch_size * output_size dot products of length input_size.
            forward_mult = batch_size * elems
            forward_add = batch_size * (input_size - 1) * output_size

            # dW = X^T @ delta: input_size * output_size dot products of length
            # batch_size, i.e. (batch_size - 1) additions each. (Not the
            # forward-pass formula batch_size * (input_size - 1) * output_size.)
            backward_mult_weight = batch_size * elems
            backward_add_weight = (batch_size - 1) * elems

            # dX = delta @ W^T, not needed for the first layer.
            backward_mult_input = batch_size * elems if i > 0 else 0
            backward_add_input = batch_size * (output_size - 1) * input_size if i > 0 else 0

            if opt == 'sgd':
                # w <- w - lr * g : 1 mult + 1 add per weight
                weight_update_mult = elems
                weight_update_add = elems
            else:  # adam
                # m_t = b1*m + (1-b1)*g             : 2 mult, 1 add
                # v_t = b2*v + (1-b2)*g^2           : 3 mult, 1 add
                # m_hat, v_hat, sqrt, final divide  : 4 ops, counted as mult
                # sqrt(v_hat) + eps                 : 1 add
                # w <- w - lr * update              : 1 mult, 1 add
                weight_update_mult = elems * (2 + 3 + 4 + 1)
                weight_update_add = elems * (1 + 1 + 1 + 1)

            total_mult += forward_mult + backward_mult_weight + backward_mult_input + weight_update_mult
            total_add += forward_add + backward_add_weight + backward_add_input + weight_update_add

    if binary:
        return total_xnor, total_popcount, total_increment
    return total_mult, total_add

In [None]:
# Training-memory estimate (in bits) per architecture, keyed by its `layers`
# string. Each element count from compute_memory is multiplied by the bit width
# assumed for that tensor (1, 8 or 32). The two Adam tables are created here and
# filled in the Adam appendix.
mem_sgd_small, mem_sgd_bin_m1, mem_ll, mem_baldassi, mem_bin_adam, mem_small_adam = {}, {}, {}, {}, {}, {}

# Training set stored at 1 bit per input value.
binary_data_bits = input_dim*dataset_size*1

# Ours: 8-bit weights and pre-activations, 1-bit local-loss heads.
for arch in our_all['layers']:
    w, w_ll, w_gl, h, d = compute_memory(arch)
    mem_ll[arch] = binary_data_bits + w[arch]*8 + w_ll[arch]*1 + h[arch]*8

if dataset != "prototypes":
    # NOTE(review): `input_dim*dataset_size//32*32` parses as
    # ((input_dim*dataset_size)//32)*32, i.e. ~1 bit per input value — the same
    # as the binary-input baseline. If 32-bit FP inputs were intended this would
    # be `input_dim*dataset_size*32`; confirm before publishing.
    for arch in sgd_small_all['layers']:
        w, w_ll, w_gl, h, d = compute_memory(arch)
        mem_sgd_small[arch] = input_dim*dataset_size//32*32 + (w[arch] + h[arch] + d[arch])*32

# Binary-input SGD: 32-bit weights, pre-activations and gradient buffer.
for arch in sgd_bin_m1_all['layers']:
    w, w_ll, w_gl, h, d = compute_memory(arch)
    mem_sgd_bin_m1[arch] = binary_data_bits + (w[arch] + h[arch] + d[arch])*32

# Baldassi et al.: 8-bit weights / pre-activations, 1-bit grouping-layer weights.
for arch in baldassi_all['layers']:
    w, w_ll, w_gl, h, d = compute_memory(arch, baldassi=True)
    mem_baldassi[arch] = binary_data_bits + w[arch]*8 + w_gl[arch]*1 + h[arch]*8

our_all['memory'] = our_all['layers'].map(mem_ll)
sgd_bin_m1_all['memory'] = sgd_bin_m1_all['layers'].map(mem_sgd_bin_m1)
if dataset != "prototypes":
    sgd_small_all['memory'] = sgd_small_all['layers'].map(mem_sgd_small)
baldassi_all['memory'] = baldassi_all['layers'].map(mem_baldassi)

## EXPERIMENT 1

In [None]:
# Architectures compared in Experiment 1, per dataset. The "sgd" list is shared
# by both SGD baselines; None keeps every row of that method's table.
exp1_layers = {
    "imagenettetl": {
        "ours": ["525_525", "255_255", "135_135", "75_75", "35_35"],
        "sgd": ["81_81", "36_36", "19_19", "10_10", "5_5"],
        "baldassi": ["660_60", "286_26", "154_14", "77_7", "44_4"],
    },
    "prototypes": {
        "ours": ["525_525", "255_255", "135_135", "75_75", "35_35"],
        "sgd": None,
        "baldassi": None,
    },
    "cifar10tl": {
        "ours": ["525_525", "255_255", "135_135", "75_75", "35_35"],
        "sgd": ["70_70", "33_33", "18_18", "10_10", "5_5"],
        "baldassi": ["561_51", "264_24", "143_13", "77_7", "44_4"],
    },
    "cifar100tl": {
        "ours": ["4095_4095", "2025_2025", "1035_1035", "525_525", "255_255"],
        "sgd": ["715_715", "312_312", "145_145", "70_70", "33_33"],
        "baldassi": ["5951_541", "2541_231", "1177_107", "561_51", "264_24"],
    },
    "fmnist": {
        "ours": ["525_525", "255_255", "135_135", "75_75", "35_35"],
        "sgd": ["105_105", "44_44", "21_21", "11_11", "5_5"],
        "baldassi": ["880_80", "363_33", "165_15", "88_8", "44_4"],
    },
}
# Fail loudly instead of leaving our_df & co. undefined (or stale from a previous run).
if dataset not in exp1_layers:
    raise ValueError(f"No Experiment 1 configuration for dataset {dataset!r}")
exp1_sel = exp1_layers[dataset]


def _select_layers(frame, wanted):
    """Rows of `frame` whose 'layers' is in `wanted` (all rows if None), re-indexed from 0."""
    subset = frame if wanted is None else frame[frame["layers"].isin(wanted)]
    return subset.reset_index(drop=True)


our_df = _select_layers(our_all, exp1_sel["ours"])
sgd_bin_m1_df = _select_layers(sgd_bin_m1_all, exp1_sel["sgd"])
if dataset != "prototypes":
    sgd_small_df = _select_layers(sgd_small_all, exp1_sel["sgd"])
baldassi_df = _select_layers(baldassi_all, exp1_sel["baldassi"])

In [None]:
# Experiment 1a: test accuracy vs. training memory (bits / 8 = bytes).
fig = plt.figure()

# Sort each method by memory so the connecting lines run left to right.
our_df = our_df.sort_values(by='memory')
baldassi_df = baldassi_df.sort_values(by='memory')
sgd_bin_m1_df = sgd_bin_m1_df.sort_values(by='memory')
if dataset != "prototypes":
    sgd_small_df = sgd_small_df.sort_values(by='memory')


def plot(df, label, color, single_layer = False, sgd = False):
    """Draw one method as an error-bar curve of accuracy against memory in bytes.

    SGD baselines use triangle markers drawn underneath the other methods;
    the rest use circles, with a dashed line when ``single_layer`` is set.
    """
    if not sgd:
        marker, line, layer = 'o', ('dashed' if single_layer else '-'), 2
    else:
        marker, line, layer = 'v', '-', 1
    plt.errorbar(df['memory'] / 8, df['mean_val_acc'], yerr=df['std_val_acc'],
                 fmt=marker, color=color, linestyle=line, label=f"{label}",
                 capsize=3, zorder=layer)


plot(our_df, 'Ours', "black")
plot(baldassi_df, 'Baldassi et al. [10]', "darkred", single_layer=True)
if dataset != "prototypes":
    plot(sgd_small_df, 'FP input full-precision SGD', "darkblue", sgd=True)
plot(sgd_bin_m1_df, 'Binary input full-precision SGD', "darkgreen", sgd=True)

# Per-dataset y axis: (tick start, tick stop, tick step, ylim bottom, ylim top).
y_axis = {
    "fmnist": (39, 90, 5, 38, 90),
    "imagenettetl": (36, 93, 8, 35, 93),
    "prototypes": (35, 99, 7, 34, 99),
    "cifar100tl": (0, 60.5, 5, -1, 61),
    "cifar10tl": (32, 88, 5, 31, 88),
}
if dataset in y_axis:
    tick_start, tick_stop, tick_step, y_low, y_high = y_axis[dataset]
    plt.yticks(np.arange(tick_start, tick_stop, tick_step))
    plt.ylim(y_low, y_high)

from matplotlib.ticker import FormatStrFormatter
plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

plt.grid(True, which="both")
plt.xlabel('Memory requirements (bytes)')
plt.ylabel('Test accuracy')
plt.legend(loc='lower right')
plt.show()

In [None]:
# fn = Path(f'fig/exp1a_{dataset}.pdf')
# fig.savefig(fn, bbox_inches='tight')

In [None]:
# Cost-weighted operation count of one training step, expressed in Boolean gates.
# Full-precision baselines: each multiplication / addition costs mul_cost / add_cost.
mul_cost = 10000
add_cost = 10000


def _fp_gates(layers):
    """Gate cost of one full-precision SGD step for `layers`."""
    mult, add = compute_operations(layers)
    return mult * mul_cost + add * add_cost


if dataset != "prototypes":
    sgd_small_df['operations'] = sgd_small_df['layers'].apply(_fp_gates)
sgd_bin_m1_df['operations'] = sgd_bin_m1_df['layers'].apply(_fp_gates)

# Binary methods: per-op gate costs of XNOR, popcount and increment.
xnor_cost = 1
popcount_cost = 10
increment_cost = 10


def _binary_gates(layers, localloss):
    """Gate cost of one binary training step for `layers`."""
    xnor, popcount, increment = compute_operations(layers, binary=True, localloss=localloss)
    return xnor * xnor_cost + popcount * popcount_cost + increment * increment_cost


baldassi_df['operations'] = baldassi_df['layers'].apply(lambda x: _binary_gates(x, localloss=False))
our_df['operations'] = our_df['layers'].apply(lambda x: _binary_gates(x, localloss=True))

In [None]:
# Experiment 1b: test accuracy vs. Boolean-gate cost of one training step.
fig = plt.figure()

# Sort by the plotted x value (operations) so each curve's line is monotone in x.
our_df = our_df.sort_values(by='operations')
sgd_bin_m1_df = sgd_bin_m1_df.sort_values(by='operations')
if dataset != "prototypes":
    sgd_small_df = sgd_small_df.sort_values(by='operations')
baldassi_df = baldassi_df.sort_values(by='operations')


def plot(df, label, color, single_layer = False, sgd = False):
    """Draw one method as an error-bar curve of accuracy against gate count.

    SGD baselines use triangle markers drawn underneath the other methods;
    the rest use circles, with a dashed line when ``single_layer`` is set.
    """
    if sgd:
        zorder, linestyle, fmt = 1, '-', 'v'
    else:
        zorder, linestyle, fmt = 2, ('dashed' if single_layer else '-'), 'o'
    plt.errorbar(df['operations'], df['mean_val_acc'], yerr=df['std_val_acc'], fmt=fmt, color=color, linestyle=linestyle, label=f"{label}", capsize=3, zorder=zorder)


plot(our_df, 'Ours', "black")
plot(baldassi_df, 'Baldassi et al. [10]', "darkred", single_layer=True)
if dataset != "prototypes":
    plot(sgd_small_df, 'FP input full-precision SGD', "darkblue", sgd=True)
plot(sgd_bin_m1_df, 'Binary input full-precision SGD', "darkgreen", sgd=True)

# Per-dataset (tick start, tick stop, tick step, ylim bottom, ylim top, legend location).
y_axis = {
    "fmnist": (39, 90, 5, 38, 90, "lower left"),
    "imagenettetl": (36, 93, 8, 35, 93, "lower left"),
    "prototypes": (35, 99, 7, 34, 99, "lower right"),
    "cifar100tl": (0, 60.5, 5, -1, 61, "lower right"),
    "cifar10tl": (32, 88, 5, 31, 88, "lower right"),
}
# Default so this cell never depends on a `legend_loc` left over from another cell.
legend_loc = "lower right"
if dataset in y_axis:
    tick_start, tick_stop, tick_step, y_low, y_high, legend_loc = y_axis[dataset]
    plt.yticks(np.arange(tick_start, tick_stop, tick_step))
    plt.ylim(y_low, y_high)

from matplotlib.ticker import FormatStrFormatter
plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

plt.xscale('log')
plt.grid(True, which="both")
plt.xlabel('Number of Boolean gates')
plt.ylabel('Test accuracy')
plt.legend(loc=legend_loc)
plt.show()

In [None]:
# fn = Path(f'fig/exp1b_{dataset}.pdf')
# fig.savefig(fn, bbox_inches='tight')

## EXPERIMENT 2

In [None]:
# Our method's accuracy for several group sizes; `layers` is cast to str so it
# can be matched against string layer names.
group_size_df = pd.read_csv(f'out/{dataset}/grouping_factor.csv').astype({'layers': str})

In [None]:
# Experiment 2: test accuracy vs. group size γ for a few hidden-layer widths.
fig = plt.figure()

# Sort by group size so every curve is drawn left to right.
group_size_df = group_size_df.sort_values(by='grouping_factor')

# One curve per hidden-layer width (75, 135, 525, 2025), each with its own colour.
width_colors = {75: "darkgreen", 135: "darkblue", 525: "grey", 2025: "darkred"}
for layer, color in width_colors.items():
    subset = group_size_df[group_size_df['layers'] == str(layer)]
    plt.errorbar(subset['grouping_factor'], subset['mean_val_acc'], yerr=subset['std_val_acc'],
                 color=color, capsize=5, fmt="-o", label=f"Hidden layer size: {layer}")

# Per-dataset y axis: (tick start, tick stop, tick step, ylim bottom, ylim top).
y_axis = {
    "fmnist": (67, 89, 2, 66, 88),
    "imagenettetl": (83, 92, 1, 82.5, 91.5),
    "cifar10tl": (55, 86, 2.5, 54, 86),
}
if dataset in y_axis:
    tick_start, tick_stop, tick_step, y_low, y_high = y_axis[dataset]
    plt.yticks(np.arange(tick_start, tick_stop, tick_step))
    plt.ylim(y_low, y_high)

from matplotlib.ticker import FormatStrFormatter
plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

plt.xlabel('Group size γ')
plt.ylabel('Test accuracy')
plt.xscale('log')
plt.legend(loc='lower right')
plt.grid(True)
plt.show()

In [None]:
# fn = Path(f'fig/exp2_{dataset}.pdf')
# fig.savefig(fn, bbox_inches='tight')

## EXPERIMENT 3

In [None]:
# Architectures for Experiment 3: our method vs. Baldassi et al. variants.
#   opt    -> "enhanced" single hidden layer (presumably a re-tuned grouping-layer size)
#   2l     -> two hidden layers
#   2l_opt -> "enhanced" two hidden layers
exp3_ours = ["525_525", "255_255", "135_135", "75_75", "35_35"]
exp3_baldassi = {
    "imagenettetl": {
        "opt": ["660_220", "286_286", "154_154", "77_77", "44_44"],
        "2l": ["660_60_660_60", "286_26_286_26", "154_14_154_14", "77_7_77_7", "44_4_44_4"],
        "2l_opt": ["660_220_660_220", "286_286_286_286", "154_154_154_154", "77_77_77_77", "44_44_44_44"],
    },
    "cifar10tl": {
        "opt": ["561_51", "264_88", "143_143", "77_77", "44_44"],
        "2l": ["561_51_561_51", "264_24_264_24", "143_13_143_13", "77_7_77_7", "44_4_44_4"],
        "2l_opt": ["561_51_561_51", "264_88_264_88", "143_143_143_143", "77_77_77_77", "44_44_44_44"],
    },
    "fmnist": {
        "opt": ["880_80", "363_33", "165_33", "88_8", "44_44"],
        "2l": ["880_80_880_80", "363_33_363_33", "165_15_165_15", "88_8_88_8", "44_4_44_4"],
        "2l_opt": ["880_80_880_80", "363_33_363_33", "165_33_165_33", "88_8_88_8", "44_44_44_44"],
    },
}
# Other datasets previously left baldassi_df_opt & co. undefined (or stale).
if dataset not in exp3_baldassi:
    raise ValueError(f"Experiment 3 is only configured for {sorted(exp3_baldassi)}, got {dataset!r}")

exp3_sel = exp3_baldassi[dataset]
our_df = our_all[our_all["layers"].isin(exp3_ours)]
baldassi_df_opt = baldassi_all[baldassi_all["layers"].isin(exp3_sel["opt"])]
baldassi_df_2l = baldassi_all[baldassi_all["layers"].isin(exp3_sel["2l"])]
baldassi_df_2l_opt = baldassi_all[baldassi_all["layers"].isin(exp3_sel["2l_opt"])]

In [None]:
# Experiment 3: test accuracy vs. memory (bytes) for our method and Baldassi et al. variants.
fig = plt.figure()


def plot(df, label, color, single_layer = False, sgd = False):
    """Draw one curve of accuracy against memory (bits / 8 = bytes) with std error bars.

    Rows are sorted by memory on a copy, so the caller's frame is not mutated.
    Here the callers pass boolean-indexed slices of `our_all` / `baldassi_all`,
    and an in-place sort on such a slice can raise SettingWithCopyWarning.

    ``sgd`` only selects a style (triangle markers, solid line, drawn
    underneath); otherwise circles are used, with a dashed line if ``single_layer``.
    """
    df = df.sort_values(by='memory')
    if sgd:
        zorder, linestyle, fmt = 1, '-', 'v'
    else:
        zorder, linestyle, fmt = 2, ('dashed' if single_layer else '-'), 'o'
    plt.errorbar(df['memory']/8, df['mean_val_acc'], yerr=df['std_val_acc'], fmt=fmt, color=color, linestyle=linestyle, label=f"{label}", capsize=3, zorder=zorder)


plot(our_df, 'Ours - 2 hidden layers', "black")
plot(baldassi_df_2l, 'Baldassi et al. [10] - 2 hidden layers', "darkgreen", sgd=True)
plot(baldassi_df_opt, 'Baldassi et al. [10] enhanced - 1 hidden layer', "darkred", single_layer=True)
plot(baldassi_df_2l_opt, 'Baldassi et al. [10] enhanced - 2 hidden layers', "darkblue", sgd=True)

# Per-dataset y axis: (tick start, tick stop, tick step, ylim bottom, ylim top).
y_axis = {
    "fmnist": (39, 89, 3.5, 38, 89),
    "imagenettetl": (27, 93, 8, 25, 93),
    "prototypes": (35, 99, 7, 34, 99),
    "cifar10tl": (16.5, 87.5, 5, 15.5, 87.5),
}
if dataset in y_axis:
    tick_start, tick_stop, tick_step, y_low, y_high = y_axis[dataset]
    plt.yticks(np.arange(tick_start, tick_stop, tick_step))
    plt.ylim(y_low, y_high)

from matplotlib.ticker import FormatStrFormatter
plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

plt.grid(True, which="both")
plt.xlabel('Memory requirements (bytes)')
plt.ylabel('Test accuracy')
plt.legend(loc='lower right')
plt.show()

In [None]:
# fn = Path(f'fig/exp3_{dataset}.pdf')
# fig.savefig(fn, bbox_inches='tight')

## EXPERIMENT 4

In [None]:
# Full-precision Adam results, rescaled by 100 to the same percent scale as the SGD tables.
adam_small_all = pd.read_csv(f'out/{dataset}/adam.csv')
adam_small_all[['mean_val_acc', 'std_val_acc']] = adam_small_all[['mean_val_acc', 'std_val_acc']] * 100

In [None]:
# Experiment 4: test accuracy vs. depth for architectures built from 1035-wide layers.


def _depth_sweep(frame):
    """Rows whose 'layers' contains '1035', with a '#layers' depth column, sorted by depth."""
    sweep = frame[frame['layers'].map(lambda x: '1035' in x)].copy()
    sweep['#layers'] = sweep['layers'].map(lambda x: x.count('_') + 1)
    return sweep.sort_values(by='#layers')


our_df = _depth_sweep(our_all)
sgd_df = _depth_sweep(sgd_small_all)
adam_df = _depth_sweep(adam_small_all)

fig = plt.figure()

sgd_color = "darkgreen"
adam_color = "darkblue"
sgd_alpha = 0.4
adam_alpha = 0.4

# (marker, linestyle, zorder) per optimizer; anything else is our method, drawn on top.
_optimizer_styles = {'sgd': ('v', 'dashed', 1), 'adam': ('s', 'dotted', 2)}


def plot(df, label, color, optimizer=None, alpha=1.0):
    """Draw accuracy against depth ('#layers') for one method; `df` must already be sorted by depth."""
    fmt, linestyle, zorder = _optimizer_styles.get(optimizer, ('o', '-', 3))
    plt.errorbar(df['#layers'], df['mean_val_acc'], yerr=df['std_val_acc'],
                 fmt=fmt, color=color, linestyle=linestyle, label=f"{label}",
                 capsize=3, zorder=zorder, alpha=alpha)


plot(our_df, 'Ours', color="black")
plot(sgd_df, 'Full-precision SGD', color=sgd_color, optimizer='sgd', alpha=sgd_alpha)
plot(adam_df, 'Full-precision Adam', color=adam_color, optimizer='adam', alpha=adam_alpha)

# Per-dataset y axis: (tick start, tick stop, tick step, ylim bottom, ylim top).
y_axis = {
    "fmnist": (81, 91, 1.5, 80, 91),
    "imagenettetl": (0, 101, 10, -1, 101),
    "cifar10tl": (5, 96, 10, 4, 96),
}
if dataset in y_axis:
    tick_start, tick_stop, tick_step, y_low, y_high = y_axis[dataset]
    plt.yticks(np.arange(tick_start, tick_stop, tick_step))
    plt.ylim(y_low, y_high)

from matplotlib.ticker import FormatStrFormatter
plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

plt.xticks(np.arange(1, 11, 1))
plt.xlabel('Number of layers')
plt.ylabel('Test accuracy')
plt.legend(loc='lower left')
plt.grid(True)
plt.show()

In [None]:
# fn = Path(f'fig/exp4_{dataset}.pdf')
# fig.savefig(fn, bbox_inches='tight')

## APPENDIX

### Adam

In [None]:
# Appendix: Adam baselines — binary input always, FP input except for prototypes.
adam_bin_all = pd.read_csv(f'out/{dataset}/adam_bin.csv')
adam_tables = [adam_bin_all]
if dataset != "prototypes":
    adam_small_all = pd.read_csv(f'out/{dataset}/adam.csv')
    adam_tables.append(adam_small_all)

# Rescale accuracies by 100 (same percent scale as the SGD tables).
for table in adam_tables:
    for col in ('mean_val_acc', 'std_val_acc'):
        table[col] = table[col] * 100

In [None]:
# Adam memory (bits): 32-bit weights (tripled by compute_memory(adam=True)),
# pre-activations and gradient buffer, plus the stored training set.
for arch in adam_bin_all['layers']:
    w, w_ll, w_gl, h, d = compute_memory(arch, adam=True)
    mem_bin_adam[arch] = input_dim*dataset_size*1 + (w[arch] + h[arch] + d[arch])*32
if dataset != "prototypes":
    # NOTE(review): `input_dim*dataset_size//32*32` is ~1 bit per input value
    # (floor-divide then multiply), same as the binary-input case; if 32-bit
    # FP inputs were intended it should be `input_dim*dataset_size*32`.
    for arch in adam_small_all['layers']:
        w, w_ll, w_gl, h, d = compute_memory(arch, adam=True)
        mem_small_adam[arch] = input_dim*dataset_size//32*32 + (w[arch] + h[arch] + d[arch])*32

adam_bin_all['memory'] = adam_bin_all['layers'].map(mem_bin_adam)
if dataset != "prototypes":
    adam_small_all['memory'] = adam_small_all['layers'].map(mem_small_adam)

In [None]:
# Binary-input Adam: all architectures. FP-input Adam: drop the 1035-wide
# runs used by the Experiment 4 depth sweep.
adam_bin_df = adam_bin_all.reset_index(drop=True)

if dataset != "prototypes":
    is_depth_sweep = adam_small_all['layers'].str.contains('1035')
    adam_small_df = adam_small_all[~is_depth_sweep].reset_index(drop=True)

In [None]:
# Appendix (Adam): test accuracy vs. training memory (bits / 8 = bytes).
fig = plt.figure()

# NOTE(review): `our_df` is whatever the last selection cell assigned — after
# Restart & Run All that is the Experiment 4 depth sweep (1035-wide layers),
# not the Experiment 1 architectures. Confirm which set this figure should show.
our_df = our_df.sort_values(by='memory')
adam_bin_df = adam_bin_df.sort_values(by='memory')
if dataset != "prototypes":
    adam_small_df = adam_small_df.sort_values(by='memory')


def plot(df, label, color, single_layer = False, sgd = False):
    """Draw one method as an error-bar curve of accuracy against memory in bytes.

    ``sgd`` selects triangle markers drawn underneath; otherwise circles,
    dashed when ``single_layer`` is set.
    """
    if not sgd:
        marker, line, layer = 'o', ('dashed' if single_layer else '-'), 2
    else:
        marker, line, layer = 'v', '-', 1
    plt.errorbar(df['memory'] / 8, df['mean_val_acc'], yerr=df['std_val_acc'],
                 fmt=marker, color=color, linestyle=line, label=f"{label}",
                 capsize=3, zorder=layer)


plot(our_df, 'Ours', "black")
if dataset != "prototypes":
    plot(adam_small_df, 'FP input full-precision Adam', "darkblue", sgd=True)
plot(adam_bin_df, 'Binary input full-precision Adam', "darkgreen", sgd=True)

# Per-dataset y axis: (tick start, tick stop, tick step, ylim bottom, ylim top).
y_axis = {
    "fmnist": (30, 91, 6, 29, 91),
    "imagenettetl": (5, 94, 8, 4, 94),
    "prototypes": (28, 99, 7, 27, 99),
    "cifar100tl": (0, 60.5, 5, -1, 61),
    "cifar10tl": (5, 86, 8, 4, 86),
}
if dataset in y_axis:
    tick_start, tick_stop, tick_step, y_low, y_high = y_axis[dataset]
    plt.yticks(np.arange(tick_start, tick_stop, tick_step))
    plt.ylim(y_low, y_high)

from matplotlib.ticker import FormatStrFormatter
plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

plt.grid(True, which="both")
plt.xlabel('Memory requirements (bytes)')
plt.ylabel('Test accuracy')
plt.legend(loc='lower right')
plt.show()

In [None]:
# fn = Path(f'fig/exp_appendix3a_{dataset}.pdf')
# fig.savefig(fn, bbox_inches='tight')

In [None]:
# Gate cost of one full-precision Adam step (same per-op costs as the SGD baselines).
mul_cost = 10000
add_cost = 10000


def _adam_gates(layers):
    """Gate cost of one full-precision Adam step for `layers`."""
    mult, add = compute_operations(layers, optimizer="adam")
    return mult * mul_cost + add * add_cost


if dataset != "prototypes":
    adam_small_df['operations'] = adam_small_df['layers'].apply(_adam_gates)
adam_bin_df['operations'] = adam_bin_df['layers'].apply(_adam_gates)

In [None]:
# Appendix (Adam): test accuracy vs. Boolean-gate cost of one training step.
fig = plt.figure()

# `our_df` is whichever selection ran last; after Restart & Run All it is the
# Experiment 4 depth sweep, which has no 'operations' column (-> KeyError).
# Compute it here with the Experiment 1b gate costs instead of relying on leftover state.
# NOTE(review): confirm which "Ours" architectures this appendix figure should show.
if 'operations' not in our_df.columns:
    gate_costs = [xnor_cost, popcount_cost, increment_cost]
    our_df = our_df.assign(operations=our_df['layers'].apply(
        lambda x: sum(a * b for a, b in zip(compute_operations(x, binary=True, localloss=True), gate_costs))))

# Sort by the plotted x value so each curve's line is monotone in x.
our_df = our_df.sort_values(by='operations')
adam_bin_df = adam_bin_df.sort_values(by='operations')
if dataset != "prototypes":
    adam_small_df = adam_small_df.sort_values(by='operations')


def plot(df, label, color, single_layer = False, sgd = False):
    """Draw one method as an error-bar curve of accuracy against gate count.

    ``sgd`` selects triangle markers drawn underneath; otherwise circles,
    dashed when ``single_layer`` is set.
    """
    if sgd:
        zorder, linestyle, fmt = 1, '-', 'v'
    else:
        zorder, linestyle, fmt = 2, ('dashed' if single_layer else '-'), 'o'
    plt.errorbar(df['operations'], df['mean_val_acc'], yerr=df['std_val_acc'], fmt=fmt, color=color, linestyle=linestyle, label=f"{label}", capsize=3, zorder=zorder)


plot(our_df, 'Ours', "black")
if dataset != "prototypes":
    plot(adam_small_df, 'FP input full-precision Adam', "darkblue", sgd=True)
plot(adam_bin_df, 'Binary input full-precision Adam', "darkgreen", sgd=True)

# Per-dataset (tick start, tick stop, tick step, ylim bottom, ylim top, legend location).
y_axis = {
    "fmnist": (30, 91, 6, 29, 91, "lower left"),
    "imagenettetl": (5, 94, 8, 4, 94, "lower left"),
    "prototypes": (28, 99, 7, 27, 99, "lower right"),
    "cifar100tl": (0, 60.5, 5, -1, 61, "lower right"),
    "cifar10tl": (5, 86, 8, 4, 86, "lower left"),
}
# Previously undefined for prototypes / cifar100tl (silently reused a stale value).
legend_loc = "lower right"
if dataset in y_axis:
    tick_start, tick_stop, tick_step, y_low, y_high, legend_loc = y_axis[dataset]
    plt.yticks(np.arange(tick_start, tick_stop, tick_step))
    plt.ylim(y_low, y_high)

from matplotlib.ticker import FormatStrFormatter
plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

plt.xscale('log')
plt.grid(True, which="both")
plt.xlabel('Number of Boolean gates')
plt.ylabel('Test accuracy')
plt.legend(loc=legend_loc)
plt.show()

In [None]:
# fn = Path(f'fig/exp_appendix3b_{dataset}.pdf')
# fig.savefig(fn, bbox_inches='tight')

### Robustness

In [None]:
# Our method's accuracy across robustness values r; `layers` is cast to str so it
# can be matched against string layer names.
rob_df = pd.read_csv(f'out/{dataset}/rob.csv').astype({'layers': str})

In [None]:
# Appendix: test accuracy vs. robustness r for a few hidden-layer widths.
fig = plt.figure()

# Sort by robustness so every curve is drawn left to right.
rob_df = rob_df.sort_values(by='rob')

# One curve per hidden-layer width (75, 135, 525, 2025), each with its own colour.
width_colors = {75: "darkgreen", 135: "darkblue", 525: "grey", 2025: "darkred"}
for layer, color in width_colors.items():
    subset = rob_df[rob_df['layers'] == str(layer)]
    plt.errorbar(subset['rob'], subset['mean_val_acc'], yerr=subset['std_val_acc'],
                 color=color, capsize=5, fmt="-o", label=f"Hidden layer size: {layer}")

# Per-dataset y axis: (tick start, tick stop, tick step, ylim bottom, ylim top).
y_axis = {
    "fmnist": (77, 88, 1, 76, 88),
    "imagenettetl": (83.5, 92, 1, 82.5, 92.5),
    "cifar10tl": (75, 85, 1, 74, 85),
}
if dataset in y_axis:
    tick_start, tick_stop, tick_step, y_low, y_high = y_axis[dataset]
    plt.yticks(np.arange(tick_start, tick_stop, tick_step))
    plt.ylim(y_low, y_high)

from matplotlib.ticker import FormatStrFormatter
plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

plt.xticks([0.0, 0.25, 0.5, 0.75, 1.0])
plt.xlabel(r"Robustness $r$")
plt.ylabel('Test accuracy')
plt.legend(loc='lower right')
plt.grid(True)
plt.show()

In [None]:
# fn = Path(f'fig/exp_appendix1_{dataset}.pdf')
# fig.savefig(fn, bbox_inches='tight')

### Grouping layer size optimization

In [None]:
# Baldassi et al. runs, split into hidden width (`layers`) and grouping /
# classification layer size (`grouping_factor`).
group_size_df = pd.read_csv(f'out/{dataset}/baldassi.csv')
# Keep only "<hidden>_<grouping>" architectures (exactly two underscore-separated sizes).
# .copy() detaches the filtered frame so the column writes below don't act on a
# slice of the original (which triggers pandas' SettingWithCopyWarning).
group_size_df = group_size_df[group_size_df['layers'].apply(lambda x: len(x.split('_')) == 2)].copy()
group_size_df['grouping_factor'] = group_size_df['layers'].apply(lambda x: int(x.split('_')[1]))
group_size_df['layers'] = group_size_df['layers'].apply(lambda x: x.split('_')[0])

In [None]:
# Appendix: Baldassi et al. accuracy vs. classification-layer size, one curve per hidden width.
fig = plt.figure()

# Sort by classification-layer size so every curve is drawn left to right.
group_size_df = group_size_df.sort_values(by='grouping_factor')

# The first four widths get distinct colours; any further ones are drawn in black.
palette = ["darkgreen", "darkblue", "grey", "darkred"]
for i, layer in enumerate(group_size_df['layers'].unique()):
    subset = group_size_df[group_size_df['layers'] == layer]
    color = palette[i] if i < len(palette) else "black"
    plt.errorbar(subset['grouping_factor'], subset['mean_val_acc'], yerr=subset['std_val_acc'],
                 color=color, capsize=5, fmt="-o", label=f"Hidden layer size: {layer}")

# Per-dataset y axis: (tick start, tick stop, tick step, ylim bottom, ylim top).
y_axis = {
    "fmnist": (10, 91, 10, 8, 92),
    "imagenettetl": (45, 96, 5, 44, 96),
    "cifar10tl": (5, 86, 10, 3, 87),
}
if dataset in y_axis:
    tick_start, tick_stop, tick_step, y_low, y_high = y_axis[dataset]
    plt.yticks(np.arange(tick_start, tick_stop, tick_step))
    plt.ylim(y_low, y_high)

from matplotlib.ticker import FormatStrFormatter
plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

plt.xscale('log')
plt.xlabel('Baldassi et al. [10] classification layer size ')
plt.ylabel('Test accuracy')
plt.legend(loc='lower right')
plt.grid(True)
plt.show()

In [None]:
# fn = Path(f'fig/exp_appendix2_{dataset}.pdf')
# fig.savefig(fn, bbox_inches='tight')

### UCI

In [None]:
# UCI results, loaded under UCI-specific names so the per-`dataset` frames loaded
# at the top of the notebook (e.g. `baldassi_all`) are not overwritten.
localloss_uci_all = pd.read_csv('out/uci/our.csv')
baldassi_uci_all = pd.read_csv('out/uci/baldassi.csv')


def _uci_accuracy_table(runs, acc_column):
    """Build the per-(dataset, architecture) accuracy table for one method.

    Parameters
    ----------
    runs : pd.DataFrame
        Must contain the columns ``dataset``, ``layers``, ``mean_val_acc`` and
        ``std_val_acc``.
    acc_column : str
        Name of the result column holding the ``"mean ± std"`` strings.

    Returns
    -------
    pd.DataFrame
        Columns ``["Dataset", "Layers", acc_column, "Rank"]`` (original row index
        kept). Rows are sorted by dataset ascending, then by the product of the
        '_'-separated layer sizes descending (ties broken by the layers string).
        ``Rank`` is 1 for the largest architecture of each dataset.
    """
    table = runs[["dataset", "layers", "mean_val_acc", "std_val_acc"]].copy()
    # Single-layer configurations (e.g. "135") would otherwise be parsed as ints.
    table["layers"] = table["layers"].astype(str)
    layer_dim = table["layers"].apply(lambda x: np.prod([int(i) for i in x.split('_')]))
    table = table.assign(Layer_Dim=layer_dim).sort_values(
        by=["dataset", "Layer_Dim", "layers"], ascending=[True, False, True]
    )
    table[acc_column] = [
        f"{mean:.2f} ± {std:.2f}"
        for mean, std in zip(table["mean_val_acc"], table["std_val_acc"])
    ]
    # Rows are already ordered largest-first within each dataset, so the position
    # inside the group is the rank.
    table["Rank"] = table.groupby("dataset").cumcount() + 1
    table = table.rename(columns={"dataset": "Dataset", "layers": "Layers"})
    return table[["Dataset", "Layers", acc_column, "Rank"]]


df_table_ll = _uci_accuracy_table(localloss_uci_all, "OUR Validation Accuracy")
df_table_baldassi = _uci_accuracy_table(baldassi_uci_all, "baldassi Validation Accuracy")

In [None]:
# Scatter plot of our UCI test accuracy (x) against Baldassi et al. (y) for each
# (dataset, architecture size) pair; points above the dashed x = y line are cases
# where Baldassi et al. score higher.
from matplotlib.ticker import FormatStrFormatter

fig = plt.figure()


def _mean_from_formatted(acc_strings):
    """Return the mean parsed from a Series of "mean ± std" strings, as floats."""
    return acc_strings.str.split(' ± ').str[0].astype(float)


# Pair the two methods by (dataset, size rank) rather than by row position, so a
# dataset or architecture present in only one table cannot misalign later points.
paired = df_table_ll.merge(df_table_baldassi, on=['Dataset', 'Rank'], how='inner')
localloss_acc = _mean_from_formatted(paired['OUR Validation Accuracy'])
baldassi_acc = _mean_from_formatted(paired['baldassi Validation Accuracy'])

# Reference line x = y.
plt.plot([0, 100], [0, 100], color='black', linestyle='dashed')

# Rank 1 is the largest architecture of its dataset, so Rank maps to size class.
colors = ['blue', 'green', 'red']
labels = ['Large architectures', "Medium architectures", "Small architectures"]
size_class = (paired['Rank'] - 1) % len(colors)
for k, (color, label) in enumerate(zip(colors, labels)):
    mask = size_class == k
    # Called even when the mask is empty so every size class appears in the legend.
    plt.scatter(localloss_acc[mask], baldassi_acc[mask], color=color, label=label)

# Axis limits and one-decimal tick labels on both axes.
plt.xlim(48, 102)
plt.ylim(48, 102)
plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
plt.gca().xaxis.set_major_formatter(FormatStrFormatter('%.1f'))

plt.xlabel('Our proposed solution test accuracy')
plt.ylabel('Baldassi et al. [10] test accuracy')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# fn = Path(f'fig/exp_appendix3_uci.pdf')
# fig.savefig(fn, bbox_inches='tight')

## Additional Experiments

### Learning Dynamics

In [None]:
# Per-step learning-curve results for our method and the SGD baselines.
our_all_learning = pd.read_csv(Path('out') / dataset / 'our_per_step.csv')
sgd_bin_all_learning = pd.read_csv(Path('out') / dataset / 'sgd_bin_per_step.csv')
# No sgd_per_step.csv is read for the prototypes dataset.
if dataset != "prototypes":
    sgd_small_all_learning = pd.read_csv(Path('out') / dataset / 'sgd_per_step.csv')

In [None]:
# Architectures whose learning curves are compared: ours always uses 525_525,
# the SGD baselines use a dataset-specific smaller network.
sgd_learning_layers = {
    "imagenettetl": "81_81",
    "cifar10tl": "70_70",
    "fmnist": "105_105",
}
# Fail loudly instead of hitting a NameError (or silently reusing a frame left
# over from a previous dataset) when no SGD architecture is configured.
if dataset not in sgd_learning_layers:
    raise ValueError(
        f"No SGD learning-dynamics architecture configured for dataset {dataset!r}; "
        f"add it to sgd_learning_layers (configured: {sorted(sgd_learning_layers)})"
    )
sgd_layer = sgd_learning_layers[dataset]

our_df = our_all_learning[our_all_learning["layers"] == "525_525"].reset_index(drop=True)
sgd_bin_m1_df = sgd_bin_all_learning[sgd_bin_all_learning["layers"] == sgd_layer].reset_index(drop=True)
if dataset != "prototypes":
    sgd_small_df = sgd_small_all_learning[sgd_small_all_learning["layers"] == sgd_layer].reset_index(drop=True)

In [None]:
# Memory estimate per layer configuration, attached to each learning-curve frame
# as a `memory` column. The multipliers (1 per input value for binary input, 8 or
# 32 per weight/activation) suggest the unit is bits -- NOTE(review): confirm.
mem_sgd_small, mem_sgd_bin_m1, mem_ll, mem_baldassi = {}, {}, {}, {}

# compute_memory depends only on the layers string, so evaluate it once per
# distinct configuration rather than once per per-step row.
for layer in our_df['layers'].unique():
    mem_w, mem_w_ll, mem_w_gl, mem_h, mem_d = compute_memory(layer)
    mem_ll[layer] = input_dim*dataset_size*1 + mem_w[layer]*8 + mem_w_ll[layer]*1 + mem_h[layer]*8
if dataset != "prototypes":
    for layer in sgd_small_df['layers'].unique():
        mem_w, mem_w_ll, mem_w_gl, mem_h, mem_d = compute_memory(layer)
        # NOTE(review): `//32*32` rounds the input size down to a multiple of 32
        # instead of multiplying by 32 like the other full-precision terms -- confirm intent.
        mem_sgd_small[layer] = input_dim*dataset_size//32*32 + mem_w[layer]*32 + mem_h[layer]*32 + mem_d[layer]*32
for layer in sgd_bin_m1_df['layers'].unique():
    mem_w, mem_w_ll, mem_w_gl, mem_h, mem_d = compute_memory(layer)
    mem_sgd_bin_m1[layer] = input_dim*dataset_size*1 + mem_w[layer]*32 + mem_h[layer]*32 + mem_d[layer]*32

# Every configuration in the frames has an entry, so map never yields NaN here.
our_df['memory'] = our_df['layers'].map(mem_ll)
sgd_bin_m1_df['memory'] = sgd_bin_m1_df['layers'].map(mem_sgd_bin_m1)
if dataset != "prototypes":
    sgd_small_df['memory'] = sgd_small_df['layers'].map(mem_sgd_small)

In [None]:
# Multiply the SGD baselines' mean/std val/train accuracy columns by 100, the same
# scaling the loading cell applies to the summary frames. Not idempotent:
# running this cell twice scales twice.
sgd_bin_m1_df[['mean_val_acc', 'std_val_acc', 'mean_train_acc', 'std_train_acc']] = (
    sgd_bin_m1_df[['mean_val_acc', 'std_val_acc', 'mean_train_acc', 'std_train_acc']] * 100
)
if dataset != "prototypes":
    sgd_small_df[['mean_val_acc', 'std_val_acc', 'mean_train_acc', 'std_train_acc']] = (
        sgd_small_df[['mean_val_acc', 'std_val_acc', 'mean_train_acc', 'std_train_acc']] * 100
    )

In [None]:
def get_sorted_layers(df):
    """Return the distinct ``layers`` values of `df`, ordered by memory.

    Each configuration is keyed by the ``memory`` value of its first row (as
    chosen by ``groupby(...).first()``), and the list is sorted ascending by it.
    """
    per_config = df.groupby('layers', as_index=False).first()
    by_memory = per_config[['layers', 'memory']].sort_values('memory')
    return by_memory['layers'].tolist()

from matplotlib.ticker import FormatStrFormatter


def _plot_mean_std(ax, subset, label, color):
    """Draw ``mean_val_acc`` against ``step`` on `ax` with a shaded +/-1 std band.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    subset : DataFrame with ``step``, ``mean_val_acc`` and ``std_val_acc`` columns.
    label : legend label for the line.
    color : colour shared by the line and its band.
    """
    ax.plot(subset['step'], subset['mean_val_acc'], label=label, color=color)
    ax.fill_between(
        subset['step'],
        subset['mean_val_acc'] - subset['std_val_acc'],
        subset['mean_val_acc'] + subset['std_val_acc'],
        color=color, alpha=0.15,
    )


fig, ax = plt.subplots(figsize=(8, 5))

# Ours: step 0 is skipped.
layers_ours = get_sorted_layers(our_df)
for layer in layers_ours:
    subset = our_df[(our_df['layers'] == layer) & (our_df['step'] != 0)]
    _plot_mean_std(ax, subset, 'Ours', "black")

# SGD baselines: steps are shifted by +1 -- presumably so their 0-based steps line
# up with ours once our step 0 is dropped (TODO confirm).
layers_sgd_bin = get_sorted_layers(sgd_bin_m1_df)
for layer in layers_sgd_bin:
    subset = sgd_bin_m1_df[sgd_bin_m1_df['layers'] == layer]
    _plot_mean_std(ax, subset.assign(step=subset['step'] + 1),
                   "Binary input full-precision SGD", "darkgreen")

if dataset != "prototypes":
    layers_sgd_small = get_sorted_layers(sgd_small_df)
    for layer in layers_sgd_small:
        subset = sgd_small_df[sgd_small_df['layers'] == layer]
        _plot_mean_std(ax, subset.assign(step=subset['step'] + 1),
                       'FP input full-precision SGD', "darkblue")

# Dataset-specific y ticks and limits; other datasets keep matplotlib's defaults.
learning_y_axis = {
    "fmnist": (np.arange(26, 90, 7), (25, 91)),
    "imagenettetl": (np.arange(17, 96, 7), (16, 95)),
    "prototypes": (np.arange(32, 99, 8), (30, 98)),
    "cifar10tl": (np.arange(19, 86, 6), (18, 86)),
}
if dataset in learning_y_axis:
    y_ticks, y_limits = learning_y_axis[dataset]
    ax.set_yticks(y_ticks)
    ax.set_ylim(*y_limits)
ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

ax.grid(True, which="both")
ax.set_xlabel('Epoch')
ax.set_ylabel('Test accuracy')
ax.legend(loc='lower right')
plt.show()

In [None]:
# fn = Path(f'fig/exp_add1_{dataset}.pdf')
# fig.savefig(fn, bbox_inches='tight')

### Catastrophic Forgetting

In [None]:
# Results of the hidden-weight bit-width (weight_clip) sweep for our method.
our_all_forget = pd.read_csv(Path('out') / dataset / 'our_forget.csv')

In [None]:
# Test accuracy of our method against log2(weight_clip) + 1, the value the x axis
# labels as the number of hidden-weight bits. `np` and `plt` come from the
# imports in the notebook's first cell.
from matplotlib.ticker import FormatStrFormatter

df = our_all_forget.sort_values('weight_clip').copy()
df['log2_weight_clip_plus1'] = np.log2(df['weight_clip']) + 1

fig, ax = plt.subplots(figsize=(7, 4))

ax.errorbar(
    df['log2_weight_clip_plus1'],
    df['mean_val_acc'],
    yerr=df['std_val_acc'],
    fmt='-o',
    color='black',
    capsize=4,
    label='Test accuracy'
)

# Fixed y range with one-decimal tick labels.
ax.set_yticks(np.arange(10, 91, 10))
ax.set_ylim(5, 95)
ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

ax.grid(True, which="both")
ax.set_xlabel('Hidden weights bits')
ax.set_ylabel('Test accuracy')
plt.show()

In [None]:
# fn = Path(f'fig/exp_add2_{dataset}.pdf')
# fig.savefig(fn, bbox_inches='tight')

### Train Out Classifiers

In [None]:
# For every layer configuration in our_train_out.csv ("Train-out"), show its
# mean ± std test accuracy next to the run with the same configuration in
# our.csv ("Fixed-out").
our_all = pd.read_csv(f'out/{dataset}/our.csv')
our_train_out_all = pd.read_csv(f'out/{dataset}/our_train_out.csv')

# Layer configurations as strings (robust to pandas parsing single-layer configs
# as ints), ordered by the product of their '_'-separated sizes.
layers = sorted(
    our_train_out_all['layers'].astype(str).tolist(),
    key=lambda x: np.prod([int(i) for i in x.split('_')]),
)

# One row per configuration (the first if it appears more than once), indexed by
# the layers string for direct lookup.
train_out_by_layer = (
    our_train_out_all.assign(layers=our_train_out_all['layers'].astype(str))
    .drop_duplicates('layers')
    .set_index('layers')
)
fixed_out_by_layer = (
    our_all.assign(layers=our_all['layers'].astype(str))
    .drop_duplicates('layers')
    .set_index('layers')
)

missing_layers = [layer for layer in layers if layer not in fixed_out_by_layer.index]
if missing_layers:
    raise ValueError(
        f"Layer configurations {missing_layers} appear in our_train_out.csv "
        f"but not in our.csv for dataset {dataset!r}"
    )

results = []
for layer in layers:
    train_row = train_out_by_layer.loc[layer]
    fixed_row = fixed_out_by_layer.loc[layer]
    results.append({
        "Layers": layer,
        "Train-out": f"{train_row['mean_val_acc']:.2f} ± {train_row['std_val_acc']:.2f}",
        "Fixed-out": f"{fixed_row['mean_val_acc']:.2f} ± {fixed_row['std_val_acc']:.2f}"
    })

results_df = pd.DataFrame(results)
display(results_df)