# (05) Trainers (KABA)

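**Motivation**: For every trained model (each `Trainer` directory found under the results directory), compute validation losses, forward-pass outputs on the train / validation / test splits, and permutation importances of the latents `z` for linearly predicting each column of `g`. Results are cached per fit under `tmp/trainer_analysis/` so that finished fits are skipped on re-runs. <br>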

In [1]:
# HIDE CODE


import os, sys
from IPython.display import display

# Project directories, all rooted at $HOME/Dropbox/git
git_dir = os.path.join(os.environ['HOME'], 'Dropbox/git')
# extras / figures / scratch folders are siblings inside the jb-MTMST repo
extras_dir, fig_base_dir, tmp_dir = (
    os.path.join(git_dir, f'jb-MTMST/{sub}')
    for sub in ('_extras', 'figs', 'tmp')
)

# GitHub
sys.path.insert(0, os.path.join(git_dir, '_MTMST'))
from vae.train_vae import TrainerVAE, ConfigTrainVAE
from vae.vae2d import VAE, ConfigVAE
from figures.fighelper import *
from analysis.glm import *


# Notebook-wide setup: silence DeprecationWarnings, switch to rich's print,
# render matplotlib figures inline, and apply the plotting style.
# NOTE(review): `warnings` is never imported explicitly in this cell --
# presumably it arrives through one of the star imports above; confirm.
warnings.filterwarnings('ignore', category=DeprecationWarning)
# shadows the builtin `print` from this point on
from rich.jupyter import print
%matplotlib inline
# NOTE(review): set_style is not defined in this cell -- presumably it comes
# from the figures.fighelper star import; confirm.
set_style()

In [2]:
# Root of all saved results, as a pathlib.Path
path = pathlib.Path(results_dir())

# Glob used to locate trainer folders (narrower alternative kept for reference)
pattern = '**/Trainer' # '**/*fixate1*/**/Trainer'

# Unique, alphabetically sorted directories matching the pattern
trainer_paths = sorted({
    str(p)
    for p in path.rglob(pattern)
    if p.is_dir()
})
len(trainer_paths)

34

In [3]:
# Before running the analysis: list every fit whose cached result is missing.
for fit_path in trainer_paths:
    # second-to-last path component, i.e. the folder holding 'Trainer'
    fit_name = fit_path.split('/')[-2]
    f = f"{pjoin(tmp_dir, 'trainer_analysis', fit_name)}.npy"
    if os.path.isfile(f):
        # cache file present -> nothing to report
        continue
    print(fit_name, os.path.isfile(f))

In [4]:
# Main analysis loop: for every fit without a cached result, load the trainer,
# plot its norm histogram, evaluate it, compute permutation importances of the
# latents, and save everything under <tmp_dir>/trainer_analysis.
for fit_path in trainer_paths:
    # second-to-last path component, i.e. the folder holding 'Trainer'
    fit_name = fit_path.split('/')[-2]
    f = pjoin(tmp_dir, 'trainer_analysis', fit_name)
    f = f"{f}.npy"
    # cached result already on disk -> skip this fit
    if os.path.isfile(f):
        continue

    print('~' * 12 + '   fit name :   ' + '~' * 12)
    print(fit_name)
    print('-' * 40)
    # load the trainer on the GPU; the second return value is not used here
    tr, _ = load_model_lite(fit_path, 'cuda')
    # concatenate the model's log-norms and exponentiate them
    # NOTE(review): presumably per-layer weight norms -- confirm in vae.vae2d
    all_norms = to_np(torch.cat(
        tr.model.all_lognorm).exp())
    avg_norms = all_norms.mean()

    # histogram of the norms (100 equal bins on [0.5, 1.5]) with the mean
    # marked by a dashed red line; values outside the bin range are not shown
    fig, ax = create_figure(1, 1, (5, 2.5))
    sns.histplot(all_norms, bins=np.linspace(0.5, 1.5, 101), ax=ax)
    ax.axvline(avg_norms, color='r', ls='--', lw=1.2, label=f'avg = {avg_norms:0.3f}')
    title = f"sim = {tr.model.cfg.sim},   " + r"$\beta$ = " + f"{tr.cfg.kl_beta:0.2f}"
    ax.set_title(title, fontsize=12)
    ax.legend()
    plt.show()
    
    # evaluation + forward passes on the 'trn' / 'vld' / 'tst' splits
    # NOTE(review): use_ema=False presumably selects the raw (non-EMA)
    # weights -- confirm in TrainerVAE
    val, loss = tr.validate(use_ema=False)
    data_trn, _ = tr.forward('trn', freeze=True, use_ema=False)
    data_vld, _ = tr.forward('vld', freeze=True, use_ema=False)
    data_tst, _ = tr.forward('tst', freeze=True, use_ema=False)
    
    # one-line summary; the reported NELBO is mean('epe') + mean('kl')
    msg = f"{tr.model.cfg.sim}:\tbeta = {tr.cfg.kl_beta},\t"
    msg += f"NELBO: {loss['epe'].mean() + loss['kl'].mean():0.2f}"
    print(msg)
    print({k: v.mean() for k, v in loss.items()})

    # regress
    # NOTE: `f` is rebound here -- it no longer holds the cache file path but
    # the dataset's `f` joined with `f_aux` (presumably lists of variable
    # names, one per column of `g` -- confirm against the dataset class)
    f = tr.dl_vld.dataset.f + tr.dl_vld.dataset.f_aux
    # main and auxiliary targets stacked column-wise (validation split)
    g = np.concatenate([
        tr.dl_vld.dataset.g,
        tr.dl_vld.dataset.g_aux,
    ], axis=1)
    # same stacking for the test split
    g_tst = np.concatenate([
        tr.dl_tst.dataset.g,
        tr.dl_tst.dataset.g_aux,
    ], axis=1)

    # rows: one per entry of `f`; columns: one per latent dimension of z
    shape = (len(f), data_vld['z'].shape[1])
    importances_mu = np.zeros(shape)
    importances_sd = np.zeros(shape)

    # for each target column: fit a linear map z -> g[:, i] on the validation
    # split, then measure permutation importance of every latent on the test
    # split (5 shuffles, fixed seed)
    # NOTE(review): sk_linear / sk_inspect are presumably sklearn.linear_model
    # and sklearn.inspection -- confirm where they are imported
    for i in tqdm(range(len(f)), leave=False):
        _lr = sk_linear.LinearRegression().fit(
            data_vld['z'], g[:, i])
        result = sk_inspect.permutation_importance(
            estimator=_lr,
            X=data_tst['z'],
            y=g_tst[:, i],
            n_repeats=5,
            random_state=0,
        )
        importances_mu[i] = result.importances_mean
        importances_sd[i] = result.importances_std

    # bundle all results for this fit into a single object
    everything = {
        'val': val,
        'loss': loss,
        'data_trn': data_trn,
        'data_vld': data_vld,
        'data_tst': data_tst,
        'importances_mu': importances_mu,
        'importances_sd': importances_sd,
        'f': f,
    }
    # NOTE(review): the skip-check at the top of the loop looks for
    # '<fit_name>.npy' in this directory, so save_obj is expected to append
    # the '.npy' extension -- confirm
    save_obj(
        obj=everything,
        file_name=fit_name,
        save_dir=pjoin(tmp_dir, 'trainer_analysis'),
        verbose=True,
        mode='npy',
    )

    # release cached GPU memory before moving on to the next fit
    torch.cuda.empty_cache()

In [5]:
# After running the analysis: any fit printed here still has no cached result.
for fit_path in trainer_paths:
    # second-to-last path component, i.e. the folder holding 'Trainer'
    fit_name = fit_path.split('/')[-2]
    f = f"{pjoin(tmp_dir, 'trainer_analysis', fit_name)}.npy"
    if os.path.isfile(f):
        # cache file present -> nothing to report
        continue
    print(fit_name, os.path.isfile(f))