In [None]:
import os
import json

import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import numpy as np

In [None]:
# Directory holding the per-run training logs, named `lr={lr1}_{lr2}_log.json`
# (e.g. "two_layer_results" or "five_layer_results"), relative to the notebook.
RESULTS_DIR = "five_layer_results"

SMALL_SIZE = 10
MEDIUM_SIZE = 12
BIGGER_SIZE = 14

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
# The suptitle size lives under the 'figure' group; using 'axes' here would
# override the axes-title size set above instead.
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title


def get_all_files(lr, results_dir=None):
    """Return the names of all log files whose first learning rate is ``lr``.

    Log files are named ``lr={lr1}_{lr2}_log.json``. ``lr`` is formatted with an
    f-string, so it must be passed in the same form used for the file names.

    Args:
        lr: First learning rate (``lr1``) to select.
        results_dir: Directory to scan; defaults to ``RESULTS_DIR``.

    Returns:
        List of matching file names (not full paths), in ``os.listdir`` order.
    """
    if results_dir is None:
        results_dir = RESULTS_DIR
    # The trailing underscore anchors the match: without it lr=0.1 would also
    # pick up the files of lr=0.15, lr=0.125, ...
    prefix = f'lr={lr}_'
    return [name for name in os.listdir(results_dir)
            if name.startswith(prefix) and name.endswith("log.json")]


def get_files(results_dir=None):
    """Return the names of all ``*log.json`` files in the results directory.

    Args:
        results_dir: Directory to scan; defaults to ``RESULTS_DIR``.

    Returns:
        List of file names (not full paths), in ``os.listdir`` order.
    """
    if results_dir is None:
        results_dir = RESULTS_DIR
    return [name for name in os.listdir(results_dir) if name.endswith("log.json")]

    
def append_id(filename, id):
    """Insert ``_<id>`` just before the extension of ``filename``.

    Only the last extension is split off (``"a.tar.gz" -> "a.tar_<id>.gz"``);
    a name without an extension simply gets the suffix appended
    (``"plot" -> "plot_<id>"``) instead of raising.
    """
    # splitext only looks at the final path component, so dots in directory
    # names are not mistaken for the extension.
    root, ext = os.path.splitext(filename)
    return f"{root}_{id}{ext}"


def get_filename_individual(lr1, lr2):
    """Return the file name (no directory) for the single-run plot of (lr1, lr2).

    Example: ``get_filename_individual(0.1, 0.2) -> 'lr1=0.1_lr2=0.2.pdf'``.
    """
    return f'lr1={lr1}_lr2={lr2}.pdf'


def get_filename(lr, vmax):
    """Return the file name (no directory) for the heat-map plot of ``lr``.

    Example: ``get_filename(1e-4, 20) -> '0.0001_vmax=20.pdf'``.
    """
    return f'{lr}_vmax={vmax}.pdf'


def get_filename_range(lr1, lr2_low, lr2_high):
    """Return the file name (no directory) for the lr2-range overlay plot.

    Example: ``get_filename_range(1e-4, 5e-05, 2e-4)
    -> 'lr1=0.0001_lr2_low=5e-05_lr2_high=0.0002.pdf'``.
    """
    # Same `key=value` form for both bounds (previously `lr2_high=_…`).
    return f'lr1={lr1}_lr2_low={lr2_low}_lr2_high={lr2_high}.pdf'


In [None]:
def extract_data(file_path, running_avg):
    """Load a training log and return its train/test top-1 error curves.

    The log is a JSON list with one entry per logged step. Each entry must
    contain ``entry["train"]["acc1"]`` and ``entry["test"]["acc1"]``. These are
    presumably percentages, given the ``100 - acc`` conversion below.

    Args:
        file_path: Path to a ``*_log.json`` file.
        running_avg: If true, smooth both curves with ``get_running_avg``,
            which returns shorter arrays.

    Returns:
        Tuple ``(train_err, test_err)`` of numpy arrays equal to ``100 - acc1``.
    """
    # `with` guarantees the handle is closed (the file used to be left open).
    with open(file_path) as f:
        data = json.load(f)

    train_acc = np.array([entry["train"]["acc1"] for entry in data])
    test_acc = np.array([entry["test"]["acc1"] for entry in data])

    if running_avg:
        train_acc, test_acc = get_running_avg(train_acc), get_running_avg(test_acc)

    return 100 - train_acc, 100 - test_acc


def get_running_avg(x, step=3):
    """Simple moving average of ``x`` over windows of ``step`` samples.

    Only full windows are returned, so the result has ``len(x) - step + 1``
    entries (empty if ``len(x) < step``). Entry ``i`` is ``mean(x[i:i + step])``.

    Raises:
        ValueError: if ``step < 1``.
    """
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")
    # Prepending 0 makes cumsum[i] the sum of the first i samples, so the first
    # window x[0:step] is included (without it, x[0] never contributed).
    cumsum = np.cumsum(np.insert(np.asarray(x, dtype=float), 0, 0.0))
    return (cumsum[step:] - cumsum[:-step]) / float(step)


In [None]:
def plot_individual_run(lr1, lr2, running_avg=False):
    """Plot the train and test error of one (lr1, lr2) run and save it.

    Reads ``RESULTS_DIR/lr={lr1}_{lr2}_log.json`` through ``extract_data`` and
    draws one stacked panel per returned curve, all sharing a log-scaled x axis
    (the index of the log entry). The figure is written to
    ``plots/<get_filename_individual(lr1, lr2)>`` and then shown.
    """
    errors = extract_data(os.path.join(RESULTS_DIR, f"lr={lr1}_{lr2}_log.json"), running_avg)

    viridis = matplotlib.colormaps['viridis']
    panel_colors = [viridis(50 / 1000), viridis(350 / 1000)]
    panel_labels = ["Train error", "Test error"]

    _, axes = plt.subplots(len(errors), 1, sharex=True)
    for axis, curve, color, label in zip(axes, errors, panel_colors, panel_labels):
        axis.set_xscale('log')
        axis.plot(np.arange(0, curve.shape[0]), curve, color=color, lw=4)
        axis.set_ylabel(label)

    plt.suptitle(fr"$\eta_{{\mathbf{{W}}}} = {lr1}$, $\eta_{{\mathbf{{v}}}} = {lr2}$")

    out_dir = "plots"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    plt.savefig(os.path.join(out_dir, get_filename_individual(lr1, lr2)))
    plt.show()
    

def _decade_ticks(values, exponents):
    """Tick positions/labels for the powers of ten that lie inside ``values``' range.

    For every exponent ``e`` with ``10**e`` in ``[min(values), max(values)]``
    (up to floating-point tolerance), returns the index of the entry of
    ``values`` closest to ``10**e`` together with the label ``$10^{e}$``.
    Decades outside the range are skipped so their labels do not pile up on the
    first/last column.
    """
    values = np.asarray(values, dtype=float)
    low, high = values.min(), values.max()
    positions, labels = [], []
    for e in exponents:
        target = 10.0 ** e
        above_low = target >= low or np.isclose(target, low, rtol=1e-9, atol=0.0)
        below_high = target <= high or np.isclose(target, high, rtol=1e-9, atol=0.0)
        if above_low and below_high:
            positions.append(int(np.argmin(np.abs(values - target))))
            labels.append(f"$10^{{{e}}}$")
    return positions, labels


def _report_nonfinite(lr, lrs, train_err, test_err, plot_nans):
    """Report the lr2 values around the point where runs start containing nan/inf.

    A run is non-finite if any entry of its train or test curve is nan/inf.
    ``lrs`` must be sorted ascending and aligned with the rows of the error
    grids. When some but not all runs are non-finite, this takes the first lr2
    after the last fully finite run, followed by up to nine lr2 values below
    it. Each is printed, or plotted with ``plot_individual_run`` when
    ``plot_nans`` is true.
    """
    finite_runs = np.isfinite(train_err).all(axis=1) & np.isfinite(test_err).all(axis=1)
    if finite_runs.all():
        print('No nans/inf values')
        return
    if not finite_runs.any():
        print('All runs contain nan/inf values')
        return

    print(r"Lr2 with nan/inf, then some lower lr2")
    first_after_last_finite = np.nonzero(finite_runs)[0][-1] + 1
    for idx in range(first_after_last_finite, first_after_last_finite - 10, -1):
        # Skip indices outside the grid instead of overflowing or wrapping to the end.
        if not 0 <= idx < len(lrs):
            continue
        if plot_nans:
            plot_individual_run(lr, lrs[idx])
        else:
            print(lrs[idx])


def plot_results(lr, vmax=100, keep_nan=True, plot_nans=False, num_its=1000, geo_samples=None):
    """Heat-map train/test error over (lr2, log entry) for a fixed lr1 and save it.

    Loads every ``RESULTS_DIR/lr={lr}_{lr2}_log.json`` run and stacks the
    curves into ``(n_lr2, num_its)`` grids sorted by lr2. It then reports runs
    containing nan/inf and draws two heat maps (train on top, test below). lr2
    runs along x, labelled by decades of lr2 / lr1. Log entries run along y,
    with the last entry at the top. The figure is saved to
    ``plots/<get_filename(lr, vmax)>`` and shown.

    Args:
        lr: First learning rate (``lr1``), in the same textual form as the file names.
        vmax: Upper end of the colour scale; the lower end is fixed at 0.
        keep_nan: Currently unused; kept for backward compatibility.
        plot_nans: If true, plot the runs around the nan/inf boundary instead
            of only printing their lr2.
        num_its: Number of log entries expected in every run.
        geo_samples: Iteration number of each log entry (ascending, length
            ``num_its``), used to place the ``10^k`` y ticks. Defaults to
            ``np.arange(num_its)`` (one entry per iteration). TODO: confirm this
            against how the runs were logged.
    """
    f_ext = "_log.json"
    prefix = f"lr={lr}_"
    # lr2 is the text between "lr={lr}_" and "_log.json"; keep it as a string
    # so the file name can be rebuilt exactly as it was written.
    lrs = [f[len(prefix):-len(f_ext)] for f in get_all_files(lr)
           if f.startswith(prefix) and f.endswith(f_ext)
           and len(f) > len(prefix) + len(f_ext)]

    print(fr'Loading {len(lrs)} files')
    if not lrs:
        return

    lrs_float = [float(l) for l in lrs]
    lrs_float, lrs = zip(*sorted(zip(lrs_float, lrs)))

    if geo_samples is None:
        geo_samples = np.arange(num_its)
    geo_samples = np.asarray(geo_samples)
    if geo_samples.shape != (num_its,):
        raise ValueError(f"geo_samples must have length num_its={num_its}, got shape {geo_samples.shape}")

    train_err, test_err = np.zeros((len(lrs), num_its)), np.zeros((len(lrs), num_its))
    for i, l in enumerate(lrs):
        train_err[i], test_err[i] = extract_data(
            os.path.join(RESULTS_DIR, f"{prefix}{l}{f_ext}"), running_avg=False)

    _report_nonfinite(lr, lrs, train_err, test_err, plot_nans)
    ratios = np.array(lrs_float) / lr

    fig, ax = plt.subplots(2, 1, sharex=True)

    # Columns are lr2 values; rows are log entries, flipped so the last one is on top.
    im1 = ax[0].imshow(train_err.transpose()[::-1, :], interpolation='none', aspect='auto', vmax=vmax, vmin=0)
    im2 = ax[1].imshow(test_err.transpose()[::-1, :], interpolation='none', aspect='auto', vmax=vmax, vmin=0)
    fig.colorbar(im1, ax=ax[0])
    fig.colorbar(im2, ax=ax[1])

    fig.suptitle(fr"$\eta_{{\mathbf{{W}}}}={lr}$")

    # X axis (shared): label the decades of lr2 / lr1 that the grid covers.
    x_ticks, x_labels = _decade_ticks(ratios, np.arange(-5, 6))
    ax[1].set_xticks(x_ticks)
    ax[1].set_xticklabels(x_labels)
    ax[1].set_xlabel(r"$\eta_{\mathbf{v}} / \eta_{\mathbf{W}}$")

    # Y axis: because of the flip, image row r shows log entry num_its - 1 - r.
    y_ticks, y_labels = _decade_ticks(geo_samples, np.arange(1, 5))
    y_ticks = [num_its - 1 - t for t in y_ticks]
    for axis in ax:
        axis.set_yticks(y_ticks)
        axis.set_yticklabels(y_labels)
        axis.set_ylabel(r"Iteration $t$")

    save_dir = "plots"
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, get_filename(lr, vmax)), bbox_inches='tight')

    plt.show()
    return
    

In [None]:
# Function for plotting a selection of curves

def plot_individual_runs_range(lr1, lr_ratio_low, lr_ratio_high, running_avg=False):
    """Overlay the train/test error curves of all runs with lr2/lr1 in a range.

    Selects the runs ``RESULTS_DIR/lr={lr1}_{lr2}_log.json`` with
    ``lr_ratio_low * lr1 <= lr2 <= lr_ratio_high * lr1``. They are plotted in
    ascending lr2 order, with train error on top and test error below, on a
    log-scaled x axis. The figure is saved to
    ``plots/<get_filename_range(lr1, lr2_low, lr2_high)>`` and shown.
    """
    f_ext = "_log.json"
    prefix = f"lr={lr1}_"
    lr2_low, lr2_high = lr_ratio_low * lr1, lr_ratio_high * lr1

    # (value, text) pairs: the float is used for filtering, sorting and the
    # legend; the original text rebuilds the exact on-disk file name.
    runs = []
    for f in get_all_files(lr1):
        if not (f.startswith(prefix) and f.endswith(f_ext)) or len(f) <= len(prefix) + len(f_ext):
            continue
        lr2_text = f[len(prefix):-len(f_ext)]
        lr2 = float(lr2_text)
        if lr2_low <= lr2 <= lr2_high:
            runs.append((lr2, lr2_text))
    runs.sort()

    fig, ax = plt.subplots(2, 1, sharex=True)
    ylabels = ["Train error", "Test error"]
    # Configure the axes up front so they are set even when no run matches.
    for k in range(2):
        ax[k].set_xscale('log')
        ax[k].set_ylim([0, 100])
        ax[k].set_ylabel(ylabels[k])

    for _, lr2_text in runs:
        data = extract_data(os.path.join(RESULTS_DIR, f"{prefix}{lr2_text}{f_ext}"), running_avg)
        for k in range(2):
            data_vec = data[k]
            ax[k].plot(np.arange(0, data_vec.shape[0]), data_vec, lw=4)

    if runs:
        ax[0].legend([fr"$\eta_{{\mathbf{{v}}}} = {lr2}$" for lr2, _ in runs])
    # Capital W, matching the eta_W notation used in the other plots.
    fig.suptitle(fr"$\eta_{{\mathbf{{W}}}} = {lr1}$")

    save_dir = "plots"
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, get_filename_range(lr1, lr2_low, lr2_high)))
    plt.show()
    


In [None]:
plot_individual_run(lr1=0.1, lr2=0.1)

In [None]:
# Heat maps for lr1 = 1e-4 with the colour scale capped at an error of 20.
# (plot_results takes no batch_norm / uniform_noise arguments.)
plot_results(1e-4, vmax=20, plot_nans=False)

In [None]:
# Overlay all runs with 0.5 <= lr2 / lr1 <= 2 for lr1 = 1e-4.
plot_individual_runs_range(1e-4, lr_ratio_low=5e-1, lr_ratio_high=2)