In [None]:
from collections import OrderedDict
import re
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
import torch
from tqdm.notebook import tqdm
import random
import string

from eval import get_run_metrics, read_run_dir, get_model_from_run
from plot_utils import basic_plot, collect_results, relevant_model_names

%matplotlib inline
%load_ext autoreload
%autoreload 2

random.seed(42)

sns.set_theme('notebook', 'darkgrid')
palette = sns.color_palette('colorblind')

run_dir = "../models"

In [None]:
# Table of the training runs found under `run_dir`; it is later filtered to the
# selected task / run_id via collect_results(..., valid_row=valid_row).
df = read_run_dir(run_dir)
# df  # list all the runs in our run_dir

In [None]:
# --- Run to analyse ----------------------------------------------------------
# `task` is the task's sub-path under run_dir; `run_id` selects one run in it.
run_id = "4b42e1c1-0537-4b5b-8492-211352f8294e"
task = "linear_regression/lang_pretrained/random/gpt2-small/"
run_path = os.path.join(run_dir, task, run_id)

# --- Labels used in plot titles ----------------------------------------------
model_size = "gpt2-small"
task_name = "linear regression"

# --- Evaluation metrics ------------------------------------------------------
# Metrics are normally precomputed at the end of training; when this flag is
# True they are recomputed for `run_path` now.
# TODO(confirm): whether get_run_metrics evaluates on held-out (test) data.
recompute_metrics = True
if recompute_metrics:
    get_run_metrics(run_path)

In [None]:
def save_current_figure(folder_path, file_name, root="plots"):
    """Save the current matplotlib figure to ``<root>/<folder_path>/<file_name>`` and close it.

    Args:
        folder_path: sub-directory under ``root`` (may contain "/"); created if missing.
        file_name: file name passed to ``plt.savefig``.
        root: top-level output directory; defaults to "plots".
    """
    out_dir = os.path.join(root, folder_path)
    # exist_ok=True replaces the check-then-create pattern (no race) and also
    # creates `root` itself when it does not exist yet.
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, file_name))
    # Close so figures don't pile up when this is called repeatedly.
    plt.close()

def generate_random_string(length=8):
    """Return `length` characters drawn (with replacement) from ASCII letters and digits.

    Draws from the module-level `random` generator, so the result is
    reproducible after `random.seed(...)`.
    """
    alphabet = string.ascii_letters + string.digits
    chars = []
    for _ in range(length):
        chars.append(random.choice(alphabet))
    return "".join(chars)

In [None]:
# Folder name (under plots/) for figures saved with save_current_figure.
# `task` is a nested path such as "a/b/c/", so flatten it into a single path
# component instead of creating a deep directory tree ending in "/_<suffix>".
fig_output_dir = "_".join([run_id, task.strip("/").replace("/", "_"), generate_random_string()])

# Plot pre-computed metrics

In [None]:
def valid_row(r):
    """Row filter for collect_results: keep rows of the selected task and run.

    Reads the globals `task` and `run_id`. NOTE: `task` is rebound to a task
    object in the interactive section, so this filter only matches as intended
    while `task` is still the path string.
    """
    same_task = r.task == task
    if not same_task:
        return same_task
    return r.run_id == run_id

# Metrics of the selected run only (see valid_row).
metrics = collect_results(run_dir, df, valid_row=valid_row)

# Only the config is needed here; the first return value is discarded.
_, conf = get_model_from_run(run_path, only_conf=True)
n_dims = conf.model.n_dims

# In-distribution ("standard") evaluation for the models relevant to this task.
models = relevant_model_names[task]
basic_plot(metrics["standard"], models=models)
plt.title(f"{model_size} model on {task_name} in-context task squared error vs in-context examples")
plt.show()
# save_current_figure(fig_output_dir, "eval_on_all_models")

In [None]:
# Plot every out-of-distribution (OOD) evaluation, i.e. every metric except "standard".
for name, metric in metrics.items():
    if name == "standard":
        continue

    # Names containing "scale" end in "=<s>"; reference level and y-limits are
    # multiplied by s**2.
    scale = float(name.split("=")[-1]) ** 2 if "scale" in name else 1.0
    # Reference level passed to basic_plot as `trivial`; 1 + 1/n_dims for "noisy" variants.
    trivial = (1 + 1 / n_dims) if "noisy" in name else 1.0

    fig, ax = basic_plot(metric, models=models, trivial=trivial * scale)
    ax.set_title(name)
    if "ortho" in name:
        ax.set_xlim(-1, n_dims - 1)
    ax.set_ylim(-.1 * scale, 1.5 * scale)
    plt.show()

    # Strip "." from the name used for the (optional) saved file.
    name = name.replace(".", "")
    # save_current_figure(fig_output_dir, "eval_on_all_models_ood_" + name)

# Interactive setup

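We now load the trained model directly and measure its in-context learning ability on a freshly sampled batch of inputs, both with the clean in-context labels and with some of those labels replaced by random values.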

In [None]:
from samplers import get_data_sampler
from tasks import get_task_sampler

In [None]:
# Load the trained model together with its training config.
model, conf = get_model_from_run(run_path)

training_cfg = conf.training
n_dims = conf.model.n_dims
batch_size = training_cfg.batch_size

# Samplers built from the training config: one draws inputs xs, the other draws task instances.
data_sampler = get_data_sampler(training_cfg.data, n_dims)
task_sampler = get_task_sampler(
    training_cfg.task, n_dims, batch_size, **training_cfg.task_kwargs
)

In [None]:
def generate_random_ys(num_changes, num_total, b_dim, y_dim=0, shuffle_indices=False):
    """
    Pick `num_changes` positions per batch row and draw replacement values for them.

    By default the positions are deterministic: the first `num_changes`
    positions (0 .. num_changes - 1) of every row. Pass
    ``shuffle_indices=True`` to instead draw `num_changes` distinct positions
    uniformly from 0 .. num_total - 1, independently for each row.

    Args:
    - num_changes (int): Number of positions to replace per row (0 <= num_changes <= num_total).
    - num_total (int): Number of available positions per row.
    - b_dim (int): Batch size (number of rows).
    - y_dim (int): Size of a trailing value dimension; 0 means one scalar per position.
    - shuffle_indices (bool): Draw random distinct positions instead of the first ones.

    Returns:
    - torch.Tensor: int64 position indices of shape (b_dim, num_changes).
    - torch.Tensor: values drawn uniformly from [0, 1) via torch.rand, of shape
      (b_dim, num_changes) if y_dim == 0 else (b_dim, num_changes, y_dim).

    Raises:
    - ValueError: if num_changes > num_total or num_changes < 0.
    """
    if num_changes > num_total:
        raise ValueError("Number of changes cannot be greater than the total number.")
    if num_changes < 0:
        raise ValueError("Number of changes cannot be negative.")

    if shuffle_indices:
        # argsort of i.i.d. uniforms gives an independent random permutation per row.
        random_indices = torch.argsort(torch.rand(b_dim, num_total), dim=1)[:, :num_changes]
    else:
        # torch.arange replaces the deprecated torch.range; one identical row per batch element.
        random_indices = torch.arange(num_changes, dtype=torch.int64).repeat(b_dim, 1)

    value_shape = (b_dim, num_changes) if y_dim == 0 else (b_dim, num_changes, y_dim)
    indices_values = torch.rand(value_shape)

    return random_indices, indices_values

In [None]:
# Sequence tasks (the "seq_*" task names) are sampled with task.generate_sequence
# instead of task.evaluate. Read the name from the loaded training config rather
# than the `task` path string: `task` is rebound to a sampled task object further
# down, so re-running this cell would otherwise test that object instead of the path.
SEQ = "seq" in conf.training.task

In [None]:
# can change this to be false
# Toggle: when True, the inputs xs are randomized at the same positions as the labels.
RANDOM_XS: bool = False

In [None]:
# Sample a batch of inputs and a fresh task instance, then build copies of the
# labels in which some in-context positions are replaced by random values
# (positions and values come from generate_random_ys).
# NOTE: this rebinds `task` from the run-path string to the sampled task object;
# earlier cells that use the path string (e.g. valid_row) must not be re-run
# after this cell.
task = task_sampler()
xs = data_sampler.sample_xs(b_size=batch_size, n_points=conf.training.curriculum.points.end)

if SEQ:
    # Sequence tasks build xs and ys from the first input of each row.
    xs, ys = task.generate_sequence(xs[:, 0, :], conf.model.n_positions)
    y_dim = ys.shape[2]
else:
    ys = task.evaluate(xs)
    y_dim = 0
print(ys.shape)

# One randomized copy is made for each count 1 .. MAX_RANDOMIZED_POSITIONS - 1.
# This is a local constant rather than an overwrite of conf.model.n_positions:
# the SEQ branch above reads that field, so mutating it would make re-runs of
# this cell generate sequences of a different length.
MAX_RANDOMIZED_POSITIONS = 2
num_indices = list(range(1, MAX_RANDOMIZED_POSITIONS))
randomized_ys_array = []
randomized_xs_array = []

for num_changes in num_indices:
    if RANDOM_XS:
        randomized_xs = xs.clone()
        random_indices, x_indices_values = generate_random_ys(
            num_changes, randomized_xs.shape[1], randomized_xs.shape[0], xs.shape[2]
        )
        # Row b receives x_indices_values[b] at positions random_indices[b].
        x_rows = torch.arange(randomized_xs.shape[0]).unsqueeze(1)
        randomized_xs[x_rows, random_indices] = x_indices_values.to(randomized_xs.dtype)
        randomized_xs_array.append(randomized_xs)

    randomized_ys = ys.clone()
    random_indices, indices_values = generate_random_ys(
        num_changes, randomized_ys.shape[1], randomized_ys.shape[0], y_dim
    )
    # Vectorized `randomized_ys[b][random_indices[b]] = indices_values[b]` for all
    # rows b: the (batch, 1) row index broadcasts against the (batch, num_changes)
    # position index.
    y_rows = torch.arange(randomized_ys.shape[0]).unsqueeze(1)
    randomized_ys[y_rows, random_indices] = indices_values.to(randomized_ys.dtype)
    randomized_ys_array.append(randomized_ys)

In [None]:
# Model predictions for the clean context and for every randomized-label context.
with torch.no_grad():
    pred = model(xs, ys).squeeze()
    print("pred", pred[0][1])  # batch row 0, position 1

    randomized_pred_array = []
    for i, ys_variant in enumerate(randomized_ys_array):
        # Use the matching randomized inputs only when xs were randomized too.
        context_xs = randomized_xs_array[i] if RANDOM_XS else xs
        randomized_pred = model(context_xs, ys_variant).squeeze()
        randomized_pred_array.append(randomized_pred)

    

In [None]:
# Per-position error (task metric) of the clean predictions and of the
# randomized-context predictions, always measured against the clean labels `ys`.
# TODO(author question): `ys` are also passed to the model as context — confirm
# this is the intended evaluation (presumably each y is predicted from the
# preceding examples only).
metric = task.get_metric()
loss = metric(pred, ys).numpy()

sparsity = conf.training.task_kwargs.sparsity if "sparsity" in conf.training.task_kwargs else None
baseline = {
    "linear_regression": n_dims,
    "sparse_linear_regression": sparsity,
    "relu_2nn_regression": n_dims,
    "decision_tree": 1,
    "seq_relu_2nn": 0,
    "seq_linear": 1,
    "seq_rec_linear": 0,
}[conf.training.task]

fig, ax = plt.subplots()
ax.plot(loss.mean(axis=0), lw=2, label="0 y rand")

# Every other randomization level starting at index 1; index n holds n + 1 randomized ys.
# NOTE(review): with a single randomization level this loop plots nothing.
for n in range(1, len(randomized_pred_array), 2):
    randomized_loss = metric(randomized_pred_array[n], ys).numpy()
    ax.plot(randomized_loss.mean(axis=0), lw=2, label=str(n + 1) + " y rand")

# Add the baseline before building the legend so it gets an entry; a single
# legend, placed outside the axes.
ax.axhline(baseline, ls="--", color="gray", label="zero estimator")
ax.set_xlabel("# in-context examples")
ax.set_ylabel("squared error")
ax.set_title(model_size + " model on " + task_name + " in-context task squared error vs in-context examples")
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()
# save_current_figure(fig_output_dir, "eval_on_transformer")

In [None]:
import numpy as np

# Persist the clean-context loss matrix for the comparison figures.
# NOTE(review): this run is linear regression and the Figures cell loads
# "linreg_small_*.npy", yet the file is named "seq_small_random.npy" — confirm the name.
np.save('seq_small_random.npy', loss)

### Figures: comparing saved loss curves across initializations

In [None]:
# Compare saved loss matrices (presumably written with np.save by separate runs
# of this notebook) for the GPT2-small initializations against the task's
# zero-estimator baseline. Only the loaded .npy files are plotted, so no live
# predictions are needed here.

# Same sparsity lookup as the evaluation cell, so "sparse_linear_regression"
# gets a numeric baseline instead of None.
sparsity = conf.training.task_kwargs.sparsity if "sparsity" in conf.training.task_kwargs else None
baseline = {
    "linear_regression": n_dims,
    "sparse_linear_regression": sparsity,
    "relu_2nn_regression": n_dims,
    "decision_tree": 1,
    "seq_relu_2nn": 0,
    "seq_linear": 1,
    "seq_rec_linear": 0,
}[conf.training.task]

# (file, legend label) pairs; replace "small" with "med" in the file names
# (and the title) to compare the medium-size runs instead.
comparison_curves = [
    ("./linreg_small_pca.npy", "with PCA"),
    ("./linreg_small_wopca.npy", "without PCA"),
    ("./linreg_small_random.npy", "random init"),
]

fig, ax = plt.subplots()
for path, label in comparison_curves:
    curve = np.load(path)
    ax.plot(curve.mean(axis=0), lw=2, label=label)

ax.axhline(baseline, ls="--", color="gray", label="zero estimator")
ax.set_xlabel("# in-context examples")
ax.set_ylabel("squared error")
ax.set_title("GPT2-small" + " on " + task_name + " in-context task")
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()