In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import logging

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc

from analysis import load_configs, load_experimental_results, get_stats
from factorization.config import IMAGE_DIR, USETEX
from factorization.models.mlp import Model

# Emit INFO-level (and above) log records from the imported modules.
logging.basicConfig(level=logging.INFO)


# Global matplotlib defaults: 8pt serif font; LaTeX text rendering only when the
# project-level USETEX flag is set, in which case the `times` package is loaded.
rc("font", family="serif", size=8)
rc("text", usetex=USETEX)
if USETEX:
    rc("text.latex", preamble=r"\usepackage{times}")


def get_names(name):
    """
    Resolve a short experiment name into its results identifier and study factors.

    Parameters
    ----------
    name: str
        Short experiment name, e.g. "iid2", "compression3" or "lr".

    Returns
    -------
    file_path: str
        Identifier handed to `load_configs` to locate the experiment.
    study_factors: list of str
        Factor names handed to `get_stats`. A fresh list is returned on every
        call, so callers may modify it freely.

    Raises
    ------
    ValueError
        If `name` is not a known experiment. Previously the function silently
        returned None, which only failed later when the result was unpacked.
    """
    experiments = {
        "iid": ("iid", ["input_factors", "output_factors", "alphas", "batch_size"]),
        "iid2": ("iid2", ["output_factors", "alphas", "batch_size"]),
        "compression": ("compression", ["input_factors", "output_factors", "emb_dim", "nb_layers"]),
        "compression2": ("compression2", ["output_factors", "emb_dim"]),
        "compression3": ("compression3", ["output_factors", "emb_dim"]),
        "factor": ("exp1_factor", ["input_factors", "output_factors"]),
        "input": ("exp1_input", ["input_factors", "emb_dim"]),
        "input2": ("exp1_input2", ["input_factors"]),
        "dim": ("exp1_dim", ["emb_dim"]),
        "layer": ("exp1_layer", ["nb_layers"]),
        "lr": ("exp1_lr", ["learning_rate"]),
        "split": ("exp1_split", ["data_split"]),
        "ffn": ("exp1_ffn", ["ffn_dim"]),
        "filtration": ("filtration", ["bernouilli"]),
    }
    try:
        file_path, study_factors = experiments[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which can never be a valid name either.
        known = ", ".join(sorted(experiments))
        raise ValueError(f"Unknown experiment name {name!r}; expected one of: {known}") from None
    return file_path, list(study_factors)

## IID runs

In [None]:
# Experiment to plot (short name resolved by get_names) and the x-axis quantity.
name = "iid2"
xaxis = "epoch"

# Settings forwarded as keyword arguments to load_experimental_results and get_stats.
kwargs = {
    "alphas": 1e-1,
    "batch_size": 2048,
}

file_path, study_factors = get_names(name)

In [None]:
# Load the experiment configurations for `file_path`, then the associated results.
# `kwargs` is forwarded as-is (presumably to select matching runs; confirm in `analysis`).
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# Aggregate the results along `xaxis`; presumably returns per-group mean frames,
# per-group std frames and the grouping keys (confirm in `analysis.get_stats`).
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
import ast
from itertools import product  # no longer used in this cell; kept in case other cells rely on it
import pandas as pd

# Merge the per-group mean curves into one frame and tag each row with its output
# complexity: the sum of the factors parsed from the string in `output_factors`.
mean = pd.concat(all_mean)
mean['output_complexity'] = mean['output_factors'].apply(lambda x: sum(ast.literal_eval(x)))

# Split the frame again, one sub-frame per (complexity, factors) pair, ordered by
# increasing complexity and then by factors. A single sorted groupby visits only the
# combinations that actually occur, instead of scanning the whole frame once per
# element of the Cartesian product of all observed values.
keys = ['output_complexity', 'output_factors']
all_mean = [group for _, group in mean.groupby(keys, sort=True)]

In [None]:
# Test loss against epochs for the IID runs; the legend is exported as its own figure.
fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))
# fig, ax = plt.subplots(1, 1, figsize=(7, 5))

y_name = "test_loss"
legend = []
all_plots = []
# Keep every second group (2nd, 4th, ...) to limit clutter. Only mean curves are
# drawn: `all_std` is deliberately not iterated, since it is unused here and was not
# regrouped like `all_mean`, so zipping with it could silently drop curves.
for pos, val in enumerate(all_mean, start=1):
    if pos % 2 == 1:
        continue
    a, = ax.plot(val[y_name], linewidth=1)
    all_plots.append(a)
    legend.append(rf'$(q_i)=${val[keys[1]].values[0]}')
ax.set_yscale('log')
ax.set_xscale('log')
ax.set_xlim(2e2, 1e4)
ax.set_ylim(5e-3, 1e0)
ax.set_xlabel("Epoch", fontsize=8)
ax.set_ylabel(r"Loss ${\cal L}$", fontsize=8)
ax.tick_params(axis='both', labelsize=6)
ax.grid()
fig.savefig(IMAGE_DIR / f'{name}.pdf', bbox_inches='tight')

# Narrow, axis-free figure that only holds the legend of the curves drawn above.
fig, ax = plt.subplots(1, 1, figsize=(.25, 1.5))
ax.axis('off')
ax.legend(all_plots, legend, loc='center', ncol=1, fontsize=6)
fig.savefig(IMAGE_DIR / f'{name}_leg.pdf', bbox_inches='tight')

## Compression runs

Effect of embedding dimension on losses w.r.t. epochs

In [None]:
# Experiment to plot (short name resolved by get_names) and the x-axis quantity.
name = "compression3"
# The embedding dimension is pinned in `kwargs` below, so the curves of this section
# are drawn against epochs; "emb_dim" could not serve as the x-axis here.
xaxis = "epoch"

# Settings forwarded as keyword arguments to load_experimental_results and get_stats.
kwargs = {
    "emb_dim": 8,
}

file_path, study_factors = get_names(name)

In [None]:
# Load the experiment configurations for `file_path`, then the associated results.
# `kwargs` is forwarded as-is (presumably to select matching runs; confirm in `analysis`).
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# The x-axis is set to "epoch" right before aggregating, so the statistics of this
# section are computed along epochs whatever value `xaxis` held before this cell.
xaxis = "epoch"
# Presumably returns per-group mean frames, per-group std frames and the grouping
# keys (confirm in `analysis.get_stats`).
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
import ast
from itertools import product  # no longer used in this cell; kept in case other cells rely on it
import pandas as pd

# Merge the per-group mean curves into one frame and tag each row with its output
# complexity: the sum of the factors parsed from the string in `output_factors`.
mean = pd.concat(all_mean)
mean['output_complexity'] = mean['output_factors'].apply(lambda x: sum(ast.literal_eval(x)))

# Split the frame again, one sub-frame per (complexity, factors) pair, ordered by
# increasing complexity and then by factors. A single sorted groupby visits only the
# combinations that actually occur, instead of scanning the whole frame once per
# element of the Cartesian product of all observed values.
keys = ['output_complexity', 'output_factors']
all_mean = [group for _, group in mean.groupby(keys, sort=True)]

In [None]:
# Loss against epochs, one curve per regrouped output-factor configuration; the
# legend is exported as its own figure.
fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))

y_name = "loss"
legend = []
all_plots = []
# Only mean curves are drawn: `all_std` is deliberately not iterated, since it is
# unused here and was not regrouped like `all_mean`.
for val in all_mean:
    a, = ax.plot(val[y_name], linewidth=2)
    all_plots.append(a)
    legend.append(rf'$\sum q_i=${val["output_complexity"].values[0]}')
ax.set_yscale('log')
ax.set_xscale('log')
ax.set_xlim(1e3, 1e5)
ax.set_xlabel("Epoch", fontsize=8)
ax.set_ylabel(r"Loss ${\cal L}$", fontsize=8)
ax.tick_params(axis='both', labelsize=6)
ax.grid()
fig.savefig(IMAGE_DIR / f'{name}.pdf', bbox_inches='tight')

# Narrow, axis-free figure that only holds the legend of the curves drawn above.
fig, ax = plt.subplots(1, 1, figsize=(.25, 1.5))
ax.axis('off')
ax.legend(all_plots, legend, loc='center', ncol=1, fontsize=6)
fig.savefig(IMAGE_DIR / f'{name}_leg.pdf', bbox_inches='tight')

Effect of embedding dimension on the loss after 10 000 epochs

In [None]:
# Experiment to plot (short name resolved by get_names) and the x-axis quantity.
name = "compression3"
xaxis = "emb_dim"

# Settings forwarded as keyword arguments to load_experimental_results and get_stats.
# NOTE(review): "final" presumably restricts the results to the last epoch of each
# run, in line with this section's heading; confirm in `analysis`.
kwargs = {
    "final": True,
}

file_path, study_factors = get_names(name)

In [None]:
# Load the experiment configurations for `file_path`, then the associated results.
# `kwargs` is forwarded as-is (confirm how it is interpreted in `analysis`).
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# Aggregate the results along `xaxis`; presumably returns per-group mean frames,
# per-group std frames and the grouping keys (confirm in `analysis.get_stats`).
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
import ast
from itertools import product  # no longer used in this cell; kept in case other cells rely on it
import pandas as pd

# Merge the per-group mean curves into one frame and tag each row with its output
# complexity: the sum of the factors parsed from the string in `output_factors`.
mean = pd.concat(all_mean)
mean['output_complexity'] = mean['output_factors'].apply(lambda x: sum(ast.literal_eval(x)))

# Split the frame again, one sub-frame per (complexity, factors) pair, ordered by
# increasing complexity and then by factors. A single sorted groupby visits only the
# combinations that actually occur, instead of scanning the whole frame once per
# element of the Cartesian product of all observed values.
keys = ['output_complexity', 'output_factors']
all_mean = [group for _, group in mean.groupby(keys, sort=True)]

In [None]:
# Loss against embedding dimension, one curve per regrouped output-factor
# configuration; the legend is exported as its own figure.
fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))
# fig, ax = plt.subplots(1, 1, figsize=(10, 7))

y_name = "loss"
legend = []
all_plots = []
# Only mean curves are drawn: `all_std` is deliberately not iterated, since it is
# unused here and was not regrouped like `all_mean`.
for i, val in enumerate(all_mean):
    curve_color = f'C{i}'
    a, = ax.plot(val[y_name], linewidth=1, color=curve_color)
    sum_q = val[keys[0]].values[0]
    # Dashed vertical marker at x = sum q_i, in the colour of the curve. Groups 0 and
    # 3 are drawn 0.2 to the left (presumably so that coinciding markers remain
    # distinguishable; confirm) and group 7 gets no marker.
    if i in (0, 3):
        marker_x = sum_q - .2
    elif i == 7:
        marker_x = None
    else:
        marker_x = sum_q
    if marker_x is not None:
        ax.plot([marker_x, marker_x], [3e-5, 1e-2], linewidth=1, color=curve_color, linestyle='--')
    all_plots.append(a)
    legend.append(rf'$\sum q_i=${sum_q}')
ax.set_yscale('log')
ax.set_xlabel(r"Emb. dim. $d$", fontsize=8)
ax.set_ylabel(r"Loss ${\cal L}$", fontsize=8)
ax.tick_params(axis='both', labelsize=6)
ax.grid()

fig.savefig(IMAGE_DIR / f'{name}_emb.pdf', bbox_inches='tight')

# Narrow, axis-free figure that only holds the legend of the curves drawn above.
fig, ax = plt.subplots(1, 1, figsize=(.25, 1.5))
# fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.axis('off')
ax.legend(all_plots, legend, loc='center', ncol=1, fontsize=6)
fig.savefig(IMAGE_DIR / f'{name}_emb_leg.pdf', bbox_inches='tight')

Effect of the number of layers and embedding dimension.

In [None]:
# Experiment to plot (short name resolved by get_names) and the x-axis quantity.
# "flops" is not loaded from disk: it is a column added to `res` in a later cell.
name = "compression"
# xaxis = "epoch"
xaxis = "flops"

# Settings forwarded as keyword arguments to load_experimental_results and get_stats.
kwargs = {
    "input_factors": [2, 2, 2, 3, 3, 5],
    "output_factors": [2, 3, 5],
    "nb_layers": 1,
    # "emb_dim": 10,
}

file_path, study_factors = get_names(name)

In [None]:
# Load the experiment configurations for `file_path`, then the associated results.
# `kwargs` is forwarded as-is (presumably to select matching runs; confirm in `analysis`).
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# Add the column used as x-axis: the value of Model.get_flops for each row's
# architecture (presumably a per-epoch cost; confirm in factorization.models.mlp),
# multiplied by the row's epoch.
res['flops'] = Model.get_flops(res['emb_dim'], res['ffn_dim'], res['nb_layers'], res['output_size'])
res['flops'] *= res['epoch']

In [None]:
# Aggregate the results along `xaxis`; presumably returns per-group mean frames,
# per-group std frames and the grouping keys (confirm in `analysis.get_stats`).
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
# Loss against the FLOP column for the compression runs; the legend is exported as
# its own figure.
fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))
# fig, ax = plt.subplots(1, 1, figsize=(10, 7))

y_name = "loss"
legend = []
all_plots = []
# Keep every third group (3rd, 6th, ...) to limit clutter. Only mean curves are
# drawn, so the unused `all_std` is not iterated.
for pos, val in enumerate(all_mean, start=1):
    if pos % 3 != 0:
        continue
    a, = ax.plot(val[y_name], linewidth=1)
    all_plots.append(a)
    legend.append(' '.join([fr'$d=${val[key].values[0]}' for key in keys]))
ax.set_yscale('log')
ax.set_xscale('log')
# ax.set_xlabel("Epoch", fontsize=8)
ax.set_xlabel("Flop", fontsize=8)
ax.set_ylabel(r"Loss ${\cal L}$", fontsize=8)
ax.tick_params(axis='both', labelsize=6)
ax.grid()

fig.savefig(IMAGE_DIR / f'{name}_{xaxis}.pdf', bbox_inches='tight')

# Narrow, axis-free figure that only holds the legend of the curves drawn above.
fig, ax = plt.subplots(1, 1, figsize=(.25, 1.5))
# fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.axis('off')
ax.legend(all_plots, legend, loc='center', ncol=1, fontsize=6)
fig.savefig(IMAGE_DIR / f'{name}_{xaxis}_leg.pdf', bbox_inches='tight')

## Connectivity runs

In [None]:
# Experiment to plot (short name resolved by get_names) and the x-axis quantity.
name = "filtration"
xaxis = "epoch"

# Settings forwarded as keyword arguments to load_experimental_results and get_stats.
kwargs = {
    "alphas": 1e-2,
}

file_path, study_factors = get_names(name)

In [None]:
# Load the experiment configurations for `file_path`, then the associated results.
# `kwargs` is forwarded as-is (presumably to select matching runs; confirm in `analysis`).
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# Aggregate the results along `xaxis`; presumably returns per-group mean frames,
# per-group std frames and the grouping keys (confirm in `analysis.get_stats`).
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
# Test loss against epochs for the filtration runs; the legend is exported as its
# own figure.
fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))

y_name = "test_loss"
legend = []
all_plots = []
# Hand-picked subset of groups (1-based positions) shown in the figure.
shown_positions = (1, 2, 3, 4, 5, 6, 9, 14)
# Only mean curves are drawn, so the unused `all_std` is not iterated.
for pos, val in enumerate(all_mean, start=1):
    if pos not in shown_positions:
        continue
    a, = ax.plot(val[y_name], linewidth=1.5)
    all_plots.append(a)
    legend.append(' '.join([rf'$\beta$={val[key].values[0]:.2f}' for key in keys]))
ax.set_yscale('log')
ax.set_xscale('log')
ax.set_xlabel("Epoch", fontsize=8)
ax.set_ylabel(r"Test loss ${\cal L}$", fontsize=8)
ax.tick_params(axis='both', labelsize=6)
ax.grid()
fig.savefig(IMAGE_DIR / f'{name}.pdf', bbox_inches='tight')

# Narrow, axis-free figure that only holds the legend of the curves drawn above.
fig, ax = plt.subplots(1, 1, figsize=(.25, 1.5))
ax.axis('off')
ax.legend(all_plots, legend, loc='center', ncol=1, fontsize=6)
fig.savefig(IMAGE_DIR / f'{name}_leg.pdf', bbox_inches='tight')

## Other Generalization Study

Input factors

In [None]:
# Experiment to plot (short name resolved by get_names) and the x-axis quantity.
name = "input2"
xaxis = "epoch"

# Settings forwarded as keyword arguments to load_experimental_results and get_stats.
kwargs = {
    "bernouilli": 0.2,
}

file_path, study_factors = get_names(name)

In [None]:
# Load the experiment configurations for `file_path`, then the associated results.
# `kwargs` is forwarded as-is (presumably to select matching runs; confirm in `analysis`).
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# Aggregate the results along `xaxis`; presumably returns per-group mean frames,
# per-group std frames and the grouping keys (confirm in `analysis.get_stats`).
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
import ast
from itertools import product  # no longer used in this cell; kept in case other cells rely on it
import pandas as pd

# Merge the per-group mean curves into one frame and tag each row with its input
# complexity: the sum of the factors parsed from the string in `input_factors`.
mean = pd.concat(all_mean)
mean['input_complexity'] = mean['input_factors'].apply(lambda x: sum(ast.literal_eval(x)))

# Split the frame again, one sub-frame per (complexity, factors) pair, ordered by
# increasing complexity and then by factors. A single sorted groupby visits only the
# combinations that actually occur, instead of scanning the whole frame once per
# element of the Cartesian product of all observed values.
keys = ['input_complexity', 'input_factors']
all_mean = [group for _, group in mean.groupby(keys, sort=True)]

In [None]:
# Test loss against epochs for the input-factor runs; the legend is exported as its
# own figure.
fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))
# fig, ax = plt.subplots(1, 1, figsize=(7, 5))

y_name = "test_loss"
legend = []
all_plots = []
# Hand-picked subset (1-based positions): every 8th group below position 36, then
# all groups from 36 onwards except 39 and 40. Only mean curves are drawn:
# `all_std` is deliberately not iterated, since it is unused here and was not
# regrouped like `all_mean`.
for pos, val in enumerate(all_mean, start=1):
    if (pos % 8 != 0 and pos < 36) or pos in (39, 40):
        continue
    a, = ax.plot(val[y_name], linewidth=1.5)
    all_plots.append(a)
    legend.append(rf'$(p_i)=${val[keys[1]].values[0]}')
ax.set_yscale('log')
ax.set_xscale('log')
# ax.set_yticks([1e-0, 2e0, 3e0])
ax.set_xlabel("Epoch", fontsize=8)
ax.set_ylabel(r"Test loss ${\cal L}$", fontsize=8)
ax.tick_params(axis='both', labelsize=6)
ax.grid()
fig.savefig(IMAGE_DIR / f'{name}.pdf', bbox_inches='tight')

# Narrow, axis-free figure that only holds the legend of the curves drawn above.
fig, ax = plt.subplots(1, 1, figsize=(.25, 1.5))
# fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.axis('off')
ax.legend(all_plots, legend, loc='center', ncol=1, fontsize=6)
fig.savefig(IMAGE_DIR / f'{name}_leg.pdf', bbox_inches='tight')

Train/test split

In [None]:
# Experiment to plot (short name resolved by get_names) and the x-axis quantity.
# "flops" is not loaded from disk: it is a column added to `res` in a later cell.
name = "split"
xaxis = "flops"

# No extra settings for this experiment.
kwargs = {
}

file_path, study_factors = get_names(name)

In [None]:
# Load the experiment configurations for `file_path`, then the associated results.
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# Add the column used as x-axis: the epoch scaled by `data_split`
# (presumably the fraction of data used for training; confirm against the configs).
res['flops'] = res['epoch'] * res['data_split']

In [None]:
# Aggregate the results along `xaxis`; presumably returns per-group mean frames,
# per-group std frames and the grouping keys (confirm in `analysis.get_stats`).
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
# Test loss against the FLOP column for the train/test split runs; the legend is
# exported as its own figure.
fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))

y_name = "test_loss"
legend = []
all_plots = []
# Keep every second group (1st, 3rd, ...) to limit clutter. Only mean curves are
# drawn, so the unused `all_std` is not iterated.
for pos, val in enumerate(all_mean, start=1):
    if pos % 2 == 0:
        continue
    a, = ax.plot(val[y_name], linewidth=1.5)
    all_plots.append(a)
    # The value of the first key is shown as a percentage.
    legend.append(rf'{100 * val[keys[0]].values[0]:.0f} %')
ax.set_yscale('log')
ax.set_xscale('log')
ax.set_xlabel("Flop", fontsize=8)
ax.set_ylabel(r"Test loss ${\cal L}$", fontsize=8)
ax.tick_params(axis='both', labelsize=6)
ax.grid()
fig.savefig(IMAGE_DIR / f'{name}.pdf', bbox_inches='tight')

# Narrow, axis-free figure that only holds the legend of the curves drawn above.
fig, ax = plt.subplots(1, 1, figsize=(.25, 1.5))
ax.axis('off')
ax.legend(all_plots, legend, loc='center', ncol=1, fontsize=6)
fig.savefig(IMAGE_DIR / f'{name}_leg.pdf', bbox_inches='tight')

Embedding dimension

In [None]:
# Experiment to plot (short name resolved by get_names) and the x-axis quantity.
# "flops" is not loaded from disk: it is a column added to `res` in a later cell.
name = "dim"
xaxis = "flops"

# No extra settings for this experiment.
kwargs = {
}

file_path, study_factors = get_names(name)

In [None]:
# Load the experiment configurations for `file_path`, then the associated results.
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# Add the column used as x-axis: the value of Model.get_flops for each row's
# architecture (presumably a per-epoch cost; confirm in factorization.models.mlp),
# multiplied by the row's epoch.
res['flops'] = Model.get_flops(res['emb_dim'], res['ffn_dim'], res['nb_layers'], res['output_size'])
res['flops'] *= res['epoch']

In [None]:
# Aggregate the results along `xaxis`; presumably returns per-group mean frames,
# per-group std frames and the grouping keys (confirm in `analysis.get_stats`).
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
# Test loss against the FLOP column for the embedding-dimension runs; the legend is
# exported as its own figure.
fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))

y_name = "test_loss"
legend = []
all_plots = []
# Keep every second group (2nd, 4th, ...) to limit clutter. Only mean curves are
# drawn, so the unused `all_std` is not iterated.
for pos, val in enumerate(all_mean, start=1):
    if pos % 2 != 0:
        continue
    a, = ax.plot(val[y_name], linewidth=1.5)
    all_plots.append(a)
    legend.append(rf'$d=${val[keys[0]].values[0]}')
ax.set_yscale('log')
ax.set_xscale('log')
ax.set_yticks([6e-1, 1e-0, 2e0, 3e0])
ax.set_xlabel("Flop", fontsize=8)
ax.set_ylabel(r"Test loss ${\cal L}$", fontsize=8)
ax.tick_params(axis='both', labelsize=6)
ax.grid()
fig.savefig(IMAGE_DIR / f'{name}.pdf', bbox_inches='tight')

# Narrow, axis-free figure that only holds the legend of the curves drawn above.
fig, ax = plt.subplots(1, 1, figsize=(.25, 1.5))
ax.axis('off')
ax.legend(all_plots, legend, loc='center', ncol=1, fontsize=6)
fig.savefig(IMAGE_DIR / f'{name}_leg.pdf', bbox_inches='tight')

FFN dimension

In [None]:
# Experiment to plot (short name resolved by get_names) and the x-axis quantity.
# "flops" is not loaded from disk: it is a column added to `res` in a later cell.
name = "ffn"
xaxis = "flops"

# No extra settings for this experiment.
kwargs = {
}

file_path, study_factors = get_names(name)

In [None]:
# Load the experiment configurations for `file_path`, then the associated results.
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# Add the column used as x-axis: the value of Model.get_flops for each row's
# architecture (presumably a per-epoch cost; confirm in factorization.models.mlp),
# multiplied by the row's epoch.
res['flops'] = Model.get_flops(res['emb_dim'], res['ffn_dim'], res['nb_layers'], res['output_size'])
res['flops'] *= res['epoch']

In [None]:
# Aggregate the results along `xaxis`; presumably returns per-group mean frames,
# per-group std frames and the grouping keys (confirm in `analysis.get_stats`).
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
# Test loss against the FLOP column for the FFN-dimension runs; the legend is
# exported as its own figure.
fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))

y_name = "test_loss"
legend = []
all_plots = []
# Every group is drawn (the former `ind % 1` filter never skipped anything).
# Only mean curves are drawn, so the unused `all_std` is not iterated.
for val in all_mean:
    a, = ax.plot(val[y_name], linewidth=1.5)
    all_plots.append(a)
    legend.append(rf'$h=${val[keys[0]].values[0]}')
ax.set_yscale('log')
ax.set_xscale('log')
ax.set_yticks([6e-1, 1e-0, 2e0, 3e0])
ax.set_xlabel("Flop", fontsize=8)
ax.set_ylabel(r"Test loss ${\cal L}$", fontsize=8)
ax.tick_params(axis='both', labelsize=6)
ax.grid()
fig.savefig(IMAGE_DIR / f'{name}.pdf', bbox_inches='tight')

# Narrow, axis-free figure that only holds the legend of the curves drawn above.
fig, ax = plt.subplots(1, 1, figsize=(.25, 1.5))
ax.axis('off')
ax.legend(all_plots, legend, loc='center', ncol=1, fontsize=6)
fig.savefig(IMAGE_DIR / f'{name}_leg.pdf', bbox_inches='tight')

Number of layers

In [None]:
# Experiment to plot (short name resolved by get_names) and the x-axis quantity.
# "flops" is not loaded from disk: it is a column added to `res` in a later cell.
name = "layer"
xaxis = "flops"

# No extra settings for this experiment.
kwargs = {
}

file_path, study_factors = get_names(name)

In [None]:
# Load the experiment configurations for `file_path`, then the associated results.
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# Add the column used as x-axis: the value of Model.get_flops for each row's
# architecture (presumably a per-epoch cost; confirm in factorization.models.mlp),
# multiplied by the row's epoch.
res['flops'] = Model.get_flops(res['emb_dim'], res['ffn_dim'], res['nb_layers'], res['output_size'])
res['flops'] *= res['epoch']

In [None]:
# Aggregate the results along `xaxis`; presumably returns per-group mean frames,
# per-group std frames and the grouping keys (confirm in `analysis.get_stats`).
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
# Test loss against the FLOP column for the number-of-layers runs; the legend is
# exported as its own figure.
fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))

y_name = "test_loss"
legend = []
all_plots = []
# Every group is drawn (the former `ind % 1` filter never skipped anything).
# Only mean curves are drawn, so the unused `all_std` is not iterated.
for val in all_mean:
    a, = ax.plot(val[y_name], linewidth=1.5)
    all_plots.append(a)
    legend.append(rf'nb layers={val[keys[0]].values[0]}')
ax.set_yscale('log')
ax.set_xscale('log')
ax.set_yticks([6e-1, 1e-0, 2e0, 3e0])
ax.set_xlabel("Flop", fontsize=8)
ax.set_ylabel(r"Test loss ${\cal L}$", fontsize=8)
ax.tick_params(axis='both', labelsize=6)
ax.grid()
fig.savefig(IMAGE_DIR / f'{name}.pdf', bbox_inches='tight')

# Narrow, axis-free figure that only holds the legend of the curves drawn above.
fig, ax = plt.subplots(1, 1, figsize=(.25, 1.5))
ax.axis('off')
ax.legend(all_plots, legend, loc='center', ncol=1, fontsize=6)
fig.savefig(IMAGE_DIR / f'{name}_leg.pdf', bbox_inches='tight')

Learning rate

In [None]:
# Experiment to plot (short name resolved by get_names) and the x-axis quantity.
name = "lr"
xaxis = "epoch"

# No extra settings for this experiment.
kwargs = {
}

file_path, study_factors = get_names(name)

In [None]:
# Load the experiment configurations for `file_path`, then the associated results.
all_configs = load_configs(file_path)
res = load_experimental_results(all_configs, **kwargs)

In [None]:
# Aggregate the results along `xaxis`; presumably returns per-group mean frames,
# per-group std frames and the grouping keys (confirm in `analysis.get_stats`).
all_mean, all_std, keys = get_stats(res, study_factors, xaxis=xaxis, **kwargs)

In [None]:
# Test loss against epochs for the learning-rate runs; the legend is exported as
# its own figure.
fig, ax = plt.subplots(1, 1, figsize=(1.75, 1.5))

y_name = "test_loss"
legend = []
all_plots = []
# Keep every second group (1st, 3rd, ...) to limit clutter. Only mean curves are
# drawn, so the unused `all_std` is not iterated.
for pos, val in enumerate(all_mean, start=1):
    if pos % 2 == 0:
        continue
    a, = ax.plot(val[y_name], linewidth=1.5)
    all_plots.append(a)
    legend.append(rf'lr={val[keys[0]].values[0]}')
ax.set_yscale('log')
ax.set_xscale('log')
ax.set_yticks([6e-1, 1e-0, 2e0, 3e0])
# This section aggregates along epochs (no FLOP column is computed for it), so the
# axis is labelled accordingly.
ax.set_xlabel("Epoch", fontsize=8)
ax.set_ylabel(r"Test loss ${\cal L}$", fontsize=8)
ax.tick_params(axis='both', labelsize=6)
ax.grid()
fig.savefig(IMAGE_DIR / f'{name}.pdf', bbox_inches='tight')

# Narrow, axis-free figure that only holds the legend of the curves drawn above.
fig, ax = plt.subplots(1, 1, figsize=(.25, 1.5))
ax.axis('off')
ax.legend(all_plots, legend, loc='center', ncol=1, fontsize=6)
fig.savefig(IMAGE_DIR / f'{name}_leg.pdf', bbox_inches='tight')