In [32]:
from collections import OrderedDict
import re
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from tqdm.notebook import tqdm

from eval import * 
from plot_utils import *
from models import *
from samplers import *
%matplotlib inline
%load_ext autoreload
%autoreload 2

sns.set_theme('notebook', 'darkgrid')
palette = sns.color_palette('colorblind')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [34]:
run_dir = "./models"

df = read_run_dir(run_dir)
task = "linear_regression"

run_id = "pretrained"  # if you train more models, replace with the run_id from the table above

run_path = os.path.join(run_dir, task, run_id)

_, conf = get_model_from_run(run_path, only_conf=True)

In [52]:
def valid_row(r):
    return r.task == task and r.run_id == run_id

metrics = collect_results(run_dir, df, valid_row=valid_row)
_, conf = get_model_from_run(run_path, only_conf=True)
n_dims = conf.model.n_dims
print("Recorded losses for skewed inputs")
print(metrics['skewed']['Transformer']['mean']()

linear_regression_pretrained pretrained


100%|███████████████████████████████████████| 15/15 [00:00<00:00, 245760.00it/s]


[1.0172086715698243,
 0.6402848720550537,
 0.5014616489410401,
 0.3776614427566528,
 0.31359572410583497,
 0.2116999626159668,
 0.18059089183807372,
 0.15971052646636963,
 0.1399226188659668,
 0.11061776876449585,
 0.1028898000717163,
 0.11127077341079712,
 0.10663849115371704,
 0.1094159483909607,
 0.1145581603050232,
 0.11590192317962647,
 0.09470928311347962,
 0.10462424755096436,
 0.10046427249908448,
 0.14134728908538818,
 0.11887277364730835,
 0.11603138446807862,
 0.12148859500885009,
 0.10129406452178955,
 0.11050920486450196,
 0.11128654479980468,
 0.12795674800872803,
 0.12330986261367798,
 0.1367938280105591,
 0.14273830652236938,
 0.14310156106948851,
 0.11852085590362549,
 0.13730885982513427,
 0.11852630376815795,
 0.15389941930770873,
 0.11677205562591553,
 0.10656915903091431,
 0.12296102046966553,
 0.1085016131401062,
 0.09641695022583008,
 0.12826961278915405]

In [None]:
#### Confirm that we obtain the same 

In [39]:
result = build_evals(conf)
skewed_sampler = GaussianSampler(n_dims=20, scale=result['skewed']['data_sampler_kwargs']['scale'])

In [None]:
from samplers import get_data_sampler
from tasks import get_task_sampler

In [None]:
from samplers import get_data_sampler
from tasks import get_task_sampler

model, conf = get_model_from_run(run_path)

n_dims = conf.model.n_dims
batch_size = conf.training.batch_size

data_sampler = get_data_sampler(conf.training.data, n_dims)
task_sampler = get_task_sampler(
    conf.training.task,
    n_dims,
    batch_size,
    **conf.training.task_kwargs
)

In [None]:
xs = data_sampler.sample_xs(b_size=batch_size, n_points=conf.training.curriculum.points.end)
ys = task.evaluate(xs)

In [None]:
with torch.no_grad():
    pred = model(xs, ys)

In [None]:
metric = task.get_metric()
loss = metric(pred, ys).numpy()
print("Losses for new skewed data")
print(["%.3f" % (val/20) for val in loss.mean(axis=0)])