In [None]:
from jaxl.constants import *
from jaxl.datasets import get_dataset
from jaxl.learning_utils import get_learner
from jaxl.models.svm import *
from jaxl.plot_utils import set_size
from jaxl.utils import parse_dict

import copy
import jax
import jax.numpy as jnp
import jax.random as jrandom
import json
import math
import matplotlib.pyplot as plt
import numpy as np
import optax
import os

from functools import partial
from orbax.checkpoint import PyTreeCheckpointer, CheckpointManager

# plt.style.use("seaborn")

In [None]:
# Pin the notebook to one GPU.
# NOTE(review): presumably this must run before JAX initialises its backend to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Root of the jaxl checkout. Set the JAXL_BASE_DIR environment variable to point
# elsewhere instead of editing this cell; the default is the original machine's path.
# (Mac checkout: "/Users/chanb/research/personal/jaxl/jaxl")
base_dir = os.environ.get("JAXL_BASE_DIR", "/home/bryanpu1/projects/jaxl/jaxl")

# Experiment to analyse, relative to `base_dir`.
# Larger-architecture runs (Mac):
# rel_path = "logs/icl-sigmoid_bce-context_len_8-large_arch/default-10-04-23_17_34_52-aa4fa4c8-db6a-4dd4-b91d-64d72b8b56a8"
# rel_path = "logs/icl-sigmoid_bce-context_len_8-large_arch/all_ones-10-04-23_17_34_55-dbeed583-fe53-417e-9c7d-4f712dae126e"
# rel_path = "logs/icl-sigmoid_bce-context_len_8-large_arch/one_hot-10-04-23_17_35_06-e36233bc-9649-4434-afa3-1e218a145e71/"
# num_blocks = 8 run (Salient 2):
rel_path = "logs/icl-sigmoid_bce-context_len_16-num_blocks_8/one_hot-10-07-23_10_46_11-22362b3d-0b55-4c9f-af72-982e364d61a9"

learner_path = os.path.join(base_dir, rel_path)

# Held-out evaluation settings.
test_dataset_seed = 999  # seed for sampling the test tasks
sequence_len = 40  # examples per task; the first `context_len` form the context, the rest are queries
num_tasks = 30  # number of test tasks (set to 1 for a quick run)

# Plot sizing: document width in points, passed to `set_size`.
doc_width_pt = 1000.0

In [None]:
# Load the experiment configuration that was saved next to the checkpoints.
config_path = os.path.join(learner_path, "config.json")
with open(config_path, "r") as config_file:
    config_dict = json.load(config_file)
    config = parse_dict(config_dict)

query_pred_only = getattr(config.model_config, "query_pred_only", False)
print(config.learner_config.losses)
use_sigmoid = config.learner_config.losses[0] == CONST_SIGMOID_BCE


def _build_prediction_processor(query_only, sigmoid_head):
    """Return the function that post-processes raw model outputs.

    - ``query_only`` and ``sigmoid_head``: flatten, apply a sigmoid and threshold
      at 0.5, returning one-hot rows from ``np.eye(2)``.
    - ``query_only`` only: take index ``-1`` along axis 1.
    - otherwise: take index 0 along axis 1 and index ``-1`` along axis 2
      (presumably the final sequence position -- TODO confirm output shape).
    """
    if not query_only:
        return lambda preds: preds[:, 0, -1]
    if not sigmoid_head:
        return lambda preds: preds[:, -1]

    def _sigmoid_to_one_hot(preds):
        probs = jax.nn.sigmoid(preds.flatten())
        return np.eye(2)[(probs >= 0.5).astype(int)]

    return _sigmoid_to_one_hot


process_prediction = _build_prediction_processor(query_pred_only, use_sigmoid)

learner = get_learner(config.learner_config, config.model_config, config.optimizer_config)

# Restore the most recent checkpoint.
checkpoint_manager = CheckpointManager(
    os.path.join(learner_path, "models"),
    PyTreeCheckpointer(),
)
params = checkpoint_manager.restore(checkpoint_manager.latest_step())
# Replace the restored positional-encoding entry with an empty dict.
# NOTE(review): presumably the encoding has no trainable parameters -- confirm.
params[CONST_MODEL_DICT][CONST_MODEL][CONST_POSITIONAL_ENCODING] = {}
llm_model = learner._model

In [None]:
# Inspect the restored parameters stored under CONST_PREDICTOR (presumably the model's output head).
params[CONST_MODEL_DICT][CONST_MODEL][CONST_PREDICTOR]

In [None]:
# Inspect the learner's dataset configuration; the test configuration below is a modified deep copy of it.
config.learner_config.dataset_config

In [None]:
# Build the held-out test configuration from a deep copy of the learner's dataset config:
# fixed-length contexts, `num_tasks` sequences of `sequence_len` examples each.
input_range = [-1.0, 1.0]
ns_test_config = copy.deepcopy(vars(config.learner_config.dataset_config))

wrapper_fields = vars(ns_test_config["dataset_wrapper"])
wrapper_fields["type"] = "FixedLengthContextDataset"
ns_test_config["dataset_wrapper"] = parse_dict(wrapper_fields)

kwargs_fields = vars(ns_test_config["dataset_kwargs"])
kwargs_fields.update(
    num_sequences=num_tasks,
    sequence_length=sequence_len,
    params_bound=[-0.5, 0.5],
    inputs_range=input_range,
    # margin=0.5,
)
ns_test_config["dataset_kwargs"] = parse_dict(kwargs_fields)

ns_test_config = parse_dict(ns_test_config)

In [None]:
# Build the held-out test dataset from the modified config with a fixed seed.
test_dataset = get_dataset(ns_test_config, seed=test_dataset_seed)
# Underlying dataset object; its `_inputs` / `_targets` arrays are read directly by `get_result`.
unwrapped_dataset = test_dataset._dataset

In [None]:
context_len = config.learner_config.dataset_config.dataset_wrapper.kwargs.context_len


def get_result(dataset, task_i, context_len):
    """Run the in-context learner on one test task.

    The first ``context_len`` examples of task ``task_i`` form the context, and
    every remaining example of that task is used as a query.

    Args:
        dataset: raw dataset exposing ``_inputs`` and ``_targets`` arrays indexed
            as ``[task, example, ...]``.
        task_i: index of the task to evaluate.
        context_len: number of leading examples used as context (may be 0).

    Returns:
        ``(queries, preds, outputs, context_inputs, context_outputs)``, where
        ``preds`` are the raw model outputs from one forward pass per query
        (vectorised with ``jax.vmap``).
    """
    # Copy the leading examples out as the context (same values as stacking them one by one).
    context_inputs = np.array(dataset._inputs[task_i, :context_len])
    context_outputs = np.array(dataset._targets[task_i, :context_len])

    queries = dataset._inputs[task_i, context_len:]
    outputs = dataset._targets[task_i, context_len:]

    # vmap over the queries only; params and the context are shared (in_axes=None).
    preds, _ = jax.vmap(llm_model.forward, in_axes=[None, 0, None])(
        params[CONST_MODEL_DICT][CONST_MODEL],
        # Two singleton axes per query: (num_queries, 1, 1, ...).
        queries[:, None, None],
        {
            # Leading axis of size 1 on the shared context.
            CONST_CONTEXT_INPUT: context_inputs[None, :],
            CONST_CONTEXT_OUTPUT: context_outputs[None, :],
        },
    )
    return queries, preds, outputs, context_inputs, context_outputs

# Analysis

In [None]:
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.svm import LinearSVC
from sklearn.neighbors import KNeighborsClassifier

In [None]:
# Per task: the context/query data, the LLM predictions on the held-out queries,
# and the ground-truth decision boundary evaluated at both ends of `input_range`.
res = {}

for task_i in range(num_tasks):
    queries, preds, outputs, context_inputs, context_outputs = get_result(
        unwrapped_dataset, task_i, context_len
    )
    preds = process_prediction(preds)
    res[task_i] = {
        "data": {
            "context_inputs": context_inputs,
            "context_outputs": context_outputs,
            "queries": queries,
            "outputs": outputs,
        },
        "llm": preds,
    }

    # Task parameters are laid out as [bias, w_1, w_2]; the ground-truth grid labels
    # use the same layout (`x @ params[1:] + params[:1]`). The boundary
    # bias + w_1 x_1 + w_2 x_2 = 0 gives x_2 = -(w_1 x_1 + bias) / w_2, so the
    # bias must be included.
    gt = test_dataset.params[task_i]
    res[task_i]["gt"] = -(np.array(input_range) * gt[1] + gt[0]) / gt[2]

In [None]:
# Sweeps for the classical comparators; SVM and LR share the same C grid but get independent lists.
_shared_reg_grid = (1e-2, 1e-1, 5e-1, 1.0, 2.0, 10.0, 100.0, 1000.0)
svm_regs = list(_shared_reg_grid)  # LinearSVC `C` values
lr_regs = list(_shared_reg_grid)  # LogisticRegression `C` values
knn_ks = [1, 3]  # neighbour counts for k-NN

### K-NN

In [None]:
def make_knn(inputs, outputs, num_neighbours):
    """Fit a k-nearest-neighbour classifier.

    Args:
        inputs: feature matrix, one row per example.
        outputs: targets whose argmax along axis 1 is the class label
            (one-hot rows in this notebook).
        num_neighbours: ``n_neighbors`` for ``KNeighborsClassifier``.

    Returns:
        The fitted single-step sklearn pipeline.
    """
    class_labels = np.argmax(outputs, axis=1)
    classifier = KNeighborsClassifier(n_neighbors=num_neighbours)
    pipeline = make_pipeline(classifier)
    pipeline.fit(inputs, class_labels)
    return pipeline

In [None]:
# Fit one k-NN classifier per `k` on each task's context.
for task_i in range(num_tasks):
    context_inputs = res[task_i]["data"]["context_inputs"]
    context_outputs = res[task_i]["data"]["context_outputs"]
    res[task_i]["knn"] = {
        k: make_knn(context_inputs, context_outputs, k) for k in knn_ks
    }

### SVM

In [None]:
def make_svm(inputs, outputs, reg_coef):
    """Fit a linear SVM classifier (sklearn ``LinearSVC``, at most 2000 iterations).

    NOTE: ``LinearSVC`` defaults to the squared-hinge loss, not the textbook hinge loss.

    Args:
        inputs: feature matrix, one row per example.
        outputs: targets whose argmax along axis 1 is the class label.
        reg_coef: the ``C`` (inverse regularisation strength) parameter.

    Returns:
        The fitted single-step sklearn pipeline; ``pipeline[0]`` is the ``LinearSVC``.
    """
    targets = np.argmax(outputs, axis=1)
    model = make_pipeline(LinearSVC(C=reg_coef, max_iter=2000))
    return model.fit(inputs, targets)

In [None]:
# Fit one SVM per regularisation coefficient on each task's context and record
# the context points lying on or inside the margin (|f(x)| <= 1, plus a tiny
# round-off tolerance), which are treated as support vectors.
svm_info = {}

for task_i in range(num_tasks):
    context_inputs = res[task_i]["data"]["context_inputs"]
    context_outputs = res[task_i]["data"]["context_outputs"]
    task_svm_info = svm_info.setdefault(task_i, {})
    svms = {}

    for svm_reg in svm_regs:
        fitted_svm = make_svm(context_inputs, context_outputs, svm_reg)
        svms[svm_reg] = fitted_svm

        margins = np.abs(fitted_svm.decision_function(context_inputs))
        sv_indices = np.flatnonzero(margins <= 1 + 1e-15)
        task_svm_info[svm_reg] = {
            "support_vectors": context_inputs[sv_indices],
            "support_labels": context_outputs[sv_indices],
            "support_vector_indices": sv_indices,
        }

    res[task_i]["svm"] = svms

### Logistic Regression

In [None]:
def make_lr(inputs, outputs, penalty, reg_coef):
    """Fit a logistic-regression classifier (at most 2000 iterations).

    Args:
        inputs: feature matrix, one row per example.
        outputs: targets whose argmax along axis 1 is the class label.
        penalty: sklearn ``penalty`` argument (this notebook passes ``"l2"``).
        reg_coef: the ``C`` (inverse regularisation strength) parameter.

    Returns:
        The fitted single-step sklearn pipeline; ``pipeline[0]`` is the ``LogisticRegression``.
    """
    estimator = LogisticRegression(penalty=penalty, C=reg_coef, max_iter=2000)
    return make_pipeline(estimator).fit(inputs, np.argmax(outputs, axis=1))

In [None]:
# Fit one L2-regularised logistic regression per coefficient on each task's context.
for task_i in range(num_tasks):
    context_inputs = res[task_i]["data"]["context_inputs"]
    context_outputs = res[task_i]["data"]["context_outputs"]
    res[task_i]["lr"] = {
        lr_reg: make_lr(context_inputs, context_outputs, "l2", lr_reg)
        for lr_reg in lr_regs
    }

### ICL Analysis

In [None]:
# Dense evaluation grid over the input square, used to visualise and compare decision regions.
model_preds = {}
num_models = len(lr_regs) + len(svm_regs)  # SVM + LR panels per task in the comparison figure

delta = 0.01  # grid spacing
# Derived from `input_range` so the grid always covers the test inputs' range.
# linspace includes both endpoints exactly, unlike a floating-point-step arange.
num_grid_points = int(round((input_range[1] - input_range[0]) / delta)) + 1
xs_grid = np.linspace(input_range[0], input_range[1], num_grid_points)
# All (x_1, x_2) grid points, shape (num_grid_points**2, 2).
test_queries = np.stack(np.meshgrid(xs_grid, xs_grid)).reshape((2, -1)).T

In [None]:
# For each task: LLM predictions on every grid point (with that task's context),
# and the ground-truth one-hot labels, class 1 where bias + w . x >= 0.
for task_i in range(num_tasks):
    context_inputs = res[task_i]["data"]["context_inputs"]
    context_outputs = res[task_i]["data"]["context_outputs"]
    task_preds = model_preds.setdefault(task_i, {})

    batched_forward = jax.vmap(llm_model.forward, in_axes=[None, 0, None])
    llm_preds, _ = batched_forward(
        params[CONST_MODEL_DICT][CONST_MODEL],
        test_queries[:, None, None],
        {
            CONST_CONTEXT_INPUT: context_inputs[None, :],
            CONST_CONTEXT_OUTPUT: context_outputs[None, :],
        },
    )
    llm_preds = process_prediction(llm_preds)
    task_preds["llm"] = llm_preds

    gt_scores = test_queries @ test_dataset.params[task_i, 1:] + test_dataset.params[task_i, :1]
    gt_labels = (gt_scores >= 0).flatten().astype(int)
    task_preds["gt"] = np.eye(2)[gt_labels]

In [None]:
# One figure per task, with one panel per SVM / LR comparator. Each panel shows:
# - the LLM's predicted class regions on the grid;
# - the ground-truth boundary (black);
# - the comparator's boundary (blue);
# - for SVMs, the support vectors.
for task_i in range(num_tasks):
    ncols = 4
    nrows = math.ceil(num_models / ncols)
    # squeeze=False keeps `axes` 2-D, so one indexing scheme works for any nrows.
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
        layout="constrained",
        squeeze=False,
    )

    model_classes = {
        "SVM": res[task_i]["svm"],
        "LR": res[task_i]["lr"],
    }

    # LLM predictions do not depend on the comparator; compute their labels once.
    llm_preds = model_preds[task_i]["llm"]
    llm_pred_labels = np.argmax(llm_preds, axis=-1)

    model_i = 0
    for model_class, models in model_classes.items():
        for reg_coef, model in models.items():
            # Comparator boundary from the fitted linear model (first pipeline step):
            # coef . x + intercept = 0  =>  x_2 = -(coef_1 x_1 + intercept) / coef_2.
            linear_model = model[0]
            model_out = (
                -(np.array(input_range) * linear_model.coef_[0, 0] + linear_model.intercept_[0])
                / linear_model.coef_[0, 1]
            )

            ax = axes[model_i // ncols, model_i % ncols]

            for possible_label in [0, 1]:
                idxes = np.where(llm_pred_labels == possible_label)
                ax.scatter(
                    test_queries[idxes][:, 0],
                    test_queries[idxes][:, 1],
                    label=f"{possible_label}" if model_i == 0 else "",
                    s=5,
                )

            if model_class == "SVM":
                ax.scatter(
                    svm_info[task_i][reg_coef]["support_vectors"][:, 0],
                    svm_info[task_i][reg_coef]["support_vectors"][:, 1],
                    label="support vector" if model_i == 0 else "",
                    color="black",
                )

            ax.plot(
                np.array(input_range),
                res[task_i]["gt"],
                label="Ground truth" if model_i == 0 else "",
                color="black",
                linewidth=1,
            )

            ax.plot(
                np.array(input_range),
                model_out,
                color="blue",
                linewidth=1,
                label="Comparator" if model_i == 0 else "",
            )

            ax.set_xlim(input_range[0] - 0.01, input_range[1] + 0.01)
            ax.set_ylim(input_range[0] - 0.01, input_range[1] + 0.01)
            ax.set_title(f"{model_class} {reg_coef}")
            model_i += 1

    # Hide any unused panels in the last row.
    for empty_i in range(model_i, nrows * ncols):
        axes[empty_i // ncols, empty_i % ncols].set_axis_off()

    fig.legend(
        bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
        loc="lower center",
        ncols=4,
        borderaxespad=0.0,
        frameon=True,
        fontsize="8",
    )
    fig.supxlabel("$x_1$")
    fig.supylabel("$x_2$")
    # The figure already uses constrained layout. Calling tight_layout would swap
    # the layout engine (with a warning), so it is deliberately not called here.
    plt.show()

### Mean 0-1 Prediction Agreement (LLM vs. Comparators)

In [None]:
# For each comparator: the fraction of grid points where its predicted class
# matches the LLM's, collected per task and summarised as mean +/- std in percent.
# NOTE: despite the name, `losses` stores agreement rates (higher = more similar).
losses = {}

for task_i in range(num_tasks):
    model_classes = {
        "SVM": res[task_i]["svm"],
        "LR": res[task_i]["lr"],
        "KNN": res[task_i]["knn"],
    }

    llm_preds = np.argmax(model_preds[task_i]["llm"], axis=1)
    for model_class, models in model_classes.items():
        for reg_coef, model in models.items():
            losses.setdefault((model_class, reg_coef), []).append(
                np.mean(model.predict(test_queries) == llm_preds)
            )

for model_info, curr_loss in losses.items():
    # Scale the std itself by 100. `curr_loss * 100` on a Python list repeats the
    # list, which would leave the std as an unscaled fraction next to a percentage mean.
    print(
        "{}: {:.2f}% +/- {:.2f}".format(
            model_info, np.mean(curr_loss) * 100, np.std(curr_loss) * 100
        )
    )

### Using Support Vectors as Context

In [None]:
def _make_get_new_pairs(query_last):
    """Build ``get_new_pairs`` for the model's sequence layout.

    The returned function maps ``(support_vectors, support_labels, query, context_len)``
    to ``(inputs, labels)``. Assuming ``len(support_vectors) == len(support_labels) <= context_len``
    and a single-row ``query``, ``inputs`` has ``context_len + 1`` rows and
    ``labels`` has ``context_len`` rows.

    - ``query_last=True``: zero rows come first, then the examples, then the query,
      so the query is the final row; labels are left-padded with zeros.
    - ``query_last=False``: examples, then the query, then zero rows; labels are
      right-padded with zeros.
    """

    def get_new_pairs(support_vectors, support_labels, query, context_len):
        input_pad = jnp.zeros(
            (context_len - len(support_vectors), *support_vectors.shape[1:])
        )
        label_pad = jnp.zeros(
            (context_len - len(support_labels), *support_labels.shape[1:])
        )
        if query_last:
            new_inputs = jnp.concatenate((input_pad, support_vectors, query))
            new_labels = jnp.concatenate((label_pad, support_labels))
        else:
            new_inputs = jnp.concatenate((support_vectors, query, input_pad))
            new_labels = jnp.concatenate((support_labels, label_pad))
        return new_inputs, new_labels

    return get_new_pairs


get_new_pairs = _make_get_new_pairs(query_last=bool(query_pred_only))

In [None]:
# For each task, build a figure with:
# - rows 0..nrows-1: the LLM's predictions when its context is replaced by the SVM
#   support vectors (one panel per `svm_reg`), overlaid with the ground-truth line,
#   the original context points, the SVM margins and the support vectors;
# - row `nrows`: k-NN fitted on the full context;
# - row `nrows + 1`: k-NN fitted on the support vectors of the last `svm_reg`.
context_res = {}
context_knn_ks = [1, 3, 5, 7]

for task_i in range(num_tasks):
    context_res.setdefault(task_i, {})

    ncols = 4
    nrows = math.ceil(len(svm_regs) / ncols)
    # squeeze=False keeps `axes` 2-D. The grid always has at least three rows, so
    # indexing it with a single integer would select a whole row of axes.
    fig, axes = plt.subplots(
        nrows + 2,
        ncols,
        figsize=set_size(doc_width_pt, 0.95, (nrows + 2, ncols), False),
        layout="constrained",
        squeeze=False,
    )

    context_inputs = res[task_i]["data"]["context_inputs"]
    context_outputs = res[task_i]["data"]["context_outputs"]

    for idx, svm_reg in enumerate(svm_regs):
        ax = axes[idx // ncols, idx % ncols]
        support_vectors = svm_info[task_i][svm_reg]["support_vectors"]
        support_labels = svm_info[task_i][svm_reg]["support_labels"]

        num_support_vectors = len(support_vectors)
        # One sequence per grid point, with the support vectors as the context.
        inputs, outputs = jax.vmap(get_new_pairs, in_axes=[None, None, 0, None])(
            support_vectors, support_labels, test_queries[:, None], context_len
        )

        # The last row of each sequence is passed as the query, the rest as context.
        llm_preds, _ = llm_model.forward(
            params[CONST_MODEL_DICT][CONST_MODEL],
            inputs[:, [-1]],
            {
                CONST_CONTEXT_INPUT: inputs[:, :-1],
                CONST_CONTEXT_OUTPUT: outputs,
            },
        )
        if not query_pred_only:
            # Prediction at the query's position, which directly follows the support vectors.
            llm_preds = llm_preds[:, num_support_vectors]
        elif use_sigmoid:
            llm_preds = process_prediction(llm_preds)
        # NOTE: overwritten for every `svm_reg`; after the loop these entries hold
        # the results for the last coefficient in `svm_regs`.
        context_res[task_i]["llm"] = llm_preds
        context_res[task_i]["support_vectors"] = support_vectors
        context_res[task_i]["support_labels"] = support_labels

        llm_pred_labels = np.argmax(llm_preds, axis=-1)

        for possible_label in [0, 1]:
            idxes = np.where(llm_pred_labels == possible_label)
            if len(idxes[0]) == 0:
                continue
            ax.scatter(
                test_queries[idxes][:, 0],
                test_queries[idxes][:, 1],
                # Label only this loop's first panel, so the figure legend lists each
                # class once and does not depend on state left over from other cells.
                label=f"{possible_label}" if idx == 0 else "",
                s=5,
            )

        ax.plot(
            np.array(input_range),
            res[task_i]["gt"],
            label="Ground truth" if idx == 0 else "",
            color="red",
            alpha=0.3,
        )

        ax.scatter(
            context_inputs[:, 0],
            context_inputs[:, 1],
            c=context_outputs[:, -1],
            s=30,
            cmap=plt.cm.Paired,
        )
        DecisionBoundaryDisplay.from_estimator(
            res[task_i]["svm"][svm_reg],
            context_inputs,
            ax=ax,
            grid_resolution=50,
            plot_method="contour",
            colors="k",
            levels=[-1, 0, 1],
            alpha=0.5,
            linestyles=["--", "-", "--"],
        )
        ax.scatter(
            svm_info[task_i][svm_reg]["support_vectors"][:, 0],
            svm_info[task_i][svm_reg]["support_vectors"][:, 1],
            s=100,
            linewidth=1,
            facecolors="none",
            edgecolors="k",
        )
        ax.set_xlim(input_range[0] - 0.01, input_range[1] + 0.01)
        ax.set_ylim(input_range[0] - 0.01, input_range[1] + 0.01)
        ax.set_title(f"Reg. Coef.: {svm_reg}")

    # Row `nrows`: k-NN fitted on the full context.
    for idx, k in enumerate(context_knn_ks):
        ax = axes[nrows, idx % ncols]
        ax.set_xlim(input_range[0] - 0.01, input_range[1] + 0.01)
        ax.set_ylim(input_range[0] - 0.01, input_range[1] + 0.01)
        ax.set_title(f"K: {k}")
        if len(context_inputs) < k:
            continue  # KNeighborsClassifier needs at least k training points

        knn = make_knn(context_inputs, context_outputs, k)
        knn_preds = knn.predict(test_queries)
        for possible_label in [0, 1]:
            idxes = np.where(knn_preds == possible_label)
            if len(idxes[0]) == 0:
                continue
            ax.scatter(test_queries[idxes][:, 0], test_queries[idxes][:, 1], s=5)

    # Row `nrows + 1`: k-NN fitted on the support vectors of the last `svm_reg`
    # (`support_vectors` / `support_labels` still hold the final loop iteration's values).
    for idx, k in enumerate(context_knn_ks):
        ax = axes[nrows + 1, idx % ncols]
        ax.set_title(f"K: {k}")
        ax.set_xlim(input_range[0] - 0.01, input_range[1] + 0.01)
        ax.set_ylim(input_range[0] - 0.01, input_range[1] + 0.01)
        if len(support_vectors) < k:
            continue

        knn = make_knn(support_vectors, support_labels, k)
        knn_preds = knn.predict(test_queries)
        for possible_label in [0, 1]:
            idxes = np.where(knn_preds == possible_label)
            if len(idxes[0]) == 0:
                continue
            ax.scatter(test_queries[idxes][:, 0], test_queries[idxes][:, 1], s=5)

    fig.legend(
        bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
        loc="lower center",
        ncols=3,
        borderaxespad=0.0,
        frameon=True,
        fontsize="8",
    )
    fig.supxlabel("$x_1$")
    fig.supylabel("$x_2$")
    # Constrained layout is already active; tight_layout would replace it, so it is not called.
    plt.show()

In [None]:
# For each comparator: agreement between its grid predictions and the LLM's grid
# predictions when the LLM is given support vectors as context. `context_res[task_i]["llm"]`
# holds the predictions for the last `svm_reg` only.
# NOTE: despite the name, the stored values are agreement rates (higher = more similar).
support_vector_losses = {}

for task_i in range(num_tasks):
    model_classes = {
        "SVM": res[task_i]["svm"],
        "LR": res[task_i]["lr"],
        "KNN": res[task_i]["knn"],
    }

    llm_preds = np.argmax(context_res[task_i]["llm"], axis=1)
    for model_class, models in model_classes.items():
        for reg_coef, model in models.items():
            support_vector_losses.setdefault((model_class, reg_coef), []).append(
                np.mean(model.predict(test_queries) == llm_preds)
            )

for model_info, curr_loss in support_vector_losses.items():
    # Scale the std itself; multiplying the Python list by 100 would only repeat it.
    print(
        "{}: {:.2f}% +/- {:.2f}".format(
            model_info, np.mean(curr_loss) * 100, np.std(curr_loss) * 100
        )
    )

### Permute Support Vectors

In [None]:
# For each task and each `svm_reg`, feed the SVM support vectors to the LLM as
# context in `num_permutations` random orders, one panel per ordering, so the
# predictions can be compared across orderings.
permute_res = {}

num_permutations = 4
perm_seed = 9999

# One PRNG key per permutation. Tied to `num_permutations` so the number of keys
# always matches the number of panels allocated below.
permutation_keys = jrandom.split(jrandom.PRNGKey(perm_seed), num_permutations)

for task_i in range(num_tasks):
    permute_res.setdefault(task_i, {})

    ncols = 4
    nrows = math.ceil((len(svm_regs) * num_permutations) / ncols)
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
        layout="constrained",
        squeeze=False,  # keep `axes` 2-D for uniform indexing
    )

    context_inputs = res[task_i]["data"]["context_inputs"]
    context_outputs = res[task_i]["data"]["context_outputs"]

    ax_idx = 0
    for svm_reg in svm_regs:
        for perm_i, permutation_key in enumerate(permutation_keys):
            ax = axes[ax_idx // ncols, ax_idx % ncols]
            support_vectors = svm_info[task_i][svm_reg]["support_vectors"]
            support_labels = svm_info[task_i][svm_reg]["support_labels"]

            perm_idxes = jrandom.permutation(
                permutation_key, x=np.arange(len(support_vectors))
            )
            support_vectors = support_vectors[perm_idxes]
            support_labels = support_labels[perm_idxes]

            num_support_vectors = len(support_vectors)
            inputs, outputs = jax.vmap(get_new_pairs, in_axes=[None, None, 0, None])(
                support_vectors, support_labels, test_queries[:, None], context_len
            )

            # Use the permuted support vectors as context; the last row is the query argument.
            llm_preds, _ = llm_model.forward(
                params[CONST_MODEL_DICT][CONST_MODEL],
                inputs[:, [-1]],
                {
                    CONST_CONTEXT_INPUT: inputs[:, :-1],
                    CONST_CONTEXT_OUTPUT: outputs,
                },
            )
            if not query_pred_only:
                llm_preds = llm_preds[:, num_support_vectors]
            elif use_sigmoid:
                llm_preds = process_prediction(llm_preds)
            # NOTE: overwritten per panel; holds the last (svm_reg, permutation) results.
            permute_res[task_i]["llm"] = llm_preds
            permute_res[task_i]["support_vectors"] = support_vectors
            permute_res[task_i]["support_labels"] = support_labels

            llm_pred_labels = np.argmax(llm_preds, axis=-1)

            for possible_label in [0, 1]:
                idxes = np.where(llm_pred_labels == possible_label)
                if len(idxes[0]) == 0:
                    continue
                ax.scatter(
                    test_queries[idxes][:, 0],
                    test_queries[idxes][:, 1],
                    label=f"{possible_label}" if ax_idx == 0 else "",
                    s=5,
                )

            ax.plot(
                np.array(input_range),
                res[task_i]["gt"],
                label="Ground truth" if ax_idx == 0 else "",
                color="red",
                alpha=0.3,
            )

            ax.scatter(
                context_inputs[:, 0],
                context_inputs[:, 1],
                c=context_outputs[:, -1],
                s=30,
                cmap=plt.cm.Paired,
            )
            DecisionBoundaryDisplay.from_estimator(
                res[task_i]["svm"][svm_reg],
                context_inputs,
                ax=ax,
                grid_resolution=50,
                plot_method="contour",
                colors="k",
                levels=[-1, 0, 1],
                alpha=0.5,
                linestyles=["--", "-", "--"],
            )
            ax.scatter(
                svm_info[task_i][svm_reg]["support_vectors"][:, 0],
                svm_info[task_i][svm_reg]["support_vectors"][:, 1],
                s=100,
                linewidth=1,
                facecolors="none",
                edgecolors="k",
            )
            ax.set_xlim(input_range[0] - 0.01, input_range[1] + 0.01)
            ax.set_ylim(input_range[0] - 0.01, input_range[1] + 0.01)
            ax.set_title(f"Reg. Coef.: {svm_reg} (perm. {perm_i})")
            ax_idx += 1

    fig.legend(
        bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
        loc="lower center",
        ncols=3,
        borderaxespad=0.0,
        frameon=True,
        fontsize="8",
    )
    fig.supxlabel("$x_1$")
    fig.supylabel("$x_2$")
    # Constrained layout is already active; tight_layout would replace it, so it is not called.
    plt.show()

## Exemplar Length

In [None]:
# Grid 0-1 loss against the ground-truth labels, as a function of how many leading
# context examples (`examplar_len`) are shown. Computed for the LLM and for an
# SVM / LR fitted on the same examples (last coefficient of each sweep).
ex_len_res = {
    "llm": {},
    "lr": {},
    "svm": {},
}
loss_per_examplar_len = {}

for task_i in range(num_tasks):
    context_inputs = res[task_i]["data"]["context_inputs"]
    context_outputs = res[task_i]["data"]["context_outputs"]
    gt_labels = np.argmax(model_preds[task_i]["gt"], axis=-1)

    for examplar_len in range(2, context_len):
        for model_class in ex_len_res:
            ex_len_res[model_class].setdefault(examplar_len, [])
        loss_per_examplar_len.setdefault(examplar_len, [])

        examplar_inputs = context_inputs[:examplar_len]
        examplar_outputs = context_outputs[:examplar_len]
        inputs, outputs = jax.vmap(get_new_pairs, in_axes=[None, None, 0, None])(
            examplar_inputs,
            examplar_outputs,
            test_queries[:, None],
            context_len,
        )

        # Use the first `examplar_len` context examples as context.
        llm_preds, _ = llm_model.forward(
            params[CONST_MODEL_DICT][CONST_MODEL],
            inputs[:, [-1]],
            {
                CONST_CONTEXT_INPUT: inputs[:, :-1],
                CONST_CONTEXT_OUTPUT: outputs,
            },
        )
        if not query_pred_only:
            llm_preds = llm_preds[:, examplar_len]
        elif use_sigmoid:
            llm_preds = process_prediction(llm_preds)
        llm_loss = np.mean(np.argmax(llm_preds, axis=-1) != gt_labels)
        ex_len_res["llm"][examplar_len].append(llm_loss)
        # Always the LLM's loss. The comparators' losses are kept separately in
        # `ex_len_res`, so they must not leak into this summary.
        loss_per_examplar_len[examplar_len].append(llm_loss)

        # The classical comparators need both classes present among the examples.
        if len(np.unique(np.argmax(examplar_outputs, axis=-1))) > 1:
            svm = make_svm(examplar_inputs, examplar_outputs, svm_regs[-1])
            ex_len_res["svm"][examplar_len].append(
                np.mean(svm.predict(test_queries) != gt_labels)
            )
            lr = make_lr(examplar_inputs, examplar_outputs, "l2", lr_regs[-1])
            ex_len_res["lr"][examplar_len].append(
                np.mean(lr.predict(test_queries) != gt_labels)
            )

In [None]:
# Mean / std of the per-task losses recorded for each exemplar length.
# The loop variable is `example_losses`, not `losses`, so this cell does not
# overwrite the global `losses` agreement dict built earlier in the notebook.
for examplar_len, example_losses in loss_per_examplar_len.items():
    print(
        "Examplar Length: {}, Loss - mean: {} std: {}".format(
            examplar_len, np.mean(example_losses), np.std(example_losses)
        )
    )

In [None]:
# Mean 0-1 loss (+/- one std band) versus exemplar length for each model class.
ncols = 1
nrows = 1
fig, ax = plt.subplots(
    nrows,
    ncols,
    figsize=set_size(750, 0.95, (nrows, ncols), True),
    layout="constrained",
)

for model_class, examplar_lens in ex_len_res.items():
    # Keep only lengths with at least one recorded loss: the comparators skip lengths
    # whose examples contain a single class, and np.mean([]) would yield NaN.
    # Sorting the lengths up front keeps xs, ys and stds aligned.
    lengths = sorted(length for length, vals in examplar_lens.items() if len(vals) > 0)
    xs = np.array(lengths)
    ys = np.array([np.mean(examplar_lens[length]) for length in lengths])
    stds = np.array([np.std(examplar_lens[length]) for length in lengths])
    ax.plot(xs, ys, label=model_class)
    ax.fill_between(xs, ys + stds, ys - stds, alpha=0.1)
ax.legend()
ax.set_ylabel("0-1 Loss")
ax.set_xlabel("Examplar Length")
# Constrained layout is already active; tight_layout would replace it, so it is not called.
plt.show()

## Check SVM in Representation Space

In [None]:
# Task 0: get the model's latent representation for each context example, queried
# with the full context, to fit an SVM in representation space below.
task_i = 0
context_inputs = res[task_i]["data"]["context_inputs"]
context_outputs = res[task_i]["data"]["context_outputs"]

print(test_queries.shape, context_inputs.shape)

# One sequence per context example: the whole context plus that example as the query.
inputs, outputs = jax.vmap(get_new_pairs, in_axes=[None, None, 0, None])(
    context_inputs, context_outputs, context_inputs[:, None], context_len
)

print(inputs.shape, outputs.shape)

# Named `latent_reprs` (not `repr`) to avoid shadowing the built-in `repr`.
latent_reprs, _ = llm_model.get_latent(
    params[CONST_MODEL_DICT][CONST_MODEL],
    inputs[:, [-1]],
    {
        CONST_CONTEXT_INPUT: inputs[:, :-1],
        CONST_CONTEXT_OUTPUT: outputs,
    },
)
print(latent_reprs.shape)

# Representation at the last index of axis 1 (presumably the query position -- confirm).
last_reprs = latent_reprs[:, -1]
print(last_reprs.shape)

In [None]:
# SVM labels: class index 0 becomes -1; other class indices are unchanged (1 for binary one-hot targets).
class_indices = np.argmax(context_outputs, axis=-1)
train_y = np.where(class_indices == 0, -1, class_indices)

In [None]:
# Solve the SVM primal with the project's `primal_svm`, first on the model's
# representations and then on the raw inputs, using the same +/-1 labels.
# NOTE(review): assumes `primal_svm` returns (objective value, primal solution) -- see jaxl.models.svm.
loss, primal_sol_repr = primal_svm(last_reprs, train_y)
print(loss)

loss, primal_sol_input = primal_svm(context_inputs, train_y)
print(loss)

In [None]:
# Solve the SVM dual in representation space and in input space. Indices with
# |alpha| > 1e-10 are printed as the support vectors of each solution.
# NOTE(review): assumes `dual_svm` returns (objective value, dual coefficients) -- see jaxl.models.svm.
loss, dual_sol_repr = dual_svm(last_reprs, train_y)
print(loss)
print(np.where(np.abs(dual_sol_repr) > 1e-10))

loss, dual_sol_input = dual_svm(context_inputs, train_y)
print(loss)
print(np.where(np.abs(dual_sol_input) > 1e-10))

Quick sanity check: compare the support vectors found by scikit-learn with those identified by the CVXPY-based dual solutions above.

In [None]:
# Fit sklearn's LinearSVC on the representations with the same C as the last
# input-space SVM (`svm_regs[-1]`), then print its support-vector indices next to
# the input-space ones.
repr_svm = make_svm(last_reprs, context_outputs, svm_regs[-1])
decision_function = repr_svm.decision_function(last_reprs)
# Same margin tolerance that was used to pick the input-space support vectors in `svm_info`.
support_vector_indices = np.where(np.abs(decision_function) <= 1 + 1e-15)[0]
print(support_vector_indices)
print(svm_info[task_i][svm_regs[-1]]["support_vector_indices"])

In [None]:
# Input-space primal solution (presumably weights followed by bias -- confirm in jaxl.models.svm).
print(*primal_sol_input)

In [None]:
# Recover the primal (w, b) from the input-space dual solution:
# w = sum_i alpha_i y_i x_i, and b = y_s - w . x_s, where s is the first index
# with |alpha_s| > 1e-10.
# NOTE(review): for a soft-margin SVM, b is usually taken from a support vector with 0 < alpha < C.
weighted_duals = dual_sol_input * train_y
recovered_weights = (weighted_duals[:, None] * context_inputs).sum(axis=0)

support_vector_index = np.flatnonzero(np.abs(dual_sol_input) > 1e-10)[0]
support_point = context_inputs[support_vector_index]
recovered_bias = train_y[support_vector_index] - recovered_weights @ support_point
print(*recovered_weights, recovered_bias)

Recover the primal parameters from the dual derived in representation space

In [None]:
# Combine the dual coefficients found in representation space with the raw
# inputs, giving an input-space (w, b) for comparison with the input-space solution.
# NOTE(review): this uses `context_inputs`, not `last_reprs`; confirm that an
# input-space recovery is intended.
weighted_duals = dual_sol_repr * train_y
recovered_weights = (weighted_duals[:, None] * context_inputs).sum(axis=0)

support_vector_index = np.flatnonzero(np.abs(dual_sol_repr) > 1e-10)[0]
support_point = context_inputs[support_vector_index]
recovered_bias = train_y[support_vector_index] - recovered_weights @ support_point
print(*recovered_weights, recovered_bias)

Compare the dual parameters from the input and representation spaces

In [None]:
# Dual coefficients in input space, then in representation space.
print(dual_sol_input)
print(dual_sol_repr)

In [None]:
# For the non-sigmoid head, show the difference between the two columns of the
# predictor's "kernel" parameter. If the head produces two logits linearly from the
# representation (presumably -- confirm), their difference is linear with these weights.
if not use_sigmoid:
    predictor_kernel = params[CONST_MODEL_DICT][CONST_MODEL][CONST_PREDICTOR]["params"][
        "kernel"
    ]
    # An expression nested in `if` is not auto-displayed by Jupyter, so print it explicitly.
    print(predictor_kernel[:, 0] - predictor_kernel[:, 1])

In [None]:
# Representation-space primal solution (compared with the predictor kernel below).
primal_sol_repr

Distance between the predictor head's kernel-column difference and the representation-space SVM primal weights

In [None]:
# Euclidean distance between the predictor kernel's column difference and the
# representation-space SVM primal weights (all but the last entry of `primal_sol_repr`).
# NOTE(review): assumes the last entry of `primal_sol_repr` is the bias.
# NOTE(review): `train_y` maps class 0 to -1, so a positive SVM score means class 1.
# The matching head direction would then be kernel[:, 1] - kernel[:, 0]; verify the
# sign used here. Scales also differ (the SVM is margin-normalised).
if not use_sigmoid:
    predictor_kernel = params[CONST_MODEL_DICT][CONST_MODEL][CONST_PREDICTOR]["params"][
        "kernel"
    ]
    # An expression nested in `if` is not auto-displayed by Jupyter, so print it explicitly.
    print(
        np.linalg.norm(
            predictor_kernel[:, 0] - predictor_kernel[:, 1] - primal_sol_repr[:-1]
        )
    )
    )