# Generalization study

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import logging

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc

from analysis import load_configs, load_experimental_results, get_stats
from factorization.config import IMAGE_DIR, USETEX
from factorization.models.mlp import Model

# Configure the root logger to display INFO-level messages.
logging.basicConfig(level=logging.INFO)


# Global matplotlib defaults for every figure in this notebook: 8pt serif font,
# with LaTeX text rendering switched on or off by the project's USETEX flag.
rc("font", family="serif", size=8)
rc("text", usetex=USETEX)
if USETEX:
    # Only meaningful when LaTeX rendering is on: load the `times` package.
    rc("text.latex", preamble=r"\usepackage{times}")


def get_names(name):
    match name:
        case "filtration":
            return "filtration", ["beta"]
        case "generalization":
            return "generalization", ["statistical_complexity"]
        case "emb_dim":
            return "convergence_gen", ["emb_dim", "nb_parents"]
        case "split":
            return "split", ["data_split", "nb_parents"]
        case "isoflop":
            return "isoflop", ["data_split", "nb_parents"]
        case "isoflop_long":
            return "isoflop_long", ["data_split", "nb_parents"]

In [None]:
# Export switch read by every plotting cell: when True, figures are drawn at
# paper size and written as PDFs under IMAGE_DIR (curves and legend as separate
# files); when False, curves and legend are shown side by side and nothing is saved.
save_fig = False

## Filtration

In [None]:
# Study to analyse and the column used as x-axis of the curves.
name = "filtration"
xaxis = "epoch"

# Keyword arguments forwarded to both `load_experimental_results` and
# `get_stats` (presumably they select the matching runs -- confirm in `analysis`).
# The commented entries are alternative settings kept for quick switching.
kwargs = {
    "input_factors": [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
    # "input_factors": [8, 8, 8, 8],
    "output_factors": [8, 8, 8, 8],
    # "output_factors": [4096],
    "alphas": 1e-1,
}

# Resolve the results identifier and the factors that distinguish the curves.
file_path, study_factors = get_names(name)

In [None]:
# Load the configurations stored for this study, then the corresponding
# experimental results; `kwargs` is passed along to the results loader.
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# Aggregate the results per studied factor. The plotting cell iterates over
# `all_mean` / `all_std` in lockstep and uses `keys` to label each curve
# (presumably per-group means and standard deviations -- confirm in `get_stats`).
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
# Test loss against epoch, one curve per entry of `all_mean`.
# Export mode uses a single paper-sized panel; interactive mode reserves a
# second panel for the legend.
if not save_fig:
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    ax = axes[0]
else:
    fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))

y_name = "test_loss"
all_plots, legend = [], []
for mean, _std in zip(all_mean, all_std):
    (line,) = ax.plot(mean[y_name], linewidth=1.5)
    all_plots.append(line)
    # Label each curve with the first grouping key, rounded to two decimals.
    beta = mean[keys[0]].values[0]
    legend.append(rf"$\beta$={beta:.2f}")

ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("Epoch", fontsize=8)
ax.set_ylabel(r"Test loss ${\cal L}_{\text{unobs.}}$", fontsize=8)
ax.tick_params(axis="both", labelsize=6)
ax.grid()

if save_fig:
    fig.savefig(IMAGE_DIR / f"{name}.pdf", bbox_inches="tight")
    # The legend is exported as its own narrow, axis-free figure.
    fig, ax = plt.subplots(1, 1, figsize=(0.25, 1.5))
    leg = ax.legend(all_plots, legend, loc="center", ncol=1, fontsize=6)
    ax.axis("off")
    fig.savefig(IMAGE_DIR / f"{name}_leg.pdf", bbox_inches="tight")
else:
    # Interactive mode: the legend fills the otherwise empty right panel.
    ax = axes[1]
    leg = ax.legend(all_plots, legend, loc="center", ncol=1, fontsize=12)
    ax.axis("off")

## Statistical Complexity

In [None]:
# Study to analyse and the column used as x-axis of the curves.
name = "generalization"
xaxis = "epoch"

# Keyword arguments forwarded to both `load_experimental_results` and
# `get_stats` (presumably they select the matching runs -- confirm in `analysis`).
# The commented entries are alternative settings kept for quick switching.
kwargs = {
    # "input_factors": [8, 8, 8, 8],
    "output_factors": [8, 8, 8, 8],
    # "output_factors": [4096],
    "alphas": 1e-1,
}

# Resolve the results identifier and the factors that distinguish the curves.
file_path, study_factors = get_names(name)

In [None]:
# Load the configurations stored for this study, then the corresponding
# experimental results; `kwargs` is passed along to the results loader.
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# Aggregate the results per studied factor. The plotting cell iterates over
# `all_mean` / `all_std` in lockstep and uses `keys` to label each curve
# (presumably per-group means and standard deviations -- confirm in `get_stats`).
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
# Test loss against epoch, one curve per entry of `all_mean`, labelled with
# the value of the first grouping key.
if not save_fig:
    # Interactive layout: curves on the left, legend on the right.
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    ax = axes[0]
else:
    # Export layout: a single paper-sized panel.
    fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))

y_name = "test_loss"
all_plots, legend = [], []
for mean, _std in zip(all_mean, all_std):
    (line,) = ax.plot(mean[y_name], linewidth=1.5)
    all_plots.append(line)
    chi = mean[keys[0]].values[0]
    legend.append(rf"$\chi$={chi}")

ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("Epoch", fontsize=8)
ax.set_ylabel(r"Test loss ${\cal L}_{\text{unobs.}}$", fontsize=8)
ax.tick_params(axis="both", labelsize=6)
ax.grid()

if save_fig:
    fig.savefig(IMAGE_DIR / f"{name}.pdf", bbox_inches="tight")
    # Separate, axis-free figure that only carries the legend.
    fig, ax = plt.subplots(1, 1, figsize=(0.25, 1.5))
    leg = ax.legend(all_plots, legend, loc="center", ncol=1, fontsize=6)
    ax.axis("off")
    fig.savefig(IMAGE_DIR / f"{name}_leg.pdf", bbox_inches="tight")
else:
    ax = axes[1]
    leg = ax.legend(all_plots, legend, loc="center", ncol=1, fontsize=12)
    ax.axis("off")

## Effect of embedding dimension

In [None]:
# Study to analyse; curves are drawn against the `flops` column, which is
# added to the results in a later cell.
name = "emb_dim"
xaxis = "flops"

# Keyword arguments forwarded to both `load_experimental_results` and
# `get_stats` (presumably they select the matching runs -- confirm in `analysis`).
kwargs = {
    "scheduler": "custom",
    "nb_parents": 3,
}

# Resolve the results identifier and the factors that distinguish the curves.
file_path, study_factors = get_names(name)

In [None]:
# Load the configurations stored for this study, then the corresponding
# experimental results; `kwargs` is passed along to the results loader.
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# Add a `flops` column: start from `Model.get_flops` evaluated on the
# architecture columns (embedding dim, FFN dim, number of layers, output size) ...
res['flops'] = Model.get_flops(res['emb_dim'], res['ffn_dim'], res['nb_layers'], res['output_size'])
# ... then scale by epoch, input size and data split (presumably turning a
# per-sample cost into cumulative training compute -- confirm against `Model.get_flops`).
res['flops'] *= res['epoch'] * res['input_size'] * res['data_split']

In [None]:
# Aggregate the results per studied factor. The plotting cell iterates over
# `all_mean` / `all_std` in lockstep and uses `keys` to label each curve
# (presumably per-group means and standard deviations -- confirm in `get_stats`).
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
# Test loss against flops, one curve per entry of `all_mean`, labelled with
# the value of the second grouping key.
if not save_fig:
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    ax = axes[0]
else:
    fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))

y_name = "test_loss"
all_plots, legend = [], []
for ind, (mean, _std) in enumerate(zip(all_mean, all_std), start=1):
    # The 8th, 9th and 10th groups (1-based) are left out of the figure.
    if 7 < ind < 11:
        continue
    (line,) = ax.plot(mean[y_name], linewidth=1.5)
    all_plots.append(line)
    dim = mean[keys[1]].values[0]
    legend.append(rf"$d={dim}$")

ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("Flop", fontsize=8)
ax.set_ylabel(r"Test loss ${\cal L}_{\text{unobs.}}$", fontsize=8)
ax.tick_params(axis="both", labelsize=6)
ax.grid()

if save_fig:
    fig.savefig(IMAGE_DIR / f"{name}.pdf", bbox_inches="tight")
    # The legend is exported as its own narrow, axis-free figure.
    fig, ax = plt.subplots(1, 1, figsize=(0.25, 1.5))
    leg = ax.legend(all_plots, legend, loc="center", ncol=1, fontsize=6)
    ax.axis("off")
    fig.savefig(IMAGE_DIR / f"{name}_leg.pdf", bbox_inches="tight")
else:
    ax = axes[1]
    leg = ax.legend(all_plots, legend, loc="center", ncol=1, fontsize=12)
    ax.axis("off")

## Data split

In [None]:
# Study to analyse; here the data split itself is the x-axis of the curves.
name = "split"
xaxis = "data_split"

# Keyword arguments forwarded to both `load_experimental_results` and
# `get_stats` (their exact effect, e.g. of "final", is defined in `analysis`).
kwargs = {
    "scheduler": "custom",
    "final": True,
}

# Resolve the results identifier and the factors that distinguish the curves.
file_path, study_factors = get_names(name)

In [None]:
# Load the configurations stored for this study, then the corresponding
# experimental results; `kwargs` is passed along to the results loader.
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# Add a `flops` column: start from `Model.get_flops` evaluated on the
# architecture columns (embedding dim, FFN dim, number of layers, output size) ...
res['flops'] = Model.get_flops(res['emb_dim'], res['ffn_dim'], res['nb_layers'], res['output_size'])
# ... then scale by epoch, input size and data split (presumably turning a
# per-sample cost into cumulative training compute -- confirm against `Model.get_flops`).
res['flops'] *= res['epoch'] * res['input_size'] * res['data_split']

In [None]:
# Aggregate the results per studied factor. The plotting cell iterates over
# `all_mean` / `all_std` in lockstep and uses `keys` to label each curve
# (presumably per-group means and standard deviations -- confirm in `get_stats`).
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
# Test loss against the data split, one curve per entry of `all_mean`,
# labelled with the value of the first grouping key.
if not save_fig:
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    ax = axes[0]
else:
    fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))

y_name = "test_loss"
all_plots, legend = [], []
for mean, _std in zip(all_mean, all_std):
    (line,) = ax.plot(mean[y_name], linewidth=1.5)
    all_plots.append(line)
    label_value = mean[keys[0]].values[0]
    # The TeX part stays a plain raw string so its braces need no escaping.
    legend.append(r"$|\operatorname{pa}_j|=" + f"{label_value}$")

ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel(r"Data split $\gamma$", fontsize=8)
ax.set_ylabel(r"Test loss ${\cal L}_{\text{unobs.}}$", fontsize=8)
ax.tick_params(axis="both", labelsize=6)
ax.grid()

if save_fig:
    fig.savefig(IMAGE_DIR / f"{name}_scaling.pdf", bbox_inches="tight")
    # The legend is exported as its own narrow, axis-free figure.
    fig, ax = plt.subplots(1, 1, figsize=(0.25, 1.5))
    leg = ax.legend(all_plots, legend, loc="center", ncol=1, fontsize=6)
    ax.axis("off")
    fig.savefig(IMAGE_DIR / f"{name}_scaling_leg.pdf", bbox_inches="tight")
else:
    ax = axes[1]
    leg = ax.legend(all_plots, legend, loc="center", ncol=1, fontsize=12)
    ax.axis("off")

## Data split and flops

In [None]:
# Same "split" study as above, now drawn against the `flops` column that is
# added to the results in a later cell.
name = "split"
xaxis = "flops"

# Keyword arguments forwarded to both `load_experimental_results` and
# `get_stats`; "nb_parents" is also read by the plotting cells to name the
# exported PDF files.
kwargs = {
    "nb_parents": 1,
}

# Resolve the results identifier and the factors that distinguish the curves.
file_path, study_factors = get_names(name)

In [None]:
# Load the configurations stored for this study, then the corresponding
# experimental results; `kwargs` is passed along to the results loader.
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# Add a `flops` column: start from `Model.get_flops` evaluated on the
# architecture columns (embedding dim, FFN dim, number of layers, output size) ...
res['flops'] = Model.get_flops(res['emb_dim'], res['ffn_dim'], res['nb_layers'], res['output_size'])
# ... then scale by epoch, input size and data split (presumably turning a
# per-sample cost into cumulative training compute -- confirm against `Model.get_flops`).
res['flops'] *= res['epoch'] * res['input_size'] * res['data_split']

In [None]:
# Aggregate the results per studied factor. The plotting cell iterates over
# `all_mean` / `all_std` in lockstep and uses `keys` to label each curve
# (presumably per-group means and standard deviations -- confirm in `get_stats`).
# The isoflop cells further down also read `all_mean` as produced here.
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
# Test loss against flops for a hand-picked subset of the entries of
# `all_mean`, each labelled with the value of the first grouping key.
if not save_fig:
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    ax = axes[0]
else:
    fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))

y_name = "test_loss"
all_plots, legend = [], []
# 1-based positions of the curves to draw (a sparser alternative: 1, 2, 6, 19).
shown = {1, 2, 3, 4, 6, 10, 19}
for ind, (mean, _std) in enumerate(zip(all_mean, all_std), start=1):
    if ind in shown:
        (line,) = ax.plot(mean[y_name], linewidth=1.5)
        all_plots.append(line)
        gamma = mean[keys[0]].values[0]
        legend.append(rf"$\gamma={gamma:.2f}$")

ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("Flop", fontsize=8)
ax.set_ylabel(r"Test loss ${\cal L}_{\text{unobs.}}$", fontsize=8)
ax.tick_params(axis="both", labelsize=6)
ax.grid()

if save_fig:
    # The curve figure is suffixed with the number of parents; the legend file is not.
    fig.savefig(IMAGE_DIR / f'{name}_{kwargs["nb_parents"]}.pdf', bbox_inches="tight")
    fig, ax = plt.subplots(1, 1, figsize=(0.25, 1.5))
    leg = ax.legend(all_plots, legend, loc="center", ncol=1, fontsize=6)
    ax.axis("off")
    fig.savefig(IMAGE_DIR / f"{name}_leg.pdf", bbox_inches="tight")
else:
    ax = axes[1]
    leg = ax.legend(all_plots, legend, loc="center", ncol=1, fontsize=12)
    ax.axis("off")

## Isoflop curves

In [None]:
# Flop budgets at which every curve of `all_mean` is sampled.
all_flops = [3e10, 1e11, 3e11, 1e12, 3e12, 1e13, 3e13, 1e14, 3e14]
# all_flops = [2e10, 4e10, 8e10, 16e10, 32e10]

# all_val[i, j]: test loss of curve i at budget j; all_keys[i]: its data split.
all_val = np.empty((len(all_mean), len(all_flops)))
all_keys = np.empty(len(all_mean))

for row, curve in enumerate(all_mean):
    all_keys[row] = curve['data_split'].values[0]
    # Index of the curve (presumably the flops axis, given xaxis="flops" -- confirm).
    axis_values = curve.index.values
    for col, budget in enumerate(all_flops):
        # Number of index entries strictly below the budget, used as the
        # position of the sample to read.
        ind = np.count_nonzero(axis_values < budget)
        try:
            all_val[row, col] = curve.iloc[ind]['test_loss']
        except IndexError:
            # The curve stops before reaching this budget.
            all_val[row, col] = np.nan

In [None]:
# Isoflop curves: test loss against data split at fixed flop budgets.
# Reads `all_flops`, `all_keys` and `all_val` from the sampling cell, and
# `kwargs["nb_parents"]` from whichever study configuration was run last
# (the "Data split and flops" one when the notebook is run top to bottom).
if save_fig:
    fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))
else:
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    ax = axes[0]

legend = []
all_plots = []
# Only every other budget is drawn to keep the figure readable.
for ind in range(0, len(all_flops), 2):
    a, = ax.plot(all_keys, all_val[:, ind], linewidth=1.5)
    all_plots.append(a)
    # Render e.g. 3e10 as "$F=3\cdot 10^{10}$". Splitting the scientific
    # notation handles negative exponents too, and the raw f-string avoids
    # the invalid "\c" escape of a plain literal.
    mantissa, exponent = f'{all_flops[ind]:.0e}'.split('e')
    leg = rf'$F={mantissa}\cdot 10^{{{int(exponent)}}}$'

    legend.append(leg)
ax.set_yscale('log')
ax.set_xscale('log')
ax.set_xlabel(r"Data split $\gamma$", fontsize=8)
ax.set_ylabel(r"Test loss ${\cal L}_{\text{unobs.}}$", fontsize=8)
ax.tick_params(axis='both', labelsize=6)
ax.grid()

if save_fig:
    fig.savefig(IMAGE_DIR / f'isoflop_{kwargs["nb_parents"]}.pdf', bbox_inches='tight')
    # The legend is exported as its own narrow, axis-free figure.
    fig, ax = plt.subplots(1, 1, figsize=(.25, 1.5))
    leg = ax.legend(all_plots, legend, loc='center', ncol=1, fontsize=6)
    ax.axis('off')
    fig.savefig(IMAGE_DIR / 'isoflop_leg.pdf', bbox_inches='tight')
else:
    ax = axes[1]
    leg = ax.legend(all_plots, legend, loc='center', ncol=1, fontsize=12)
    ax.axis('off')