In [1]:
import os

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from tqdm.notebook import tqdm

from samplers import get_data_sampler
from tasks import get_task_sampler
from utils import get_model_from_run

%matplotlib notebook
sns.set_theme('notebook', 'darkgrid')

run_path = "/juice2/u/tsipras/context-runs/linear_regression/1a641a1f-24c6-403f-9762-cb7ab46ab69e"

In [2]:
# Load the trained model with its config, move it to the GPU and put it in eval mode.
model, conf = get_model_from_run(run_path)
model.cuda()
model.eval()

# Prompt geometry: input dimension and prompt length come from the training config.
n_dims = conf.model.n_dims
n_points = conf.training.curriculum.points.end  # presumably the final curriculum length — confirm
batch_size = 10  # number of prompts per batch

In [3]:
# Build the input sampler and the task sampler from the training config.
data_sampler = get_data_sampler(conf.training.data, n_dims)
task_sampler = get_task_sampler(conf.training.task, n_dims, batch_size)
# Draw one task instance for the whole batch (the alignment check below reads task.w_b;
# presumably one weight vector per prompt — confirm against tasks.py).
task = task_sampler()
# Not used by the cells below; pgd() obtains its own metric via task.get_metric().
metric = task.get_metric()

In [4]:
# Prompt position analysed below; later cells index [:, query], so it must be a valid point.
query = 30
assert 0 <= query < n_points, f"query={query} must be in [0, n_points={n_points})"

# One batch of prompts on the GPU, labelled by the sampled task.
xs = data_sampler.sample_xs(n_points=n_points, b_size=batch_size).cuda()
ys = task.evaluate(xs)

In [5]:
# Sanity check: per prompt, compare the direction in which the prediction at
# `query` changes with the query input against the task's weight vector
# (cosine similarity; values near 1 mean the two are aligned).
xs_var = xs.clone().detach().requires_grad_(True)
query_ys = model(xs_var, ys)[:, query]
query_ys.sum().backward()  # a single backward pass for the whole batch

# Unit gradient direction at the query position, one row per prompt.
wp = xs_var.grad[:, query, :]
wp = wp.div(wp.norm(dim=1, keepdim=True))

# Task weights (assumes w_b is (batch, n_dims, 1) — TODO confirm), unit-normalized on wp's device.
w = task.w_b[:, :, 0]
w = w.to(wp.device)
w = w.div(w.norm(dim=1, keepdim=True))

w.mul(wp).sum(dim=1)

tensor([0.9997, 0.9998, 0.9996, 0.9995, 0.9995, 0.9995, 0.9996, 0.9997, 0.9998,
        0.9994], device='cuda:0')

In [6]:
def pgd(xs, task, perturb_idx, target, eps=1, step_size=0.01, num_steps=100,
        return_trajectories=True, net=None):
    """Projected gradient ascent on selected prompt points to maximize the loss at one position.

    Each step re-labels the current (perturbed) inputs with ``task.evaluate``,
    runs the model, and takes a unit-normalized gradient step that increases
    ``metric(pred, ys)[:, target]``. The step is applied only to the points in
    ``perturb_idx``. After every step, each point is projected so that
    (a) its perturbation ``xs_p[b, i] - xs[b, i]`` has L2 norm <= ``eps``, and
    (b) its own L2 norm does not exceed that of the original ``xs[b, i]``.

    Args:
        xs: original inputs, indexed as (batch, n_points, n_dims).
        task: object exposing ``evaluate(xs) -> ys`` and ``get_metric()``.
        perturb_idx: point indices (list or 1-D tensor) that may be perturbed.
        target: point index whose loss is maximized.
        eps: per-point L2 radius of the allowed perturbation.
        step_size: length of each normalized gradient step.
        num_steps: number of ascent steps.
        return_trajectories: if True, also return the per-step target losses.
        net: callable ``net(xs, ys) -> predictions``; defaults to the
            notebook-global ``model``.

    Returns:
        The perturbed inputs (detached). If ``return_trajectories`` is True,
        returns a tuple whose second element is a list with one numpy array of
        per-prompt target losses per step, recorded before that step's update.
    """
    forward = model if net is None else net
    xs = xs.detach()
    max_norms = xs.norm(dim=2, keepdim=True)
    metric = task.get_metric()
    tiny = 1e-12  # guards the divisions below against zero-norm vectors

    xs_p = xs
    trajectories = []
    for _ in range(num_steps):
        ys = task.evaluate(xs_p)
        xs_p = xs_p.clone().detach().requires_grad_(True)
        pred = forward(xs_p, ys)

        loss = metric(pred, ys)[:, target]
        if return_trajectories:
            trajectories.append(loss.detach().cpu().numpy())
        loss.sum().backward()
        grad = xs_p.grad.detach()

        with torch.no_grad():
            # Unit-length step per point. Points that do not influence the
            # target loss have zero gradient; the clamp keeps them at a zero
            # step instead of NaN.
            step = grad / grad.norm(dim=2, keepdim=True).clamp_min(tiny)
            xs_next = xs_p.detach().clone()
            xs_next[:, perturb_idx] += step_size * step[:, perturb_idx]

            # Project each point's perturbation onto the L2 ball of radius eps.
            # NOTE: Tensor.renorm(dim=2) would instead clip the norm of every
            # xs[:, :, k] slice (one coordinate across all prompts and points).
            delta = xs_next - xs
            delta_norm = delta.norm(dim=2, keepdim=True)
            delta = delta * (eps / delta_norm.clamp_min(tiny)).clamp(max=1.0)
            xs_p = xs + delta

            # Shrink any point whose norm grew beyond its original norm.
            cur_norm = xs_p.norm(dim=2, keepdim=True)
            xs_p = xs_p * (max_norms / cur_norm.clamp_min(tiny)).clamp(max=1.0)

    if return_trajectories:
        return xs_p.detach(), trajectories
    return xs_p.detach()

In [7]:
# Perturb every in-context example before the query (indices 0..query-1) and track
# the loss at the query position over the PGD steps. torch.arange(query) covers all
# of them; torch.arange(query - 1) would leave index query - 1 untouched.
_, trajectories = pgd(xs, task, torch.arange(query), query)

fig, ax = plt.subplots()
for i in range(len(xs)):
    ax.plot([t[i] for t in trajectories])  # one curve per prompt in the batch
ax.set(title="PGD on in-context examples", xlabel="PGD step", ylabel=f"loss at point {query}");

<IPython.core.display.Javascript object>

In [8]:
# Perturb only the query input itself and track the loss at that position over the PGD steps.
_, trajectories = pgd(xs, task, [query], query)

fig, ax = plt.subplots()
for i in range(len(xs)):
    ax.plot([t[i] for t in trajectories])  # one curve per prompt in the batch
ax.set(title="PGD on the query input", xlabel="PGD step", ylabel=f"loss at point {query}");

<IPython.core.display.Javascript object>