In [None]:
import os

import matplotlib.pyplot as plt
import torch

from models import GDModel, NeuralNetwork
from tasks import get_task_sampler
from samplers import get_data_sampler
from utils import get_model_from_run


# Location of the trained run to evaluate. Set CONTEXT_RUNS_DIR to point at a
# local copy; the default is the original cluster path.
run_dir = os.environ.get("CONTEXT_RUNS_DIR", "/juice2/u/nlp/data/tsipras/context-runs")
run_path = os.path.join(run_dir, "relu_2nn_regression", "c6d38404-c4df-4c37-bc67-73d0bbf324bc")
model, conf = get_model_from_run(run_path)

# Seed torch so the sampled task / evaluation data are reproducible under
# Restart & Run All (assumes the samplers draw from torch's global RNG — TODO confirm).
torch.manual_seed(0)

# Use the run's configured input dimension for both the data sampler and the
# task, instead of a hard-coded 20 that could silently disagree with the config.
n_dims = conf.model.n_dims
# Presumably the number of prompts per batch; it must equal the second argument
# passed to data_sampler.sample_xs(...) in the sampling cell — TODO confirm.
batch_size = 64

data_sampler = get_data_sampler("gaussian", n_dims)

task = get_task_sampler(
    conf.training.task,
    n_dims,
    batch_size,
    **conf.training.task_kwargs,
)()

metric = task.get_metric()

In [None]:
# Results store for GD baselines; the baseline cell fills it with entries keyed
# by (hidden, opt, lr, num_steps).
perf = {}

# One shared evaluation batch used for both the transformer and the baselines.
# NOTE(review): 101 appears to be the number of points per prompt and 64 the
# batch size (matching the 64 given to get_task_sampler) — confirm against
# samplers.sample_xs.
xs = data_sampler.sample_xs(101, 64)
ys = task.evaluate(xs)

In [None]:
# Evaluate the trained model on the shared batch. This is inference only, so
# disable autograd: otherwise a backward graph for the whole batch is built and
# kept alive by `pred`.
# NOTE(review): if the model has dropout, model.eval() may also be needed —
# depends on what get_model_from_run returns; confirm.
with torch.no_grad():
    pred = model(xs, ys)
# Average the per-element metric over dim 0 (presumably the batch dimension),
# leaving one value per prompt position.
m = metric(pred, ys).mean(dim=0).detach()

In [None]:
# GD baseline: a GDModel wrapping a NeuralNetwork, scored at every 10th prompt
# position. Sizes come from the evaluation batch itself instead of being
# hard-coded (previously 101 points and in_size=20), so this cell stays
# consistent if the sampling settings change.
# Assumes ys is (batch, n_points) and xs is (batch, n_points, n_dims) — TODO confirm.
n_points = ys.shape[1]
inds = torch.arange(0, n_points, 10)

hidden = 400
opt = "sgd"
lr = 0.1
num_steps = 400

gd_model = GDModel(
    model_class=NeuralNetwork,
    model_class_args={
        "in_size": xs.shape[-1],  # feature dimension of the sampled inputs
        "hidden_size": hidden,
        "out_size": 1,
    },
    opt_alg=opt,
    batch_size=100,
    lr=lr,
    num_steps=num_steps,
)

# Not wrapped in torch.no_grad(): GDModel presumably runs gradient steps
# internally and needs autograd.
pred = gd_model(xs, ys, inds=inds)
m_gd = metric(pred.detach().cpu(), ys[:, inds]).mean(dim=0).detach()
# Key order (hidden, opt, lr, num_steps) is also used for the legend labels.
perf[(hidden, opt, lr, num_steps)] = m_gd

In [None]:
# Transformer error at every prompt position vs. each GD baseline, which is
# evaluated only at the positions in `inds`.
fig, ax = plt.subplots()
# Label the transformer curve so it appears in the legend, and so the legend
# has an entry even when `perf` is empty.
ax.plot(m, label="transformer")
for config, curve in perf.items():
    ax.plot(inds, curve, label=str(config))
ax.set_xlabel("prompt position")
ax.set_ylabel("mean metric")
ax.set_title("Transformer vs. GD baselines")
ax.legend();

In [None]:
# Candidate GD-baseline configurations. Each tuple uses the same field order
# as the `perf` keys: (hidden, opt, lr, num_steps).
params = [
    (100, "adam", 0.005, 100),  # Adam: small step size, few steps
    (100, "sgd", 0.1, 400),     # SGD: larger step size, more steps
]