In [None]:
import json
import math
import os
import sys

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

sys.path.append('code/')
from linear_utils import is_float

In [None]:
# Folder holding the JSON result files (*.txt), relative to the notebook.
# Point this at another experiment folder as needed, e.g. "results/five_layer_regression_results".
RESULTS_DIR = "results/two_layer_results_l2/transform_data/theoretical"
# Alias kept for any code that refers to the results folder as RESULT_PATH.
RESULT_PATH = RESULTS_DIR

SMALL_SIZE = 10
MEDIUM_SIZE = 12
BIGGER_SIZE = 14

plt.rc('font', size=SMALL_SIZE)            # default text size
plt.rc('axes', titlesize=SMALL_SIZE)       # axes (subplot) titles
plt.rc('axes', labelsize=MEDIUM_SIZE)      # x and y axis labels
plt.rc('xtick', labelsize=SMALL_SIZE)      # x tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)      # y tick labels
plt.rc('legend', fontsize=SMALL_SIZE)      # legend text
# The figure-level title lives in the 'figure' group; setting 'axes' here would
# silently override the axes-title size configured above.
plt.rc('figure', titlesize=BIGGER_SIZE)    # figure title (suptitle)


def get_all_files(sweep=None, key_word="", *, base_dir=None):
    """List the ``.txt`` result files whose name contains ``key_word``.

    Args:
        sweep: optional sub-folder of the results root to list instead of the root.
        key_word: substring every returned file name must contain ("" keeps all).
        base_dir: results root; defaults to ``RESULTS_DIR``.

    Returns:
        Sorted list of bare file names (not paths). Callers must join them with
        the folder that was listed.
    """
    root = RESULTS_DIR if base_dir is None else base_dir
    folder = root if sweep is None else os.path.join(root, sweep)

    # Sorted so callers that pick the first match (e.g. get_file) are deterministic;
    # os.listdir returns entries in arbitrary order.
    return sorted(
        name for name in os.listdir(folder)
        if name.endswith(".txt") and key_word in name
    )


def get_run_name(key_word="", val=None):
    """Build the run identifier ``<key_word>_<val>``, or just ``key_word`` when no value is given.

    Raises:
        AssertionError: if ``val`` is None and ``key_word`` is empty.
    """
    if val is None:
        # Without a value the key word alone names the run, so it must not be blank.
        assert key_word != "", "Empty run name"
        return key_word

    return key_word + f"_{val}"


def get_file(sweep=None, key_word="", val=None):
    """Return the name of one result file for ``key_word`` (and optionally ``val``).

    Args:
        sweep: optional sub-folder of the results root to search.
        key_word: run-name prefix / substring to match.
        val: swept value; when given, the file must be exactly
            ``get_run_name(key_word, val) + ".txt"``. When None, the first
            matching file returned by ``get_all_files`` is used.

    Returns:
        Bare file name (not a path); join it with the searched folder to open it.

    Raises:
        FileNotFoundError: if no matching file exists.
    """
    files = get_all_files(sweep, key_word)

    if val is None:
        if not files:
            raise FileNotFoundError(f"No result file containing '{key_word}' (sweep={sweep})")
        return files[0]

    # Listed names carry the ".txt" extension while run names do not, so the
    # extension must be added before comparing.
    target = get_run_name(key_word, val) + ".txt"
    if target not in files:
        raise FileNotFoundError(f"Result file '{target}' not found (sweep={sweep})")
    return target


def append_id(filename, id):
    """Insert ``_<id>`` before the extension of ``filename``.

    Only the last dot is treated as the extension separator, e.g.
    ``("plot.png", 3) -> "plot_3.png"``. A name without any dot simply gets
    ``_<id>`` appended instead of raising.
    """
    stem, dot, ext = filename.rpartition(".")
    if not dot:
        # rpartition returns ("", "", filename) when there is no dot.
        return f"{filename}_{id}"
    return f"{stem}_{id}.{ext}"


def get_filename(metric, sweep=None, key_word="", val=None):
    """Build a plot file name ``[<sweep>_]<metric>[_<key_word>][_<val>]`` (no extension).

    Parts are joined with "_" so the metric and key word do not run together
    (previously ``"loss" + "lr"`` produced ``"losslr"``).
    """
    parts = [metric]
    if key_word:
        parts.append(key_word)
    if val is not None:
        parts.append(f"{val}")
    name = "_".join(parts)

    if sweep is not None:
        name = sweep + "_" + name

    return name


def get_filename_range(metric, min_val, max_val, sweep=None, key_word=""):
    """Build a plot file name ``[<sweep>_]<metric>[_<key_word>]_min_<min>_max_<max>`` (no extension).

    Parts are joined with "_" so the metric and key word do not run together.
    """
    parts = [metric]
    if key_word:
        parts.append(key_word)
    parts.append(f"min_{min_val}_max_{max_val}")
    name = "_".join(parts)

    if sweep is not None:
        name = sweep + "_" + name

    return name


In [None]:
def plot_individual_run(metrics, labels, sweep=None, key_word="", val=None):
    """Plot the curves of a single run, one subplot per metric.

    The run file is found with ``get_file`` inside ``RESULTS_DIR`` (or its
    ``sweep`` sub-folder) and parsed as JSON. Each metric may hold one curve
    (1-D list) or several curves (2-D list). Curves are subsampled at ~700
    geometrically spaced iterations and drawn on a log x-axis. The figure is
    saved as PNG under ``plots/`` and shown.

    Args:
        metrics: keys of the JSON dictionary to plot.
        labels: y-axis label for each metric (same length as ``metrics``).
        sweep: optional sub-folder of ``RESULTS_DIR`` holding the run.
        key_word: run-name prefix (e.g. the swept hyper-parameter).
        val: swept value identifying the run; None uses the first matching file.

    Raises:
        AssertionError: if ``metrics``/``labels`` lengths differ or a metric is missing.
    """
    assert len(metrics) == len(labels), "Must provide a label for each metric"

    base_dir = RESULTS_DIR if sweep is None else os.path.join(RESULTS_DIR, sweep)
    # get_file returns a bare file name, so resolve it against the results folder.
    file_name = get_file(sweep, key_word, val)
    with open(os.path.join(base_dir, file_name), "r") as f:
        data_dict = json.load(f)

    # squeeze=False keeps `ax` indexable even when a single metric is plotted.
    fig, ax = plt.subplots(len(metrics), 1, figsize=(5, 6), sharex=True, squeeze=False)
    ax = ax[:, 0]

    for k, metric in enumerate(metrics):
        assert metric in data_dict, f"Unknown metric: {metric}"

        # atleast_2d lets one code path handle a single curve or a stack of curves.
        curves = np.atleast_2d(np.asarray(data_dict[metric], dtype=float))
        n_its = curves.shape[-1]
        # Geometric iteration grid starting at 1 (a log axis cannot show 0);
        # np.unique drops the duplicates produced by int truncation at small t.
        geo_samples = np.unique(np.geomspace(1, max(n_its - 1, 1), num=700).astype(int))
        geo_samples = geo_samples[geo_samples < n_its]

        for curve in curves:
            ax[k].plot(geo_samples, curve[geo_samples], lw=4)

        ax[k].set_xscale('log')
        ax[k].set_ylabel(labels[k])

    ax[-1].set_xlabel("Iterations")

    save_dir = "plots"
    os.makedirs(save_dir, exist_ok=True)
    # Explicit ".png": a name ending in e.g. "_0.1" would otherwise make
    # matplotlib infer the unsupported format "1".
    save_name = get_filename('_'.join(metrics), sweep, key_word, val) + ".png"
    fig.savefig(os.path.join(save_dir, save_name))
    plt.show()
    



In [None]:
# Function for plotting a selection of curves

def plot_individual_runs_range(metrics, labels, min_val, max_val, sweep=None, key_word=""):
    """Overlay the curves of every run whose swept value lies in ``[min_val, max_val]``.

    Run files are expected to be named ``<key_word>_<value>.txt`` inside
    ``RESULTS_DIR`` (or its ``sweep`` sub-folder). Each holds a JSON dict
    mapping metric names to per-iteration lists (assumed 1-D). One subplot is
    drawn per metric with one line per run, subsampled at ~700 geometrically
    spaced iterations on a log x-axis. The figure is saved as PNG under
    ``plots/`` and shown.

    Args:
        metrics: keys of the JSON dictionary to plot.
        labels: y-axis label for each metric (same length as ``metrics``).
        min_val, max_val: inclusive range of swept values to include.
        sweep: optional sub-folder of ``RESULTS_DIR`` holding the runs.
        key_word: run-name prefix (e.g. the swept hyper-parameter).

    Raises:
        AssertionError: if ``metrics``/``labels`` lengths differ or a metric is missing.
        ValueError: if no run falls inside the requested range.
    """
    assert len(metrics) == len(labels), "Must provide a label for each metric"

    base_dir = RESULTS_DIR if sweep is None else os.path.join(RESULTS_DIR, sweep)

    # Collect (value, file name) for "<key_word>_<number>.txt" files in range;
    # files with a non-numeric suffix or another prefix are skipped.
    prefix = key_word + "_"
    runs = []
    for name in get_all_files(sweep, key_word):
        stem = os.path.splitext(name)[0]
        if not stem.startswith(prefix):
            continue
        try:
            value = float(stem[len(prefix):])
        except ValueError:
            continue
        if min_val <= value <= max_val:
            runs.append((value, name))
    if not runs:
        raise ValueError(f"No '{key_word}' runs with value in [{min_val}, {max_val}] in {base_dir}")
    runs.sort()  # legend and colour cycle follow increasing value

    # squeeze=False keeps `ax` indexable even when a single metric is plotted.
    fig, ax = plt.subplots(len(metrics), 1, figsize=(5, 6), sharex=True, squeeze=False)
    ax = ax[:, 0]

    for value, name in runs:
        with open(os.path.join(base_dir, name), "r") as f:
            data_dict = json.load(f)

        for k, metric in enumerate(metrics):
            assert metric in data_dict, f"Unknown metric: {metric}"

            data_vec = np.asarray(data_dict[metric], dtype=float)
            n_its = len(data_vec)
            # Geometric iteration grid starting at 1 (a log axis cannot show 0).
            geo_samples = np.unique(np.geomspace(1, max(n_its - 1, 1), num=700).astype(int))
            geo_samples = geo_samples[geo_samples < n_its]

            ax[k].plot(geo_samples, data_vec[geo_samples],
                       label=key_word + f"_{value}",
                       lw=4)

    for axis, label in zip(ax, labels):
        axis.set_xscale('log')
        axis.set_ylabel(label)
        axis.legend()

    ax[-1].set_xlabel("Iterations")

    save_dir = "plots"
    os.makedirs(save_dir, exist_ok=True)
    # Explicit ".png": a name ending in e.g. "_max_0.1" would otherwise make
    # matplotlib infer the unsupported format "1".
    save_name = get_filename_range('_'.join(metrics), min_val, max_val, sweep, key_word) + ".png"
    fig.savefig(os.path.join(save_dir, save_name))
    plt.show()
    



In [None]:
def plot_results(metrics, labels, min_val, max_val, sweep=None, key_word="", vmax=20, vmin=0, keep_nan=True, plot_nans=False,
                 num_its=50000, ratio=1):
    """Heat-map each metric over (swept value, iteration) for all runs in a range.

    Run files ``<key_word>_<value>.txt`` in ``RESULTS_DIR`` (or its ``sweep``
    sub-folder) whose value lies in ``[min_val, max_val]`` are loaded as JSON
    dicts. Each metric must be a per-iteration list of length ``num_its``.
    Each metric gets one heat-map: columns are runs sorted by value, and rows
    are ~700 geometrically spaced iterations with t=1 at the bottom. Colours
    are clipped to ``[vmin, vmax]``. Runs containing nan/inf are reported.
    The figure is saved as PNG under ``plots/`` and shown.

    Args:
        metrics: keys of the JSON dictionary to plot.
        labels: title for each metric (same length as ``metrics``).
        min_val, max_val: inclusive range of swept values to include.
        sweep: optional sub-folder of ``RESULTS_DIR`` holding the runs.
        key_word: run-name prefix; also used as the x-axis label.
        vmax, vmin: colour limits, either a scalar or a list with one entry per metric.
        keep_nan: currently unused; kept for backward compatibility.
        plot_nans: if True, plot the first non-finite run and up to 9 lower-valued
            runs with ``plot_individual_run`` instead of printing their values.
        num_its: number of iterations stored per metric in each run file.
        ratio: divisor applied to the swept values before placing power-of-ten x-ticks.

    Raises:
        AssertionError: if ``metrics``/``labels`` lengths differ or a metric is missing.
        ValueError: if no run is in range or a metric's length differs from ``num_its``.
    """
    assert len(metrics) == len(labels), "Must provide a label for each metric"

    if not isinstance(vmax, list):
        vmax = [vmax] * len(metrics)
    if not isinstance(vmin, list):
        vmin = [vmin] * len(metrics)

    base_dir = RESULTS_DIR if sweep is None else os.path.join(RESULTS_DIR, sweep)

    # Collect (value, suffix, file name) for "<key_word>_<number>.txt" files in
    # range. The original suffix string is kept so the run can be looked up
    # again by name (plot_individual_run) without float round-tripping.
    prefix = key_word + "_"
    runs = []
    for name in get_all_files(sweep, key_word):
        stem = os.path.splitext(name)[0]
        if not stem.startswith(prefix):
            continue
        suffix = stem[len(prefix):]
        try:
            value = float(suffix)
        except ValueError:
            continue
        if min_val <= value <= max_val:
            runs.append((value, suffix, name))
    if not runs:
        raise ValueError(f"No '{key_word}' runs with value in [{min_val}, {max_val}] in {base_dir}")
    runs.sort()

    vals_float = [value for value, _, _ in runs]
    vals = [suffix for _, suffix, _ in runs]

    metrics_data = np.zeros((len(metrics), len(runs), num_its))
    for i, (_, _, name) in enumerate(runs):
        with open(os.path.join(base_dir, name), "r") as f:
            data_dict = json.load(f)

        for k, metric in enumerate(metrics):
            assert metric in data_dict, f"Unknown metric: {metric}"
            # numpy raises ValueError if the stored length differs from num_its.
            metrics_data[k, i, :] = data_dict[metric]

    # Report runs that diverged (contain nan/inf) together with lower-valued neighbours.
    for k, metric in enumerate(metrics):
        finite_runs = np.isfinite(metrics_data[k]).all(axis=-1)
        if finite_runs.all():
            print('No nans/inf values')
            continue
        print(r"Value with nan/inf, then some lower val")
        first_bad = int(np.flatnonzero(~finite_runs)[0])
        for i in range(first_bad, max(first_bad - 10, -1), -1):
            if plot_nans:
                plot_individual_run([metric], [labels[k]], sweep, key_word, vals[i])
            else:
                print(vals[i])

    ratios = np.asarray(vals_float) / ratio

    # Geometric iteration grid starting at 1; np.unique drops duplicate rows
    # produced by int truncation at small t.
    geo_samples = np.unique(np.geomspace(1, max(num_its - 1, 1), num=700).astype(int))
    geo_samples = geo_samples[geo_samples < num_its]

    x_ticks, x_tick_labels = _pow10_ticks(ratios)
    y_ticks, y_tick_labels = _pow10_ticks(geo_samples)

    # squeeze=False keeps `ax` indexable even when a single metric is plotted.
    fig, ax = plt.subplots(len(metrics), 1, sharex=True, squeeze=False)
    ax = ax[:, 0]

    for k in range(len(metrics)):
        # metrics_data[k] is (runs, iterations). Index in two steps: a combined
        # [k, :, list] index would move the list axis first. Transpose so rows
        # are iterations (y) and columns are runs (x).
        image = metrics_data[k][:, geo_samples].T
        im = ax[k].imshow(image, origin='lower', interpolation='none', aspect='auto',
                          vmax=vmax[k], vmin=vmin[k])
        fig.colorbar(im, ax=ax[k])

        ax[k].set_title(labels[k])

        ax[k].set_xticks(x_ticks)
        ax[k].set_xticklabels(x_tick_labels)
        ax[k].set_xlabel(key_word)

        ax[k].set_yticks(y_ticks)
        ax[k].set_yticklabels(y_tick_labels)
        ax[k].set_ylabel(r"Iteration $t$")

    save_dir = "plots"
    os.makedirs(save_dir, exist_ok=True)
    # "heatmap_" prefix avoids clobbering the line plot of the same range;
    # explicit ".png" stops matplotlib inferring a format from e.g. "_max_0.1".
    save_name = get_filename_range("heatmap_" + '_'.join(metrics), min_val, max_val, sweep, key_word) + ".png"
    fig.savefig(os.path.join(save_dir, save_name), bbox_inches='tight')

    plt.show()
    return


def _pow10_ticks(values):
    """Tick positions (indices into ``values``) and LaTeX labels for every power of ten
    spanned by the positive entries of ``values``; each tick sits on the entry closest
    in log scale. Returns two empty lists when no entry is positive."""
    values = np.asarray(values, dtype=float)
    positive = values > 0
    if not positive.any():
        return [], []

    log_vals = np.full(values.shape, np.inf)
    log_vals[positive] = np.log10(values[positive])
    # Small tolerance so exact powers (e.g. 1e-5) are not lost to rounding.
    lo = int(np.ceil(log_vals[positive].min() - 1e-9))
    hi = int(np.floor(log_vals[positive].max() + 1e-9))
    powers = range(lo, hi + 1)

    positions = [int(np.argmin(np.abs(log_vals - p))) for p in powers]
    tick_labels = [f"$10^{{{p}}}$" for p in powers]
    return positions, tick_labels

In [None]:
# TODO(review): plot_individual_run requires `metrics` (keys of the result JSON) and
# `labels` (one y-label per metric); this bare call raises TypeError until they are supplied.
plot_individual_run()

In [None]:
# The range-plotting helper is defined as `plot_individual_runs_range` (plural "runs").
# TODO(review): supply `metrics`, `labels`, `min_val` and `max_val`; the bare call
# raises TypeError until they are provided.
plot_individual_runs_range()

In [None]:
# TODO(review): plot_results requires `metrics`, `labels`, `min_val` and `max_val`;
# this bare call raises TypeError until they are supplied.
plot_results()