In [1]:
from collections import OrderedDict
import re
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from tqdm.notebook import tqdm
import numpy as np
import matplotlib as mpl

from eval import get_run_metrics, read_run_dir, get_model_from_run
from samplers import get_data_sampler
from tasks import get_task_sampler, HaarWavelets
from sklearn.linear_model import LinearRegression, Lasso, LassoCV
from sklearn.preprocessing import PolynomialFeatures
from train import get_all_deg2_term_indices


# Notebook-wide plotting defaults: seaborn "notebook" context with the "darkgrid" style.
sns.set_theme('notebook', 'darkgrid')
# Colour-blind-friendly palette; it is stored here but not referenced again in the visible cells.
palette = sns.color_palette('colorblind')
# Render figures at high resolution.
mpl.rcParams['figure.dpi'] = 300

# Relative path to the directory of trained runs; not referenced again in the visible cells.
run_dir = "../models"

### Reference solver functions


In [10]:

def get_linear_regression_rerefence_errors(xs, ys):
    """Ordinary least-squares baseline, averaged over the batch.

    For every prompt ``b`` and every prefix length ``k >= 1`` an intercept-free
    ``LinearRegression`` is fit on ``xs[b, :k]`` / ``ys[b, :k]`` and asked to
    predict the target of point ``k``.

    Args:
        xs: inputs indexed as ``xs[batch, point]``.
        ys: targets indexed as ``ys[batch, point]``; the squared-error expression
            built from it must expose ``.mean(axis=0)`` and ``.numpy()``
            (presumably a torch tensor -- confirm against the caller).

    Returns:
        ``np.ndarray`` with one batch-averaged squared error per query
        position ``1 .. xs.shape[1] - 1``.
    """
    n_prompts = xs.shape[0]
    mean_sq_errors = []
    for k in tqdm(range(1, xs.shape[1])):
        point_preds = []
        for b in range(n_prompts):
            # One fresh regressor per (prompt, prefix length) pair.
            regressor = LinearRegression(fit_intercept=False)
            regressor.fit(xs[b, :k], ys[b, :k])
            point_preds.append(regressor.predict(xs[b, k:k + 1])[0])
        point_preds = np.array(point_preds).squeeze()
        sq_err = (ys[:, k] - point_preds) ** 2
        mean_sq_errors.append(sq_err.mean(axis=0).numpy())
    return np.array(mean_sq_errors)


def transform_features_to_only_deg2_monomials(data):
    """Map inputs to all degree-2 monomials (self-squares followed by cross terms).

    Args:
        data: 2-D array, n_points x n_dim_input.

    Returns:
        Array of shape n_points x (number of squares + interaction terms): the
        first ``n_dim_input`` columns are the squared inputs, the remaining
        columns are the pairwise interaction terms from ``PolynomialFeatures``.
    """
    d = data.shape[1]
    # interaction_only=True yields [bias, x_1..x_d, x_i*x_j]; the squares are not
    # part of that expansion, so they are rebuilt from the degree-1 columns.
    expanded = PolynomialFeatures(degree=2, interaction_only=True).fit_transform(data)
    squares = expanded[:, 1:(d + 1)] ** 2
    cross_terms = expanded[:, (d + 1):]
    return np.concatenate((squares, cross_terms), axis=1)


def transform_features_to_upto_deg2_monomials(data):
    """Expand inputs into every monomial of degree 0, 1 and 2.

    Args:
        data: 2-D array-like, n_points x n_dim_input.

    Returns:
        The output of ``PolynomialFeatures(degree=2).fit_transform(data)``:
        n_points x (number of degree-0, degree-1 and degree-2 terms).
    """
    # The former unused ``n_dim_inp`` local and the commented-out
    # "interaction only" variant were dead code and have been removed.
    return PolynomialFeatures(degree=2).fit_transform(data)

def transform_features_to_only_fixedS_deg2_monomials(data, fixedS):
    """Evaluate only the degree-2 monomials listed in ``fixedS``.

    Args:
        data: 2-D array, n_points x n_dim_input.
        fixedS: 2-D integer index array, |S| x 2; row ``[a, b]`` selects the
            monomial ``x_a * x_b``.

    Returns:
        n_points x |S| array whose column ``j`` is
        ``data[:, fixedS[j, 0]] * data[:, fixedS[j, 1]]``.
    """
    first_factor = data[:, fixedS[:, 0]]
    second_factor = data[:, fixedS[:, 1]]
    return first_factor * second_factor


def get_polynomial_regression_only_deg2_monomials_fixedS_reference_errors(xs, ys, batch_fixedS):
    """Least-squares baseline restricted to each prompt's own degree-2 monomials.

    For every prompt and every prefix length ``i >= 1`` an intercept-free
    ``LinearRegression`` is fit on the first ``i`` points, after mapping them to
    the monomials listed in ``batch_fixedS[batch_id]``, and then used to
    predict the target of point ``i``.

    Args:
        xs: batch x n_points x n_dim_input inputs.
        ys: batch x n_points targets; the squared-error expression built from
            it must expose ``.numpy()``.
        batch_fixedS: indexed by batch id; each entry is a |S| x 2 array of
            monomial term indices such as ``[[1, 3], [4, 5], ...]``.

    Returns:
        ``(errors, preds)``: per-prompt squared errors and raw predictions,
        each transposed so that the query position is the last axis.
    """
    errors_by_position = []
    preds_by_position = []
    for i in tqdm(range(1, xs.shape[1])):
        batch_preds = []
        for batch_id in range(xs.shape[0]):
            monomial_pairs = batch_fixedS[batch_id]
            context_feats = transform_features_to_only_fixedS_deg2_monomials(
                xs[batch_id, :i], monomial_pairs
            )
            query_feats = transform_features_to_only_fixedS_deg2_monomials(
                xs[batch_id, i:i + 1], monomial_pairs
            )
            # One fresh regressor per (prompt, prefix length) pair.
            fitted = LinearRegression(fit_intercept=False).fit(context_feats, ys[batch_id, :i])
            batch_preds.append(fitted.predict(query_feats)[0])
        batch_preds = np.array(batch_preds).squeeze()
        errors_by_position.append(((ys[:, i] - batch_preds) ** 2).numpy())
        preds_by_position.append(batch_preds)
    return np.array(errors_by_position).T, np.array(preds_by_position).T

# fit a degree 2 poly that looks like w1.x1.x1 + w2.x1.x2 +.....+ wk.xd.xd
# where xi for i=1 to 20 are indices in input x (20 variables) i.e. d=20,
# and weights wj for j=1 to (20C1 + 20C2); 20C1 terms x1.x1, x2.x2, ...,x20.x20; 20C2 terms x1.x2 + x1.x3 +... 
def get_polynomial_regression_only_deg2_monomials_reference_errors(xs, ys):
    """Least-squares baseline over all degree-2 monomials (squares and cross terms).

    For every prompt and every prefix length ``i >= 1`` an intercept-free
    ``LinearRegression`` is fit on the degree-2 monomial features of the first
    ``i`` points and used to predict the target of point ``i``.

    Args:
        xs: batch x n_points x n_dim_input inputs.
        ys: batch x n_points targets; the squared-error expression built from
            it must expose ``.numpy()``.

    Returns:
        ``(errors, preds)``: per-prompt squared errors and raw predictions,
        each transposed so that the query position is the last axis.
    """
    errors_by_position = []
    preds_by_position = []
    for i in tqdm(range(1, xs.shape[1])):
        batch_preds = []
        for batch_id in range(xs.shape[0]):
            context_feats = transform_features_to_only_deg2_monomials(xs[batch_id, :i])
            query_feats = transform_features_to_only_deg2_monomials(xs[batch_id, i:i + 1])
            # One fresh regressor per (prompt, prefix length) pair.
            fitted = LinearRegression(fit_intercept=False).fit(context_feats, ys[batch_id, :i])
            batch_preds.append(fitted.predict(query_feats)[0])
        batch_preds = np.array(batch_preds).squeeze()
        errors_by_position.append(((ys[:, i] - batch_preds) ** 2).numpy())
        preds_by_position.append(batch_preds)
    return np.array(errors_by_position).T, np.array(preds_by_position).T

# fit a degree 2 poly that looks like w0 + w1.x1.x1 + w2.x1.x2 +.....+ wk.xd.xd
# where xi for i=1 to 20 are indices in input x (20 variables) i.e. d=20,
# and weights wj for j=0 to (20C1 + 20C2); 20C1 terms x1.x1, x2.x2, ...,x20.x20; 20C2 terms x1.x2 + x1.x3 +... 
def get_polynomial_regression_upto_deg2_monomials_reference_errors(xs, ys):
    """Least-squares baseline over all monomials of degree 0, 1 and 2, batch-averaged.

    For every prompt and every prefix length ``i >= 1`` an intercept-free
    ``LinearRegression`` is fit on the full degree-<=2 expansion of the first
    ``i`` points and used to predict the target of point ``i``. The bias
    comes from the expansion's own constant column, not from the regressor.

    Args:
        xs: batch x n_points x n_dim_input inputs.
        ys: batch x n_points targets; the squared-error expression built from
            it must expose ``.mean(axis=0)`` and ``.numpy()``.

    Returns:
        ``np.ndarray`` with one batch-averaged squared error per query
        position ``1 .. n_points - 1``.
    """
    mean_sq_errors = []
    for i in tqdm(range(1, xs.shape[1])):
        batch_preds = []
        for batch_id in range(xs.shape[0]):
            context_feats = transform_features_to_upto_deg2_monomials(xs[batch_id, :i])
            query_feats = transform_features_to_upto_deg2_monomials(xs[batch_id, i:i + 1])
            fitted = LinearRegression(fit_intercept=False).fit(context_feats, ys[batch_id, :i])
            batch_preds.append(fitted.predict(query_feats)[0])
        batch_preds = np.array(batch_preds).squeeze()
        sq_err = (ys[:, i] - batch_preds) ** 2
        mean_sq_errors.append(sq_err.mean(axis=0).numpy())
    return np.array(mean_sq_errors)

def get_sparse_regression_all_deg2_monomials_reference_errors(xs, ys, alpha=0.01, max_iter=100000):
    """Lasso baseline over all degree-2 monomials (squares and cross terms).

    For every prompt and every prefix length ``i >= 1`` an intercept-free
    ``Lasso`` is fit on the degree-2 monomial features of the first ``i``
    points and used to predict the target of point ``i``.

    Args:
        xs: inputs indexed as ``xs[batch, point]`` (n_dim_input on the last axis).
        ys: targets indexed as ``ys[batch, point]``; the squared-error
            expression built from it must expose ``.numpy()``.
        alpha: L1 regularisation strength passed to ``Lasso``.
        max_iter: iteration cap passed to ``Lasso``.

    Returns:
        ``(errors, preds)``: per-prompt squared errors and raw predictions,
        each transposed so that the query position is the last axis.
    """
    errors_by_position = []
    preds_by_position = []
    for i in tqdm(range(1, xs.shape[1])):
        batch_preds = []
        for batch_id in range(xs.shape[0]):
            context_feats = transform_features_to_only_deg2_monomials(xs[batch_id, :i])
            # A fixed-alpha Lasso is used at every prefix length (no cross-validation).
            lasso = Lasso(alpha=alpha, fit_intercept=False, max_iter=max_iter)
            lasso.fit(context_feats, ys[batch_id, :i])
            query_feats = transform_features_to_only_deg2_monomials(xs[batch_id, i:i + 1])
            batch_preds.append(lasso.predict(query_feats)[0])
        batch_preds = np.array(batch_preds).squeeze()
        errors_by_position.append(((ys[:, i] - batch_preds) ** 2).numpy())
        preds_by_position.append(batch_preds)
    return np.array(errors_by_position).T, np.array(preds_by_position).T

def get_sparse_regression_reference_errors(xs, ys):
    """Lasso baseline on the raw inputs, averaged over the batch.

    Prefix lengths below 5 use a plain ``Lasso`` with default settings; from 5
    points onwards ``LassoCV`` is used instead.

    Args:
        xs: inputs indexed as ``xs[batch, point]``.
        ys: targets indexed as ``ys[batch, point]``; the squared-error
            expression built from it must expose ``.mean(axis=0)`` and ``.numpy()``.

    Returns:
        ``np.ndarray`` with one batch-averaged squared error per query
        position ``1 .. xs.shape[1] - 1``.
    """
    mean_sq_errors = []
    for i in tqdm(range(1, xs.shape[1])):
        # NOTE(review): the threshold is presumably there because LassoCV's default
        # 5-fold cross-validation needs at least 5 samples -- confirm.
        estimator_cls = Lasso if i < 5 else LassoCV
        batch_preds = []
        for batch_id in range(xs.shape[0]):
            fitted = estimator_cls(fit_intercept=False).fit(xs[batch_id, :i], ys[batch_id, :i])
            batch_preds.append(fitted.predict(xs[batch_id, i:i + 1])[0])
        batch_preds = np.array(batch_preds).squeeze()
        sq_err = (ys[:, i] - batch_preds) ** 2
        mean_sq_errors.append(sq_err.mean(axis=0).numpy())
    return np.array(mean_sq_errors)

def get_sign_vec_cs_reference_errors(xs, ys, n_dims):
    """Infinity-norm-minimisation baseline, averaged over the batch.

    For each prompt ``b`` and each prefix of ``t + 1`` points, this solves
    ``min ||w||_inf`` subject to ``xs[b, :t+1] @ w == ys[b, :t+1]`` and uses the
    solution to predict the target of point ``t + 1``.

    NOTE(review): ``Variable``, ``Minimize``, ``cvxnorm``, ``Problem`` and
    ``cvxpy`` are not imported anywhere in the visible source; they look like
    cvxpy names (``cvxnorm`` presumably an alias of ``cvxpy.norm``) and must be
    in scope before this function can run -- confirm.

    Args:
        xs: inputs indexed as ``xs[batch, point]``; slices must expose ``.numpy()``.
        ys: targets indexed as ``ys[batch, point]``; slices must expose ``.numpy()``.
        n_dims: length of the weight vector being optimised.

    Returns:
        ``np.ndarray`` of per-position values averaged over the batch axis.
    """
    # Inf Norm Optimization
    # NOTE(review): mat_dim is computed but never used below.
    mat_dim = int(np.sqrt(xs.shape[2]))
    baseline_errors_batch = []
    for b in tqdm(range(xs.shape[0])):
        errors = []
        for t in range(xs.shape[1] - 1):
            # Fresh optimisation variable and problem for every prefix length.
            w_star = Variable([n_dims, 1])
            obj = Minimize(cvxnorm(w_star, 'inf'))
            # Equality constraints: the first t+1 points must be fit exactly.
            constraints = [ys[b,:t+1].numpy()[:,np.newaxis] == (xs[b,:t+1].numpy() @ w_star)]
            prob = Problem(obj, constraints)
            result = prob.solve()#verbose=True)
            if prob.status == cvxpy.OPTIMAL:
                # Squared error of the prediction for the next point.
                pred = w_star.value[:,0] @ xs[b,t+1].numpy()
                errors.append((pred - ys[b,t+1].numpy())**2)
            else:
                # No optimal solution: prob.value is recorded in place of a squared error.
                errors.append(prob.value)
        baseline_errors_batch.append(errors)
    return np.array(baseline_errors_batch).mean(0)

def plot_results(
    transformer_loss: list, 
    plot_title,
    task_losses: dict,
    task_kwargs: dict = {}
):
    """Plot the transformer loss curve together with any available baseline curves.

    Args:
        transformer_loss: per-position loss values. Despite the annotation it
            must expose ``.shape`` (its first axis sets the x-axis length).
        plot_title: title of the figure.
        task_losses: baseline name -> loss curve; only recognised names are drawn.
        task_kwargs: extra per-baseline settings; only
            ``task_kwargs["sign_vec_cs"]["bound"]`` is read, and only when
            ``"sign_vec_cs"`` is in ``task_losses``. Never mutated here.
    """
    n_examples = np.arange(1, transformer_loss.shape[0]+1)
    plt.plot(n_examples, transformer_loss, lw=2, label="Transformer")

    # The sign-vector baseline is special: default line width plus a marker at the bound.
    if "sign_vec_cs" in task_losses:
        plt.plot(n_examples, task_losses["sign_vec_cs"], label = "Inf Norm Minimization")
        plt.scatter(task_kwargs["sign_vec_cs"]["bound"] + 1,0, color="red", label="Bound")

    # Remaining baselines share one drawing recipe; the tuple order is the legend order.
    baseline_labels = (
        ("linear_regression", "Least Squares"),
        ("sparse_regression", "Lasso"),
        ("monomial_deg2", "L.Sq. - all deg 2 monomials"),
        ("monomial_deg0_1_2", "L.Sq. - all deg 0,1,2 monomials"),
        ("monomial_deg2_fixedS", "L.Sq. - all deg 2 fixed S monomials"),
        ("2-layerNN-GD", "2-layer NN, GD"),
        ("lasso_all_deg2_monomials", "Lasso - all deg 2 monomials"),
        ("haar_basis", "OLS on Haar Basis features (max_level=3)"),
    )
    for loss_key, curve_label in baseline_labels:
        if loss_key in task_losses:
            plt.plot(n_examples, task_losses[loss_key], lw=2, label=curve_label)

    plt.xlabel("# in-context examples")
    plt.ylabel("squared error")
    plt.title(plot_title)
    plt.legend()
    plt.show()

def plot_results_single_curve(
    loss: list, 
    loss_label: str
):
    """Plot one loss curve against the number of in-context examples.

    Args:
        loss: per-position loss values. Despite the annotation it must expose
            ``.shape`` (its first axis sets the x-axis length).
        loss_label: legend label for the curve.
    """
    n_examples = np.arange(1, loss.shape[0]+1)

    plt.plot(n_examples, loss, label=loss_label, lw=2)
    plt.xlabel("# in-context examples")
    plt.ylabel("squared error")
    plt.legend()
    plt.show()


def compute_average_loss_difference(loss1, loss2, losses_names):
    """Print the mean absolute difference between two equally shaped loss arrays.

    Raises:
        AssertionError: if the two arrays differ in shape.
    """
    assert loss1.shape == loss2.shape, "loss1 and loss2 must have same length"
    mean_abs_gap = np.abs(loss1 - loss2).mean()
    print(f"Average difference between {losses_names} losses:", mean_abs_gap)

def average_loss_from_index(ind, loss, loss_name):
    """Print the mean of ``loss[ind:]`` under a label built from ``loss_name``."""
    tail_mean = loss[ind:].mean()
    print(f"Average {loss_name} loss from index {ind} is:", tail_mean)

def average_loss_after_seeing_at_least_k_examples(k, loss, loss_name):
    """Print the mean of ``loss[k-1:]``, i.e. the positions from index ``k - 1`` onwards."""
    start = k - 1
    tail_mean = loss[start:].mean()
    print(f"Average {loss_name} loss after seeing {k} examples in prompt is:", tail_mean)

### Multitask eval functions

In [54]:
from KS_monomial_sets import monomial_terms


def get_model_from_run_folder(run_folder_path):
    """Load a run's model and config, and attach the fixed monomial family to the config.

    The run is loaded through ``get_model_from_run_folder_usual`` (which prints
    the "Loading model from ..." line). The config is then extended in place:
    ``conf.training.task_kwargs["fixedK"]`` is set to an int64 tensor built from
    ``monomial_terms["{n_dims}-{sizeOfK}-{numDeg2Select}"]``.

    Args:
        run_folder_path: folder whose only sub-folder is the run-id folder.

    Returns:
        ``(model, conf)`` with ``conf`` mutated as described above.

    Raises:
        KeyError: if the task kwargs lack ``numDeg2Select`` / ``sizeOfK`` or the
            composed key is missing from ``monomial_terms``.
    """
    # Delegate run discovery and loading so that logic lives in one place.
    # The former unused ``recompute_metrics`` local has been dropped.
    model, conf = get_model_from_run_folder_usual(run_folder_path)

    numDeg2Select = conf.training.task_kwargs["numDeg2Select"]
    sizeOfK = conf.training.task_kwargs["sizeOfK"]
    monomial_key = f"{conf.model.n_dims}-{sizeOfK}-{numDeg2Select}"
    conf.training.task_kwargs["fixedK"] = torch.tensor(monomial_terms[monomial_key], dtype=torch.int64)

    return model, conf 

def get_model_from_run_folder_usual(run_folder_path):
    """Load the model and config of the single run stored under ``run_folder_path``.

    Args:
        run_folder_path: folder that contains the run-id folder.

    Returns:
        ``(model, conf)`` as returned by ``get_model_from_run``.

    Raises:
        IndexError: if ``run_folder_path`` contains no sub-directory.
    """
    # os.listdir returns entries in arbitrary order and includes plain files, so
    # keep only sub-directories and take the first in sorted order. With exactly
    # one run-id folder this selects the same entry as before, but stray files
    # can no longer be mistaken for the run id.
    run_ids = sorted(
        entry for entry in os.listdir(run_folder_path)
        if os.path.isdir(os.path.join(run_folder_path, entry))
    )
    run_id = run_ids[0]
    run_path = os.path.join(run_folder_path, run_id)

    print(f"Loading model from {run_folder_path}/{run_id}...")
    model, conf = get_model_from_run(run_path)
    
    return model, conf 


def do_task_config(conf, n_dims, eval_batch_size):
    """Build the data sampler and task sampler described by a run's training config.

    Side effect: ``conf.training.task_kwargs["all_deg2_terms"]`` is set in place
    to ``get_all_deg2_term_indices(n_dims)`` before the task sampler is built.

    Args:
        conf: run configuration; ``conf.training`` supplies ``data``,
            ``data_kwargs``, ``task``, ``task_kwargs`` and ``num_tasks``.
        n_dims: dimensionality forwarded to both sampler factories.
        eval_batch_size: batch size forwarded to the task sampler factory.

    Returns:
        ``(data_sampler, task_sampler)``.
    """
    sampler_kwargs = conf.training.data_kwargs
    sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs

    # The config is mutated directly (no copy), so the new key persists on conf.
    conf.training.task_kwargs["all_deg2_terms"] = get_all_deg2_term_indices(n_dims)

    data_sampler = get_data_sampler(conf.training.data, n_dims, **sampler_kwargs)
    task_sampler = get_task_sampler(
        conf.training.task,
        n_dims,
        eval_batch_size,
        num_tasks=conf.training.num_tasks,
        **conf.training.task_kwargs
    )
    return data_sampler, task_sampler

def get_monomial_reference_errors(sizeOfK, xs, ys, conf, K_used=None, is_ref=("only2","upto2","onlyfixedS","lasso"), lasso_alpha=0.01):
    """Compute the requested reference-solver errors for a batch of prompts.

    Available references (selected by membership in ``is_ref``):

    * ``"only2"``      -- least squares over all degree-2 monomials (squares and
      cross terms, e.g. 20 + 190 = 210 terms for 20 inputs); per-prompt errors.
    * ``"upto2"``      -- least squares over all degree 0, 1 and 2 monomials;
      batch-averaged errors.
    * ``"onlyfixedS"`` -- least squares over each prompt's own monomial set taken
      from ``K_used``; per-prompt errors.
    * ``"lasso"``      -- Lasso over all degree-2 monomials; per-prompt errors.

    Args:
        sizeOfK: unused; kept for interface compatibility.
        xs: inputs passed through to the reference solvers.
        ys: targets passed through to the reference solvers.
        conf: unused; kept for interface compatibility.
        K_used: per-prompt monomial index sets; required when ``"onlyfixedS"``
            is requested.
        is_ref: collection of reference names to compute. The default is an
            immutable tuple, so it cannot be shared and mutated across calls.
        lasso_alpha: regularisation strength for the Lasso reference.

    Returns:
        dict with keys ``"only2"``, ``"upto2"``, ``"onlyfixedS"``, ``"lasso"``;
        references that were not requested map to ``None``.

    Raises:
        ValueError: if ``"onlyfixedS"`` is requested without ``K_used``.
    """
    # Fail fast, before any expensive fitting: previously a missing K_used
    # surfaced only as an obscure indexing error deep inside the solver.
    if "onlyfixedS" in is_ref and K_used is None:
        raise ValueError('K_used must be provided when "onlyfixedS" is requested')

    results = {"only2": None, "upto2": None, "onlyfixedS": None, "lasso": None}

    if "only2" in is_ref:
        results["only2"], _ = get_polynomial_regression_only_deg2_monomials_reference_errors(xs, ys)
    if "upto2" in is_ref:
        results["upto2"] = get_polynomial_regression_upto_deg2_monomials_reference_errors(xs, ys)
    if "onlyfixedS" in is_ref:
        results["onlyfixedS"], _ = get_polynomial_regression_only_deg2_monomials_fixedS_reference_errors(
            xs, ys, np.array(K_used)
        )
    if "lasso" in is_ref:
        results["lasso"], _ = get_sparse_regression_all_deg2_monomials_reference_errors(xs, ys, alpha=lasso_alpha)

    return results


def eval_and_plot_ID(model, bp_proxy_tf, conf, data_sampler, task_sampler, len_deg2_monomials, sizeOfK, eval_reference_list, eval_bs, eval_pts, lasso_alpha):
    """Run an in-distribution (ID) evaluation of both models and the reference solvers.

    Despite the name, nothing is plotted here; the caller plots the returned values.

    Args:
        model: model under evaluation; called as ``model(xs, ys)``.
        bp_proxy_tf: second model evaluated on the same prompts.
        conf: run config; ``conf.training.task`` is printed and
            ``conf.training.task_kwargs["fixedK"]`` is indexed by the sampled indices.
        data_sampler: provides ``sample_xs(b_size=..., n_points=...)``.
        task_sampler: zero-argument callable returning a task object.
        len_deg2_monomials: expected ``len(task.all_deg2_terms)`` (asserted).
        sizeOfK: exclusive upper bound for the random indices into ``fixedK``.
        eval_reference_list: names of the reference solvers to run.
        eval_bs: number of prompts to sample.
        eval_pts: number of points per prompt.
        lasso_alpha: regularisation strength for the Lasso reference.

    Returns:
        ``(ref_errors_dict, task, transformer_pred, bp_proxy_pred, ys)``.
    """
    print("task:", conf.training.task)
    task = task_sampler()
    xs = data_sampler.sample_xs(b_size=eval_bs, n_points=eval_pts)
    
    # One random index in [0, sizeOfK) per prompt selects which entry of fixedK generates it.
    chosenSindices = torch.randint(0, sizeOfK, (eval_bs,))
    ys = task.evaluate(xs, mode='eval', chosenSindices=chosenSindices) # ID eval
    # The monomial sets actually used, needed by the "onlyfixedS" reference.
    K_used = conf.training.task_kwargs["fixedK"][chosenSindices]
    print("current task:", task)
    print("xs shape:", xs.shape)
    print("ys shape:", ys.shape)

    # Sanity check: the task enumerates the expected number of degree-2 terms.
    assert len(task.all_deg2_terms) == len_deg2_monomials
    model.eval()
    bp_proxy_tf.eval()
    # Inference only: no gradients are tracked.
    with torch.no_grad():
        transformer_pred = model(xs, ys)
        bp_proxy_pred = bp_proxy_tf(xs, ys)
    
    ref_errors_dict=get_monomial_reference_errors(sizeOfK, xs, ys, conf, K_used, is_ref=eval_reference_list, lasso_alpha=lasso_alpha)
    return ref_errors_dict, task, transformer_pred, bp_proxy_pred, ys

def eval_and_plot_OOD(model, bp_proxy_tf, conf, data_sampler, task_sampler, len_deg2_monomials, sizeOfK, eval_reference_list, eval_bs, eval_pts, lasso_alpha):
    """Run an out-of-distribution (OOD) evaluation of both models and the reference solvers.

    Same flow as the ID evaluation, except that targets come from
    ``task.evaluate_ood(xs)`` and the monomial sets are read from
    ``task.selected_monomial_indices``. Despite the name, nothing is plotted
    here; the caller plots the returned values.

    Args:
        model: model under evaluation; called as ``model(xs, ys)``.
        bp_proxy_tf: second model evaluated on the same prompts.
        conf: run config; ``conf.training.task`` is printed.
        data_sampler: provides ``sample_xs(b_size=..., n_points=...)``.
        task_sampler: zero-argument callable returning a task object.
        len_deg2_monomials: expected ``len(task.all_deg2_terms)`` (asserted).
        sizeOfK: only forwarded to ``get_monomial_reference_errors``.
        eval_reference_list: names of the reference solvers to run.
        eval_bs: number of prompts to sample.
        eval_pts: number of points per prompt.
        lasso_alpha: regularisation strength for the Lasso reference.

    Returns:
        ``(ref_errors_dict, task, transformer_pred, bp_proxy_pred, ys)``.
    """
    print("task:", conf.training.task)
    task = task_sampler()
    xs = data_sampler.sample_xs(b_size=eval_bs, n_points=eval_pts)
    
    ys = task.evaluate_ood(xs) # OOD eval
    # Read after evaluate_ood; presumably populated by that call -- confirm in the task class.
    K_used = task.selected_monomial_indices
    print("current task:", task)
    print("xs shape:", xs.shape)
    print("ys shape:", ys.shape)

    # Sanity check: the task enumerates the expected number of degree-2 terms.
    assert len(task.all_deg2_terms) == len_deg2_monomials
    model.eval()
    bp_proxy_tf.eval()
    # Inference only: no gradients are tracked.
    with torch.no_grad():
        transformer_pred = model(xs, ys)
        bp_proxy_pred = bp_proxy_tf(xs, ys)
    
    ref_errors_dict=get_monomial_reference_errors(sizeOfK, xs, ys, conf, K_used, is_ref=eval_reference_list, lasso_alpha=lasso_alpha)
    return ref_errors_dict, task, transformer_pred, bp_proxy_pred, ys



### Plotting helpers

In [31]:
%matplotlib inline
import seaborn as sns
import matplotlib
# Same theme, palette and dpi settings as in the first cell, re-applied for this cell.
sns.set_theme('notebook', 'darkgrid')
palette = sns.color_palette('colorblind')
mpl.rcParams['figure.dpi'] = 300
# mpl.rcParams['text.usetex'] = True

# Font sizes for titles, legend and tick labels.
matplotlib.rcParams.update({
    'axes.titlesize': 8,
  'figure.titlesize': 10, # was 10
  'legend.fontsize': 10, # was 10
  'xtick.labelsize': 6,
  'ytick.labelsize': 6,
})

# Seed torch's global RNG so the sampling done in later cells is repeatable.
seed=42
torch.manual_seed(seed)
def get_df_from_pred_array(pred_arr, n_points, offset = 0):
    """Flatten a (batch, positions) array into a long-format DataFrame.

    Args:
        pred_arr: 2-D array, batch x (n_points - offset).
        n_points: exclusive upper end of the position range.
        offset: first position label.

    Returns:
        DataFrame with column ``y`` (row-major flattened values) and column
        ``x`` (position labels ``offset .. n_points - 1``, repeated per row).
    """
    n_rows = pred_arr.shape[0]
    position_labels = np.tile(np.arange(offset, n_points), n_rows)
    values = pred_arr.ravel()
    return pd.DataFrame({'y': values, 'x': position_labels})

def lineplot_with_ci(pred_or_err_arr, n_points, offset, label, ax, seed):
    """Draw the per-position mean of a (batch, positions) array with a 90% bootstrap band.

    Args:
        pred_or_err_arr: 2-D array, batch x (n_points - offset).
        n_points: exclusive upper end of the x range.
        offset: first x value.
        label: legend label of the curve.
        ax: matplotlib axes to draw on.
        seed: seed for seaborn's bootstrap resampling (1000 resamples).
    """
    long_df = get_df_from_pred_array(pred_or_err_arr, n_points=n_points, offset=offset)
    # NOTE(review): recent seaborn releases deprecate ``ci=`` in favour of
    # ``errorbar=("ci", 90)``; kept as is to stay compatible with the installed version.
    sns.lineplot(
        data=long_df,
        x="x",
        y="y",
        label=label,
        ax=ax,
        n_boot=1000,
        seed=seed,
        ci=90,
    )

In [59]:
def plot_latex(plot_dict, K, is_id, alpha, y_limit=None):
    """Plot loss@k curves (mean with confidence band) for several methods on one axis.

    Args:
        plot_dict: label -> 2-D array (batch x query positions), or ``None`` to
            skip that method. Labels may contain LaTeX math.
        K: value shown in the title as ``$K=...$``.
        is_id: ``True`` for an "ID Evaluation" title, otherwise "OOD Evaluation".
        alpha: unused; kept for interface compatibility.
        y_limit: optional ``(low, high)`` pair for the y-axis.
    """
    seed = 42
    offset = 1  # curves start at k = 1 in-context example
    plt_title = f"$K={K}$, " + ("ID Evaluation" if is_id else "OOD Evaluation")

    sns.set(style = "whitegrid", font_scale=1.5)

    fig, ax1 = plt.subplots(1, 1, figsize=(7, 5), constrained_layout=True)

    for k, v in plot_dict.items():
        if v is not None:
            print(k, v.shape)
            # The x-range comes from the array itself rather than from a
            # notebook-global ``n_points`` defined in a later cell, so the
            # function no longer depends on hidden kernel state.
            lineplot_with_ci(v, v.shape[1] + offset, offset=offset, label=k, ax=ax1, seed=seed)
    
    ax1.set_xlabel("$k$ (# in-context examples)")
    ax1.set_ylabel("$\\mathrm{loss@k}$")
    
    ax1.legend(fontsize=15)
    if y_limit is not None:
        ax1.set_ylim(y_limit[0], y_limit[1])
    plt.title(plt_title)
    plt.show()

### Main Code

In [None]:
# Train models such that they save at path some/path/prefix/polynomials/descriptive_model_folder_name/run_id_uuid
models_log_dir = "some/path/prefix"  # placeholder root of the saved runs; replace with the real prefix
n_dims=10  # input dimensionality handed to do_task_config
eval_batch_size = 1280  # number of prompts sampled per evaluation
n_points = 125  # points per evaluation prompt (passed as eval_pts)
len_deg2_monomials = 55  # expected len(task.all_deg2_terms); 55 = 10 + C(10, 2) for n_dims = 10

lasso_alpha = 0.1  # L1 strength forwarded to the Lasso reference
y_limit = [-0.125, 1.625]  # y-axis range shared by every plot

K=[10, 20, 40, 100, 500, 1000, 5000]  # values of sizeOfK; one trained run is evaluated per value
n_points_tr = 125  # only used in the run-folder name and plot title (presumably the training prompt length)
sizeOfS = 10  # only used in the run-folder name and plot title

# The BP-proxy model is shared across all K values, so it is loaded once up front.
bp_proxy_tf, _ = get_model_from_run_folder_usual(os.path.join(models_log_dir, "polynomials", "descriptive_model_folder_name-for-bp-proxy-model"))


# For every trained run (one per |K|): evaluate in-distribution and out-of-distribution,
# then plot the transformer, the BP-proxy model and the reference solvers together.
for sizeOfK in K:
    run_folder_name = f"descriptive_model_folder_name_pos130_d10_p{n_points_tr}_S{sizeOfS}_K{sizeOfK}"
    run_folder_path = os.path.join(models_log_dir, "polynomials", run_folder_name)
    model, conf = get_model_from_run_folder(run_folder_path)

    # config for the run
    print(dict(conf.training))
    data_sampler, task_sampler = do_task_config(conf, n_dims, eval_batch_size)

    # eval ID on batch of S
    eval_reference_list=["only2", "onlyfixedS", "lasso"]

    ref_errors_dict, task, transformer_pred, bp_proxy_pred, ys = eval_and_plot_ID(model, bp_proxy_tf, conf, data_sampler, task_sampler, len_deg2_monomials, sizeOfK, eval_reference_list, eval_bs=eval_batch_size, eval_pts=n_points, lasso_alpha=lasso_alpha)
    
    plot_title_str=f"|K|={sizeOfK}, |S|={sizeOfS}, p={n_points_tr}\nID: All S's are from K"

    metric = task.get_metric()
    # Drop position 0 so the model losses line up with the reference errors,
    # which start at one in-context example.
    transformer_loss = metric(transformer_pred, ys).numpy()[:,1:]
    bp_proxy_loss = metric(bp_proxy_pred, ys).numpy()[:,1:]

    # each tensor value must be batch x npoints
    # Raw strings keep the LaTeX backslashes literal; the former "\m" / "\P"
    # were invalid escape sequences (a SyntaxWarning on recent Python versions).
    plot_dict = {
        "HMICL TF": transformer_loss,
        r"$\mathrm{OLS}_{\mathcal{S}}$": ref_errors_dict["onlyfixedS"],
        r"$\mathrm{OLS}_{\Phi_{M}}$": ref_errors_dict["only2"],
        r"$\mathrm{Lasso}_{\Phi_{M}}$": ref_errors_dict["lasso"],
        r"$\mathrm{BP}_{\mathrm{proxy}}$": bp_proxy_loss
    }

    plot_latex(plot_dict, sizeOfK, is_id=True, alpha=lasso_alpha, y_limit=y_limit)

    # eval OOD
    eval_reference_list=["only2", "onlyfixedS", "lasso"]

    ref_errors_dict, task, transformer_pred, bp_proxy_pred, ys=eval_and_plot_OOD(model, bp_proxy_tf, conf, data_sampler, task_sampler, len_deg2_monomials, sizeOfK, eval_reference_list, eval_bs=eval_batch_size, eval_pts=n_points, lasso_alpha=lasso_alpha)
    
    plot_title_str=f"|K|={sizeOfK}, |S|={sizeOfS}, p={n_points_tr}\nOOD: All S's are random"

    metric = task.get_metric()
    transformer_loss = metric(transformer_pred, ys).numpy()[:,1:]
    bp_proxy_loss = metric(bp_proxy_pred, ys).numpy()[:,1:]

    # each tensor value must be batch x npoints
    plot_dict = {
        "HMICL TF": transformer_loss,
        r"$\mathrm{OLS}_{\mathcal{S}}$": ref_errors_dict["onlyfixedS"],
        r"$\mathrm{OLS}_{\Phi_{M}}$": ref_errors_dict["only2"],
        r"$\mathrm{Lasso}_{\Phi_{M}}$": ref_errors_dict["lasso"],
        r"$\mathrm{BP}_{\mathrm{proxy}}$": bp_proxy_loss
    }

    plot_latex(plot_dict, sizeOfK, is_id=False, alpha=lasso_alpha, y_limit=y_limit)
        