In [None]:
import sys
import os
import json

import math
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import numpy as np

sys.path.append('code/')
from linear_utils import is_float

In [None]:

# Global matplotlib font sizes applied to every figure drawn in this notebook.
SMALL_SIZE = 10
MEDIUM_SIZE = 12
BIGGER_SIZE = 14

plt.rc('font', size=SMALL_SIZE)          # default text size
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes (subplot) titles
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the x tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the y tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
# The figure-level title (fig.suptitle) is controlled by the 'figure' group;
# setting 'axes' here would instead override the subplot-title size above.
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title


def get_all_files(sweep=None, key_word="", results_dir=None):
    """List the ``.txt`` result files of a sweep whose names contain ``key_word``.

    Args:
        sweep: Optional sub-directory of the results root holding one sweep.
            When ``None`` the results root itself is listed.
        key_word: Substring a file name must contain to be returned
            (the empty string matches every ``.txt`` file).
        results_dir: Results root to search. Defaults to the notebook-level
            ``RESULTS_DIR`` constant, looked up at call time.

    Returns:
        Sorted list of matching file names (names only, not full paths).
        Sorting makes callers that pick ``files[0]`` deterministic, since
        ``os.listdir`` returns entries in arbitrary order.
    """
    if results_dir is None:
        results_dir = RESULTS_DIR
    base_dir = results_dir if sweep is None else os.path.join(results_dir, sweep)

    return sorted(
        name for name in os.listdir(base_dir)
        if ".txt" in name and key_word in name
    )


def get_run_name(key_word="", val=None):
    """Build a run identifier from a key word and an optional swept value.

    Args:
        key_word: Prefix of the run name.
        val: Optional value; when given it is appended as ``_<val>``.

    Returns:
        ``key_word + "_<val>"`` when ``val`` is given, else ``key_word``.

    Raises:
        AssertionError: if ``val`` is ``None`` and ``key_word`` is empty.
    """
    if val is None:
        # Without a value the key word alone must identify the run.
        assert key_word != "", "Empty run name"
        return key_word
    return key_word + f"_{val}"


def get_file(sweep=None, key_word="", val=None):
    """Return the name of one result file of a sweep.

    Args:
        sweep: Optional sweep sub-directory (see ``get_all_files``).
        key_word: Substring used to filter the sweep's ``.txt`` files.
        val: When given, the file named exactly ``<key_word>_<val>.txt`` is
            returned; when ``None`` the first matching file in sorted order is.

    Returns:
        The file name (not a full path).

    Raises:
        FileNotFoundError: if no file matches, instead of an opaque IndexError.
    """
    # Sort so that the val=None case does not depend on os.listdir ordering.
    files = sorted(get_all_files(sweep, key_word))

    if val is None:
        candidates = files
    else:
        target = get_run_name(key_word, val) + ".txt"
        candidates = [name for name in files if name == target]

    if not candidates:
        raise FileNotFoundError(
            f"No result file for sweep={sweep!r}, key_word={key_word!r}, val={val!r}"
        )
    return candidates[0]


def append_id(filename, id):
    """Insert ``_<id>`` before the last extension of ``filename``.

    ``"plot.pdf"`` becomes ``"plot_3.pdf"`` for ``id=3``. A name without an
    extension gets the suffix appended at the end (``"README"`` becomes
    ``"README_3"``) instead of raising IndexError.
    """
    stem, dot, ext = filename.rpartition('.')
    if not dot:
        return f"{filename}_{id}"
    return f"{stem}_{id}.{ext}"


def get_filename(metric, sweep=None, key_word="", val=None):
    """Compose the base name (no extension) under which a plot is saved.

    Layout: ``[<sweep>_]<metric>_<key_word>[_<val>]``.
    """
    pieces = [metric, key_word]
    if val is not None:
        pieces.append(f"{val}")
    stem = "_".join(pieces)

    if sweep is None:
        return stem
    return sweep + "_" + stem


def get_filename_range(metric, min_val, max_val, sweep=None, key_word=""):
    """Compose the base name (no extension) for a plot covering a value range.

    Layout: ``[<sweep>_]<metric>[_<key_word>]_min_<min_val>_max_<max_val>``,
    using the same ``_`` separators as ``get_filename``.
    """
    name = metric
    if key_word:
        # Keep metric and key word apart (e.g. "loss_risk_lr1", not "loss_risklr1").
        name += "_" + key_word
    name += f"_min_{min_val}_max_{max_val}"

    if sweep is not None:
        name = sweep + "_" + name

    return name


In [None]:
def plot_individual_run(metrics, labels, sweep=None, key_word="", val=None, num_its=100001):
    """Plot the curves of a single run, one subplot per metric.

    The run's result file is loaded with ``json.load`` and indexed by metric
    name. Each curve is drawn at ~700 log-spaced iterations on a log x-axis,
    and the figure is saved as a PDF under ``plots/``.

    Args:
        metrics: Keys of the result dictionary to plot.
        labels: y-axis label for each metric (same length as ``metrics``).
        sweep: Optional sweep sub-directory of ``RESULTS_DIR``.
        key_word: Run-name prefix used to locate the file (see ``get_file``).
        val: Swept value identifying the run; ``None`` picks the first match.
        num_its: Number of recorded iterations. Samples go up to index
            ``num_its - 1``, so each metric series needs at least ``num_its`` entries.

    Raises:
        AssertionError: on a labels/metrics length mismatch or an unknown metric.
    """
    assert len(metrics) == len(labels), "Must provide a label for each metric"

    file = get_file(sweep, key_word, val)
    base_dir = RESULTS_DIR if sweep is None else os.path.join(RESULTS_DIR, sweep)
    with open(os.path.join(base_dir, file), "r") as f:
        data_dict = json.load(f)

    # Log-spaced sample indices; int truncation repeats some small indices.
    geo_samples = [int(i) for i in np.geomspace(1, num_its - 1, num=700)]

    # squeeze=False keeps `axes` 2-D even for a single metric, so indexing works
    # when len(metrics) == 1 (plt.subplots would otherwise return a bare Axes).
    fig, axes = plt.subplots(len(metrics), 1, figsize=(5, len(metrics) * 3),
                             sharex=True, squeeze=False)
    axes = axes[:, 0]

    for axis, metric, label in zip(axes, metrics, labels):
        axis.set_xscale('log')
        axis.set_ylabel(label)

        assert metric in data_dict, f"Unknown metric {metric!r}"
        data_vec = np.array(data_dict[metric])
        axis.plot(geo_samples, data_vec[geo_samples], lw=4)

    axes[-1].set_xlabel("Iterations")

    save_dir = "plots"
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, get_filename('_'.join(metrics), sweep, key_word, val) + ".pdf"))
    plt.show()
    



In [None]:
# Overlay the curves of every run of a sweep whose swept value lies in a range.

def plot_individual_runs_range(metrics, labels, min_val, max_val, ymin=None, ymax=None, sweep=None, key_word="", num_its=100001):
    """Overlay all runs with ``min_val <= value <= max_val``, one subplot per metric.

    Swept values are parsed from file names of the form ``<key_word>_<value>.txt``.
    Runs are plotted in increasing value order and labelled with the value as
    written in the file name. The figure is saved as a PDF under ``plots/``.

    Args:
        metrics: Keys of each run's result dictionary to plot.
        labels: y-axis label for each metric (same length as ``metrics``).
        min_val, max_val: Inclusive range of swept values to include.
        ymin, ymax: Optional y-limits, as a scalar (same for every metric) or a
            per-metric list. Either bound may be given on its own.
        sweep: Optional sweep sub-directory of ``RESULTS_DIR``.
        key_word: Run-name prefix (the swept hyper-parameter).
        num_its: Number of recorded iterations per series (see ``plot_individual_run``).

    Raises:
        AssertionError: on a labels/metrics length mismatch or an unknown metric.
        ValueError: if a matching file name does not end in a numeric value.
    """
    assert len(metrics) == len(labels), "Must provide a label for each metric"

    if ymin is not None and not isinstance(ymin, list):
        ymin = [ymin] * len(metrics)
    if ymax is not None and not isinstance(ymax, list):
        ymax = [ymax] * len(metrics)

    # Keep each value exactly as spelled in its file name. Re-formatting a parsed
    # float ("1e-6" -> "1e-06", "0.010" -> "0.01") would not match the file.
    runs = []
    for file in get_all_files(sweep, key_word):
        val_str = file.split(key_word + "_")[-1].split(".txt")[0]
        val_float = float(val_str)
        if min_val <= val_float <= max_val:
            runs.append((val_float, val_str))
    runs.sort()

    geo_samples = [int(i) for i in np.geomspace(1, num_its - 1, num=700)]
    base_dir = RESULTS_DIR if sweep is None else os.path.join(RESULTS_DIR, sweep)

    # squeeze=False keeps `axes` indexable even for a single metric.
    fig, axes = plt.subplots(len(metrics), 1, figsize=(5, len(metrics) * 3),
                             sharex=True, squeeze=False)
    axes = axes[:, 0]

    for _, val_str in runs:
        with open(os.path.join(base_dir, get_run_name(key_word, val_str) + ".txt"), "r") as f:
            data_dict = json.load(f)

        for axis, metric in zip(axes, metrics):
            assert metric in data_dict, f"Unknown metric {metric!r}"
            data_vec = np.array(data_dict[metric])
            axis.plot(geo_samples, data_vec[geo_samples], label=val_str, lw=4)

    for k, axis in enumerate(axes):
        axis.set_xscale('log')
        axis.set_ylabel(labels[k])
        if runs:
            axis.legend(loc=1)

        # set_ylim treats None as "leave this bound unchanged", so a single
        # given bound is honoured. Skip the call entirely when neither is set.
        bottom = None if ymin is None else ymin[k]
        top = None if ymax is None else ymax[k]
        if bottom is not None or top is not None:
            axis.set_ylim(bottom=bottom, top=top)

    axes[-1].set_xlabel("Iterations")

    save_dir = "plots"
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, get_filename_range('_'.join(metrics), min_val, max_val, sweep, key_word) + ".pdf"))
    plt.show()
    



In [None]:
def _parse_sweep_values(files, key_word):
    """Return the swept values encoded in ``files`` as (floats, strings), sorted numerically.

    The string form is kept verbatim so ``get_run_name`` can rebuild the exact
    file name from it.
    """
    vals = {file.split(key_word + "_")[-1].split(".txt")[0] for file in files}
    if not vals:
        raise FileNotFoundError(f"No result files found for key_word={key_word!r}")
    vals_float, vals_str = zip(*sorted((float(v), v) for v in vals))
    return vals_float, vals_str


def _load_metric_matrix(base_dir, metrics, key_word, vals, num_its):
    """Stack every run's metric series into an array indexed [metric, run, iteration]."""
    data = np.zeros((len(metrics), len(vals), num_its))
    for i, v in enumerate(vals):
        with open(os.path.join(base_dir, get_run_name(key_word, v) + ".txt"), "r") as f:
            data_dict = json.load(f)
        for k, metric in enumerate(metrics):
            assert metric in data_dict, f"Unknown metric {metric!r}"
            data[k, i, :] = np.array(data_dict[metric])
    return data


def _report_non_finite(metrics_data, metrics, labels, vals, sweep, key_word, plot_nans, max_report=10):
    """For each metric, report the runs around the first nan/inf divergence.

    Runs are ordered by swept value. Starting at the run just after the last
    run whose series is entirely finite, up to ``max_report`` runs are reported
    towards lower values. Each is printed, or plotted individually when
    ``plot_nans`` is true.
    """
    n_runs = len(vals)
    for k, metric in enumerate(metrics):
        run_is_finite = np.isfinite(metrics_data[k]).all(axis=-1)
        if run_is_finite.all():
            print('No nans/inf values')
            continue

        print(r"Value with nan/inf, then some lower val")
        finite_idx = np.nonzero(run_is_finite)[0]
        # Clamp into range. Without this, a sweep with no finite run, or with a
        # finite last run, would index past `vals` or wrap around negatively.
        start = 0 if finite_idx.size == 0 else min(int(finite_idx[-1]) + 1, n_runs - 1)
        for j in range(min(max_report, start + 1)):
            v = vals[start - j]
            if plot_nans:
                plot_individual_run([metric], [labels[k]], sweep, key_word, v)
            else:
                print(v)


def _decade_ticks(axis_values):
    """Return (positions, exponents) for each power of ten spanned by ``axis_values``.

    ``axis_values[i]`` is the data value shown at image row/column ``i``. Each
    decade ``10**e`` within [min, max] of the positive values is placed at the
    index whose value is closest to it.
    """
    axis_values = np.asarray(axis_values, dtype=float)
    positive = axis_values[axis_values > 0]
    if positive.size == 0:
        return [], []
    # log10 rather than math.log(x, 10): the latter gives 2.999... for 1000 and
    # floors a decade too low. The tolerance absorbs rounding at exact decades.
    lo = math.ceil(np.log10(positive.min()) - 1e-9)
    hi = math.floor(np.log10(positive.max()) + 1e-9)
    exponents = list(range(lo, hi + 1))
    positions = [int(np.argmin(np.abs(axis_values - 10.0 ** e))) for e in exponents]
    return positions, exponents


def plot_results(metrics, labels, vmin=0, vmax=50, sweep=None, key_word="", keep_nan=True, plot_nans=False, 
                 num_its=100001, ratio=1, xlabel=None, title=None):
    """Heat-map each metric over (swept value, iteration) for a whole sweep.

    Each subplot is an image whose columns are runs (sorted by swept value) and
    whose rows are ~700 log-spaced iterations, latest at the top. Ticks are
    placed at every power of ten covered by each axis. Runs containing nan/inf
    are reported first. The figure is saved as a PDF under ``plots/``.

    Args:
        metrics: Keys of each run's result dictionary to plot.
        labels: Subplot title for each metric (same length as ``metrics``).
        vmin, vmax: Colour-scale limits, as a scalar or a per-metric list.
        sweep: Optional sweep sub-directory of ``RESULTS_DIR``.
        key_word: Run-name prefix (the swept hyper-parameter); default x-label.
        keep_nan: Not referenced by this function; kept for call compatibility.
        plot_nans: Plot, rather than print, the runs reported around a divergence.
        num_its: Number of recorded iterations; every series must have exactly
            this many entries.
        ratio: The x-axis shows ``swept value / ratio``.
        xlabel: x-axis label overriding ``key_word``.
        title: Optional figure title.

    Raises:
        FileNotFoundError: if the sweep contains no matching result files.
        AssertionError: on a labels/metrics length mismatch or an unknown metric.
    """
    assert len(metrics) == len(labels), "Must provide a label for each metric"

    if not isinstance(vmin, list):
        vmin = [vmin] * len(metrics)
    if not isinstance(vmax, list):
        vmax = [vmax] * len(metrics)

    vals_float, vals = _parse_sweep_values(get_all_files(sweep, key_word), key_word)
    base_dir = RESULTS_DIR if sweep is None else os.path.join(RESULTS_DIR, sweep)
    metrics_data = _load_metric_matrix(base_dir, metrics, key_word, vals, num_its)

    _report_non_finite(metrics_data, metrics, labels, vals, sweep, key_word, plot_nans)

    ratios = np.array(vals_float) / ratio

    # Subsample iterations log-uniformly. Image rows show them latest-first.
    geo_samples = [int(i) for i in np.geomspace(1, num_its - 1, num=700)]
    row_iterations = np.asarray(geo_samples)[::-1]

    x_ticks, x_exps = _decade_ticks(ratios)
    y_ticks, y_exps = _decade_ticks(row_iterations)

    # [metric, iteration, run] so imshow gets iterations as rows, runs as columns.
    metrics_data = np.transpose(metrics_data, axes=(0, 2, 1))

    # squeeze=False keeps `axes` indexable even for a single metric.
    fig, axes = plt.subplots(len(metrics), 1, figsize=(5, len(metrics) * 3),
                             sharex=False, squeeze=False)
    axes = axes[:, 0]

    for k, axis in enumerate(axes):
        im = axis.imshow(metrics_data[k, geo_samples, :][::-1, :], interpolation='none',
                         aspect='auto', vmax=vmax[k], vmin=vmin[k])
        fig.colorbar(im, ax=axis)
        axis.set_title(labels[k])

        axis.set_xticks(x_ticks)
        axis.set_xticklabels([f"$10^{{{e}}}$" for e in x_exps])
        axis.set_xlabel(key_word if xlabel is None else xlabel)

        axis.set_yticks(y_ticks)
        axis.set_yticklabels([f"$10^{{{e}}}$" for e in y_exps])
        axis.set_ylabel(r"Iteration $t$")

    if title is not None:
        fig.suptitle(title)

    fig.tight_layout(pad=1.0)

    save_dir = "plots"
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, get_filename('_'.join(metrics), sweep, key_word)) + ".pdf",
                bbox_inches='tight')
    plt.show()

In [None]:
# Root directory of the JSON result files; each sweep lives in a sub-directory.
# NOTE(review): other roots previously used here include a "/theoretical"
# sub-directory and "five_layer_regression_results" — confirm which is current.
RESULTS_DIR = (
    "results/"
    "two_layer_results_l2/"
    "transform_data"
)

In [None]:
# Train/test MSE curves of a single run (lr1 = 0.00202358965) of the lr2lay1rank sweep.
plot_individual_run(
    ["loss", "risk"],
    ["train MSE", "test MSE"],
    sweep="lr2lay1rank",
    key_word="lr1",
    val=0.00202358965,
)

In [None]:
# Overlay the lr2lay1rank runs with lr1 in [0.01, 0.1], using per-metric y-limits.
plot_individual_runs_range(
    ["loss", "risk"],
    ["train MSE", "test MSE"],
    0.01,
    0.1,
    ymin=[85, 0],
    ymax=[110, 40],
    sweep="lr2lay1rank",
    key_word="lr1",
)

In [None]:
# Heat-maps over the lr2 sweep; the x-axis shows lr2 / lr1, with lr1 = 0.001.
plot_results(
    ["loss", "risk"],
    ["train MSE", "test MSE"],
    vmin=[85, 5],
    vmax=[110, 20],
    sweep="lrfactor2lay1rank",
    key_word="lr2",
    ratio=0.001,
    xlabel="lr2 / lr1",
    title="lr1=0.001",
)

In [None]:
# Heat-maps over the kappa sweep, with the two learning rates held fixed.
plot_results(
    ["loss", "risk"],
    ["train MSE", "test MSE"],
    vmin=[85, 5],
    vmax=[105, 15],
    sweep="teokappadifflr",
    key_word="kappa",
    ratio=1,
    title="lr1=1e-6, lr2=1e-3",
)