In [None]:
from naslib.utils.io import read_json
from pathlib import Path 
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rcParams
from matplotlib.axes import Axes

In [None]:
# Notebook-wide plot styling: seaborn "white" theme first, then matplotlib
# overrides for the gap above axes titles and the base font size.
sns.set_style("white")
rcParams.update({"axes.titlepad": 15, "font.size": 12})

In [None]:
# Keys looked up in each seed's JSON log (passed as `query_key` to the helpers below).
QUERY_VAL_ACC = "valid_acc"
QUERY_CAL_ERR = "calibration_score" 
# Default name of the per-seed JSON log file that is read for every run.
LOG_FILENAME = "errors.json"

def collect_info_all_seeds(folder: Path, filename: str = LOG_FILENAME, query_key: str = QUERY_VAL_ACC):
    """Collect one logged metric from every ``seed=*`` entry found under ``folder``.

    Args:
        folder: Root that is searched recursively for entries named ``seed=*``.
        filename: Name of the JSON log read inside each matched entry.
        query_key: Key of the metric to pull out of each log. Defaults to
            ``QUERY_VAL_ACC`` (a name that is actually defined in this
            notebook, so the cell runs on a fresh kernel).

    Returns:
        A wide DataFrame whose first column is ``epochs`` followed by one
        column per matched entry (named after it), columns sorted by name.
        Epoch numbers start at 1 -- this assumes each logged metric is a
        per-epoch sequence.
    """
    seed_dirs = list(folder.rglob("seed=*"))

    scores = {p.name: read_json(p / filename)[query_key] for p in seed_dirs}
    df = pd.DataFrame(scores)
    df.index += 1   # shift the 0-based positional index to 1-based epoch numbers
    df.index.name = "epochs"
    return df.sort_index(axis=1).reset_index()

In [None]:
def plot_single_experiment(path: Path, ax: Axes, label: str | None, query_key: str, ylabel: str):
    """Plot one experiment's metric over epochs, aggregated across seeds, on ``ax``.

    Args:
        path: Experiment folder handed to ``collect_info_all_seeds``.
        ax: Axes that receives the curve.
        label: Legend entry for the curve; ``None`` draws it without a legend.
        query_key: Key of the metric to read from each seed's log.
        ylabel: Text for the y-axis label.

    Returns:
        The same ``ax`` that was passed in.
    """
    per_seed = collect_info_all_seeds(folder=path, query_key=query_key)
    # Long format: one row per (epoch, seed) so seaborn can aggregate over seeds.
    tidy = per_seed.melt(id_vars="epochs")

    sns.lineplot(
        data=tidy,
        x="epochs",
        y="value",
        errorbar="sd",          # band width = standard deviation across seeds
        err_style="band",
        label=label,
        legend="auto" if label is not None else None,
        ax=ax,
    )
    ax.set(xlabel="epochs", ylabel=ylabel)
    return ax

In [None]:
datasets = ["cifar10", "cifar100", "ImageNet16-120"]
home = Path.home() / "Desktop/Experiments"

def get_label_and_path(dataset):
    """Map each legend label to the experiment folder of ``dataset``.

    Args:
        dataset: Dataset folder name (one of ``datasets``).

    Returns:
        Dict of ``label -> Path`` for the uncalibrated (gaussian) run and the
        two split-CP runs (initial set of 10 and of 30), in that order.
    """
    # Every segment must be relative: joining a Path with a segment that starts
    # with "/" discards everything to its left, which would drop both `home`
    # and `dataset` from the result.
    base = home / "acq_search=mutation" / "nasbench201" / dataset / "acq=its" / "num_to_mutate=2"
    scp_name = "bananas__ensemble_mlp__CP_split__train_cal_split=03__num_quantiles=10"

    gaussian = base / "num_init=10" / "bananas__ensemble_mlp__gaussian__num_quantiles=10"
    scp_10 = base / "num_init=10" / scp_name
    scp_30 = base / "num_init=30" / scp_name
    return {
        "uncalibrated": gaussian,
        "scp (init_size=10)": scp_10,
        "scp (init_size=30)": scp_30,
    }

### Visualise validation accuracy and RMSCE per epoch

In [None]:
# One column per dataset: calibration error (rmsce) on the top row,
# validation accuracy on the bottom row. squeeze=False keeps `axes` 2-D
# even if `datasets` is reduced to a single entry.
fig, axes = plt.subplots(
    ncols=len(datasets),
    nrows=2,
    figsize=(15, 7),
    gridspec_kw={"height_ratios": [1.5, 2.5]},
    squeeze=False,
)

for i, dataset in enumerate(datasets):
    top_ax, bottom_ax = axes[0, i], axes[1, i]
    top_ax.set_title(dataset)  # otherwise the columns are indistinguishable
    for label, path in get_label_and_path(dataset=dataset).items():
        # Only the bottom panel carries labels, so each column gets a single legend.
        plot_single_experiment(path=path, ax=top_ax, label=None, query_key=QUERY_CAL_ERR, ylabel="rmsce")
        plot_single_experiment(path=path, ax=bottom_ax, label=label, query_key=QUERY_VAL_ACC, ylabel="validation accuracy")
    # Position the legend once, after all curves of this column are drawn.
    bottom_ax.legend(loc="lower right")
fig.tight_layout()

In [None]:
# Display all currently open matplotlib figures.
plt.show()

In [None]:
# Final-epoch summary for the bootstrap-CP run on cifar10.
# Built from the notebook's `home` base instead of a user-specific absolute
# path, so the cell also works on other machines.
path = home / "acq_search=mutation/nasbench201/cifar10/acq=its/num_to_mutate=2/num_init=10/bananas__ensemble_mlp__CP_bootstrap__num_ensemble=5__num_quantiles=10_absresidual"

# Last row = final epoch; column 0 is "epochs", the remaining columns are per-seed values.
scores_1 = collect_info_all_seeds(path).iloc[-1, 1:]
print(scores_1.mean())
print(scores_1.std())

In [None]:
# Show the per-seed last-row values from the previous cell, sorted by value.
scores_1.sort_values()