In [None]:
import _pickle as pickle
import jax
import math
import matplotlib.pyplot as plt
import numpy as np
import os

from jaxl.constants import *
from jaxl.plot_utils import set_size

In [None]:
# Width (in points) of the target document; passed to `set_size` when sizing figures.
doc_width_pt = 500.0
# Directory holding the pickled outputs of one experiment run.
baseline_path = "/home/bryanpu1/projects/jaxl/scripts/icl/results-num_blocks_8-num_tasks_5-seq_len_16-seed_9999-10-08-23_12_36_30/"


def _load_pickle(filename):
    """Unpickle `filename` located under `baseline_path`.

    The file is opened in a `with` block so the handle is closed as soon as
    the object has been loaded, instead of lingering until garbage collection.
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted result files.
    with open(os.path.join(baseline_path, filename), "rb") as f:
        return pickle.load(f)


context_data = _load_pickle("context_data.pkl")

gt = _load_pickle("ground_truth.pkl")

agent_reprs = _load_pickle("agent_reprs.pkl")

config = _load_pickle("config.pkl")
# Used below for the axis limits and as the x-values of the decision-boundary line.
input_range = config["input_range"]

# Visualize Representation SVM
Visualize the SVM trained in the representation space induced by the transformer by mapping the SVM predictions back onto the input space.
- `context_reprs` corresponds to the representation induced by feeding each of the context samples into the query token
- `input_token_context_reprs` corresponds to taking the context token directly

In [None]:
# For every (task, agent) pair, draw one figure with one panel per
# representation type. Each panel shows the SVM predictions on the query
# points, the context samples, and the ground-truth decision boundary.
nrows = 1
ncols = 2
for task_i in context_data:
    context_inputs = context_data[task_i][CONST_CONTEXT_INPUT]
    context_outputs = context_data[task_i][CONST_CONTEXT_OUTPUT]
    # Class index of each context sample (argmax over the last axis;
    # presumably the outputs are one-hot labels -- TODO confirm).
    train_y = np.argmax(context_outputs, axis=-1)

    for agent_path, agent_repr in agent_reprs.items():
        print("Processing agent {}".format(agent_path))
        fig, axes = plt.subplots(
            nrows,
            ncols,
            figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
            layout="constrained",
        )
        query_input = agent_repr["query_reprs"][task_i]
        for ax_i, repr_key in enumerate(["context_reprs", "input_token_context_reprs"]):
            ax = axes[ax_i]
            # Primal SVM solution: all but the last entry are used as the
            # weights, the last entry as the bias.
            params = agent_repr["svms"][task_i][repr_key]["primal"]["sol"]
            # Predicted label is 1 where the affine score is non-negative, else 0.
            svm_preds = (
                (
                    np.array(query_input) @ params[:-1]
                    + params[-1:]
                )
                >= 0
            ).astype(int)

            for possible_label in [0, 1]:
                # Query points predicted as this label. Assumes the rows of
                # `query_input` line up with the rows of gt["inputs"] -- TODO confirm.
                idxes = np.where(svm_preds == possible_label)
                ax.scatter(
                    gt["inputs"][idxes][:, 0],
                    gt["inputs"][idxes][:, 1],
                    # Only the first panel carries labels, so the shared
                    # figure legend is not repeated per panel.
                    label=f"{possible_label}" if ax_i == 0 else "",
                    s=5,
                )
                # Context samples whose label equals this class, drawn in
                # black ("x" for class 1, "o" for class 0).
                idxes = np.where(train_y == possible_label)
                ax.scatter(
                    context_inputs[idxes][:, 0],
                    context_inputs[idxes][:, 1],
                    label=f"{possible_label}" if ax_i == 0 else "",
                    s=15,
                    marker="x" if possible_label else "o",
                    c="black"
                )
            ax.plot(input_range, gt["decision_boundary"][task_i], color="black")
            # Small margin so points on the edge of the input range stay visible.
            ax.set_xlim(input_range[0] - 0.01, input_range[1] + 0.01)
            ax.set_ylim(input_range[0] - 0.01, input_range[1] + 0.01)
            ax.set_title("{}".format(repr_key))

        fig.legend(
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
            loc="lower center",
            ncols=4,
            borderaxespad=0.0,
            frameon=True,
            fontsize="8",
        )
        fig.supxlabel("$x_1$")
        fig.supylabel("$x_2$")
        plt.show()
        # One figure is created per (task, agent) pair; release it explicitly
        # so pyplot does not keep every figure alive on backends that do not
        # close figures automatically after `show`.
        plt.close(fig)

# Finding Support Vectors in Representation Space