In [None]:
import _pickle as pickle
import jax
import math
import matplotlib.pyplot as plt
import numpy as np
import os

from jaxl.constants import *
from jaxl.plot_utils import set_size

# Set CUDA_VISIBLE_DEVICES to the empty string for this process.
# NOTE(review): presumably meant to hide all GPUs so JAX runs on CPU; it is assigned
# after `import jax` above — confirm it still takes effect before JAX picks a backend.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [None]:
# Document width in points; passed to `set_size` when sizing every figure below.
doc_width_pt = 750.0

# Directory that holds the pickled artifacts of a single experiment run.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
# baseline_path = "/home/bryanpu1/projects/jaxl/scripts/icl/results-num_blocks_8-num_tasks_5-seq_len_16-seed_9999-10-11-23_20_07_56"
# baseline_path = "/home/bryanpu1/projects/jaxl/scripts/icl/results-num_blocks_8-seq_len_20-num_tasks_5-seq_len_20-seed_9999-10-13-23_10_30_57/"
baseline_path = "/home/bryanpu1/projects/jaxl/scripts/icl/results-num_blocks_8-smaller_delta-num_tasks_5-seq_len_20-seed_9999-10-17-23_22_07_47"


def _load_pickle(file_name):
    """Unpickle ``file_name`` from ``baseline_path`` and return the loaded object.

    The file is opened in a ``with`` block so the handle is closed as soon as the
    object has been read (the previous bare ``open(...)`` calls were never closed).
    """
    # NOTE(review): pickle can execute arbitrary code while loading — only point
    # `baseline_path` at result directories you trust.
    with open(os.path.join(baseline_path, file_name), "rb") as pickle_file:
        return pickle.load(pickle_file)


# Artifacts used by the cells below; each one is indexed by task and/or agent path there.
context_data = _load_pickle("context_data.pkl")
gt = _load_pickle("ground_truth.pkl")
agent_reprs = _load_pickle("agent_reprs.pkl")
agent_results = _load_pickle("agents.pkl")
config = _load_pickle("config.pkl")
baseline_results = _load_pickle("baseline_results.pkl")

# Two-element range used below as the x-coordinates of boundary lines and for axis limits.
input_range = config["input_range"]

# Visualize ICL Transformer

In [None]:
ncols = 4
for task_i in context_data:
    # Context examples of the current task and the class index of each one.
    context_inputs = context_data[task_i][CONST_CONTEXT_INPUT]
    context_outputs = context_data[task_i][CONST_CONTEXT_OUTPUT]
    train_y = np.argmax(context_outputs, axis=-1)
    num_context = len(context_inputs)

    for agent_path, agent_result in agent_results.items():
        print("Processing agent {}".format(agent_path))

        # Primal SVM solution fitted on the raw inputs; its last entry is used as the bias.
        svm_primal = agent_reprs[agent_path]["svms"][task_i]["input"]["primal"]["sol"]
        # Second coordinate of the SVM boundary at both ends of the input range.
        svm_db = -(np.array(input_range) * svm_primal[0] + svm_primal[2]) / svm_primal[1]

        # Signed SVM output per context point, and that output multiplied by the
        # label sign (labels 0/1 give factors -1/+1).
        primal_out = (context_inputs @ svm_primal[:-1]) + svm_primal[-1]
        primal_constraints = 2 * (0.5 - (1 - train_y)) * primal_out
        print(primal_constraints)
        on_margin = np.where(primal_constraints <= 1 + 1e-5)[0]
        support_vectors = context_inputs[on_margin]

        # Context points with the smallest positive / largest negative SVM output.
        pos_idx = np.where(primal_out > 0)[0]
        neg_idx = np.where(primal_out < 0)[0]
        closest_pos = context_inputs[pos_idx[np.argmin(primal_out[pos_idx])]]
        closest_neg = context_inputs[neg_idx[np.argmax(primal_out[neg_idx])]]

        preds_by_len = agent_result[task_i]["examplar_length"]
        nrows = math.ceil(len(preds_by_len) / ncols)
        # squeeze=False keeps `axes` two-dimensional even when there is a single row,
        # so one indexing expression serves every layout.
        fig, axes = plt.subplots(
            nrows,
            ncols,
            figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
            layout="constrained",
            squeeze=False,
        )

        for ax_i, examplar_len in enumerate(preds_by_len):
            ax = axes[ax_i // ncols, ax_i % ncols]
            # Only the first panel contributes entries to the shared legend.
            first_panel = ax_i == 0

            # Keep only the last `examplar_len` context examples.
            recent = np.arange(num_context) >= num_context - examplar_len
            shown_inputs = context_inputs[recent]
            shown_labels = train_y[recent]

            preds = np.argmax(preds_by_len[examplar_len], axis=-1)

            for possible_label in (0, 1):
                legend_name = f"{possible_label}" if first_panel else ""

                # Points of gt["inputs"] that the agent assigns to this class.
                pred_sel = np.where(preds == possible_label)
                ax.scatter(
                    gt["inputs"][pred_sel][:, 0],
                    gt["inputs"][pred_sel][:, 1],
                    label=legend_name,
                    s=5,
                    alpha=0.1,
                )

                # Selected context examples whose label is this class.
                ctx_sel = np.where(shown_labels == possible_label)
                ax.scatter(
                    shown_inputs[ctx_sel][:, 0],
                    shown_inputs[ctx_sel][:, 1],
                    label=legend_name,
                    s=15,
                    marker="x" if possible_label else "o",
                    c="black",
                )

            ax.scatter(
                support_vectors[:, 0],
                support_vectors[:, 1],
                label="support vector" if first_panel else "",
                s=70,
                facecolors="none",
                edgecolors="black",
            )

            for point, name in ((closest_pos, "closest +"), (closest_neg, "closest -")):
                ax.scatter(
                    point[0],
                    point[1],
                    label=name if first_panel else "",
                    s=70,
                    marker="^",
                )

            for boundary, colour, name in (
                (gt["decision_boundary"][task_i], "gray", "gt"),
                (svm_db, "black", "svm"),
            ):
                ax.plot(
                    input_range,
                    boundary,
                    color=colour,
                    label=name if first_panel else "",
                )

            lower, upper = input_range[0] - 0.01, input_range[1] + 0.01
            ax.set_xlim(lower, upper)
            ax.set_ylim(lower, upper)
            ax.set_title("Ex. Len.: {}".format(examplar_len))
            ax.grid(False)

        fig.legend(
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
            loc="lower center",
            ncols=10,
            borderaxespad=0.0,
            frameon=True,
            fontsize="8",
        )
        fig.supxlabel("$x_1$")
        fig.supylabel("$x_2$")
        plt.show()

# Visualize Representation SVM
Visualize the SVM trained in the representation space induced by the transformer by mapping its predictions back onto the input space.
- `query_context_reprs` corresponds to the representation induced by feeding each of the context samples in as the query token
- `input_context_reprs` corresponds to taking the context token directly

In [None]:
nrows = 1
ncols = 2
for task_i in context_data:
    # Context examples of the current task and the class index of each one.
    context_inputs = context_data[task_i][CONST_CONTEXT_INPUT]
    context_outputs = context_data[task_i][CONST_CONTEXT_OUTPUT]
    train_y = np.argmax(context_outputs, axis=-1)

    for agent_path, agent_repr in agent_reprs.items():
        print("Processing agent {}".format(agent_path))

        fig, axes = plt.subplots(
            nrows,
            ncols,
            figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
            layout="constrained",
        )
        # Representations the fitted SVM is evaluated on to colour the background points.
        # NOTE(review): assumed to be row-aligned with gt["inputs"] — confirm.
        query_input = agent_repr["query_reprs"][task_i]
        for ax_i, repr_key in enumerate(["query_context_reprs", "input_context_reprs"]):

            # Primal SVM solution fitted in this representation space; the last
            # entry is used as the bias.  No input-space boundary line is derived
            # from it: the weights live in representation space, and the value
            # previously computed here (`svm_db`) was never used.
            svm_primal = agent_repr["svms"][task_i][repr_key]["primal"]["sol"]

            # Signed SVM output per context representation, and that output
            # multiplied by the label sign (labels 0/1 give factors -1/+1).
            primal_out = (np.array(agent_repr[repr_key][task_i]) @ svm_primal[:-1]) + svm_primal[-1]
            primal_constraints = 2 * (0.5 - (1 - train_y)) * primal_out
            print(primal_out)
            support_vectors = context_inputs[np.where(primal_constraints <= 1 + 1e-5)[0]]

            # Context inputs whose SVM output lies within 0.1 (plus np.isclose's
            # default relative tolerance) of the smallest positive / largest
            # negative output.
            # NOTE(review): the tolerance can also select points on the other side
            # of the boundary, and np.min/np.max raise if one side is empty.
            poss = np.where(primal_out > 0)[0]
            negs = np.where(primal_out < 0)[0]
            closest_pos_dist = np.min(primal_out[poss])
            closest_neg_dist = np.max(primal_out[negs])
            closest_pos = context_inputs[np.isclose(closest_pos_dist, primal_out, atol=0.1)]
            closest_neg = context_inputs[np.isclose(closest_neg_dist, primal_out, atol=0.1)]

            ax = axes[ax_i]
            # Class predicted by the SVM for every query representation (1 when output >= 0).
            svm_preds = (
                (
                    np.array(query_input) @ svm_primal[:-1]
                    + svm_primal[-1:]
                )
                >= 0
            ).astype(int)

            # Background: points of gt["inputs"] grouped by the SVM's predicted class.
            for possible_label in [0, 1]:
                idxes = np.where(svm_preds == possible_label)
                ax.scatter(
                    gt["inputs"][idxes][:, 0],
                    gt["inputs"][idxes][:, 1],
                    label=f"{possible_label}" if ax_i == 0 else "",
                    s=5,
                    alpha=0.1
                )
            # Foreground: context examples, marker chosen by their label.
            for possible_label in [0, 1]:
                idxes = np.where(train_y == possible_label)
                ax.scatter(
                    context_inputs[idxes][:, 0],
                    context_inputs[idxes][:, 1],
                    label=f"{possible_label}" if ax_i == 0 else "",
                    s=15,
                    marker="x" if possible_label else "o",
                    c="black"
                )

            ax.scatter(
                support_vectors[:, 0],
                support_vectors[:, 1],
                label="support vector" if ax_i == 0 else "",
                s=70,
                facecolors="none",
                edgecolors="black",
            )

            ax.scatter(
                closest_pos[:, 0],
                closest_pos[:, 1],
                label="closest +" if ax_i == 0 else "",
                s=70,
                marker="^"
            )

            ax.scatter(
                closest_neg[:, 0],
                closest_neg[:, 1],
                label="closest -" if ax_i == 0 else "",
                s=70,
                marker="^"
            )

            ax.plot(
                input_range,
                gt["decision_boundary"][task_i],
                color="gray",
                label="gt" if ax_i == 0 else ""
            )

            ax.set_xlim(input_range[0] - 0.01, input_range[1] + 0.01)
            ax.set_ylim(input_range[0] - 0.01, input_range[1] + 0.01)
            ax.set_title("{}".format(repr_key))
            ax.grid(False)

        fig.legend(
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
            loc="lower center",
            ncols=4,
            borderaxespad=0.0,
            frameon=True,
            fontsize="8",
        )
        fig.supxlabel("$x_1$")
        fig.supylabel("$x_2$")
        plt.show()

# Analyzing Support Vector from Transformer Embedding
Current observation:
- There does not seem to be a gap in the classification, yet the context samples are classified incorrectly

In [None]:
nrows = 1
ncols = 3
for task_i in context_data:
    # Context examples of the current task and the class index of each one.
    context_inputs = context_data[task_i][CONST_CONTEXT_INPUT]
    context_outputs = context_data[task_i][CONST_CONTEXT_OUTPUT]
    train_y = np.argmax(context_outputs, axis=-1)

    for agent_path, agent_repr in agent_reprs.items():
        print("Processing agent {}".format(agent_path))
        # Input-space primal SVM solution and its boundary at both ends of the input range.
        svm_primal = agent_repr["svms"][task_i]["input"]["primal"]["sol"]
        svm_db = -(np.array(input_range) * svm_primal[0] + svm_primal[2]) / svm_primal[1]

        fig, axes = plt.subplots(
            nrows,
            ncols,
            figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
            layout="constrained",
        )

        # Agent predictions obtained with the largest examplar length.
        examplar_preds = agent_results[agent_path][task_i]["examplar_length"]
        one_hot_preds = examplar_preds[max(examplar_preds)]
        preds = np.argmax(one_hot_preds, axis=-1)

        # Third panel: predictions reshaped into a square image with the rows flipped.
        # NOTE(review): assumes len(preds) is a perfect square (reshape fails otherwise).
        sqrt_num_samples = int(math.sqrt(len(preds)))
        axes[2].imshow(preds.reshape((sqrt_num_samples, sqrt_num_samples))[::-1])
        axes[2].set_title("Transformer Decision Boundary")

        for ax_i, repr_key in enumerate(["query_context_reprs", "input_context_reprs"]):

            # First dimension of each context representation.
            embed_first_dim = np.array(agent_repr[repr_key][task_i])[:, 0]

            # Context inputs whose first dimension lies within 0.1 (plus np.isclose's
            # default relative tolerance) of the smallest positive / largest negative value.
            # NOTE(review): np.min/np.max raise if one side is empty.
            poss = np.where(embed_first_dim > 0)[0]
            negs = np.where(embed_first_dim < 0)[0]
            closest_pos_dist = np.min(embed_first_dim[poss])
            closest_neg_dist = np.max(embed_first_dim[negs])
            closest_pos = context_inputs[np.isclose(closest_pos_dist, embed_first_dim, atol=0.1)]
            closest_neg = context_inputs[np.isclose(closest_neg_dist, embed_first_dim, atol=0.1)]

            ax = axes[ax_i]

            # Background: points of gt["inputs"] grouped by the agent's predicted class.
            for possible_label in [0, 1]:
                idxes = np.where(preds == possible_label)
                ax.scatter(
                    gt["inputs"][idxes][:, 0],
                    gt["inputs"][idxes][:, 1],
                    label=f"{possible_label}" if ax_i == 0 else "",
                    s=5,
                    alpha=0.1
                )
            # Foreground: context examples, marker chosen by their label.
            for possible_label in [0, 1]:
                idxes = np.where(train_y == possible_label)
                ax.scatter(
                    context_inputs[idxes][:, 0],
                    context_inputs[idxes][:, 1],
                    label=f"{possible_label}" if ax_i == 0 else "",
                    s=15,
                    marker="x" if possible_label else "o",
                    c="black"
                )

            ax.scatter(
                closest_pos[:, 0],
                closest_pos[:, 1],
                label="closest +" if ax_i == 0 else "",
                s=70,
                facecolors="none",
                edgecolors="red",
            )

            ax.scatter(
                closest_neg[:, 0],
                closest_neg[:, 1],
                label="closest -" if ax_i == 0 else "",
                s=70,
                facecolors="none",
                edgecolors="blue",
            )

            ax.plot(
                input_range,
                svm_db,
                color="black",
                label="svm" if ax_i == 0 else ""
            )

            ax.set_xlim(input_range[0] - 0.01, input_range[1] + 0.01)
            ax.set_ylim(input_range[0] - 0.01, input_range[1] + 0.01)
            ax.set_title("{}".format(repr_key))
            ax.grid(False)

        fig.legend(
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
            loc="lower center",
            ncols=4,
            borderaxespad=0.0,
            frameon=True,
            fontsize="8",
        )
        fig.supxlabel("$x_1$")
        fig.supylabel("$x_2$")
        plt.show()

# Permutation Experiment
- Permute -> embed
- Embed -> permute

Expectation: if the transformer ends up being permutation invariant, then the first dimension of the representation should be the same regardless of the order of the two operations (permute-then-embed vs. embed-then-permute).

In [None]:
nrows = 2
ncols = 3
for task_i in context_data:
    # Context examples of the current task and the class index of each one.
    context_inputs = context_data[task_i][CONST_CONTEXT_INPUT]
    context_outputs = context_data[task_i][CONST_CONTEXT_OUTPUT]
    train_y = np.argmax(context_outputs, axis=-1)

    for agent_path, agent_repr in agent_reprs.items():
        print("Processing agent {}".format(agent_path))

        fig, axes = plt.subplots(
            nrows,
            ncols,
            figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
            layout="constrained",
        )

        # Agent predictions obtained with the largest examplar length.
        examplar_preds = agent_results[agent_path][task_i]["examplar_length"]
        one_hot_preds = examplar_preds[max(examplar_preds)]
        preds = np.argmax(one_hot_preds, axis=-1)

        for ax_i, repr_key in enumerate(["query_context_reprs", "input_context_reprs"]):

            # Convert each stored representation once and reuse it in all three panels.
            plain_repr = np.array(agent_repr[repr_key][task_i])
            permuted_entry = agent_repr["permutation"][repr_key][task_i]
            permuted_repr = np.array(permuted_entry["repr"])
            permute_idxes = np.array(permuted_entry["permute_idxes"])

            # Third column: first dimension of the plain representation against the
            # permuted one, with a dashed diagonal across the axes as reference.
            ax = axes[ax_i, 2]
            ax.plot(
                [0, 1],
                [0, 1],
                transform=ax.transAxes,
                linestyle="--",
                color="black",
            )
            ax.scatter(
                plain_repr[:, 0],
                permuted_repr[:, 0],
                marker="x",
                # First half of the context indices in red, second half in blue.
                color=["red" if idx < len(context_inputs) // 2 else "blue" for idx in range(len(context_inputs))]
            )
            # Columns: plain first dim, permuted first dim, permutation index.
            print(np.concatenate((
                plain_repr[:, [0]],
                permuted_repr[:, [0]],
                permute_idxes[:, None]
            ), axis=-1))
            ax.set_xlabel("Embed then permute")
            ax.set_ylabel("Permute then embed")
            ax.set_title("distance based on first dim")

            for ii, is_permute in enumerate((False, True)):
                embed_first_dim = (permuted_repr if is_permute else plain_repr)[:, 0]

                # Context inputs whose first dimension lies within 0.1 (plus np.isclose's
                # default relative tolerance) of the smallest positive / largest negative value.
                # NOTE(review): np.min/np.max raise if one side is empty.
                poss = np.where(embed_first_dim > 0)[0]
                negs = np.where(embed_first_dim < 0)[0]
                closest_pos_dist = np.min(embed_first_dim[poss])
                closest_neg_dist = np.max(embed_first_dim[negs])
                closest_pos = context_inputs[np.isclose(closest_pos_dist, embed_first_dim, atol=0.1)]
                closest_neg = context_inputs[np.isclose(closest_neg_dist, embed_first_dim, atol=0.1)]

                ax = axes[ax_i, ii]

                # Accuracy of thresholding the sigmoid of the first dimension at 0.5
                # against the context labels.
                print("{} - Permutation {} - In-context Prediction Accuracy: {}%".format(
                    repr_key,
                    is_permute,
                    np.mean((jax.nn.sigmoid(embed_first_dim) >= 0.5) == train_y) * 100
                ))

                # Only the top-left panel contributes entries to the shared legend.
                first_panel = ax_i + ii == 0

                # Background: points of gt["inputs"] grouped by the agent's predicted class.
                for possible_label in [0, 1]:
                    idxes = np.where(preds == possible_label)
                    ax.scatter(
                        gt["inputs"][idxes][:, 0],
                        gt["inputs"][idxes][:, 1],
                        label=f"{possible_label}" if first_panel else "",
                        s=5,
                        alpha=0.1
                    )
                # Foreground: context examples, marker chosen by their label.
                for possible_label in [0, 1]:
                    idxes = np.where(train_y == possible_label)
                    ax.scatter(
                        context_inputs[idxes][:, 0],
                        context_inputs[idxes][:, 1],
                        label=f"{possible_label}" if first_panel else "",
                        s=15,
                        marker="x" if possible_label else "o",
                        c="black"
                    )

                ax.scatter(
                    closest_pos[:, 0],
                    closest_pos[:, 1],
                    label="closest +" if first_panel else "",
                    s=70,
                    facecolors="none",
                    edgecolors="red",
                )

                ax.scatter(
                    closest_neg[:, 0],
                    closest_neg[:, 1],
                    label="closest -" if first_panel else "",
                    s=70,
                    facecolors="none",
                    edgecolors="blue",
                )

                ax.plot(
                    input_range,
                    gt["decision_boundary"][task_i],
                    color="gray",
                    label="gt" if first_panel else ""
                )

                ax.set_xlim(input_range[0] - 0.01, input_range[1] + 0.01)
                ax.set_ylim(input_range[0] - 0.01, input_range[1] + 0.01)
                ax.set_title("{}{}".format(repr_key, " permuted" if is_permute else ""))
                ax.grid(False)

        fig.legend(
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
            loc="lower center",
            ncols=4,
            borderaxespad=0.0,
            frameon=True,
            fontsize="8",
        )
        fig.supxlabel("$x_1$")
        fig.supylabel("$x_2$")
        plt.show()

# Visualizing Attention Weights

In [None]:
from jaxl.constants import *
from jaxl.learning_utils import get_learner
from jaxl.utils import parse_dict

import json

from orbax.checkpoint import PyTreeCheckpointer, CheckpointManager


def load_llm(learner_path: str):
    """Rebuild a learner from ``learner_path`` and restore its latest checkpoint.

    Reads ``config.json`` from ``learner_path``, builds the learner from the parsed
    configuration, and restores the most recent step saved under ``models``.

    :param learner_path: directory containing ``config.json`` and a ``models`` folder
    :return: tuple ``(llm_params, llm_model, config)`` — the restored parameters, the
        learner's model, and the parsed configuration
    """
    with open(os.path.join(learner_path, "config.json"), "r") as config_file:
        raw_config = json.load(config_file)
    config = parse_dict(raw_config)

    learner = get_learner(
        config.learner_config,
        config.model_config,
        config.optimizer_config,
    )

    manager = CheckpointManager(
        os.path.join(learner_path, "models"),
        PyTreeCheckpointer(),
    )
    llm_params = manager.restore(manager.latest_step())

    # The positional-encoding entry of the restored parameters is replaced by an empty dict.
    model_params = llm_params[CONST_MODEL_DICT][CONST_MODEL]
    model_params[CONST_POSITIONAL_ENCODING] = dict()

    return llm_params, learner._model, config

In [None]:
# Restore the agent stored under the first key of `agent_results`.
first_agent_path = list(agent_results)[0]
agent_model = load_llm(first_agent_path)


In [None]:
import chex

from flax import linen as nn
from typing import Tuple

class GPTBlockWithInspection(nn.Module):
    """Causal self-attention block that also returns the attention sub-layer output.

    ``__call__`` creates, in this order, a LayerNorm, a SelfAttention, a second
    LayerNorm and two Dense layers.
    NOTE(review): the parameter tree passed to ``apply`` must match the auto-generated
    names of these submodules — presumably this mirrors the trained model's GPT
    block; confirm against its definition.
    """

    # : The number of attention heads
    num_heads: int

    # : The embedding dimensionality
    embed_dim: int

    @nn.compact
    def __call__(self, x: chex.Array) -> Tuple[chex.Array, chex.Array]:
        """Apply the block to ``x``.

        :param x: input whose last axis is dropped (``x[..., 0]``) to build the causal mask
        :return: ``(x, attention)`` — the block output after both residual additions, and
            the output of the self-attention sub-layer (not the attention weights)
        """
        # Causal mask built from x with its last axis removed.
        mask = nn.make_causal_mask(x[..., 0])
        # Self-attention on the layer-normalised input, added back as a residual.
        attention = nn.SelfAttention(self.num_heads)(nn.LayerNorm()(x), mask)
        x = x + attention
        # Feed-forward part: LayerNorm -> Dense -> GELU -> Dense, both Dense layers of
        # width `embed_dim`, added back as a second residual.
        normed_x = nn.gelu(nn.Dense(self.embed_dim)(nn.LayerNorm()(x)))
        x = x + nn.Dense(self.embed_dim)(normed_x)
        return x, attention

In [None]:
# NOTE(review): head count and width are hard-coded — presumably they match the
# restored model's configuration; confirm.
inspection = GPTBlockWithInspection(num_heads=1, embed_dim=64)

## Visualize Representation SVM at each step

Generate intermediate representations for the queries

In [None]:
# Unpack the (params, model, config) tuple returned by `load_llm`; the config is discarded.
(llm_params, llm_model, _) = agent_model

In [None]:
# Task inspected by the following cells, together with its context examples.
task_i = 2
task_context = context_data[task_i]
context_inputs = task_context[CONST_CONTEXT_INPUT]
context_outputs = task_context[CONST_CONTEXT_OUTPUT]

In [None]:
# The same context (with a leading axis added) is shared by every query.
context_batch = {
    CONST_CONTEXT_INPUT: context_inputs[None],
    CONST_CONTEXT_OUTPUT: context_outputs[None],
}

# Map `tokenize` over the first axis of the queries only; parameters and context are broadcast.
batched_tokenize = jax.vmap(llm_model.tokenize, in_axes=[None, 0, None])
tokens, _ = batched_tokenize(
    llm_params[CONST_MODEL_DICT][CONST_MODEL],
    gt["inputs"][:, None, None],
    context_batch,
)
print(tokens.shape)

In [None]:
gpt_params = agent_model[0][CONST_MODEL_DICT][CONST_MODEL]["gpt"]["params"]

# Entry 0 holds the tokens before any layer is applied; index -1 along the third
# axis is kept (presumably the query token — TODO confirm).
context_input_reprs = [tokens[:, 0, -1]]

for layer_name, layer_params in gpt_params.items():
    print(layer_name)
    layer_variables = {"params": layer_params}
    if layer_name.startswith("GPTBlock"):
        tokens, _ = inspection.apply(layer_variables, tokens)
    elif layer_name.startswith("LayerNorm"):
        tokens = nn.LayerNorm().apply(layer_variables, tokens)
    # One entry per parameter group, including groups that match neither prefix
    # (those leave `tokens` unchanged).
    context_input_reprs.append(tokens[:, 0, -1])

Generate intermediate representations for the contexts

In [None]:
# The same context (with a leading axis added) is shared by every query.
context_batch = {
    CONST_CONTEXT_INPUT: context_inputs[None],
    CONST_CONTEXT_OUTPUT: context_outputs[None],
}

# Here the context inputs themselves are fed in as the queries.
batched_tokenize = jax.vmap(llm_model.tokenize, in_axes=[None, 0, None])
tokens, _ = batched_tokenize(
    llm_params[CONST_MODEL_DICT][CONST_MODEL],
    context_inputs[:, None, None],
    context_batch,
)
print(tokens.shape)

In [None]:
gpt_params = agent_model[0][CONST_MODEL_DICT][CONST_MODEL]["gpt"]["params"]

# Entry 0 of each list is taken from the tokens before any layer is applied.
# `tokens[0, 0, :-1:2]`: every second position except the last, from the first query only
# (presumably the context-input tokens — TODO confirm).
# `tokens[:, 0, -1]`: the last position of every query.
input_int_reprs = [tokens[0, 0, :-1:2]]
query_int_reprs = [tokens[:, 0, -1]]

for layer_name, layer_params in gpt_params.items():
    print(layer_name)
    layer_variables = {"params": layer_params}
    if layer_name.startswith("GPTBlock"):
        tokens, _ = inspection.apply(layer_variables, tokens)
    elif layer_name.startswith("LayerNorm"):
        tokens = nn.LayerNorm().apply(layer_variables, tokens)
    print(tokens.shape)
    input_int_reprs.append(tokens[0, 0, :-1:2])
    query_int_reprs.append(tokens[:, 0, -1])

In [None]:
def get_agent_repr(context_data, queries, agent_path, task_i):
    """Compute the agent's latent representation of ``queries`` for one task.

    Loads the agent stored at ``agent_path`` and maps its ``get_latent`` over the
    first axis of ``queries``, using the context of task ``task_i``.

    :param context_data: mapping from task to its context inputs and outputs
    :param queries: query array, or ``None`` to use the task's context inputs as queries
    :param agent_path: path passed to ``load_llm``
    :param task_i: key of the task in ``context_data``
    :return: the first element of the pair returned by the batched ``get_latent``
    """
    print(task_i)
    llm_params, llm_model, _ = load_llm(agent_path)

    task_context = context_data[task_i]
    context_inputs = task_context[CONST_CONTEXT_INPUT]
    context_outputs = task_context[CONST_CONTEXT_OUTPUT]

    print(context_inputs, context_outputs)

    fall_back_to_context = queries is None
    if fall_back_to_context:
        queries = context_inputs
    print("USE CONTEXT INPUTS" if fall_back_to_context else "USE QUERIES")

    context_batch = {
        CONST_CONTEXT_INPUT: context_inputs[None],
        CONST_CONTEXT_OUTPUT: context_outputs[None],
    }
    batched_get_latent = jax.vmap(llm_model.get_latent, in_axes=[None, 0, None])
    latent, _ = batched_get_latent(
        llm_params[CONST_MODEL_DICT][CONST_MODEL],
        queries[:, None, None],
        context_batch,
    )
    print(latent)

    return latent

In [None]:
# Pin the agent explicitly instead of relying on whatever value `agent_path` was
# left with by the plotting loops above (their last key).  The first key of
# `agent_results` is the agent that `agent_model` was loaded from, so the
# comparisons in the following cells refer to the same checkpoint.
agent_path = list(agent_results.keys())[0]
query_og_res = get_agent_repr(context_data, None, agent_path, task_i)

In [None]:
# Largest absolute deviation; a signed maximum could read ~0 while large negative
# differences go unnoticed.
np.max(np.abs(tokens[:, 0, -1] - query_og_res[:, 0, -1]))

In [None]:
# Largest absolute deviation; a signed maximum could read ~0 while large negative
# differences go unnoticed.
np.max(np.abs(query_int_reprs[-1] - query_og_res[:, 0, -1]))

In [None]:
# Largest absolute deviation; a signed maximum could read ~0 while large negative
# differences go unnoticed.
np.max(np.abs(agent_reprs[agent_path]["query_context_reprs"][task_i] - query_int_reprs[-1]))

In [None]:
# Largest absolute deviation; a signed maximum could read ~0 while large negative
# differences go unnoticed.
np.max(np.abs(agent_reprs[agent_path]["query_context_reprs"][task_i] - query_og_res[:, 0, -1]))

In [None]:
agent_path = list(agent_reprs.keys())[0]
print(agent_path)

# For each stored representation, report per layer whether the recomputed
# representation matches it, then the largest and smallest signed difference
# against the last recomputed layer.
stored_reprs = agent_reprs[agent_path]
for stored_key, layer_reprs in (
    ("input_context_reprs", input_int_reprs),
    ("query_context_reprs", query_int_reprs),
):
    for layer_repr in layer_reprs:
        print(np.allclose(stored_reprs[stored_key][task_i], layer_repr))

    # `layer_repr` still refers to the final entry of `layer_reprs` here.
    print(
        np.max((stored_reprs[stored_key][task_i] - layer_repr)),
        np.min((stored_reprs[stored_key][task_i] - layer_repr))
    )

In [None]:
from jaxl.models.svm import *

# Class indices of the context outputs, with class 0 relabelled as -1.
train_y = np.argmax(context_outputs, axis=-1)
train_y[train_y == 0] = -1

# One `primal_svm` result per intermediate representation, for the input-token
# and the query-token representations respectively.
svm_input_reprs, svm_query_reprs = [], []
for layer_input_repr, layer_query_repr in zip(input_int_reprs, query_int_reprs):
    print(layer_input_repr.shape, layer_query_repr.shape)
    svm_input_reprs.append(primal_svm(layer_input_repr, train_y))
    svm_query_reprs.append(primal_svm(layer_query_repr, train_y))

In [None]:
ncols = 3
nrows = math.ceil(len(svm_query_reprs) / ncols)

# squeeze=False keeps `axes` two-dimensional, so the (row, column) indexing below
# also works when everything fits in a single row.
fig, axes = plt.subplots(
    nrows,
    ncols,
    figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
    layout="constrained",
    squeeze=False,
)

# One panel per intermediate representation: the SVM fitted on the context
# query-token representations is evaluated on the representations of gt["inputs"].
for ax_i, (svm_repr, context_input_repr) in enumerate(zip(svm_query_reprs, context_input_reprs)):
    ax = axes[ax_i // ncols, ax_i % ncols]

    # Second element of the `primal_svm` result, used as [weights..., bias].
    # NOTE(review): presumably the primal solution vector — confirm against primal_svm.
    svm_primal = svm_repr[1]
    svm_preds = (
        (
            np.array(context_input_repr) @ svm_primal[:-1]
            + svm_primal[-1:]
        )
        >= 0
    ).astype(int)

    # Background: points of gt["inputs"] grouped by the SVM's predicted class.
    for possible_label in [0, 1]:
        idxes = np.where(svm_preds == possible_label)
        ax.scatter(
            gt["inputs"][idxes][:, 0],
            gt["inputs"][idxes][:, 1],
            label=f"{possible_label}" if ax_i == 0 else "",
            s=5,
            alpha=0.1
        )
    # Foreground: context examples; labels here are -1/+1.
    for possible_label in [-1, 1]:
        idxes = np.where(train_y == possible_label)
        ax.scatter(
            context_inputs[idxes][:, 0],
            context_inputs[idxes][:, 1],
            label=f"{possible_label}" if ax_i == 0 else "",
            s=15,
            marker="x" if possible_label == 1 else "o",
            c="black"
        )

    ax.set_xlim(input_range[0] - 0.01, input_range[1] + 0.01)
    ax.set_ylim(input_range[0] - 0.01, input_range[1] + 0.01)
    ax.set_title("repr {}".format(ax_i))
    ax.grid(False)

fig.legend(
    bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
    loc="lower center",
    ncols=4,
    borderaxespad=0.0,
    frameon=True,
    fontsize="8",
)
fig.suptitle("Context-input Query Intermediate Representation")
fig.supxlabel("$x_1$")
fig.supylabel("$x_2$")
plt.show()

In [None]:
ncols = 3
nrows = math.ceil(len(svm_input_reprs) / ncols)

# squeeze=False keeps `axes` two-dimensional, so the (row, column) indexing below
# also works when everything fits in a single row.
fig, axes = plt.subplots(
    nrows,
    ncols,
    figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
    layout="constrained",
    squeeze=False,
)

# One panel per intermediate representation: the SVM fitted on the context
# input-token representations is evaluated on the representations of gt["inputs"].
for ax_i, (svm_repr, context_input_repr) in enumerate(zip(svm_input_reprs, context_input_reprs)):
    ax = axes[ax_i // ncols, ax_i % ncols]

    # Second element of the `primal_svm` result, used as [weights..., bias].
    # NOTE(review): presumably the primal solution vector — confirm against primal_svm.
    svm_primal = svm_repr[1]
    svm_preds = (
        (
            np.array(context_input_repr) @ svm_primal[:-1]
            + svm_primal[-1:]
        )
        >= 0
    ).astype(int)

    # Background: points of gt["inputs"] grouped by the SVM's predicted class.
    for possible_label in [0, 1]:
        idxes = np.where(svm_preds == possible_label)
        ax.scatter(
            gt["inputs"][idxes][:, 0],
            gt["inputs"][idxes][:, 1],
            label=f"{possible_label}" if ax_i == 0 else "",
            s=5,
            alpha=0.1
        )
    # Foreground: context examples; labels here are -1/+1.
    for possible_label in [-1, 1]:
        idxes = np.where(train_y == possible_label)
        ax.scatter(
            context_inputs[idxes][:, 0],
            context_inputs[idxes][:, 1],
            label=f"{possible_label}" if ax_i == 0 else "",
            s=15,
            # Compare against 1 explicitly: -1 is truthy, so a bare truth test
            # would draw both classes with the same "x" marker.
            marker="x" if possible_label == 1 else "o",
            c="black"
        )

    ax.set_xlim(input_range[0] - 0.01, input_range[1] + 0.01)
    ax.set_ylim(input_range[0] - 0.01, input_range[1] + 0.01)
    ax.set_title("repr {}".format(ax_i))
    ax.grid(False)

fig.legend(
    bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
    loc="lower center",
    ncols=4,
    borderaxespad=0.0,
    frameon=True,
    fontsize="8",
)
fig.supxlabel("$x_1$")
fig.supylabel("$x_2$")
fig.suptitle("Context-input Input Intermediate Representation")
plt.show()

# TODO: Write my own self attention

In [None]:
# NOTE(review): always fails — presumably a deliberate stop so that "Run All" halts
# before the exploratory cells below; confirm before removing.
assert 0

In [None]:
# Print the "SelfAttention_0" parameters of every parameter group of the GPT model
# (None for groups that have no such entry).
gpt_params = agent_model[0][CONST_MODEL_DICT][CONST_MODEL]["gpt"]["params"]
for gpt_block_params in gpt_params.values():
    print(gpt_block_params.get("SelfAttention_0", None))

# Analyzing Dual

In [None]:
import cvxpy as cp

def recover_dual_from_primal(primal_sol, train_x, train_y):
    """Solve for non-negative dual variables given a primal SVM solution.

    Maximises ``sum_i alpha_i * (1 - s_i * f(x_i))`` subject to ``alpha >= 0``, where
    ``f(x) = x @ primal_sol[:-1] + primal_sol[-1]`` and ``s_i`` is the label sign
    (labels 0/1 give -1/+1).

    :param primal_sol: array ``[weights..., bias]`` of the primal solution
    :param train_x: two-dimensional array of training inputs, one row per sample
    :param train_y: one-dimensional array of labels
    :return: ``(loss, alphas)`` — half the squared weight norm plus the optimal
        objective value, and the solver's value for the dual variables
    """
    assert len(train_x.shape) == 2
    assert len(train_y.shape) == 1

    num_samples, _ = train_x.shape

    print("Primal solution: {}, shape: {}".format(primal_sol, primal_sol.shape))

    weights = primal_sol[:-1]
    bias = primal_sol[-1]

    primal_out = (train_x @ weights) + bias
    primal_constraints = 2 * (0.5 - (1 - train_y)) * primal_out

    # Objective coefficients; values numerically equal to zero are snapped to exactly zero.
    coefs = 1 - primal_constraints
    coefs[np.isclose(coefs, 0)] = 0

    # Non-negativity of the dual variables, written as G @ alpha >= h with G = I, h = 0.
    G = np.eye(num_samples)
    h = np.zeros(num_samples)
    dual_var = cp.Variable(num_samples)
    objective = cp.Maximize(coefs.T @ dual_var)
    prob = cp.Problem(objective, [G @ dual_var >= h])

    optimal_value = prob.solve(verbose=False)
    loss = 0.5 * (np.linalg.norm(weights) ** 2) + optimal_value
    return loss, dual_var.value

In [None]:
from pprint import pprint

# Sections reported per agent: (printed title, key under "svms", whether the SVM
# was fitted on a stored representation rather than on the raw context inputs).
dual_sections = (
    ("input ", "input", False),
    ("query_context_reprs", "query_context_reprs", True),
    ("input_context_reprs", "input_context_reprs", True),
)

for task_i in context_data:
    context_inputs = context_data[task_i][CONST_CONTEXT_INPUT]
    context_outputs = context_data[task_i][CONST_CONTEXT_OUTPUT]
    train_y = np.argmax(context_outputs, axis=-1)

    for agent_path, agent_repr in agent_reprs.items():

        print("=" * 50)
        print("Processing agent {}".format(agent_path))

        for section_title, svm_key, fitted_on_repr in dual_sections:
            print("-" * 50)
            print(section_title)
            print("-" * 50)

            svm_primal_entry = agent_repr["svms"][task_i][svm_key]["primal"]
            if fitted_on_repr:
                # Stored solution and representation are converted to numpy arrays first.
                primal_sol = np.array(svm_primal_entry["sol"])
                train_x = np.array(agent_repr[svm_key][task_i])
            else:
                primal_sol = svm_primal_entry["sol"]
                train_x = context_inputs

            loss, alphas = recover_dual_from_primal(primal_sol, train_x, train_y)

            # Report the dual variables in ascending order together with their indices.
            sort_idxes = np.argsort(alphas)
            print("Primal loss: {}, Dual loss: {}".format(
                svm_primal_entry["loss"],
                loss
            ))
            print("Dual solution (sorted): {}".format(alphas[sort_idxes]))
            print("Sort indices: {}".format(sort_idxes))
        print("")