In [None]:
import _pickle as pickle
import jax
import math
import matplotlib.pyplot as plt
import numpy as np
import os

from jaxl.constants import *
from jaxl.plot_utils import set_size

In [None]:
# Document width in points; passed to `set_size` when sizing every figure below.
doc_width_pt = 750.0
# NOTE: machine-specific absolute path to one experiment's result directory.
# Point this at your own run before executing the notebook.
baseline_path = "/home/bryanpu1/projects/jaxl/scripts/icl/results-num_blocks_8-num_tasks_5-seq_len_16-seed_9999-10-11-23_20_07_56"


def _load_result(filename):
    """Unpickle ``filename`` from ``baseline_path`` and close the file afterwards.

    NOTE(security): ``pickle.load`` can execute arbitrary code while loading.
    Only point ``baseline_path`` at result directories you produced or trust.
    """
    with open(os.path.join(baseline_path, filename), "rb") as result_file:
        return pickle.load(result_file)


# Per-task context samples; each entry is read through the
# CONST_CONTEXT_INPUT / CONST_CONTEXT_OUTPUT keys in the cells below.
context_data = _load_result("context_data.pkl")

# Ground truth; the plotting cells read its "inputs" and "decision_boundary" entries.
gt = _load_result("ground_truth.pkl")

# Per-agent representations and fitted SVMs (read via "svms", "query_reprs",
# "context_reprs" and "input_token_context_reprs" below).
agent_reprs = _load_result("agent_reprs.pkl")

# Per-agent predictions, indexed by task and then by exemplar length.
agent_results = _load_result("agents.pkl")

config = _load_result("config.pkl")

baseline_results = _load_result("baseline_results.pkl")

# Two-element range used both as the x-coordinates of the plotted decision
# boundaries and as the axis limits of every panel.
input_range = config["input_range"]

# Visualize ICL Transformer

In [None]:
# For every task and agent, draw one panel per exemplar length showing the
# transformer's predictions over the ground-truth inputs, the context samples
# in use, and the input-space SVM fitted on the context.
ncols = 4  # panels per figure row
for task_i in context_data:
    context_inputs = context_data[task_i][CONST_CONTEXT_INPUT]
    context_outputs = context_data[task_i][CONST_CONTEXT_OUTPUT]
    # Class index of each context sample (argmax over the last axis).
    train_y = np.argmax(context_outputs, axis=-1)

    for agent_path, agent_result in agent_results.items():
        print("Processing agent {}".format(agent_path))

        # Primal solution laid out as [w..., b]; the boundary line below assumes
        # 2-D inputs (w0 * x1 + w1 * x2 + b = 0 solved for x2) -- TODO confirm.
        svm_primal = agent_reprs[agent_path]["svms"][task_i]["input"]["primal"]["sol"]
        svm_db = -(np.array(input_range) * svm_primal[0] + svm_primal[2]) / svm_primal[1]

        # Signed margin y_i * (w . x_i + b), with labels {0, 1} mapped to {-1, +1}.
        primal_out = (context_inputs @ svm_primal[:-1]) + svm_primal[-1]
        primal_constraints = 2 * (0.5 - (1 - train_y)) * primal_out
        print(primal_constraints)
        # Samples on or inside the margin.
        support_vectors = context_inputs[np.where(primal_constraints <= 1)[0]]

        # Context sample closest to the boundary on each side. Either side may
        # be empty (all samples scored with the same sign), in which case
        # argmin/argmax would raise, so the marker is skipped instead.
        poss = np.where(primal_out > 0)[0]
        negs = np.where(primal_out < 0)[0]
        closest_pos = (
            context_inputs[poss[np.argmin(primal_out[poss])]] if len(poss) else None
        )
        closest_neg = (
            context_inputs[negs[np.argmax(primal_out[negs])]] if len(negs) else None
        )

        per_task_results = agent_result[task_i]
        nrows = math.ceil(len(per_task_results) / ncols)
        # squeeze=False always yields a 2-D axes array, so single-row and
        # multi-row grids are indexed the same way.
        fig, axes = plt.subplots(
            nrows,
            ncols,
            figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
            layout="constrained",
            squeeze=False,
        )

        for ax_i, examplar_len in enumerate(per_task_results):
            # Keep only the last `examplar_len` context samples.
            mask = (np.arange(len(context_inputs)) >= len(context_inputs) - examplar_len)
            ax = axes[ax_i // ncols, ax_i % ncols]
            one_hot_preds = per_task_results[examplar_len]
            preds = np.argmax(one_hot_preds, axis=-1)

            for possible_label in [0, 1]:
                # Ground-truth inputs coloured by the predicted class.
                idxes = np.where(preds == possible_label)
                ax.scatter(
                    gt["inputs"][idxes][:, 0],
                    gt["inputs"][idxes][:, 1],
                    label=f"{possible_label}" if ax_i == 0 else "",
                    s=5,
                    alpha=0.1,
                )
                # Context samples in use, marked by their true class.
                idxes = np.where(train_y[mask] == possible_label)
                ax.scatter(
                    context_inputs[mask][idxes][:, 0],
                    context_inputs[mask][idxes][:, 1],
                    label=f"{possible_label}" if ax_i == 0 else "",
                    s=15,
                    marker="x" if possible_label else "o",
                    c="black"
                )

            ax.scatter(
                support_vectors[:, 0],
                support_vectors[:, 1],
                label="support vector" if ax_i == 0 else "",
                s=70,
                facecolors="none",
                edgecolors="black",
            )

            if closest_pos is not None:
                ax.scatter(
                    closest_pos[0],
                    closest_pos[1],
                    label="closest +" if ax_i == 0 else "",
                    s=70,
                    marker="^"
                )

            if closest_neg is not None:
                ax.scatter(
                    closest_neg[0],
                    closest_neg[1],
                    label="closest -" if ax_i == 0 else "",
                    s=70,
                    marker="^"
                )

            ax.plot(
                input_range,
                gt["decision_boundary"][task_i],
                color="gray",
                label="gt" if ax_i == 0 else ""
            )
            ax.plot(
                input_range,
                svm_db,
                color="black",
                label="svm" if ax_i == 0 else ""
            )
            ax.set_xlim(input_range[0] - 0.01, input_range[1] + 0.01)
            ax.set_ylim(input_range[0] - 0.01, input_range[1] + 0.01)
            ax.set_title("Ex. Len.: {}".format(examplar_len))
            ax.grid(False)

        # Hide grid cells left over when the panel count is not a multiple of
        # ncols, so no empty frames are drawn.
        for unused_ax in axes.ravel()[len(per_task_results):]:
            unused_ax.set_visible(False)

        fig.legend(
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
            loc="lower center",
            ncols=10,
            borderaxespad=0.0,
            frameon=True,
            fontsize="8",
        )
        fig.supxlabel("$x_1$")
        fig.supylabel("$x_2$")
        plt.show()

# Visualize Representation SVM
Visualize the SVM trained in the representation space induced by the transformer by mapping the SVM predictions back onto the input space.
- `context_reprs` corresponds to the representation induced by feeding in each of the context samples into the query token
- `input_token_context_reprs` corresponds to taking the context token directly

In [None]:
# For every task and agent, draw one panel per representation type showing the
# predictions of the SVM fitted in that representation space, plotted over the
# ground-truth inputs together with the context samples.
nrows = 1
ncols = 2
for task_i in context_data:
    context_inputs = context_data[task_i][CONST_CONTEXT_INPUT]
    context_outputs = context_data[task_i][CONST_CONTEXT_OUTPUT]
    # Class index of each context sample (argmax over the last axis).
    train_y = np.argmax(context_outputs, axis=-1)

    for agent_path, agent_repr in agent_reprs.items():
        print("Processing agent {}".format(agent_path))

        fig, axes = plt.subplots(
            nrows,
            ncols,
            figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
            layout="constrained",
        )
        # Representations of the query points; presumably row-aligned with
        # gt["inputs"], since the SVM predictions index into it -- verify.
        query_input = agent_repr["query_reprs"][task_i]
        for ax_i, repr_key in enumerate(["context_reprs", "input_token_context_reprs"]):

            # Primal solution laid out as [w..., b] in representation space.
            # No boundary line is derived from it: the weights live in
            # representation space, not in the plotted input plane.
            svm_primal = agent_repr["svms"][task_i][repr_key]["primal"]["sol"]

            # Signed margin y_i * (w . r_i + b), with labels {0, 1} mapped to {-1, +1}.
            primal_out = (np.array(agent_repr[repr_key][task_i]) @ svm_primal[:-1]) + svm_primal[-1]
            primal_constraints = 2 * (0.5 - (1 - train_y)) * primal_out
            print(primal_out)
            # Context samples on or inside the margin, shown at their input-space location.
            support_vectors = context_inputs[np.where(primal_constraints <= 1)[0]]

            # Context sample closest to the boundary on each side. Either side
            # may be empty (all samples scored with the same sign), in which
            # case argmin/argmax would raise, so the marker is skipped instead.
            poss = np.where(primal_out > 0)[0]
            negs = np.where(primal_out < 0)[0]
            closest_pos = (
                context_inputs[poss[np.argmin(primal_out[poss])]] if len(poss) else None
            )
            closest_neg = (
                context_inputs[negs[np.argmax(primal_out[negs])]] if len(negs) else None
            )

            ax = axes[ax_i]
            # SVM class prediction for every query representation.
            svm_preds = (
                (
                    np.array(query_input) @ svm_primal[:-1]
                    + svm_primal[-1:]
                )
                >= 0
            ).astype(int)

            for possible_label in [0, 1]:
                idxes = np.where(svm_preds == possible_label)
                ax.scatter(
                    gt["inputs"][idxes][:, 0],
                    gt["inputs"][idxes][:, 1],
                    label=f"{possible_label}" if ax_i == 0 else "",
                    s=5,
                    alpha=0.1
                )
            for possible_label in [0, 1]:
                idxes = np.where(train_y == possible_label)
                ax.scatter(
                    context_inputs[idxes][:, 0],
                    context_inputs[idxes][:, 1],
                    label=f"{possible_label}" if ax_i == 0 else "",
                    s=15,
                    marker="x" if possible_label else "o",
                    c="black"
                )

            ax.scatter(
                support_vectors[:, 0],
                support_vectors[:, 1],
                label="support vector" if ax_i == 0 else "",
                s=70,
                facecolors="none",
                edgecolors="black",
            )

            if closest_pos is not None:
                ax.scatter(
                    closest_pos[0],
                    closest_pos[1],
                    label="closest +" if ax_i == 0 else "",
                    s=70,
                    marker="^"
                )

            if closest_neg is not None:
                ax.scatter(
                    closest_neg[0],
                    closest_neg[1],
                    label="closest -" if ax_i == 0 else "",
                    s=70,
                    marker="^"
                )

            ax.plot(
                input_range,
                gt["decision_boundary"][task_i],
                color="gray",
                label="gt" if ax_i == 0 else ""
            )

            ax.set_xlim(input_range[0] - 0.01, input_range[1] + 0.01)
            ax.set_ylim(input_range[0] - 0.01, input_range[1] + 0.01)
            ax.set_title("{}".format(repr_key))
            ax.grid(False)

        fig.legend(
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
            loc="lower center",
            ncols=4,
            borderaxespad=0.0,
            frameon=True,
            fontsize="8",
        )
        fig.supxlabel("$x_1$")
        fig.supylabel("$x_2$")
        plt.show()

# Analyzing Dual

In [None]:
import cvxpy as cp

def recover_dual_from_primal(primal_sol, train_x, train_y):
    """Solve a linear program over non-negative multipliers built from a primal solution.

    The primal solution is read as ``[w..., b]``. For each sample the signed
    margin ``y_i * (w . x_i + b)`` is computed, with labels in {0, 1} mapped to
    {-1, +1}. The LP then maximises ``sum_i (1 - margin_i) * alpha_i`` subject
    to ``alpha_i >= 0``.

    Args:
        primal_sol: 1-D array-like of length ``D + 1`` (weights then bias).
        train_x: 2-D array-like of shape ``(N, D)``.
        train_y: 1-D array-like of ``N`` class indices.

    Returns:
        ``(loss, alphas)``: the LP objective value returned by the solver and
        the length-``N`` multiplier vector. The LP is unbounded whenever any
        ``1 - margin_i`` is positive. The solver then leaves no variable value,
        so ``alphas`` is returned as an all-NaN vector (and the status is
        printed) rather than ``None``, keeping downstream indexing working.

    Raises:
        AssertionError: if ``train_x`` is not 2-D or ``train_y`` is not 1-D.
    """
    # Accept lists and other array-likes; item assignment on `coefs` below
    # needs a plain, writable NumPy array.
    primal_sol = np.asarray(primal_sol)
    train_x = np.asarray(train_x)
    train_y = np.asarray(train_y)

    assert len(train_x.shape) == 2
    assert len(train_y.shape) == 1

    N, _ = train_x.shape

    print("Primal solution: {}, shape: {}".format(primal_sol, primal_sol.shape))

    primal_out = (train_x @ primal_sol[:-1]) + primal_sol[-1]
    # 2 * (0.5 - (1 - y)) maps y = 0 -> -1 and y = 1 -> +1.
    primal_constraints = 2 * (0.5 - (1 - train_y)) * primal_out
    coefs = 1 - primal_constraints
    # Snap numerically-zero coefficients (samples exactly on the margin) to 0.
    coefs[np.isclose(coefs, 0)] = 0
    G = np.eye(N)
    h = np.zeros(N)
    dual_var = cp.Variable(N)
    prob = cp.Problem(
        cp.Maximize(coefs.T @ dual_var),
        [G @ dual_var >= h],
    )
    loss = prob.solve(verbose=False)
    alphas = dual_var.value
    if alphas is None:
        # Infeasible/unbounded problems leave the variable without a value.
        print(
            "Dual LP has no solution (status: {}); returning NaN multipliers".format(
                prob.status
            )
        )
        alphas = np.full(N, np.nan)
    return loss, alphas

In [None]:
from pprint import pprint

# Print, for every task and agent, the multipliers recovered from the primal
# SVM solution in three spaces: the raw inputs and the two representation types.
for task_i in context_data:
    context_inputs = context_data[task_i][CONST_CONTEXT_INPUT]
    context_outputs = context_data[task_i][CONST_CONTEXT_OUTPUT]
    train_y = np.argmax(context_outputs, axis=-1)

    for agent_path, agent_repr in agent_reprs.items():

        print("=" * 50)
        print("Processing agent {}".format(agent_path))

        for repr_key in ("input", "context_reprs", "input_token_context_reprs"):
            in_input_space = repr_key == "input"

            # Section header; the input-space section keeps its trailing space.
            print("-" * 50)
            print("input " if in_input_space else repr_key)
            print("-" * 50)

            if in_input_space:
                # Input space: the stored solution and raw context inputs are
                # passed through untouched.
                loss, alphas = recover_dual_from_primal(
                    agent_repr["svms"][task_i]["input"]["primal"]["sol"],
                    context_inputs,
                    train_y
                )
            else:
                # Representation space: solution and features are converted to
                # NumPy arrays before the call.
                loss, alphas = recover_dual_from_primal(
                    np.array(agent_repr["svms"][task_i][repr_key]["primal"]["sol"]),
                    np.array(agent_repr[repr_key][task_i]),
                    train_y
                )

            # Report the multipliers in ascending order with their original indices.
            sort_idxes = np.argsort(alphas)
            print("Primal loss: {}, Dual loss: {}".format(
                agent_repr["svms"][task_i][repr_key]["primal"]["loss"],
                loss
            ))
            print("Dual solution (sorted): {}".format(alphas[sort_idxes]))
            print("Sort indices: {}".format(sort_idxes))
        print("")