In [None]:
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import os
import numpy as np

In [None]:
# Set RESULTS_DIR to the folder that holds the result CSV files before calling
# the file and plot helpers below; they read this module-level name.
# RESULTS_DIR = os.path.join("P:/early_stopping_double_descent", "...")  # enter the folder name here (e.g., two_layer_results)

# Font sizes (in points) shared by every figure in this notebook.
SMALL_SIZE = 10
MEDIUM_SIZE = 12
BIGGER_SIZE = 14

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
# The figure title (plt.suptitle) is governed by the 'figure' rc group; setting
# 'axes' here would only override the axes title size configured above.
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title


def get_all_files(lr, batch_norm):
    """List the result CSV files in RESULTS_DIR whose first learning rate is ``lr``.

    File names are expected to look like ``lr={lr}_{lr2}.csv`` or
    ``lr={lr}_{lr2}_batch_norm.csv``.

    Args:
        lr: First learning rate; it is formatted with ``str()`` rules to build
            the file-name prefix.
        batch_norm: If truthy, keep only names containing ``"batch_norm"``;
            otherwise keep only names that do not contain it.

    Returns:
        A list of file names (not full paths), in ``os.listdir`` order.
    """
    # The trailing underscore terminates the first learning-rate field, so
    # lr=0.01 does not also pick up files such as "lr=0.015_...".
    prefix = f'lr={lr}_'
    want_batch_norm = bool(batch_norm)
    return [
        name
        for name in os.listdir(RESULTS_DIR)
        if name.startswith(prefix)
        and name.endswith(".csv")  # callers strip a ".csv" suffix from each name
        and ("batch_norm" in name) == want_batch_norm
    ]

def get_files(batch_norm):
    """List every ``.csv`` file in RESULTS_DIR, split by batch-norm marker.

    Args:
        batch_norm: If truthy, return only names containing ``"batch_norm"``;
            otherwise return only names that do not contain it.

    Returns:
        A list of file names (not full paths), in ``os.listdir`` order.
    """
    want_batch_norm = bool(batch_norm)
    selected = []
    for entry in os.listdir(RESULTS_DIR):
        if not entry.endswith(".csv"):
            continue
        if ("batch_norm" in entry) == want_batch_norm:
            selected.append(entry)
    return selected

def append_id(filename, id):
    """Insert ``_{id}`` between the stem and the extension of ``filename``.

    ``append_id("run.csv", "batch_norm")`` gives ``"run_batch_norm.csv"``.
    A name without an extension simply gets ``_{id}`` appended, and dots in
    directory components are never treated as the extension separator.

    Args:
        filename: File name or path.
        id: Tag to insert; formatted with ``str()`` rules.

    Returns:
        The tagged file name or path as a string.
    """
    # splitext only looks at the last path component, unlike a plain
    # rsplit('.') which would split inside a dotted directory name.
    root, ext = os.path.splitext(filename)
    return f"{root}_{id}{ext}"

def get_filename_individual(lr1, lr2, batch_norm):
    """Build the PDF file name for a single-run figure.

    The ``w=`` part is taken from the underscore-separated tokens of the last
    path component of RESULTS_DIR: the final two tokens, or the two tokens
    before a trailing ``linear`` token.

    Args:
        lr1: Value placed after ``lr1=`` in the name.
        lr2: Value placed after ``lr2=`` in the name.
        batch_norm: If truthy, a ``batch_norm`` tag is appended to the stem.

    Returns:
        The figure file name (no directory), ending in ``.pdf``.
    """
    tokens = os.path.basename(RESULTS_DIR).split('_')
    is_linear = tokens[-1] == "linear"
    w = '_'.join(tokens[-3:-1] if is_linear else tokens[-2:])

    figure_name = f'lr1={lr1}_lr2={lr2}_w={w}.pdf'

    # Tags are appended in this fixed order: batch_norm first, then linear.
    tags = []
    if batch_norm:
        tags.append("batch_norm")
    if is_linear:
        tags.append("linear")
    for tag in tags:
        figure_name = append_id(figure_name, tag)
    return figure_name

def plot_individual_run(lr1, lr2, batch_norm):
    """Plot columns 0 and 1 of one run's CSV file on stacked log-x axes.

    Reads ``lr={lr1}_{lr2}.csv`` (with a ``_batch_norm`` tag before the
    extension when ``batch_norm`` is truthy) from RESULTS_DIR, draws each of
    the first two columns at 700 geometrically spaced row indices, saves the
    figure as a PDF under ``plots/`` and shows it.

    Args:
        lr1: First learning rate, used in the file name and the title.
        lr2: Second learning rate, used in the file name and the title.
        batch_norm: Whether to read the batch-norm variant of the run.
    """
    file_path = os.path.join(RESULTS_DIR, f"lr={lr1}_{lr2}.csv")
    if batch_norm:
        file_path = append_id(file_path, "batch_norm")
    data = pd.read_csv(file_path, header=None)

    # Geometric spacing gives even coverage on the log-scaled x axis; the
    # range is derived from the number of rows actually read.
    geo_samples = [int(i) for i in np.geomspace(1, len(data) - 1, num=700)]
    cmap = matplotlib.colormaps['viridis']
    colorList = [cmap(50 / 1000), cmap(350 / 1000)]
    # NOTE(review): plot_results treats column 0 as risk and column 1 as loss,
    # so these labels look stale. No legend is drawn, so they are not visible.
    labelList = ['same stepsize', 'scaled stepsize']

    fig, ax = plt.subplots(2,1, sharex=True)
    for k in range(2):
        ax[k].set_xscale('log')
        data_vec = data[k]
        ax[k].plot(geo_samples, data_vec[geo_samples],
            color=colorList[k],
            label=labelList[k],
            lw=4)

    plt.suptitle(fr"$\eta_{{\mathbf{{W}}}} = {lr1}$, $\eta_{{\mathbf{{v}}}} = {lr2}$")
    # Make sure the output folder exists so savefig does not fail on a fresh checkout.
    os.makedirs('plots', exist_ok=True)
    plt.savefig(os.path.join('plots', get_filename_individual(lr1, lr2, batch_norm)))
    plt.show()
    


def get_filename(lr, vmax, batch_norm):
    """Build the PDF file name for a heat-map figure.

    The ``w=`` part comes from the underscore-separated tokens of the last
    path component of RESULTS_DIR: the final two tokens, or the two tokens
    preceding a trailing ``linear`` token.

    Args:
        lr: Value placed at the start of the name.
        vmax: Value placed after ``vmax=`` in the name.
        batch_norm: If truthy, a ``batch_norm`` tag is appended to the stem.

    Returns:
        The figure file name (no directory), ending in ``.pdf``.
    """
    parts = os.path.split(RESULTS_DIR)[1].split('_')
    if parts[-1] == "linear":
        linear = True
        w_parts = parts[-3:-1]
    else:
        linear = False
        w_parts = parts[-2:]
    w = '_'.join(w_parts)

    name = f'{lr}_vmax={vmax}_w={w}.pdf'
    # batch_norm is tagged before linear, so the stem ends "..._batch_norm_linear".
    name = append_id(name, "batch_norm") if batch_norm else name
    name = append_id(name, "linear") if linear else name
    return name

def plot_results(lr, vmax=20, keep_nan=True, batch_norm=False, plot_nans=False):
    """Draw risk and loss heat maps over (second learning rate, iteration).

    Loads every ``lr={lr}_{lr2}`` result file from RESULTS_DIR, stacks column 0
    (risks) and column 1 (losses) of each file, and shows them as two images:
    x is the sorted second learning rate, y is the iteration subsampled at 700
    geometrically spaced indices with the largest iteration at the top.
    The figure is saved as a PDF under ``plots/`` and shown.

    Args:
        lr: First learning rate; selects the files and is used in the title.
        vmax: Upper colour limit of both images (lower limit is 0).
        keep_nan: Unused; kept so existing calls keep working.
        batch_norm: Whether to use the batch-norm result files.
        plot_nans: When some runs contain nan/inf, plot the runs around the
            last fully finite one instead of printing their learning rates.

    Raises:
        AssertionError: If the number of matching files is not 99.
    """
    n_iterations = 50000  # rows expected in every result file
    n_samples = 700       # iterations kept for the image

    files = get_all_files(lr, batch_norm)
    if batch_norm:
        f_ext = "_batch_norm.csv"
        lrs = [f.split('_')[1] for f in files]
    else:
        f_ext = ".csv"
        lrs = [f.split('_')[-1][:-4] for f in files]
    assert len(lrs) == 99, "Something wrong with the amount of files"
    # Sort numerically but keep the original strings for rebuilding file names.
    lrs_float, lrs = zip(*sorted((float(s), s) for s in lrs))

    risks = np.zeros((len(lrs), n_iterations))
    losses = np.zeros((len(lrs), n_iterations))
    for i, lr2 in enumerate(lrs):
        data = pd.read_csv(os.path.join(RESULTS_DIR, f"lr={lr}_{lr2}{f_ext}"), header=None)
        risks[i] = data[0]
        losses[i] = data[1]

    # Boolean mask per run; testing the mask (rather than the row indices)
    # keeps the check correct when run 0 is the only finite one.
    finite_rows = np.isfinite(risks).all(axis=-1)
    if finite_rows.all():
        print('No nans/inf values')
    elif finite_rows.any():
        print(r"Lr2 with nan/inf, then some lower lr2")
        idx = np.flatnonzero(finite_rows)[-1] + 1
        for j in range(10):
            k = idx - j
            # Stay inside the list: idx may equal len(lrs) and idx - j may go below 0.
            if not 0 <= k < len(lrs):
                continue
            if plot_nans:
                plot_individual_run(lr, lrs[k], batch_norm)
            else:
                print(lrs[k])
    else:
        print('All lr2 runs contain nan/inf values')
    ratios = np.array(lrs_float)/lr

    # Subsample epochs
    geo_samples = np.geomspace(1, n_iterations - 1, num=n_samples).astype(int)
    fig, ax = plt.subplots(2,1, sharex=True)

    # Plot (rows are reversed so that the last iteration is at the top)
    im1 = ax[0].imshow(risks[:, geo_samples].transpose()[::-1, :], interpolation='none', aspect='auto', vmax=vmax, vmin=0)
    im2 = ax[1].imshow(losses[:, geo_samples].transpose()[::-1, :], interpolation='none', aspect='auto', vmax=vmax, vmin=0)
    fig.colorbar(im1, ax=ax[0])
    fig.colorbar(im2, ax=ax[1])

    # Title
    title = fr"$\eta_{{\mathbf{{W}}}}={lr}$"
    if batch_norm:
        title += ", with batch normalization"
    plt.suptitle(title)

    # X axis: tick at the column whose ratio lr2/lr is closest to each power of ten
    x_exponents = np.arange(-5, 6)
    x_ticks = [int(np.argmin(np.abs(ratios - 10.0 ** e))) for e in x_exponents]
    ax[1].set_xticks(x_ticks)
    ax[1].set_xticklabels([f"$10^{{{i}}}$" for i in x_exponents])
    ax[1].set_xlabel(r"$\eta_{\mathbf{v}} / \eta_{\mathbf{W}}$")

    # Y axis: sample i is drawn in image row (n_samples - 1 - i) because of the
    # row reversal above, so the tick position has to be mirrored as well.
    y_exponents = np.arange(1, 5)
    y_ticks = [
        len(geo_samples) - 1 - int(np.argmin(np.abs(geo_samples - 10.0 ** e)))
        for e in y_exponents
    ]
    for k in range(2):
        ax[k].set_yticks(y_ticks)
        ax[k].set_yticklabels([f"$10^{{{i}}}$" for i in y_exponents])
        ax[k].set_ylabel(r"Iteration $t$")

    # Save file
    os.makedirs("plots", exist_ok=True)
    figname = get_filename(lr, vmax, batch_norm)
    plt.savefig(os.path.join("plots", figname), bbox_inches='tight')

    plt.show()
    return

def plot_same_lr(vmax=20, batch_norm=False):
    """Draw risk and loss heat maps for runs that use one learning rate twice.

    Loads every ``lr={lr}_{lr}`` result file listed by ``get_files`` from
    RESULTS_DIR, stacks column 0 (risks) and column 1 (losses), and shows them
    as two images: x is the sorted learning rate, y is the iteration
    subsampled at 700 geometrically spaced indices with the largest iteration
    at the top. The figure is saved under ``plots/`` with a ``same_lr`` prefix
    and shown.

    Args:
        vmax: Upper colour limit of both images.
        batch_norm: Whether to use the batch-norm result files.

    Raises:
        AssertionError: If the number of result files is not 91.
    """
    n_iterations = 50000  # rows expected in every result file
    n_samples = 700       # iterations kept for the image

    files = get_files(batch_norm)
    if batch_norm:
        f_ext = "_batch_norm.csv"
        lrs = [f.split('_')[1] for f in files]
    else:
        f_ext = ".csv"
        lrs = [f.split('_')[-1][:-4] for f in files]
    assert len(lrs) == 91, "Not right amount of files"
    # Sort numerically but keep the original strings for rebuilding file names.
    lrs_float, lrs = zip(*sorted((float(s), s) for s in lrs))
    lrs_float = np.array(lrs_float)

    risks = np.zeros((len(lrs), n_iterations))
    losses = np.zeros((len(lrs), n_iterations))
    for i, lr_str in enumerate(lrs):
        data = pd.read_csv(os.path.join(RESULTS_DIR, f"lr={lr_str}_{lr_str}{f_ext}"), header=None)
        risks[i] = data[0]
        losses[i] = data[1]

    # Subsample epochs
    geo_samples = np.geomspace(1, n_iterations - 1, num=n_samples).astype(int)
    fig, ax = plt.subplots(2,1, sharex=True)

    # Plot (rows are reversed so that the last iteration is at the top)
    im1 = ax[0].imshow(risks[:, geo_samples].transpose()[::-1, :], interpolation='none', aspect='auto', vmax=vmax)
    fig.colorbar(im1, ax=ax[0])
    im2 = ax[1].imshow(losses[:, geo_samples].transpose()[::-1, :], interpolation='none', aspect='auto', vmax=vmax)
    fig.colorbar(im2, ax=ax[1])

    # Title
    title = fr"Same lr for both layers"
    if batch_norm:
        title += ", with batch normalization"
    plt.suptitle(title)

    # X axis: tick at the column whose learning rate is closest to each power of ten
    x_exponents = np.arange(-7, -1)
    x_ticks = [int(np.argmin(np.abs(lrs_float - 10.0 ** e))) for e in x_exponents]
    ax[1].set_xticks(x_ticks)
    ax[1].set_xticklabels([f"$10^{{{i}}}$" for i in x_exponents])
    ax[1].set_xlabel("lr")

    # Y axis: sample i is drawn in image row (n_samples - 1 - i) because of the
    # row reversal above, so the tick position has to be mirrored as well.
    y_exponents = np.arange(1, 5)
    y_ticks = [
        len(geo_samples) - 1 - int(np.argmin(np.abs(geo_samples - 10.0 ** e)))
        for e in y_exponents
    ]
    for k in range(2):
        ax[k].set_yticks(y_ticks)
        ax[k].set_yticklabels([f"$10^{{{i}}}$" for i in y_exponents])
        ax[k].set_ylabel(r"Iteration $t$")

    # Use a fixed prefix: a leaked loop variable would name the file after the
    # largest learning rate, which can collide with a plot_results figure.
    os.makedirs("plots", exist_ok=True)
    figname = get_filename("same_lr", vmax, batch_norm)
    plt.savefig(os.path.join("plots", figname), bbox_inches='tight')

    plt.show()
    
    

In [None]:
# Heat maps for a first-layer step size of 1e-2, without batch normalization.
plot_results(lr=1e-2, vmax=20, batch_norm=False, plot_nans=False)

In [None]:
# Single run; lr2 is written with full float precision because it is formatted
# into the name of the CSV file that gets read.
plot_individual_run(lr1=1e-7, lr2=0.00023299518099999998, batch_norm=False)

In [None]:
# Heat maps for each first-layer step size, without batch normalization.
for eta_w in (1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2):
    plot_results(lr=eta_w, vmax=20, batch_norm=False, plot_nans=False)