In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib2tikz

In [None]:
# Global figure styling for this notebook: seaborn's white-grid theme at the
# "paper" scale, then larger label/tick/legend fonts and an 8x8-inch default
# figure size.
sns.set_style("whitegrid")
sns.set_context("paper")

rc_overrides = {
    'axes.labelsize': '22',
    'xtick.labelsize': '18',
    'ytick.labelsize': '18',
    'legend.fontsize': '18',
    'figure.figsize': (8, 8),
}
# Applied after set_context() so these sizes take precedence over the
# "paper" context's font sizes.
plt.rcParams.update(rc_overrides)

def increase_linewidth(ax, linewidth=3, legend_linewidth=5):
    """Thicken every line on ``ax`` and the line handles of its legend.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes whose lines (``ax.get_lines()``) are restyled in place.
    linewidth : float, default 3
        Width applied to each plotted line.
    legend_linewidth : float, default 5
        Width applied to each line handle in the legend.

    Returns
    -------
    The same ``ax``, so the call can be chained.
    """
    for line in ax.get_lines():
        line.set_linewidth(linewidth)
    # ax.legend() rebuilds the legend from the axes' labelled artists. Its line
    # handles are separate artists from the plotted lines, so they are widened
    # on their own.
    legend = ax.legend()
    for handle in legend.get_lines():
        handle.set_linewidth(legend_linewidth)
    return ax

In [None]:
# Select the experiment and load its statistics table. `directory` is used as
# a plain string prefix, so it must keep its trailing slash.
question_type = "mass"
directory = "features_plots/"
stats_path = "{}{}_stats.h5".format(directory, question_type)
stats = pd.read_hdf(stats_path)

In [None]:
# Three kinds of figure:
# - loss per epoch for every feature set;
# - the best train/validation accuracy reached by each feature set;
# - one train-vs-validation accuracy plot per feature set.
# Each figure gets its own explicit Figure/Axes, so nothing is drawn onto a
# figure that an earlier plot left current.
prefix = directory + question_type

fig, ax = plt.subplots()
sns.lineplot(x="Epoch", y="Loss", hue="features", markers=True, data=stats, ax=ax)
increase_linewidth(ax)
fig.savefig(prefix + "_losses.pdf")
plt.show()

fig, ax = plt.subplots()
# Select the accuracy columns before aggregating, so max() only runs on them.
best_acc = stats.groupby("features")[["Train Accuracy", "Val Accuracy"]].max()
best_acc.plot.bar(ax=ax)
# Use grid(True), not a bare grid(). The no-argument form toggles visibility,
# and the "whitegrid" style already has the grid on, so a toggle would hide it.
ax.grid(True)
ax.set_ylabel("Accuracy")
ax.set_xlabel("")
fig.savefig(prefix + "_acc_hist.pdf")
plt.show()

for label in stats.features.unique():
    subset = stats[stats.features == label]
    # A fresh figure per feature set. Without it, the first iteration drew
    # onto the still-current bar chart.
    fig, ax = plt.subplots()
    # Label the lines directly instead of relying on artist order in
    # legend(labels=...). That order also counts seaborn's confidence bands.
    sns.lineplot(x="Epoch", y="Val Accuracy", data=subset, ax=ax, label="Validation")
    sns.lineplot(x="Epoch", y="Train Accuracy", data=subset, ax=ax, label="Training")
    ax.set_title(label)
    ax.legend(loc=2)
    ax.set_ylabel("Accuracy")
    ax.grid(True)
    # matplotlib2tikz.save exports the current figure, which is `fig` here.
    matplotlib2tikz.save(prefix + "_" + str(label) + "_acc_plot.tikz")
    plt.show()

In [None]:
# Overlay every feature set's accuracy curve: one figure for validation
# accuracy, one for training accuracy. Each is saved under the same
# directory/question_type prefix as the other outputs.
accuracy_plots = [
    ("Val Accuracy", "_all_val_acc_plot.pdf"),
    ("Train Accuracy", "_all_train_acc_plot.pdf"),
]
for column, suffix in accuracy_plots:
    # Explicit figure, so the curves never land on a figure an earlier cell
    # left open.
    fig, ax = plt.subplots()
    sns.lineplot(x="Epoch", y=column, hue="features", data=stats, ax=ax)
    increase_linewidth(ax)
    fig.savefig(directory + question_type + suffix)
    plt.show()