In [None]:
from jaxl.constants import *
from jaxl.datasets import get_dataset
from jaxl.learning_utils import get_learner
from jaxl.utils import parse_dict

import copy
import json
import numpy as np
import os

from orbax.checkpoint import PyTreeCheckpointer, CheckpointManager

In [None]:
# Root of the jaxl checkout. Set the JAXL_BASE_DIR environment variable to point at
# your own checkout instead of editing this machine-specific default.
base_dir = os.environ.get("JAXL_BASE_DIR", "/Users/chanb/research/personal/jaxl/jaxl")
experiment_dir = os.path.join(base_dir, "logs/icl-linear_sgd-full_context_20")

# Trained runs in this experiment; switch `run_name` to evaluate the other one.
NO_ENC_RUN = "gpt-no_enc-08-21-23_07_36_59-971a17db-73ed-4f77-b463-5887644e3385"
POS_ENC_RUN = "gpt-pos_enc-08-21-23_07_36_58-5473978c-10db-433b-889b-f136405d7a7e"
run_name = POS_ENC_RUN

learner_path = os.path.join(experiment_dir, run_name)
# Seed passed to `get_dataset` when building the evaluation dataset.
test_dataset_seed = 9999

In [None]:
# Rebuild the learner from the run's saved config.json.
config_path = os.path.join(learner_path, "config.json")
with open(config_path, "r") as f:
    config_dict = json.load(f)
config = parse_dict(config_dict)

learner = get_learner(
    config.learner_config, config.model_config, config.optimizer_config
)

# Restore the most recent checkpoint stored under `<learner_path>/models`.
checkpoint_dir = os.path.join(learner_path, "models")
checkpoint_manager = CheckpointManager(
    checkpoint_dir,
    PyTreeCheckpointer(),
)
latest_step = checkpoint_manager.latest_step()
if latest_step is None:
    # `latest_step()` returns None when nothing has been saved; fail loudly here
    # instead of handing None to `restore`.
    raise FileNotFoundError(f"No checkpoints found under {checkpoint_dir}")
params = checkpoint_manager.restore(latest_step)

# Replace the restored positional-encoding parameters with an empty dict.
# NOTE(review): presumably done so the model can be evaluated on the longer
# `sequence_len` used below; confirm the model accepts empty positional-encoding params.
params[CONST_MODEL_DICT][CONST_MODEL][CONST_POSITIONAL_ENCODING] = dict()
model = learner._model

In [None]:
# Display the learner's dataset config as loaded from the run's config.json;
# the test config in the next cell is derived from it.
config.learner_config.dataset_config

In [None]:
# Number of dataset items evaluated below (passed to `get_result` / `get_latent`).
sequence_len = 80

# Derive the test dataset config from the learner's. `vars` returns the object's live
# `__dict__`, so deep-copy first: editing it directly would silently mutate `config`
# and make this cell crash on re-run (`vars` of the already-converted dict).
test_config = vars(copy.deepcopy(config.learner_config.dataset_config))
test_config["dataset_kwargs"] = vars(test_config["dataset_kwargs"])
# NOTE(review): the dataset is asked for one more element than `sequence_len`;
# confirm the +1 matches how the dataset indexes positions.
test_config["dataset_kwargs"]["sequence_length"] = sequence_len + 1
test_config["dataset_kwargs"]["params_bound"] = [-0.5, 0.5]
test_config["dataset_kwargs"]["inputs_range"] = [-0.5, 0.5]
test_config = parse_dict(test_config)

In [None]:
# Build the evaluation dataset from the modified test config with a fixed seed.
test_dataset = get_dataset(test_config, seed=test_dataset_seed)

In [None]:
def get_result(dataset, sequence_length):
    """Run the restored model's forward pass on the first `sequence_length` items.

    Each `dataset[i]` is unpacked as `(context_input, context_output, query, output)`.
    The four fields are stacked along a new leading axis and passed to
    `model.forward` in a single call, using the globally restored `params`.

    Returns:
        `(queries, preds, outputs)`: the stacked queries, the model's predictions,
        and the stacked ground-truth outputs.
    """
    examples = [dataset[idx] for idx in range(sequence_length)]
    # Transpose the list of 4-tuples into four per-field stacks.
    context_inputs, context_outputs, queries, outputs = (
        np.stack(column) for column in zip(*examples)
    )

    context = {
        CONST_CONTEXT_INPUT: context_inputs,
        CONST_CONTEXT_OUTPUT: context_outputs,
    }
    model_params = params[CONST_MODEL_DICT][CONST_MODEL]
    preds, _ = model.forward(model_params, queries, context)
    return queries, preds, outputs

In [None]:
def mse(preds, outputs):
    """Return the mean squared error between `preds` and `outputs`.

    The difference uses ordinary array broadcasting, so callers should make sure the
    two shapes actually match.
    """
    squared_error = (preds - outputs) ** 2
    return np.mean(squared_error)

In [None]:
queries, preds, outputs = get_result(test_dataset, sequence_len)
loss = mse(preds, outputs)
# Print the three shapes so mismatches between predictions and targets are visible.
print(*(array.shape for array in (queries, preds, outputs)))

In [None]:
# Mean squared error of the model's predictions on the test dataset.
loss

# Basic Prediction Result

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Skip the first `context_len` queries when plotting.
# NOTE(review): presumably these positions see less than a full context; confirm
# against the dataset wrapper.
context_len = config.learner_config.dataset_config.dataset_wrapper.kwargs.context_len

fig, ax = plt.subplots()
ax.scatter(queries[context_len:], preds[context_len:], label="prediction")
ax.scatter(queries[context_len:], outputs[context_len:], label="ground truth")
ax.set(
    xlabel="input",
    ylabel="output",
    title=f"Predictions vs. ground truth (first {context_len} queries skipped)",
)
ax.legend();

In [None]:
def get_latent(dataset, sequence_length):
    """Compute the restored model's latents for the first `sequence_length` items.

    Each `dataset[i]` is unpacked as `(context_input, context_output, query, output)`.
    The ground-truth `output` is not needed here and is discarded. The remaining
    fields are stacked along a new leading axis and passed to `model.get_latent`
    using the globally restored `params`.

    Returns:
        `(queries, latent)`: the stacked queries and the latents from
        `model.get_latent`.
    """
    context_inputs, context_outputs, queries = [], [], []
    for seq_i in range(sequence_length):
        # Ground-truth outputs were previously collected and stacked but never used.
        context_input, context_output, query, _ = dataset[seq_i]
        context_inputs.append(context_input)
        context_outputs.append(context_output)
        queries.append(query)
    context_inputs = np.stack(context_inputs)
    context_outputs = np.stack(context_outputs)
    queries = np.stack(queries)

    latent, _ = model.get_latent(
        params[CONST_MODEL_DICT][CONST_MODEL],
        queries,
        {
            CONST_CONTEXT_INPUT: context_inputs,
            CONST_CONTEXT_OUTPUT: context_outputs,
        }
    )
    return queries, latent

In [None]:
queries, latents = get_latent(test_dataset, sequence_len)
# Print both shapes to check that one latent was produced per query.
print(*(array.shape for array in (queries, latents)))

In [None]:
from jaxl.utils import l2_norm

In [None]:
# Print every latent, one per line.
# NOTE(review): `l2_norm` is imported above but unused; presumably it was meant to
# summarize these latents rather than printing them in full.
for idx in range(len(latents)):
    latent = latents[idx]
    print(latent)