In [None]:
import sys
import os

# Make the project code importable from this notebook. This assumes the notebook
# lives one directory below the project root (e.g. <root>/notebooks/).
#  - project_root is needed for the `from src.<module> import ...` statements used below
#    (Python resolves the package `src` by looking for a `src/` dir on sys.path).
#  - src_path is kept, as before, so modules inside src/ are also importable without
#    the `src.` prefix.
# Previously only src_path was added, so `import src.*` worked only when the project
# root happened to be importable for some other reason.
notebook_dir = os.getcwd()  # current working directory (the notebook's folder)
project_root = os.path.abspath(os.path.join(notebook_dir, ".."))  # one level up
src_path = os.path.join(project_root, "src")

# Append each entry only once, so re-running this cell is idempotent
for path_entry in (src_path, project_root):
    if path_entry not in sys.path:
        sys.path.append(path_entry)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pickle
import torch.nn as nn
from torch.utils.data import DataLoader

from src.data_io import load_runinfo_from_rundir, reload_lossbaseline_dict
from src.data_tools import data_train_test_split_linear, DatasetWrapper, load_dataset
from src.nn_loss_baselines import theory_linear_expected_error, report_dataset_loss
from src.nn_model_base import load_modeltype_from_fname, load_model_from_rundir
from src.nn_train_methods import (train_model, theory_linear_expected_error, loss_if_predict_linalg,
                              loss_if_predict_linalg_shrunken, loss_if_predict_zero)
from src.settings import DIR_MODELS, DIR_OUT, DIR_RUNS, COLOR_TARGET, COLOR_PRED, COLOR_INPUT

In [None]:
"""
Script for analysis of case 0: linear subspace task (with varying context length)
"""

def gamma_star(sigma2_pure_context, sigma2_corruption):
    """Return gamma* = 1 / (sigma2_pure_context + sigma2_corruption).

    Args:
        sigma2_pure_context: the `sigma2_pure_context` setting given to the data generator.
        sigma2_corruption: the `sigma2_corruption` setting given to the data generator.

    Returns:
        Reciprocal of the summed variances. Works elementwise if array-like
        inputs are passed.
    """
    total_variance = sigma2_pure_context + sigma2_corruption
    return 1 / total_variance


# Plot palette (hex color strings) used by the figures below.
COLOR_TRAIN, COLOR_TEST = '#4C72B0', '#7BC475'             # train-set / test-set lines
# progressively darker shades of COLOR_TEST (distinguish context lengths)
COLOR_TEST_DARK1, COLOR_TEST_DARK2 = '#64A461', '#569256'
COLOR_PROJ = '#BCBEC0'        # projection baseline, $P \tilde x$
COLOR_PROJSHRUNK = '#B56BAC'  # shrunken-projection baseline, $\gamma P \tilde x$

## Train a model on the linear subspace task
(Alternatively, skip training and load a previously trained model.)

In [None]:
# --- Run configuration ----------------------------------------------------------
# These settings determine both (A) the synthetic dataset generated via
# data_train_test_split_linear and (B) the model trained on it.
context_len = 500   # context length L
dim_n = 16          # ambient dimension n (try 16/32 up to 128)

nn_model = 'TransformerModelV1noresOmitLast'

datagen_case = 0  # {0, 1, 2} -> {linear, GMM clusters, manifold}

seed_dataset = 0
seed_torch   = 0  # note: dataset seed 4 with torch seed 4 was problematic for multitraj 100/16/8 gradflow sgd 0.5

epochs = 100
# batch size defaults:
# - if train size is 8000, set batch_size to 800
# - if train size is 800,  set batch_size to 80
batch_size = 80  # reminder: 'manifold' notebook is using batch size of one (1)

# Interval (in number of batches) between full train/test loss evaluations.
# Default of 4 when bsz was 80 with an 800-sample train set (10 batches per epoch);
# e.g. with 1000 batches per epoch, an interval of 100 gives 10 loss evals per epoch.
full_loss_sample_interval = 4

optimizer_choice = 'adam'  # 'sgd' or 'adam'

if optimizer_choice == 'sgd':
    optimizer_lr = 0.1  # other values tried: 0.01, 0.5 (spheres case used 80*0.1 with 400 epochs)
    # Multiply the LR by `gamma` at 80% and 90% of training (scheduler_kwargs = None disables this).
    # Milestones must be integer epoch indices (torch MultiStepLR convention, assuming that is
    # what train_model builds). The old `0.8*int(epochs)` yielded floats such as 5.6 for
    # epochs=7, which never match an epoch counter, so the decay silently never fired.
    scheduler_kwargs = dict(milestones=[int(0.8 * epochs), int(0.9 * epochs)], gamma=0.1)
else:
    assert optimizer_choice == 'adam'
    optimizer_lr = 1e-2  # default: 1e-2
    scheduler_kwargs = None

# Visualization / saving flags passed to train_model
flag_save_dataset = True     # default: False (saved dataset files can be large)
flag_vis_loss = True
flag_vis_weights = True
flag_vis_batch_perf = True

# --- Dataset-generation kwargs ------------------------------------------------------
num_W_in_dataset = 1000
context_examples_per_W = 1
samples_per_context_example = 1

# Settings common to every datagen case. Because both per-W multipliers above are 1,
# num_W_in_dataset is also the total train+test size.
base_kwargs = {
    'context_len': context_len,
    'dim_n': dim_n,
    'num_W_in_dataset': num_W_in_dataset,
    'context_examples_per_W': context_examples_per_W,
    'samples_per_context_example': samples_per_context_example,
    'test_ratio': 0.2,
    'verbose': True,
    'as_torch': True,
    'savez_fname': None,  # set internally depending on flag_save_dataset
    'seed': seed_dataset,
    'style_origin_subspace': True,
    'style_corruption_orthog': False,
}

# Case-specific (linear subspace) settings -- edit manually if desired
assert datagen_case == 0
linear_datagen_kwargs = {
    **base_kwargs,
    'sigma2_corruption': 1.0,
    'sigma2_pure_context': 2.0,
    'style_subspace_dimensions': 8,  # int or 'random'
}

# Train the model. train_model builds the dataset from `linear_datagen_kwargs` and returns
# it alongside the network, its io info, the recorded losses and the dataloaders.
# train_plus_test_size = num_W * context_examples_per_W * samples_per_context_example.
(net, model_fname, io_dict, loss_vals_dict,
 train_loader, test_loader,
 x_train, y_train, x_test, y_test,
 train_data_subspaces, test_data_subspaces) = train_model(
    # --- model ---
    nn_model=nn_model,
    restart_nn_instance=None,
    restart_dataset=None,
    # --- data ---
    datagen_case=datagen_case,
    datagen_kwargs=linear_datagen_kwargs,
    context_len=context_len,
    dim_n=dim_n,
    train_plus_test_size=num_W_in_dataset * context_examples_per_W * samples_per_context_example,
    # --- optimization ---
    seed_torch=seed_torch,
    epochs=epochs,
    batch_size=batch_size,
    optimizer_choice=optimizer_choice,
    optimizer_lr=optimizer_lr,
    scheduler_kwargs=scheduler_kwargs,
    full_loss_sample_interval=full_loss_sample_interval,
    # --- saving / visualization ---
    flag_save_dataset=flag_save_dataset,
    flag_vis_loss=flag_vis_loss,
    flag_vis_weights=flag_vis_weights,
    flag_vis_batch_perf=flag_vis_batch_perf,
)

In [None]:
# Main data storage object
datadict_eval_loss = dict()  # will fill up one for each evaluated dim d (plus one for 'train' baselines
datadict_eval_loss['linalg_proj'] = dict()
datadict_eval_loss['linalg_proj_shrunken'] = dict()
datadict_eval_loss['predict_zero'] = dict()

# 1) Load the model we just trained
models_to_load = [ io_dict['dir_base'] ]
model_dirs     = [ a for a in models_to_load] 
model_labels   = ['n=16, d=8']

"""
# 1) Load a previously trained model
# Could alternatively load a pre-existing training run
dir_parent = DIR_RUNS
models_to_load = ['check_ddim_vary_inference']
model_dirs = [dir_parent + os.sep + a for a in models_to_load]
model_labels = ['n=16, d=8']
"""

# Checkpoint to evaluate: an int epoch, or None for the final epoch
epoch_to_load = None

# Inference-time subspace dims d to probe: 1..15 (stops one dim below n = 16)
list_of_subspace_dim_d = np.arange(1, 16, dtype=int)

# --- 0) Settings for inference-dataset generation ---------------------------------
criterion = nn.MSELoss()
seed_dataset = 7  # chosen to differ from the seed used to generate the training data

# parameters that control dataset size / variety
num_W_in_dataset = 1000
context_examples_per_W = 1
test_ratio = None  # None -> every sample lands in the 'train' bucket (test split is None)
style_origin_subspace = True
style_corruption_orthog = False
batch_size = num_W_in_dataset  # full-batch evaluation

# Context lengths L to sweep. Set to None to use only the context length stored in the runinfo file.
context_lengths_to_eval = [30, 40, 50, 60, 80, 100, 200, 500]

# Perform inference on the loaded model (use a trained and untrained model from model directory)
# Evaluate every model in models_to_load. For each run directory:
#   1) reload its run settings, its trained and untrained (epoch-0) weights, its stored
#      loss baselines, and the exact dataset it was trained on;
#   2) sweep context length L x inference subspace dim d on freshly generated datasets,
#      recording model losses plus dataset-specific baselines.
for idx_model, model_str in enumerate(models_to_load):
    dir_run = model_dirs[idx_model]

    datadict_eval_loss[model_str] = dict()
    subdict = datadict_eval_loss[model_str]

    # 1A - load the runinfo (settings the model was trained with)
    runinfo_dict = load_runinfo_from_rundir(dir_run)
    loaded_dim_n = runinfo_dict['dim_n']
    nn_model_str = runinfo_dict['model']
    epochs = runinfo_dict['epochs']
    loaded_dim_d = runinfo_dict['style_subspace_dimensions']
    sigma2_pure_context = runinfo_dict['sigma2_pure_context']
    sigma2_corruption = runinfo_dict['sigma2_corruption']
    # NOTE: these two override the module-level style_* settings, so the inference
    # datasets use the same geometry options as the training run
    style_corruption_orthog = runinfo_dict['style_corruption_orthog']
    style_origin_subspace = runinfo_dict['style_origin_subspace']
    full_loss_sample_interval = runinfo_dict['full_loss_sample_interval']
    optimizer_choice = runinfo_dict['optimizer_choice']
    context_length = runinfo_dict['context_len']

    # Fall back to the training context length when no sweep was requested
    # (with several models, the first model's context length is then used for all).
    if context_lengths_to_eval is None:
        context_lengths_to_eval = [context_length]

    subdict['loaded_dim_d'] = loaded_dim_d
    subdict['loaded_dim_n'] = loaded_dim_n

    # debugging override
    #context_length = 30  # 30= is smallest context len that satisfies internal asserts (somewhat arbitrary)

    # compute gamma_star from runinfo dict parameters
    gamma_star_val = gamma_star(sigma2_pure_context, sigma2_corruption)

    # 1B - load the model: selected epoch (None -> final) and the untrained init (epoch 0)
    model_fname = runinfo_dict['fname']
    nn_model = runinfo_dict['model']
    LOADED_NETWORK = load_model_from_rundir(dir_run, epoch_int=epoch_to_load)
    LOADED_NETWORK_UNTRAINED = load_model_from_rundir(dir_run, epoch_int=0)

    # 1C - load the baselines from data_for_replot dir
    dir_replot = dir_run + os.sep + 'data_for_replot'
    lossbaseline_dict = reload_lossbaseline_dict(dir_replot)

    # 1D - load the dataset used during training
    loaded_x_train, loaded_y_train, loaded_x_test, loaded_y_test = load_dataset(
        dir_run + os.sep + 'training_dataset_split.npz', as_torch=True)
    loaded_trainloader = DataLoader(DatasetWrapper(loaded_x_train, loaded_y_train), batch_size=batch_size, shuffle=True)
    loaded_testloader = DataLoader(DatasetWrapper(loaded_x_test, loaded_y_test), batch_size=batch_size, shuffle=True)

    # 2) Populate datadict_eval_loss
    # - 'trained': baselines + results for the loaded (trained) weights
    subdict['trained'] = dict()
    # - 'untrained': results for the epoch-0 weights
    subdict['untrained'] = dict()
    # - per-context-length / per-dim-d inference results are filled in the sweep below

    subdict['trained']['vals_loss_predict_zero'] = lossbaseline_dict['dumb_A_mse_on_train']
    subdict['trained']['vals_loss_theory_linalg'] = lossbaseline_dict['heuristic_mse_on_train']
    subdict['trained']['vals_loss_theory_linalg_shrunken'] = lossbaseline_dict['heuristic_mse_shrunken_on_train']
    subdict['trained']['eq_expected_error_linalg_shrunken'] = theory_linear_expected_error(
        loaded_dim_n, loaded_dim_d, sigma2_corruption, sigma2_pure_context)

    # Loss curves recorded during training.
    # NOTE: pickle.load can execute arbitrary code -- only point dir_run at trusted run dirs.
    with open(dir_replot + os.sep + 'loss_vals_dict.pkl', 'rb') as handle:
        loss_vals_dict = pickle.load(handle)
    subdict['trained']['loss_vals_dict'] = loss_vals_dict

    # Sanity check: re-evaluate the reloaded network on its own train/test data
    subdict['trained']['reload_train_and_eval'] = report_dataset_loss(LOADED_NETWORK, criterion, loaded_trainloader, 'RELOAD train', print_val=True)
    subdict['trained']['reload_test_and_eval'] = report_dataset_loss(LOADED_NETWORK, criterion, loaded_testloader, 'RELOAD test', print_val=True)

    def gen_dataset_given_dim_d(dim_d_inference, context_len_select, seed_dataset):
        """Generate an inference dataset for one (subspace dim d, context length L) pair.

        Uses the loaded run's n, noise variances and style options, with this cell's
        dataset-size settings. The dataset is not saved to disk.

        Args:
            dim_d_inference: subspace dimension d used to generate the data.
            context_len_select: context length L.
            seed_dataset: seed passed to the data generator.

        Returns:
            (X_train, Y_train, train_data_subspaces, dataloader), where the DataLoader
            uses batch_size from this cell (full batch by default).
        """
        # Note: test_ratio is None, so all data samples will be in 'train' bucket (foo_test are all None)
        X_train, Y_train, X_test, Y_test, train_data_subspaces, test_data_subspaces = data_train_test_split_linear(
            context_len=context_len_select,
            dim_n=loaded_dim_n,
            num_W_in_dataset=num_W_in_dataset,
            context_examples_per_W=context_examples_per_W,
            samples_per_context_example=1,
            test_ratio=None,
            style_subspace_dimensions=dim_d_inference,
            style_corruption_orthog=style_corruption_orthog,
            style_origin_subspace=style_origin_subspace,
            sigma2_corruption=sigma2_corruption,
            sigma2_pure_context=sigma2_pure_context,
            verbose=True,
            savez_fname=None,
            seed=seed_dataset)

        subspaces_d_dataset = DatasetWrapper(X_train, Y_train)
        print('dataset.x.shape', subspaces_d_dataset.x.shape)

        # want clean multiples to avoid an overweighted last batch
        assert len(subspaces_d_dataset) % batch_size == 0

        # Setup data batching
        nwork = 0
        dataloader_subd = DataLoader(subspaces_d_dataset, batch_size=batch_size, shuffle=True, num_workers=nwork)

        return X_train, Y_train, train_data_subspaces, dataloader_subd

    # Sweep: for each context length, measure the loss at each value of dim d
    for lval in context_lengths_to_eval:
        subdict['trained'][lval] = dict()
        subdict['untrained'][lval] = dict()

        subdict['trained'][lval]['loss_dim_d_inference'] = np.zeros(len(list_of_subspace_dim_d))
        subdict['untrained'][lval]['loss_dim_d_inference'] = np.zeros(len(list_of_subspace_dim_d))

        datadict_eval_loss['linalg_proj'][lval] = np.zeros(len(list_of_subspace_dim_d))
        datadict_eval_loss['linalg_proj_shrunken'][lval] = np.zeros(len(list_of_subspace_dim_d))
        datadict_eval_loss['predict_zero'][lval] = np.zeros(len(list_of_subspace_dim_d))

        for idx, dim_d_inference in enumerate(list_of_subspace_dim_d):

            # 2A) create dataset without saving
            print('\nCreating inference dataset for dim_d_inference=%d...' % dim_d_inference)
            X_train, Y_train, train_data_subspaces, dataloader = gen_dataset_given_dim_d(dim_d_inference, lval, seed_dataset)
            assert len(Y_train) == num_W_in_dataset

            # 2B) report dataset loss on the dim d dataset (full), trained then untrained
            print('...checking performance for dim_d_inference=%d...' % dim_d_inference)
            loss_eval_trained = report_dataset_loss(LOADED_NETWORK, criterion, dataloader,
                                                    'trained model (inference)', print_val=True)
            subdict['trained'][lval]['loss_dim_d_inference'][idx] = loss_eval_trained
            loss_eval_untrained = report_dataset_loss(LOADED_NETWORK_UNTRAINED, criterion, dataloader,
                                                      'not-trained model (inference)', print_val=True)
            subdict['untrained'][lval]['loss_dim_d_inference'][idx] = loss_eval_untrained

            # 2C) add extra baselines specific to the dataset
            # baseline: linalg_proj
            datadict_eval_loss['linalg_proj'][lval][idx] = loss_if_predict_linalg(
                criterion, dataloader, 'DATALABEL')
            # baseline: linalg_proj_shrunken
            datadict_eval_loss['linalg_proj_shrunken'][lval][idx] = loss_if_predict_linalg_shrunken(
                criterion, dataloader, 'DATALABEL', sigma2_pure_context, sigma2_corruption,
                style_origin_subspace=style_origin_subspace, style_corruption_orthog=style_corruption_orthog)
            # baseline: predict_zero
            datadict_eval_loss['predict_zero'][lval][idx] = loss_if_predict_zero(
                criterion, dataloader, 'DATALABEL')

# ========================================================
# PLOTTING - Build the primary plot
# ========================================================
# Raw inference loss vs. inference subspace dim d: a trained and an untrained curve per
# context length L, plus horizontal reference lines for each model.
fig_dimd, ax_dimd = plt.subplots(figsize=(6, 4))

for idx_model, model_str in enumerate(models_to_load):
    dir_run = model_dirs[idx_model]
    subdict = datadict_eval_loss[model_str]
    suffix = model_labels[idx_model]
    trained_info = subdict['trained']
    untrained_info = subdict['untrained']
    train_curves = trained_info['loss_vals_dict']

    # main curves: trained (solid black) vs. untrained init (faded, hollow markers)
    for lval in context_lengths_to_eval:
        ax_dimd.plot(list_of_subspace_dim_d, trained_info[lval]['loss_dim_d_inference'], '-o',
                     label='inference L=%d (trained %s)' % (lval, suffix), color='black')
        ax_dimd.plot(list_of_subspace_dim_d, untrained_info[lval]['loss_dim_d_inference'], '-o',
                     label='inference L=%d (untrained %s)' % (lval, suffix),
                     color='black', alpha=0.5, markerfacecolor='white')

    # baselines stored with the trained run
    ax_dimd.axhline(trained_info['vals_loss_predict_zero'], label=r'predict $0$',
                    linestyle='--', alpha=0.75, linewidth=1.5, color='grey')
    ax_dimd.axhline(trained_info['vals_loss_theory_linalg'], label=r'$P \tilde x$',
                    linestyle='-', alpha=0.75, linewidth=1.5, color=COLOR_PROJ)
    ax_dimd.axhline(trained_info['vals_loss_theory_linalg_shrunken'], label=r'$\gamma P \tilde x$',
                    linestyle='-', alpha=0.75, linewidth=1.5, color=COLOR_PROJSHRUNK)
    ax_dimd.axhline(trained_info['eq_expected_error_linalg_shrunken'], label=r'Expected error at $\theta^*$',
                    linestyle=':', alpha=0.75, linewidth=3, color=COLOR_PROJSHRUNK)
    # subspace dim the model was trained on
    ax_dimd.axvline(subdict['loaded_dim_d'], label=r'$d$ (%s)' % suffix,
                    linestyle='-', alpha=0.75, linewidth=1.5, color=COLOR_TRAIN)

    # last recorded train/test loss during training
    ax_dimd.axhline(train_curves['loss_train_interval']['y'][-1], label=r'Final MSE train (%s)' % suffix,
                    linestyle='-', alpha=1, linewidth=1.5, color=COLOR_TRAIN)
    ax_dimd.axhline(train_curves['loss_test_interval']['y'][-1], label=r'Final MSE test (%s)' % suffix,
                    linestyle='-', alpha=1, linewidth=1.5, color=COLOR_TEST)

    # debugging: loss of the reloaded network on its own train/test data (thick lines on top),
    # to compare against the 'Final MSE' lines above
    ax_dimd.axhline(trained_info['reload_train_and_eval'], label=r'RELOAD MSE train (%s)' % suffix,
                    linestyle='-', alpha=1, linewidth=4.5, color=COLOR_TRAIN, zorder=11)
    ax_dimd.axhline(trained_info['reload_test_and_eval'], label=r'RELOAD MSE test (%s)' % suffix,
                    linestyle='-', alpha=1, linewidth=4.5, color=COLOR_TEST, zorder=11)

ax_dimd.set_xlabel(r'$d$ (subspace dim)')
ax_dimd.set_ylabel(r'$\frac{1}{n}$ MSE')
ax_dimd.set_title(r'Inference with lower and higher subspace dimensions than training set')
ax_dimd.legend(ncol=2)
fig_dimd.tight_layout()

fig_dimd.savefig(DIR_OUT + os.sep + 'inference_varying_dim_d.png')
fig_dimd.savefig(DIR_OUT + os.sep + 'inference_varying_dim_d.svg')
plt.show()


# Script is above. Alt plots below

In [None]:
# ========================================================
# PLOTTING - loss normalized by d (alternative view of the primary plot)
# ========================================================
# Per-curve styling, indexed by position within subset_context_lengths_to_eval
context_len_to_marker = ['^', 'd'] + ['o'] * 6
context_len_to_color = [COLOR_TEST, COLOR_TEST_DARK1] + [COLOR_TEST_DARK2] * 6

flag_mult_by_n = True        # rescale by n, just for this plot, to get nicer y-values
flag_plot_baselines = False  # original: True (overlay per-L linalg baselines)

# The index choice below is hard-coded against this exact list (selects L = 30, 50, 500)
assert context_lengths_to_eval == [30, 40, 50, 60, 80, 100, 200, 500]
subset_context_lengths_to_eval = [context_lengths_to_eval[i] for i in (0, 2, 7)]

ms = 4  # marker size
arr_subspace_dim_d = np.array(list_of_subspace_dim_d)

# start the figure for the main curves
plt.figure(figsize=(3.25, 4))

# For each loaded model: plot inference loss / d (times n when flag_mult_by_n is set) for
# the selected context lengths, trained vs. untrained (epoch-0) weights.
for idx_model, model_str in enumerate(models_to_load):

    dir_run = model_dirs[idx_model]
    subdict = datadict_eval_loss[model_str]
    suffix = model_labels[idx_model]

    loaded_dim_d = subdict['loaded_dim_d']
    loaded_dim_n = subdict['loaded_dim_n']

    # multiply by n for this plot only; a factor of 1 leaves the values unchanged
    scaling = loaded_dim_n if flag_mult_by_n else 1

    # main curves: one marker/color per selected context length
    # (the leftover debug `print(idx_model, marker_choice)` was removed)
    for idx, lval in enumerate(subset_context_lengths_to_eval):

        marker_choice = context_len_to_marker[idx]

        # plot **inference** loss for trained models
        plt.plot(list_of_subspace_dim_d, subdict['trained'][lval]['loss_dim_d_inference'] / arr_subspace_dim_d * scaling,
                 '-', label='inference L=%d (trained %s)' % (lval, suffix),
                 zorder=12,
                 marker=marker_choice,
                 color=context_len_to_color[idx],
                 markeredgecolor='k', markeredgewidth=0.5,
                 alpha=0.9, markersize=ms)

        # plot **inference** loss for ---untrained--- models
        plt.plot(list_of_subspace_dim_d, subdict['untrained'][lval]['loss_dim_d_inference'] / arr_subspace_dim_d * scaling,
                 '-', zorder=11,
                 label='inference L=%d (untrained %s)' % (lval, suffix),
                 marker=marker_choice, markeredgecolor='k', markeredgewidth=0.5,
                 color='grey', alpha=0.5, markersize=ms)

        if idx_model == 0 and flag_plot_baselines:
            # dataset baselines do not depend on the model: draw them only once
            plt.plot(list_of_subspace_dim_d, datadict_eval_loss['linalg_proj'][lval] / arr_subspace_dim_d * scaling,
                     '-', label='baseline: proj L=%d' % (lval),
                     marker=marker_choice, markeredgecolor='k', markeredgewidth=0.5,
                     color=COLOR_PROJ, alpha=0.5, markersize=ms)
            plt.plot(list_of_subspace_dim_d, datadict_eval_loss['linalg_proj_shrunken'][lval] / arr_subspace_dim_d * scaling,
                     '-', label='baseline: proj-shrunk L=%d' % (lval),
                     marker=marker_choice, markeredgecolor='k', markeredgewidth=0.5,
                     color=COLOR_PROJSHRUNK, alpha=0.5, markerfacecolor=COLOR_PROJSHRUNK, markersize=ms)

    # final MSE of the trained model (train/test set loss), at the training dim d
    loaded_model_final_loss_train = subdict['trained']['loss_vals_dict']['loss_train_interval']['y'][-1] / loaded_dim_d * scaling
    loaded_model_final_loss_test  = subdict['trained']['loss_vals_dict']['loss_test_interval']['y'][-1]  / loaded_dim_d * scaling
    plt.scatter(loaded_dim_d, loaded_model_final_loss_train, color=COLOR_TRAIN, ec='k', marker='o', s=30, zorder=15)
    plt.scatter(loaded_dim_d, loaded_model_final_loss_test, color=COLOR_TEST, ec='k', marker='o', s=30, zorder=15)

    # initial (first recorded) loss; only the train value is drawn
    loaded_model_init_loss_train = subdict['trained']['loss_vals_dict']['loss_train_interval']['y'][0] / loaded_dim_d * scaling
    loaded_model_init_loss_test  = subdict['trained']['loss_vals_dict']['loss_test_interval']['y'][0]  / loaded_dim_d * scaling
    plt.scatter(loaded_dim_d, loaded_model_init_loss_train, color='grey', ec='k', marker='o', s=30, zorder=15)

    # vertical line at the dim d the model was trained on
    plt.axvline(subdict['loaded_dim_d'], label=r'$d$ (%s)' % suffix,
                linestyle='--', alpha=0.75, linewidth=1.5, color=COLOR_TRAIN)

# BASELINE - linear MMSE estimator: expected error at theta*, evaluated on a dense grid
# of (continuous) subspace dims d in [0.1, n].
# This curve gets its own grid variable. The predict-0 block below reassigns `dense_dspace`,
# and the curve was previously plotted against that second grid ([0, 2n]) instead of the
# grid it was computed on, which stretched it horizontally by ~2x. The dimension n now
# comes from the loaded run instead of a hard-coded 16.
dense_dspace_theory = np.linspace(0.1, loaded_dim_n, 100)
eq_theory_baseline_dense_dspace = [
    theory_linear_expected_error(loaded_dim_n, DD, sigma2_corruption, sigma2_pure_context) / float(DD) * scaling
    for DD in dense_dspace_theory
]

# BASELINE - predict 0: squared error = trace(E[x x^T]), which is d * sigma2_pure_context if the
# target lies on the d-dim subspace with per-dim variance sigma2_pure_context (assumption).
# Dividing by n (MSE) and by d leaves the constant sigma2_pure_context / n (then times scaling).
dense_dspace = np.linspace(0, 2 * loaded_dim_n, 100)
eq_predict_0_baseline_dense_dspace = sigma2_pure_context / loaded_dim_n * scaling * np.ones_like(dense_dspace)

plt.plot(dense_dspace_theory, eq_theory_baseline_dense_dspace, label=r'Expected error at $\theta^*$',
         color=COLOR_PROJSHRUNK, linestyle=':', alpha=0.99, linewidth=2)
plt.plot(dense_dspace, eq_predict_0_baseline_dense_dspace, label='Predict 0',
         color='grey', linestyle='--', alpha=0.99, linewidth=2)

plt.xlabel(r'$d$ (subspace dim)')
if flag_mult_by_n:
    plt.ylabel(r'$d^{-1}$ MSE')
else:
    plt.ylabel(r'$d^{-1}$ $n^{-1}$ MSE')

# legend intentionally omitted for this small figure; enable with plt.legend(fontsize=6)
plt.title(r'Inference with lower and higher subspace dimensions than training set')

# remove right and top spine
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)

plt.xlim(-0.5, loaded_dim_n + 0.3)
plt.ylim(0.0, 2.24)

plt.tight_layout()

plt.savefig(DIR_OUT + os.sep + 'inference_varying_dim_d_normbyd.png')
plt.savefig(DIR_OUT + os.sep + 'inference_varying_dim_d_normbyd.svg')
plt.show()