In [None]:
# Switch the working directory to the project's ``src`` folder so that the
# local modules imported below resolve. The guard makes the cell idempotent:
# re-running it no longer walks one directory further each time.
import os

if os.path.basename(os.getcwd()) != "src":
    os.chdir("../src/")

In [None]:
from collections import OrderedDict
import re
import os
import math
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import seaborn as sns
import torch
from tqdm.notebook import tqdm

from eval import get_run_metrics, read_run_dir, get_model_from_run
from plot_utils import basic_plot, collect_results, relevant_model_names
from samplers import get_data_sampler
from tasks import get_task_sampler

import matplotlib as mpl
from sklearn.linear_model import LinearRegression, Lasso, LassoCV, SGDRegressor, Ridge
import numpy as np
import cvxpy
from cvxpy import Variable, Minimize, Problem
from cvxpy import norm as cvxnorm

# from cvxpy import mul_elemwise, SCS
from cvxpy import vec as cvxvec

%matplotlib inline
%load_ext autoreload
%autoreload 2

# Seaborn base theme ("notebook" context, "darkgrid" style) and a
# colour-blind-friendly palette.
sns.set_theme(context="notebook", style="darkgrid")
palette = sns.color_palette("colorblind")

# Matplotlib overrides applied on top of the seaborn theme: high-resolution,
# LaTeX-rendered figures with compact title / legend / tick fonts.
mpl.rcParams.update(
    {
        "figure.dpi": 300,
        "text.usetex": True,
        "axes.titlesize": 8,
        "figure.titlesize": 10,  # was 10
        "legend.fontsize": 10,  # was 10
        "xtick.labelsize": 6,
        "ytick.labelsize": 6,
    }
)

# Directory containing the trained runs (relative to the working directory).
run_dir = "../models"

In [None]:
SPINE_COLOR = "gray"


def format_axes(ax):
    """Apply the notebook's axis styling to ``ax`` and return it.

    Every spine is drawn thin (0.5) in ``SPINE_COLOR``; x ticks are placed on
    the bottom, y ticks on the left, and all ticks point outward in the same
    colour.
    """
    # Same colour and hairline width on all four sides.
    for side in ("top", "right", "left", "bottom"):
        spine = ax.spines[side]
        spine.set_color(SPINE_COLOR)
        spine.set_linewidth(0.5)

    ax.xaxis.set_ticks_position("bottom")
    ax.yaxis.set_ticks_position("left")

    ax.xaxis.set_tick_params(direction="out", color=SPINE_COLOR)
    ax.yaxis.set_tick_params(direction="out", color=SPINE_COLOR)
    return ax

In [None]:
# Load the trained model and its training config from
# ``<run_dir>/<task_name>/<run_id>`` and move the model to the second GPU.
# The directory name is kept in ``task_name`` (not ``task``) because ``task``
# is bound to the sampled task object further down; sharing one name for both
# made this cell and the metric cells clobber each other on re-execution.
task_name = "sparse_linear_regression"
run_id = "final_model"  # Change according to the id of the model you train
dr_model, dr_conf = get_model_from_run(os.path.join(run_dir, task_name, run_id))
dr_model.to("cuda:1")

In [None]:
# Evaluation sizes: number of prompts, input dimensionality, and prompt length
# (the ``end`` value of the training curriculum for points).
n_points = dr_conf.training.curriculum.points.end
n_dims = 20
batch_size = 1280  # 1280 #conf.training.batch_size

# Samplers built from the same data / task settings the model was trained with.
data_sampler = get_data_sampler(dr_conf.training.data, n_dims)
task_sampler = get_task_sampler(
    dr_conf.training.task,
    n_dims,
    batch_size,
    **dr_conf.training.task_kwargs,
)

In [None]:
# Draw the evaluation batch. The torch RNG is seeded first, so the order of the
# sampling calls below determines which task and inputs are drawn.
seed = 42
torch.manual_seed(seed)
# ``task`` is the sampled task object from here on (its weights are read as
# ``task.w_b`` later in the notebook).
task = task_sampler()
# presumably xs is (batch, n_points, n_dims) and ys is (batch, n_points) — TODO confirm
xs = data_sampler.sample_xs(b_size=batch_size, n_points=n_points)
ys = task.evaluate(xs)

In [None]:
# Forward pass of the transformer over the whole batch on the GPU; gradients
# are not needed, and the predictions are brought back to the CPU.
with torch.no_grad():
    transformer_preds = dr_model(
        xs.to(device="cuda:1"),
        ys.to(device="cuda:1"),
    ).to("cpu")

In [None]:
# Score the transformer's predictions with the task's own metric and keep the
# result as a squeezed NumPy array.
metric = task.get_metric()
transformer_errors = np.squeeze(metric(transformer_preds, ys).numpy())

In [None]:
# Display the transformer error averaged over axis 0 (presumably the batch axis).
transformer_errors.mean(0)

In [None]:
# Ordinary-least-squares baseline. For every prefix length i, fit a separate
# intercept-free regression on the first i examples of each prompt and predict
# example i. Position 0 has no context, so its prediction is fixed to zero.
lsq_preds = [np.zeros(xs.shape[0])]
for i in tqdm(range(1, xs.shape[1])):
    preds = np.array(
        [
            LinearRegression(fit_intercept=False)
            .fit(xs[batch_id, :i], ys[batch_id, :i])
            .predict(xs[batch_id, i : i + 1])[0]
            for batch_id in range(xs.shape[0])
        ]
    ).squeeze()
    lsq_preds.append(preds)
# Stack as (position, batch), then flip to (batch, position) like ``ys``.
lsq_preds = torch.tensor(np.array(lsq_preds).T).float()

In [None]:
# Errors of the OLS baseline under the same task metric used for the transformer.
lsq_errors = metric(lsq_preds, ys).numpy().squeeze()

In [None]:
# Ridge baseline (alpha = 1e-2, no intercept). For every prefix length i, fit
# on the first i examples of each prompt and predict example i. Position 0 has
# no context, so its prediction is fixed to zero.
ridge_preds = [np.zeros(xs.shape[0])]
for i in tqdm(range(1, xs.shape[1])):
    preds = np.array(
        [
            Ridge(fit_intercept=False, alpha=1e-2)
            .fit(xs[batch_id, :i], ys[batch_id, :i])
            .predict(xs[batch_id, i : i + 1])[0]
            for batch_id in range(xs.shape[0])
        ]
    ).squeeze()
    ridge_preds.append(preds)
# Stack as (position, batch), then flip to (batch, position) like ``ys``.
ridge_preds = torch.tensor(np.array(ridge_preds).T).float()

In [None]:
# Errors of the Ridge baseline under the same task metric used for the transformer.
ridge_errors = metric(ridge_preds, ys).numpy().squeeze()

In [None]:
# Lasso baseline (alpha = 1e-2, no intercept). For every prefix length i, fit
# on the first i examples of each prompt and predict example i. Position 0 has
# no context, so its prediction is fixed to zero.
lasso_preds = [np.zeros(xs.shape[0])]
for i in tqdm(range(1, xs.shape[1])):
    preds = np.array(
        [
            Lasso(fit_intercept=False, alpha=1e-2)
            .fit(xs[batch_id, :i], ys[batch_id, :i])
            .predict(xs[batch_id, i : i + 1])[0]
            for batch_id in range(xs.shape[0])
        ]
    ).squeeze()
    lasso_preds.append(preds)
# Stack as (position, batch), then flip to (batch, position) like ``ys``.
lasso_preds = torch.tensor(np.array(lasso_preds).T).float()

In [None]:
# NOTE(review): unlike ``lsq_errors`` / ``ridge_errors``, the Lasso errors are
# written back into ``lasso_preds``, so after this cell the name holds errors,
# not predictions (the loss plot below reads ``lasso_preds`` as errors).
# Re-running this cell without first re-running the Lasso prediction cell
# would feed those errors back into the metric.
lasso_preds = metric(lasso_preds, ys).numpy().squeeze()

In [None]:
# l2_norm_preds = []
# w_stars = []
# for b in tqdm(range(xs.shape[0])):
#     preds = [0]
#     for t in range(xs.shape[1] - 1):
#         w_star = Variable([n_dims, 1])
#         obj = Minimize(cvxnorm(w_star, 2))
#         constraints = [ys[b,:t+1].numpy()[:,np.newaxis] == (xs[b,:t+1].numpy() @ w_star)]
#         prob = Problem(obj, constraints)
#         result = prob.solve()
#         try:
#             pred = w_star.value[:,0] @ xs[b,t+1].numpy()
#         except:
#             pred = 0
# #         errors.append((pred - ys[b,t+1].numpy())**2)
#         w_stars.append(w_star.value)
#         preds.append(pred)
#     l2_norm_preds.append(preds)
# #     baseline_errors_batch.append(errors)
# # np.mean(errors)
# l2_norm_preds = torch.tensor(l2_norm_preds).float()

In [None]:
# l2_norm_errors = metric(l2_norm_preds, ys).numpy().squeeze()
# l2_norm_errors.shape

In [None]:
# l2_norm_errors.mean(axis = 0)

In [None]:
def get_df_from_pred_array(pred_arr, n_points, offset=0):
    """Turn a 2-D array of per-prompt values into a long-format DataFrame.

    ``pred_arr`` is flattened row by row into column ``"y"``. Column ``"x"``
    holds the positions ``offset, ..., n_points - 1`` repeated once per row of
    ``pred_arr``, so the row width must equal ``n_points - offset``.
    """
    n_rows = pred_arr.shape[0]
    positions = np.array(
        [pos for _ in range(n_rows) for pos in range(offset, n_points)]
    )
    return pd.DataFrame({"y": pred_arr.ravel(), "x": positions})


def lineplot_with_ci(pred_or_err_arr, n_points, offset, label, ax, seed):
    """Draw the mean curve of ``pred_or_err_arr`` on ``ax`` with a 90% bootstrap CI.

    The array is first reshaped to long format (value ``"y"`` against position
    ``"x"``); seaborn then aggregates over the rows that share a position,
    using 1000 bootstrap resamples seeded with ``seed``.
    """
    long_df = get_df_from_pred_array(
        pred_or_err_arr, n_points=n_points, offset=offset
    )
    sns.lineplot(
        data=long_df,
        x="x",
        y="y",
        label=label,
        ax=ax,
        ci=90,
        n_boot=1000,
        seed=seed,
    )

In [None]:
# Display the sparsity setting the task was trained with.
dr_conf.training.task_kwargs["sparsity"]

In [None]:
sns.set(style="whitegrid", font_scale=1.5)

s = dr_conf.training.task_kwargs["sparsity"]
# Sample-count threshold computed from the sparsity ``s`` and the input
# dimension; drawn as a dashed vertical line (also reused by the probing plot).
bound = int(2 * s * math.log(n_dims / s) + 5 * s / 4)

fig, ax = plt.subplots()
# Every error curve is divided by the sparsity before plotting.
for errors, label in (
    (transformer_errors, "Transformer"),
    (lsq_errors, "OLS"),
    (lasso_preds, "Lasso"),  # holds the Lasso *errors* despite its name
):
    lineplot_with_ci(errors / s, n_points, offset=0, label=label, ax=ax, seed=seed)

ax.axvline(bound, ls="--", color="black")
ax.annotate("Bound", xy=(bound + 0.25, 0.6), color="r", rotation=0)
# ``\\#`` yields the same LaTeX ``\#`` as before, without relying on an invalid
# Python escape sequence.
ax.set_xlabel("$k$\n(\\# in-context examples)")
ax.set_ylabel("$\\texttt{loss@}k$")
ax.set_title("Sparse Regression ICL")
format_axes(ax)
ax.legend()

# Make sure the output directory exists so that saving cannot fail on a fresh checkout.
os.makedirs("final_plots", exist_ok=True)
plt.savefig("final_plots/sr_errors.pdf", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
def recover_weights(model, xs, ys, w_b, device="cuda:0"):
    """Read off the model's implied weight vector by probing with basis vectors.

    For each dimension ``i`` the prompt ``(xs, ys)`` is extended with the
    standard basis vector ``e_i`` as an extra query (its label slot is filled
    with the true ``w_b[:, i, 0]``), and the model's prediction at that last
    position is taken as the ``i``-th probed weight.

    Args:
        model: callable ``model(xs, ys)`` returning one prediction per
            position; it is moved to ``device`` (a side effect on the caller's
            object).
        xs: in-context inputs, indexed as ``(batch, points, dims)``.
        ys: in-context targets, indexed as ``(batch, points)``.
        w_b: ground-truth weights, indexed as ``(batch, dims, 1)``.
        device: device used for the forward passes. Defaults to ``"cuda:0"``,
            the value that used to be hard-coded.

    Returns:
        ``(w_b, w_probed, error, cos_sim)`` where ``w_probed`` is
        ``(batch, dims)``, ``error`` is the mean squared difference to the
        true weights and ``cos_sim`` the mean cosine similarity to them.
    """
    model.to(device)
    batch_size = xs.size(0)
    n_dims = w_b.size(1)
    w_true = w_b[:, :, 0]
    # One identity matrix per prompt: row i is the probe for dimension i.
    es = torch.eye(n_dims).unsqueeze(0).repeat(batch_size, 1, 1)

    w_probed = []
    for i in range(n_dims):
        x_probe = torch.cat([xs, es[:, i : i + 1, :]], dim=1)
        y_probe = torch.cat([ys, w_true[:, i : i + 1]], dim=1)
        with torch.no_grad():
            pred = model(x_probe.to(device), y_probe.to(device)).cpu()
        # Only the prediction for the appended probe point matters.
        w_probed.append(pred[:, -1:])
    w_probed = torch.cat(w_probed, dim=1)

    error = ((w_probed - w_true) ** 2).mean(dim=1).mean()
    cos_sim = torch.nn.CosineSimilarity(dim=1, eps=1e-6)(w_probed, w_true).mean()

    return w_b, w_probed, error, cos_sim

In [None]:
def recover_weights_pv2(model, xs, ys, w_b, data_sampler, device="cuda:0"):
    """Estimate the model's implied weights by regressing on random probes.

    ``2 * dims + 1`` probe inputs are drawn per prompt from ``data_sampler``.
    Each probe is appended in turn to the prompt ``(xs, ys)`` (with a zero in
    its label slot) and the model's prediction at that last position is
    recorded. An intercept-free least-squares fit of those predictions on the
    probe inputs gives the probed weight vector for each prompt.

    Args:
        model: callable ``model(xs, ys)`` returning one prediction per
            position; it is moved to ``device`` (a side effect on the caller's
            object).
        xs: in-context inputs, indexed as ``(batch, points, dims)``.
        ys: in-context targets, indexed as ``(batch, points)``.
        w_b: ground-truth weights, indexed as ``(batch, dims, 1)``.
        data_sampler: object whose ``sample_xs(b_size=..., n_points=...)``
            returns the probe inputs.
        device: device used for the forward passes. Defaults to ``"cuda:0"``,
            the value that used to be hard-coded.

    Returns:
        ``(w_b, w_probed, error, cos_sim)`` where ``w_probed`` is
        ``(batch, dims)``, ``error`` is the mean squared difference to the
        true weights and ``cos_sim`` the mean cosine similarity to them.
    """
    model.to(device)
    batch_size = xs.shape[0]
    n_probes = 2 * xs.shape[-1] + 1

    x_probes = data_sampler.sample_xs(b_size=batch_size, n_points=n_probes)

    # The context and the label sequence are identical for every probe, so
    # they are built and moved to the device once, outside the loop.
    xs_dev = xs.to(device)
    y_prompt = torch.cat([ys, torch.zeros(batch_size, 1)], dim=1).to(device)

    y_probes = []
    for i in range(n_probes):
        x_prompt = torch.cat([xs_dev, x_probes[:, i : i + 1, :].to(device)], dim=1)
        with torch.no_grad():
            pred = model(x_prompt, y_prompt).cpu()
        # Only the prediction for the appended probe point matters.
        y_probes.append(pred[:, -1:])
    y_probes = torch.cat(y_probes, dim=1)

    # One least-squares fit per prompt: probe inputs -> model predictions.
    w_probed = []
    for x, y in zip(x_probes, y_probes):
        probe_model = LinearRegression(fit_intercept=False)
        probe_model.fit(x, y)
        w_probed.append(torch.tensor(probe_model.coef_[np.newaxis]).float())
    w_probed = torch.cat(w_probed, dim=0)

    w_true = w_b[:, :, 0]
    error = ((w_probed - w_true) ** 2).mean(dim=1).mean()
    cos_sim = torch.nn.CosineSimilarity(dim=1, eps=1e-6)(w_probed, w_true).mean()

    return w_b, w_probed, error, cos_sim

In [None]:
# Rebuild the samplers and redraw the evaluation batch for the probing
# experiments. The seed and the order of sampling calls match the earlier
# sampling cell, so this presumably reproduces the same task and inputs —
# TODO confirm that the samplers draw only from the torch RNG.
seed = 42
torch.manual_seed(seed)
batch_size = 1280
n_points = dr_conf.training.curriculum.points.end
data_sampler = get_data_sampler(dr_conf.training.data, n_dims)
task_sampler = get_task_sampler(
    dr_conf.training.task, n_dims, batch_size, **dr_conf.training.task_kwargs
)
task = task_sampler()
xs = data_sampler.sample_xs(b_size=batch_size, n_points=n_points)
ys = task.evaluate(xs)

In [None]:
# Probe the model's implied weights using every in-context example except the
# last one as context.
w_b, w_probed, error, cos_sim = recover_weights_pv2(
    dr_model, xs[:, :-1], ys[:, :-1], task.w_b, data_sampler
)

In [None]:
# Display the probed weight vectors (one row per prompt).
w_probed

In [None]:
# Display the mean cosine similarity between probed and true weights.
cos_sim

In [None]:
# Probe the model once per context length 1 .. n_points - 1 and collect the
# probed weights; entry k of ``w_probed_vecs`` corresponds to k + 1 in-context
# examples. Each call draws fresh probe inputs from ``data_sampler``.
w_probed_vecs = []
for n_points_i in tqdm(range(1, n_points)):
    # ``w_probed`` and ``cos_sim`` are rebound on every iteration; only the
    # per-length weights are kept.
    _, w_probed, _, cos_sim = recover_weights_pv2(
        dr_model, xs[:, :n_points_i], ys[:, :n_points_i], task.w_b, data_sampler
    )
    w_probed_vecs.append(w_probed)

In [None]:
# OLS weight estimates for every prefix length i = 1 .. xs.shape[1] - 1 and
# every prompt: an intercept-free fit on the first i examples. The resulting
# array is indexed as (prefix length, batch, dims).
lsq_weights = np.array(
    [
        [
            LinearRegression(fit_intercept=False)
            .fit(xs[batch_id, :i], ys[batch_id, :i])
            .coef_
            for batch_id in range(xs.shape[0])
        ]
        for i in tqdm(range(1, xs.shape[1]))
    ]
)

In [None]:
# Convert to a float tensor and swap the first two axes so the batch comes
# first. NOTE(review): this rebinds ``lsq_weights`` in place, so re-running the
# cell on its own would swap the axes back.
lsq_weights = torch.tensor(lsq_weights).transpose(0, 1).float()

In [None]:
# Lasso (alpha = 1e-2, no intercept) weight estimates for every prefix length
# i = 1 .. xs.shape[1] - 1 and every prompt, fitted on the first i examples.
lasso_weights = np.array(
    [
        [
            Lasso(fit_intercept=False, alpha=1e-2)
            .fit(xs[batch_id, :i], ys[batch_id, :i])
            .coef_
            for batch_id in range(xs.shape[0])
        ]
        for i in tqdm(range(1, xs.shape[1]))
    ]
)
# (prefix length, batch, dims) -> (batch, prefix length, dims), as float32.
lasso_weights = torch.tensor(lasso_weights).transpose(0, 1).float()

In [None]:
# l1_norm_weights = []
# w_stars = []
# for b in tqdm(range(xs.shape[0])):
#     weights = []
#     for t in range(xs.shape[1] - 1):
#         w_star = Variable([n_dims, 1])
#         obj = Minimize(cvxnorm(w_star, 1))
#         constraints = [ys[b,:t+1].numpy()[:,np.newaxis] == (xs[b,:t+1].numpy() @ w_star)]
#         prob = Problem(obj, constraints)
#         result = prob.solve()
#         try:
#             pred = w_star.value[:,0] @ xs[b,t+1].numpy()
#         except:
#             pred = 0
# #         errors.append((pred - ys[b,t+1].numpy())**2)
#         weights.append(w_star.value)
#     l1_norm_weights.append(weights)
# #     baseline_errors_batch.append(errors)
# # np.mean(errors)
# l1_norm_weights = torch.tensor(l1_norm_weights).float()

In [None]:
# l1_norm_weights = torch.tensor([weights[:20] for weights in l1_norm_weights]).float()

In [None]:
# l1_norm_weights.shape

In [None]:
# cos_sims_trans_lsq = []
# cos_sims_trans_lasso = []
# cos_sims_trans_l1min = []
# cos_sims_trans_gold = []
# gold_weights = task.w_b
# for n_points_i in tqdm(range(1, n_points)):

#     trans_weight_vect = w_probed_vecs[n_points_i - 1].squeeze()
#     lsq_weight_vect = lsq_weights[:, n_points_i - 1].squeeze()
#     lasso_weight_vect = lasso_weights[:, n_points_i - 1].squeeze()
#     l1_min_weight_vect = l1_norm_weights[:, min(n_points_i - 1, l1_norm_weights.shape[1] - 1)].squeeze()

#     cos_sims_trans_lsq.append(torch.nn.CosineSimilarity(dim = 1, eps = 1e-6)(trans_weight_vect, lsq_weight_vect))
#     cos_sims_trans_lasso.append(torch.nn.CosineSimilarity(dim = 1, eps = 1e-6)(trans_weight_vect, lasso_weight_vect))
#     cos_sims_trans_l1min.append(torch.nn.CosineSimilarity(dim = 1, eps = 1e-6)(trans_weight_vect, l1_min_weight_vect))
#     cos_sims_trans_gold.append(torch.nn.CosineSimilarity(dim = 1, eps = 1e-6)(trans_weight_vect, gold_weights.squeeze()))

In [None]:
# cos_sims_trans_lsq = torch.vstack(cos_sims_trans_lsq).transpose(0,1)
# cos_sims_trans_lasso = torch.vstack(cos_sims_trans_lasso).transpose(0,1)
# cos_sims_trans_l1min = torch.vstack(cos_sims_trans_l1min).transpose(0,1)
# cos_sims_trans_gold = torch.vstack(cos_sims_trans_gold).transpose(0,1)

In [None]:
# For each context length, the per-prompt mean squared difference between the
# probed transformer weights and (a) the OLS weights, (b) the Lasso weights and
# (c) the ground-truth task weights.
gold_weights = task.w_b
gold_weight_vect = gold_weights.squeeze()
mse_trans_lsq, mse_trans_lasso, mse_trans_gold = [], [], []
for n_points_i in tqdm(range(1, n_points)):
    col = n_points_i - 1  # list / tensor index for this context length
    trans_weight_vect = w_probed_vecs[col].squeeze()
    lsq_weight_vect = lsq_weights[:, col].squeeze()
    lasso_weight_vect = lasso_weights[:, col].squeeze()

    lsq_gap = trans_weight_vect - lsq_weight_vect
    lasso_gap = trans_weight_vect - lasso_weight_vect
    gold_gap = trans_weight_vect - gold_weight_vect

    mse_trans_lsq.append((lsq_gap**2).mean(axis=-1))
    mse_trans_lasso.append((lasso_gap**2).mean(axis=-1))
    mse_trans_gold.append((gold_gap**2).mean(axis=-1))

In [None]:
# Stack the per-context-length vectors row-wise, then transpose so that each
# row belongs to one prompt and each column to one context length.
mse_trans_lsq = torch.vstack(mse_trans_lsq).t()
mse_trans_lasso = torch.vstack(mse_trans_lasso).t()
mse_trans_gold = torch.vstack(mse_trans_gold).t()

In [None]:
sns.set(style="whitegrid", font_scale=1.5)

fig, ax = plt.subplots()
# Distance of the probed transformer weights to each reference weight vector.
# Raw strings keep the LaTeX backslashes without relying on invalid Python
# escape sequences such as ``\m``.
probe_curves = (
    (mse_trans_gold, r"$(w^{\mathrm{probe}}, w)$"),
    (mse_trans_lsq, r"$(w^{\mathrm{probe}}, w^{\mathrm{OLS}})$"),
    (mse_trans_lasso, r"$(w^{\mathrm{probe}}, w^{\mathrm{Lasso}})$"),
)
for mse_curve, curve_label in probe_curves:
    # The last context length is dropped. NOTE(review): the 20 / 3 scaling is
    # presumably n_dims / sparsity — TODO confirm and derive it from the config.
    lineplot_with_ci(
        mse_curve[:, :-1] * 20 / 3,
        n_points - 1,
        offset=1,
        label=curve_label,
        ax=ax,
        seed=seed,
    )

ax.set_xlabel("$k$\n(\\# in-context examples)")
ax.set_ylabel("mean squared error")
ax.set_title("Sparse Regression ICL")
format_axes(ax)
# ``bound`` is computed in the loss-plot cell above.
ax.axvline(bound, ls="--", color="black")
ax.annotate("Bound", xy=(bound + 0.25, 0.5), color="r", rotation=0)
ax.legend()

# Make sure the output directory exists so that saving cannot fail on a fresh checkout.
os.makedirs("final_plots", exist_ok=True)
plt.savefig("final_plots/sr_probing_mse.pdf", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# sns.set(style = "whitegrid", font_scale=1.5)
# # latexify(4, 3)

# fig, ax = plt.subplots()
# # ax.plot(list(range(n_points)), transformer_pe_errors.mean(axis=0), label = "With Position Encodings")
# # ax.plot(list(range(n_points)), transformer_no_pe_errors.mean(axis=0), label = "Without Position Encodings")
# lineplot_with_ci(cos_sims_trans_gold[:,:-1], n_points - 1, offset = 1, label="$(w^{\mathrm{probe}}, w)$", ax=ax, seed=seed)
# # lineplot_with_ci(lsq_errors, n_points, label="Least Squares", ax=ax, seed=seed)
# lineplot_with_ci(cos_sims_trans_lsq[:,:-1], n_points - 1,offset = 1, label="$(w^{\mathrm{probe}}, w^{\mathrm{OLS}})$", ax=ax, seed=seed)
# lineplot_with_ci(cos_sims_trans_lasso[:,:-1], n_points - 1,offset = 1, label="$(w^{\mathrm{probe}}, w^{\mathrm{Lasso}})$", ax=ax, seed=seed)
# # lineplot_with_ci(cos_sims_trans_l1min[:,:-1], n_points - 1,offset = 1, label="$(w^{\mathrm{probe}}, w^{\ell_1})$", ax=ax, seed=seed)
# # lineplot_with_ci(l2_norm_errors, n_points, label="L-2 Norm Min", ax=ax, seed=seed)
# ax.set_xlabel("$k$\n(\# in-context examples)")
# ax.set_ylabel("cosine similarity")
# ax.set_title("Sparse Regression ICL")
# plt.axvline(bound, ls="--", color="black")
# ax.annotate('Bound', xy=(bound + 0.25, 0.6), color='r', rotation=0)
# format_axes(ax)
# # plt.axhline(baseline, ls="--", color="gray", label="zero estimator")
# plt.legend()
# plt.savefig("final_plots/sr_probing.pdf", dpi = 300, bbox_inches = "tight")
# plt.show()

In [None]:
# Stack the per-context-length probed weights along a new leading axis, then
# swap the first two axes so the result is indexed as
# (prompt, context length, dims).
w_probed_vecs_t = torch.stack(w_probed_vecs, dim=0).transpose(0, 1)

In [None]:
# Display the probed weights of the first prompt at the largest context length.
w_probed_vecs_t[0][-1]

In [None]:
# Probed weights of the first prompt for its first 20 context lengths
# (rows = context length, columns = weight dimension).
probe_weights_batch0 = w_probed_vecs_t[0][:20]

sns.heatmap(probe_weights_batch0, cmap="coolwarm", linewidth=1.5)
plt.title("$w^{probe}$")
plt.xlabel("Dim")
# ``\\#`` yields the same LaTeX ``\#`` as before, without relying on an invalid
# Python escape sequence.
plt.ylabel("$k$\n(\\# in-context examples)")
# Make sure the output directory exists so that saving cannot fail on a fresh checkout.
os.makedirs("final_plots", exist_ok=True)
plt.savefig("final_plots/sparse_w_probe.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Display the shape of the heatmap data as a sanity check.
probe_weights_batch0.shape