In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import numpy.typing as npt
import matplotlib.pyplot as plt

In [2]:
def plot_empirical_error_convergence(
    loss: npt.NDArray[np.float64],
    nn_config: dict,
    annotate: bool = False,
    annotation_offset: float = 0.02,
) -> None:
    """Plot the per-epoch loss curve of a training run as a line with point markers.

    Parameters
    ----------
    loss:
        Sequence of loss values; the x position of each value is its index,
        and one x tick is placed per index (labelled "Epochs").
    nn_config:
        Network configuration mapping. Only ``nn_config["learning_rate"]`` is
        read; it is used as the legend label for the curve.
    annotate:
        If True, write each loss value (rounded to 3 decimals) next to its point.
    annotation_offset:
        Vertical offset, in data units, between a point and its annotation text.
        Defaults to 0.02, the value previously hard-coded.
    """
    # NOTE: the annotation uses np.float64 rather than np.float_, which was removed
    # in NumPy 2.0 (the old spelling made this `def` itself raise AttributeError).
    sns.set_theme(style="darkgrid", color_codes=True, rc={"figure.figsize": (8, 5)})
    ax = sns.lineplot(data=loss, markers=True, marker="o", label=f'{nn_config["learning_rate"]}')

    # Work on the returned Axes explicitly instead of pyplot's implicit "current axes".
    ax.set_xticks(np.arange(len(loss)))
    ax.set(xlabel="Epochs", ylabel="Loss", title="Convergence of empirical error")
    ax.legend()

    if annotate:
        for epoch, value in enumerate(loss):
            # Lift the label slightly above its point so it does not overlap the marker.
            ax.annotate(str(np.round(value, 3)), xy=(epoch, value), xytext=(epoch, value + annotation_offset))

In [6]:
# Load the stored experiment results: one row per evaluated configuration
# (see columns such as hidden_layer_size, batch_size, learning_rate in the output).
results: pd.DataFrame = pd.read_json("data/results.json")
# Display the six rows with the highest accuracy_score.
results.nlargest(6, "accuracy_score")

Unnamed: 0,accuracy_score,precision_score,recall_score,f1_score,confusion_matrix,random_state_seed,test_size,hidden_layer_size,batch_size,learning_rate,history
28,0.929333,"[0.9111111111111111, 0.9470588235294111, 0.907...","[0.984, 0.9877300613496931, 0.9202898550724631...","[0.9461538461538461, 0.966966966966967, 0.9136...","[[123, 0, 0, 0, 0, 0, 0, 0, 2, 0], [0, 161, 0,...",42,0.3,100,20,1.0,"{'loss': [0.82942134141922, 0.3543525338172910..."
16,0.928,"[0.9389312977099231, 0.9573170731707311, 0.968...","[0.984, 0.9631901840490791, 0.891304347826086,...","[0.9609375000000001, 0.9602446483180421, 0.928...","[[123, 0, 0, 0, 0, 0, 0, 0, 2, 0], [0, 157, 0,...",42,0.3,50,20,1.0,"{'loss': [0.8357468843460081, 0.36331853270530..."
15,0.919333,"[0.9236641221374041, 0.9693251533742331, 0.898...","[0.968, 0.9693251533742331, 0.898550724637681,...","[0.9453124999999991, 0.9693251533742331, 0.898...","[[121, 0, 1, 0, 0, 0, 1, 0, 2, 0], [0, 158, 0,...",42,0.3,50,20,0.5,"{'loss': [0.9272853136062621, 0.42916768789291..."
19,0.914,"[0.891304347826086, 0.9575757575757571, 0.9270...","[0.984, 0.9693251533742331, 0.9202898550724631...","[0.935361216730037, 0.9634146341463411, 0.9236...","[[123, 0, 0, 0, 0, 1, 0, 0, 1, 0], [0, 158, 0,...",42,0.3,50,50,1.0,"{'loss': [1.069212079048156, 0.472264409065246..."
31,0.913333,"[0.8978102189781021, 0.9461077844311371, 0.946...","[0.984, 0.9693251533742331, 0.898550724637681,...","[0.9389312977099231, 0.9575757575757571, 0.921...","[[123, 0, 0, 0, 0, 1, 0, 0, 1, 0], [0, 158, 0,...",42,0.3,100,50,1.0,"{'loss': [1.152988076210022, 0.477073132991790..."
4,0.911333,"[0.9242424242424241, 0.962962962962962, 0.9136...","[0.976, 0.9570552147239261, 0.9202898550724631...","[0.9494163424124511, 0.9599999999999991, 0.916...","[[122, 0, 0, 0, 0, 2, 0, 0, 1, 0], [0, 156, 0,...",42,0.3,25,20,1.0,"{'loss': [0.8276067972183221, 0.37989017367362..."
