In [None]:
import os
from pathlib import Path

import pandas as pd
from matplotlib import rcParams
import matplotlib.pyplot as plt
import seaborn as sns

from naslib.trial import get_dataset_api

In [None]:
# Global plot styling applied to every figure produced by this notebook.
sns.set_style("white")
rcParams.update({
    "axes.titlepad": 15,  # extra space between axes and title
    "font.size": 9,
})

In [None]:
def histplot_model_performance(data: pd.Series, xlabel="Validation Accuracy", bins=50):
    """Plot a histogram of per-model accuracies on a new figure.

    Args:
        data: Accuracy values to histogram (one value per model).
        xlabel: Label for the x-axis.
        bins: Number of histogram bins (defaults to 50, the previous fixed value).

    Returns:
        The matplotlib Figure that holds the histogram.
    """
    fig, ax = plt.subplots(1, 1)
    sns.histplot(data=data, bins=bins, ax=ax, shrink=1.0)

    # Label via the Axes/Figure objects instead of the pyplot state machine so the
    # labels always land on this figure, regardless of which figure is "current".
    ax.set_xlabel(xlabel, labelpad=1.2)
    ax.set_ylabel("Frequency")
    fig.tight_layout()
    return fig

In [None]:
# Search-space identifier handed to get_dataset_api for every dataset below.
search_space: str = "nasbench201"

CIFAR-10

In [None]:
# Models are kept only if their last recorded eval accuracy is strictly above this.
# NOTE(review): with 0, the strict ">" drops models whose accuracy is exactly 0
# (presumably diverged runs) — confirm this is intended.
threshold = 0
dataset_api = get_dataset_api(search_space_type=search_space, dataset="cifar10")

# One record per entry of nb201_data; CIFAR-10 results are read from the
# "cifar10-valid" key, taking the last element of "eval_acc1es"
# (presumably the final training epoch — verify against the NASLib data format).
api = dataset_api['nb201_data']
val_acc_cifar10 = pd.Series(
    [info["cifar10-valid"]["eval_acc1es"][-1] for info in api.values()]
)
val_acc_high_cifar10 = val_acc_cifar10[val_acc_cifar10 > threshold]

In [None]:
# Histogram of CIFAR-10 accuracies, x-axis labelled with the dataset name.
fig_cifar10 = histplot_model_performance(
    xlabel="CIFAR10",
    data=val_acc_high_cifar10,
)

CIFAR-100

In [None]:
# Models are kept only if their last recorded eval accuracy is strictly above this.
# NOTE(review): with 0, the strict ">" drops models whose accuracy is exactly 0
# (presumably diverged runs) — confirm this is intended.
threshold = 0
dataset_api = get_dataset_api(search_space_type=search_space, dataset="cifar100")

# One record per entry of nb201_data; CIFAR-100 results are read from the
# "cifar100" key, taking the last element of "eval_acc1es"
# (presumably the final training epoch — verify against the NASLib data format).
api = dataset_api['nb201_data']
val_acc_cifar100 = pd.Series(
    [info["cifar100"]["eval_acc1es"][-1] for info in api.values()]
)
val_acc_high_cifar100 = val_acc_cifar100[val_acc_cifar100 > threshold]

In [None]:
# Histogram of CIFAR-100 accuracies, x-axis labelled with the dataset name.
fig_cifar100 = histplot_model_performance(
    xlabel="CIFAR100",
    data=val_acc_high_cifar100,
)

ImageNet16-120

In [None]:
# Models are kept only if their last recorded eval accuracy is strictly above this.
# Uses ">" (not ">=") to match the CIFAR-10 / CIFAR-100 cells, so all three
# datasets in the combined figure are filtered the same way.
threshold = 0
dataset_api = get_dataset_api(search_space_type=search_space, dataset="ImageNet16-120")

# One record per entry of nb201_data; ImageNet16-120 results are read from the
# "ImageNet16-120" key, taking the last element of "eval_acc1es"
# (presumably the final training epoch — verify against the NASLib data format).
api = dataset_api['nb201_data']
val_acc_imgnet = pd.Series(
    [info["ImageNet16-120"]["eval_acc1es"][-1] for info in api.values()]
)
val_acc_high_imgnet = val_acc_imgnet[val_acc_imgnet > threshold]

In [None]:
# Label with the dataset name, consistent with the CIFAR histograms and the
# combined figure (previously fell back to the generic "Validation Accuracy").
fig_imgnet = histplot_model_performance(data=val_acc_high_imgnet, xlabel="ImageNet16-120")

In [None]:
# Side-by-side accuracy histograms for the three NAS-Bench-201 datasets.
# `fig` stays a notebook-level name because the export cell saves it.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# (data, x-axis label) for each panel, left to right.
panels = [
    (val_acc_high_cifar10, "CIFAR10"),
    (val_acc_high_cifar100, "CIFAR100"),
    (val_acc_high_imgnet, "ImageNet16-120"),
]
for ax, (acc, label) in zip(axes, panels):
    sns.histplot(data=acc, bins=50, ax=ax, shrink=1.0)
    ax.set_xlabel(label, labelpad=10)
    ax.set_ylabel("Frequency")

fig.tight_layout()
plt.show()

In [None]:
# Destination for thesis figures. Set THESIS_FIG_DIR to override the author's
# machine-specific default instead of editing this cell.
output_dir = os.environ.get(
    "THESIS_FIG_DIR", "/Users/chengchen/GitHub/master_thesis/report/thesis/figs"
)
output_path = Path(output_dir) / "nas_bench_201_val_acc.pdf"
# Create the directory if needed so the save does not fail on a fresh checkout.
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path)