In [None]:
from collections import OrderedDict
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm.notebook import tqdm

from eval import get_run_metrics, read_run_dir, get_model_from_run
from plot_utils import basic_plot, collect_results, relevant_model_names

%matplotlib inline
%load_ext autoreload
%autoreload 2

sns.set_theme('notebook', 'darkgrid')
palette = sns.color_palette('colorblind')

run_dir = "./models"

In [None]:
# Table of every training run found under `run_dir`; shown in full on purpose
# so a run_id can be picked for the cell below.
df = read_run_dir(run_dir)
df

In [None]:
# Which task family to inspect. Alternatives:
#   "sparse_linear_regression", "decision_tree", "relu_2nn_regression"
# Kept as `task_name` (not `task`) because `task` is later bound to a sampled
# task object; sharing the name made re-running this cell break the probe cell.
task_name = "linear_regression"

run_id = "pretrained"  # if you train more models, replace with the run_id from the table above

run_path = os.path.join(run_dir, task_name, run_id)

# Metrics are normally precomputed at the end of training; flip to True to redo them.
recompute_metrics = False

if recompute_metrics:
    get_run_metrics(run_path)

In [None]:
from samplers import get_data_sampler
from tasks import get_task_sampler

# Load the trained model together with the config it was trained with.
model, conf = get_model_from_run(run_path)

# Problem sizes come straight from that training config.
n_dims, batch_size = conf.model.n_dims, conf.training.batch_size

# Input (x) sampler and task sampler, configured exactly as during training.
data_sampler = get_data_sampler(conf.training.data, n_dims)
task_sampler = get_task_sampler(
    conf.training.task, n_dims, batch_size, **conf.training.task_kwargs
)

In [None]:
# Draw one task instance and a batch of inputs at the curriculum's final
# sequence length, label the inputs with that task, then run the model for
# inference only (no gradient tracking).
task = task_sampler()
xs = data_sampler.sample_xs(
    b_size=batch_size,
    n_points=conf.training.curriculum.points.end,
)
ys = task.evaluate(xs)

with torch.no_grad():
    pred = model(xs, ys)

In [None]:
# Probe the model along a 1-D slice of input space: keep the first `n_context`
# examples of batch row 0 as the in-context prompt, and sweep the query point
# over x_query = f * (1, ..., 1) for random scalars f. For each f we record the
# model's prediction and the task's true label at the query position.
n_context = 40  # in-context (x, y) examples taken from xs[0]
n_probe = 64    # number of query coordinates f to try

assert xs.shape[1] >= n_context, f"xs has only {xs.shape[1]} points, need {n_context}"

fcoord = torch.randn(n_probe)

# The prompt is the same for every probe, so build it once. Every batch row gets
# a copy of row 0's context (broadcast); the task object evaluates a full batch of
# size `batch_size`, so the tensor must keep that batch dimension.
xval_base = torch.zeros((batch_size, n_context + 1, n_dims))
xval_base[:, :n_context, :] = xs[0, :n_context, :]

pred = []
real = []
# Inference only: without no_grad every forward pass builds an autograd graph,
# which the stored predictions would keep alive (and NumPy can't convert them).
with torch.no_grad():
    for f in fcoord:
        xval = xval_base.clone()
        xval[0, n_context, :] = f  # only row 0's query is read below

        yval = task.evaluate(xval)
        predval = model(xval, yval)

        pred.append(predval[0, n_context])
        real.append(yval[0, n_context])

In [None]:
# Plot the model's response along the probe slice x_query = f * (1, ..., 1) and
# fit a straight line to it, to see how linear the model's output is in f.
# `pred` / `real` are lists of 0-d tensors; detach() keeps the conversion working
# even if they still carry autograd history.
f_np = fcoord.detach().cpu().numpy()
y = torch.stack(pred).detach().cpu().numpy()       # model predictions at the query
y_true = torch.stack(real).detach().cpu().numpy()  # ground-truth labels at the query

# Least-squares line y ~ m * f + b through the model predictions.
m, b = np.polyfit(f_np, y, 1)
x = np.linspace(f_np.min(), f_np.max(), 200)

fig, ax = plt.subplots()
ax.scatter(f_np, y, label="model prediction")
ax.scatter(f_np, y_true, marker="x", label="ground truth")
# use red as color for regression line
ax.plot(x, m * x + b, color="red", label=f"linear fit: {m:.3f} * f + {b:.3f}")
ax.set(
    xlabel="query coordinate f",
    ylabel="value at query point",
    title="Model response along x_query = f * (1, ..., 1)",
)
ax.legend()
plt.show()