# Single pass study

In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modified modules
# before each cell runs, so edits to project code apply without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import logging

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc

from analysis import load_configs, load_experimental_results, get_stats
from factorization.config import IMAGE_DIR, USETEX
from factorization.models.mlp import Model

# Show INFO-level (and more severe) log records.
logging.basicConfig(level=logging.INFO)

# Figure typography: 8pt serif text everywhere; LaTeX rendering is switched
# on or off by the project-level USETEX flag.
plt.rcParams.update(
    {
        "font.family": "serif",
        "font.size": 8,
        "text.usetex": USETEX,
    }
)
if USETEX:
    # Only relevant when LaTeX renders the text: load the `times` package.
    plt.rcParams["text.latex.preamble"] = r"\usepackage{times}"


def get_names(name):
    match name:
        case "iid":
            return "iid", ["statistical_complexity"]

In [None]:
# Output mode for the plotting cell below. True: draw a small single-axis
# figure plus a separate legend figure and write both as PDFs into IMAGE_DIR.
# False: draw one wide two-panel figure (curves left, legend right), no files.
save_fig = False

In [None]:
# Study to analyse, and the quantity handed to `get_stats` as its x-axis.
name, xaxis = "iid", "epoch"

# Keyword arguments forwarded to both `load_experimental_results` and
# `get_stats`. Optional extra entry, currently disabled: nb_parents=3.
kwargs = dict(alphas=1e-1)

# Resolve the study name into its configuration path and varied factors.
file_path, study_factors = get_names(name)

In [None]:
# Load the configurations identified by `file_path`, then the experimental
# results for those configurations. `kwargs` is forwarded to the results
# loader (presumably to select a subset of runs — confirm in `analysis`).
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# Summarise the results with respect to `study_factors` along `xaxis`.
# The names suggest per-configuration means, standard deviations and the
# grouping keys — confirm against `analysis.get_stats`.
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
# Plot the `test_loss` column of every entry of `all_mean` on log-log axes and
# draw the legend on its own axis: the right-hand panel when working
# interactively, a separate PDF when `save_fig` is set.
if save_fig:
    # Small single-axis figure; its legend is exported separately below.
    fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))
else:
    # Wide two-panel figure: curves on the left, legend on the right.
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    ax = axes[0]

y_name = "test_loss"
legend = []
all_plots = []
# The second element of each pair is not drawn; `zip` is kept so that the
# loop still stops at the shorter of `all_mean` and `all_std`.
for mean_stats, _std_stats in zip(all_mean, all_std):
    # `ax.plot` returns a list with one line here; keep the handle so the
    # legend can be built on a different axis afterwards.
    (line,) = ax.plot(mean_stats[y_name], linewidth=1.5)
    all_plots.append(line)
    # Label each curve with the first value stored under the first key
    # returned by `get_stats`.
    legend.append(r'$\chi=$' + f'{mean_stats[keys[0]].values[0]}')

ax.set_yscale('log')
ax.set_xscale('log')
ax.set_xlabel("Epoch", fontsize=8)
ax.set_ylabel(r"Loss ${\cal L}$", fontsize=8)
ax.tick_params(axis='both', labelsize=6)
ax.grid()

if save_fig:
    fig.savefig(IMAGE_DIR / f'{name}.pdf', bbox_inches='tight')

    # Standalone legend: a narrow, axis-free figure written to its own file.
    fig, ax = plt.subplots(1, 1, figsize=(.25, 1.5))
    leg = ax.legend(all_plots, legend, loc='center', ncol=1, fontsize=6)
    ax.axis('off')
    fig.savefig(IMAGE_DIR / f'{name}_leg.pdf', bbox_inches='tight')
else:
    # Use the otherwise empty right-hand panel purely as a legend holder.
    ax = axes[1]
    leg = ax.legend(all_plots, legend, loc='center', ncol=1, fontsize=12)
    ax.axis('off')