In [1]:
import jax
jax.config.update('jax_platform_name', 'cpu')
import pickle as pkl
import matplotlib.pyplot as plt
import jax.numpy as jnp
import numpy as np 
import os
import numpy as np
import pickle as pkl
import jax.numpy as jnp
from tqdm import tqdm
from jax.tree_util import tree_map,tree_flatten,tree_flatten_with_path,keystr,tree_map_with_path
import matplotlib.pyplot as plt
from jax.numpy.linalg import matrix_norm,vector_norm


def compare_stats_settings(path):
    """Find the setting with the highest mean peak test accuracy in an experiment dir.

    Expects the layout ``path/<setting>/<run>/stats.pkl``. Each pickle holds a
    dict with ``"test_acc"`` and ``"train_acc"`` pytrees. The leaf key path must
    render via ``keystr`` as ``"[<int>]"``; presumably that int is the evaluation
    step and each leaf holds one accuracy per replica/seed along its last axis
    (TODO confirm against the code that writes stats.pkl).

    For every run, the row with the highest test accuracy is picked independently
    for each entry of the last axis. The key of that row and the test/train
    accuracies at that row are recorded. Train accuracy is indexed with the
    *test* argmax, which assumes both pytrees share the same keys.

    Args:
        path: Experiment directory containing one sub-directory per setting.
            Non-directory entries and runs without ``stats.pkl`` are skipped.

    Returns:
        ``(None, None)`` if ``path`` is not a directory or yields no usable run.
        Otherwise ``(best, settings)``:

        * ``settings`` maps each setting name to a tuple
          ``(best_keys, best_test_accs, train_accs_at_best)``; all runs of the
          setting are concatenated.
        * ``best`` is ``(mean_best_test_acc, setting_name, settings[setting_name])``
          for the setting whose ``best_test_accs`` has the largest mean.
    """
    settings = {}
    if not os.path.isdir(path):
        return None, None

    # Sorted so the result (and max() tie-breaking) does not depend on filesystem order.
    for setting in sorted(os.listdir(path)):
        setting_dir = os.path.join(path, setting)
        # Stray files next to the setting folders would otherwise make
        # os.listdir(setting_dir) raise NotADirectoryError.
        if not os.path.isdir(setting_dir):
            continue

        argmax_test_accs = []
        maximum_test_accs = []
        train_accs = []
        for run in sorted(os.listdir(setting_dir)):
            stats_file = os.path.join(setting_dir, run, "stats.pkl")
            if not os.path.isfile(stats_file):
                continue
            # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
            with open(stats_file, "rb") as f:
                stats = pkl.load(f)

            test_leaves = tree_flatten_with_path(stats["test_acc"])[0]
            train_leaves = tree_flatten_with_path(stats["train_acc"])[0]
            if not test_leaves:
                # Nothing was evaluated in this run; jnp.stack([]) would raise.
                continue

            keys = [key_path for key_path, _ in test_leaves]
            test_acc = jnp.stack([leaf for _, leaf in test_leaves])
            train_acc = jnp.stack([leaf for _, leaf in train_leaves])

            # Row of the best test accuracy for each last-axis entry, computed once.
            best_rows = np.asarray(jnp.argmax(test_acc, axis=0))
            columns = np.arange(test_acc.shape[-1])

            # keystr renders the path like "[171102]"; strip the brackets to get the int key.
            argmax_test_accs.append(np.asarray([int(keystr(keys[i])[1:-1]) for i in best_rows]))
            maximum_test_accs.append(np.asarray(test_acc[best_rows, columns]))
            train_accs.append(np.asarray(train_acc[best_rows, columns]))

        if argmax_test_accs:
            settings[setting] = (
                np.concatenate(argmax_test_accs),
                np.concatenate(maximum_test_accs),
                np.concatenate(train_accs),
            )

    if not settings:
        return None, None

    best_setting = max(settings, key=lambda name: np.mean(settings[name][1]))
    best_mean = np.mean(settings[best_setting][1])
    return (best_mean, best_setting, settings[best_setting]), settings

def _mean_acc_curve(acc_tree):
    """Return ``(keys, means)`` for an accuracy pytree, sorted by integer key.

    The leaf key path must render via ``keystr`` as ``"[<int>]"`` (the same
    convention ``compare_stats_settings`` relies on). Each leaf is averaged over
    its last axis. Returns two empty arrays for an empty pytree.
    """
    leaves = tree_flatten_with_path(acc_tree)[0]
    if not leaves:
        return np.asarray([], dtype=int), np.asarray([], dtype=float)
    # Parse "[171102]" -> 171102 so matplotlib draws a numeric, not categorical, x axis.
    keys = np.asarray([int(keystr(key_path)[1:-1]) for key_path, _ in leaves])
    means = np.asarray(jnp.mean(jnp.stack([leaf for _, leaf in leaves]), axis=-1))
    order = np.argsort(keys)
    return keys[order], means[order]


def plot_stats(*paths):
    """Plot mean train/test accuracy curves of one or more runs side by side.

    Each path is a run directory that may contain ``stats.pkl``: a pickled dict
    with ``"train_acc"`` and ``"test_acc"`` pytrees keyed by integers (presumably
    evaluation steps, TODO confirm). Each leaf is averaged over its last axis.
    A run without ``stats.pkl`` leaves an empty, titled panel.

    The train-minus-test gap ("dif") is drawn only when both curves share the
    same keys, since otherwise the subtraction would be misaligned.

    Args:
        *paths: Run directories, one subplot column each. Nothing is drawn
            when no path is given.
    """
    if not paths:
        # plt.subplots(ncols=0) raises; there is nothing to plot anyway.
        return

    # squeeze=False always yields a 2-D array of axes, even for a single path.
    fig, axs = plt.subplots(nrows=1, ncols=len(paths), squeeze=False)
    fig.set_size_inches(len(paths) * 4, 6)
    for path, ax in zip(paths, axs[0]):
        ax.set_title(path)
        stats_file = os.path.join(path, "stats.pkl")
        if not os.path.isfile(stats_file):
            continue

        # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
        with open(stats_file, "rb") as f:
            stats = pkl.load(f)

        train_x, train_y = _mean_acc_curve(stats["train_acc"])
        test_x, test_y = _mean_acc_curve(stats["test_acc"])

        ax.plot(train_x, train_y, label="train acc", c="blue")
        ax.plot(test_x, test_y, label="test acc", c="green")
        if np.array_equal(train_x, test_x):
            ax.plot(test_x, train_y - test_y, label="dif", c="red")
        ax.set_ylim(0.0, 1.0)
        ax.set_xlabel("step")
        ax.set_ylabel("accuracy")
        ax.legend()
        

In [None]:
# Experiment roots to compare; each is expected to hold <setting>/<run>/stats.pkl.
# NOTE(review): the saved output below predates the b16_* entries in this list;
# restart the kernel and re-run so the printed results match the code.
exps = [
    "./exps_adam/standard",
    "./exps_adam/wd",
    "b16_exps_adam/wd",
    "./exps_adamw/wd",
    "./exps_adam/norm",
    "b16_exps_adam/norm",
    "./exps_adam/norm_layerwise_stepscale",
    "./exps_adam/center_norm",
    "./exps_adam/center_norm_uncenter",
    "./exps_adam/center_std_uncenter",
    "./exps_adam/reverse_center_norm",
]
# For each root print the best-setting summary (mean best test acc, setting name,
# per-run arrays), or None when the directory is missing or has no usable runs.
for exp in exps:
    print(f"{exp}: {compare_stats_settings(exp)[0]}")

./exps_adam/standard: (0.79946667, '0.00025', (array([171102, 171102, 168900]), array([0.8002001 , 0.8022    , 0.79600006], dtype=float32), array([0.99977994, 0.9993001 , 0.9981202 ], dtype=float32)))
./exps_adam/wd: (0.80576664, '0.00015000000000000001', (array([139169, 168900, 130917]), array([0.8084    , 0.80250007, 0.8064    ], dtype=float32), array([0.99744004, 0.9958    , 0.99690014], dtype=float32)))
./exps_adamw/wd: None
./exps_adam/norm: (0.83150005, '0.8_100', (array([164628, 168900, 168536]), array([0.8339001, 0.8321001, 0.8285001], dtype=float32), array([0.99028   , 0.9908001 , 0.98946005], dtype=float32)))
./exps_adam/norm_layerwise_stepscale: None
./exps_adam/center_norm: None
./exps_adam/center_norm_uncenter: None
./exps_adam/center_std_uncenter: None
./exps_adam/reverse_center_norm: None
