In [None]:
from jaxl.constants import *
from jaxl.datasets import get_dataset
from jaxl.learning_utils import get_learner
from jaxl.utils import parse_dict

import json
import numpy as np
import optax
import os

from orbax.checkpoint import PyTreeCheckpointer, CheckpointManager

In [None]:
# Common prefix of the run directory below.
base_dir = "/Users/chanb/research/personal/jaxl/jaxl"
# Log directory of the training run whose config and checkpoints are loaded.
learner_path = (
    f"{base_dir}/logs/"
    "icl-noiseless-no_bias-2d_linear-active_dim_2-full_context_16/"
    "gpt-pos_enc-09-09-23_15_09_55-0c965292-9e2e-4268-87df-3221b558bb19"
)
# Seed passed to get_dataset when the test dataset is built.
test_dataset_seed = 999

In [None]:
# Rebuild the learner from the configuration saved inside the run directory.
config_path = os.path.join(learner_path, "config.json")
with open(config_path, "r") as f:
    config_dict = json.load(f)
# Parsing does not need the file handle, so it runs once the file is closed.
config = parse_dict(config_dict)

learner = get_learner(
    config.learner_config,
    config.model_config,
    config.optimizer_config,
)

# Restore the most recent checkpoint found under "<learner_path>/models".
checkpoint_manager = CheckpointManager(
    os.path.join(learner_path, "models"),
    PyTreeCheckpointer(),
)
latest_step = checkpoint_manager.latest_step()
params = checkpoint_manager.restore(latest_step)

# Overwrite the positional-encoding entry of the restored model params with
# an empty dict before the params are used for inference.
params[CONST_MODEL_DICT][CONST_MODEL][CONST_POSITIONAL_ENCODING] = dict()
model = learner._model

In [None]:
# Display the dataset configuration stored in the loaded learner config.
config.learner_config.dataset_config

In [None]:
# Build the test dataset configuration from the loaded one, overriding the
# sequence length, the parameter bound and the input range.
sequence_len = 80
input_range = [-1.0, 1.0]

# vars() hands back the namespace's own __dict__, so both levels are copied
# before being edited. Editing them directly would rewrite `config` in place
# (its dataset_kwargs would become a plain dict) and make this cell raise a
# TypeError on a second run, because vars() cannot be applied to a dict.
test_config = dict(vars(config.learner_config.dataset_config))
test_dataset_kwargs = dict(vars(test_config["dataset_kwargs"]))
# NOTE(review): one more element than the number of evaluated items;
# presumably the extra slot is needed to form the last query - confirm.
test_dataset_kwargs["sequence_length"] = sequence_len + 1
test_dataset_kwargs["params_bound"] = [-0.5, 0.5]
test_dataset_kwargs["inputs_range"] = input_range
test_config["dataset_kwargs"] = test_dataset_kwargs

# Convert the plain dicts back into the namespace form produced by parse_dict.
ns_test_config = parse_dict(test_config)

In [None]:
# Instantiate the test dataset from the overridden config, using the fixed test seed.
test_dataset = get_dataset(ns_test_config, seed=test_dataset_seed)

In [None]:
def get_result(dataset, sequence_length):
    """Run the model on the first ``sequence_length`` items of ``dataset``.

    Each item is unpacked as ``(context_input, context_output, query, output)``.
    The four fields are stacked along a new leading axis and the whole batch is
    passed through ``model.forward`` in a single call.

    Note: this reads the notebook-level ``model`` and ``params`` globals.

    Args:
        dataset: indexable object whose items unpack into the four fields above.
        sequence_length: number of leading items (indices ``0..n-1``) to use.

    Returns:
        Tuple ``(queries, preds, outputs, context_inputs, context_outputs)``.
    """
    samples = [dataset[idx] for idx in range(sequence_length)]

    # Split the per-item 4-tuples into one stacked array per field.
    context_inputs = np.stack([ctx_in for ctx_in, _, _, _ in samples])
    context_outputs = np.stack([ctx_out for _, ctx_out, _, _ in samples])
    queries = np.stack([query for _, _, query, _ in samples])
    outputs = np.stack([output for _, _, _, output in samples])

    model_params = params[CONST_MODEL_DICT][CONST_MODEL]
    contexts = {
        CONST_CONTEXT_INPUT: context_inputs,
        CONST_CONTEXT_OUTPUT: context_outputs,
    }
    # The second value returned by forward is not used here.
    preds, _ = model.forward(model_params, queries, contexts)
    return queries, preds, outputs, context_inputs, context_outputs

In [None]:
def ce_loss(logits, y_one_hot):
    """Return the mean of ``optax.softmax_cross_entropy(logits, y_one_hot)``."""
    per_example_loss = optax.softmax_cross_entropy(logits, y_one_hot)
    return np.mean(per_example_loss)

In [None]:
# Run the model on the first `sequence_len` test items and score its predictions.
(queries, preds, outputs, context_inputs, context_outputs) = get_result(
    test_dataset, sequence_length=sequence_len
)
loss = ce_loss(preds, outputs)
# Shape check of the stacked queries, predictions and targets.
print(*(arr.shape for arr in (queries, preds, outputs)))

In [None]:
# Display the mean cross-entropy computed in the previous cell.
loss

# Basic Prediction Result

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Only queries from index `context_len` onwards are plotted.
context_len = config.learner_config.dataset_config.dataset_wrapper.kwargs.context_len
pred_labels = np.argmax(preds[context_len:], axis=-1)

# First and second input coordinate of every plotted query.
query_x1 = queries[context_len:, 0, 0]
query_x2 = queries[context_len:, 0, 1]
for possible_label in (0, 1):
    idxes = np.where(pred_labels == possible_label)
    plt.scatter(query_x1[idxes], query_x2[idxes], label=f"{possible_label}")

# Overlay the line x2 = -x1 * w[1] / w[2], with w the first entry of the dataset params.
# NOTE(review): w[0] is not used - presumably a bias term that is zero for
# this run; confirm before reusing this cell on a dataset with a bias.
decision_boundary = test_dataset.params[0]
boundary_x1 = np.array(input_range)
out = -boundary_x1 * decision_boundary[1] / decision_boundary[2]
plt.plot(boundary_x1, out, label="decision boundary", color="black")

# Both axes span the same input range.
lower, upper = input_range[0], input_range[1]
plt.xlim(lower, upper)
plt.ylim(lower, upper)

# Legend sits just above the axes, laid out in a single row of three entries.
plt.legend(
    bbox_to_anchor=(0.0, 1.01, 1.0, 0.0),
    loc="lower center",
    ncols=3,
    borderaxespad=0.0,
    frameon=True,
)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")

# SVM

In [None]:
from sklearn.svm import LinearSVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.inspection import DecisionBoundaryDisplay

In [None]:
# Hinge-loss linear SVM applied to standardized features.
linear_svc = LinearSVC(
    dual="auto",
    loss="hinge",
    random_state=0,
    tol=1e-7,
    C=10.0,
)
svm = make_pipeline(StandardScaler(), linear_svc)

# Fit on the last stacked context; argmax over axis 1 turns the context
# outputs (presumably one-hot) into integer class labels.
svm_labels = np.argmax(context_outputs[-1], axis=1)
svm.fit(context_inputs[-1], svm_labels)

In [None]:
# Inputs of the last stacked context, i.e. the data the SVM was fit on.
svm_inputs = context_inputs[-1]

decision_function = svm.decision_function(svm_inputs)
print(decision_function)

# Samples whose absolute decision value is at most 1 are highlighted as support vectors.
support_vector_indices = np.where(np.abs(decision_function) <= 1)[0]
support_vectors = svm_inputs[support_vector_indices]

# Context points, coloured by the last component of their output vector.
plt.scatter(
    svm_inputs[:, 0],
    svm_inputs[:, 1],
    c=context_outputs[-1, ..., -1],
    s=30,
    cmap=plt.cm.Paired,
)

# Decision boundary (level 0, solid) and the two margins (levels -1 and 1, dashed).
ax = plt.gca()
DecisionBoundaryDisplay.from_estimator(
    svm,
    svm_inputs,
    ax=ax,
    grid_resolution=50,
    plot_method="contour",
    colors="k",
    levels=[-1, 0, 1],
    alpha=0.5,
    linestyles=["--", "-", "--"],
)

# Hollow black circles around the support vectors.
plt.scatter(
    support_vectors[:, 0],
    support_vectors[:, 1],
    s=100,
    linewidth=1,
    facecolors="none",
    edgecolors="k",
)