In [5]:
import torch
from torch import nn
from torch.autograd import grad
import numpy as np
from torch.utils.data import DataLoader

import os

from hessian_hvp_utils import hessian_vector_product, hessians
from mnist_logistic_binary import create_binary_MNIST, preproc_binary_MNIST
from inverse_hvp import get_inverse_hvp

DATA_DIR = "./data"
MODEL_DIR = "./model"
MODEL_PT = "mnist_logistic_reg.pt"
# get_inverse_hvp's LiSSA estimate presumably samples random training mini-batches
# (see `lissa_params` in the later cells) -- seed the RNGs so `s_test` and every
# influence value derived from it are reproducible under Restart & Run All.
SEED = 0

torch.manual_seed(SEED)
np.random.seed(SEED)

train_dataset, test_dataset = create_binary_MNIST(data_dir=DATA_DIR)
# shuffle=False: loader order equals dataset index order, so per-point results
# computed by iterating a loader line up with dataset indices.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# One linear layer over flattened 28x28 images producing a single logit,
# trained/evaluated with BCEWithLogitsLoss.
model = nn.Sequential(nn.Linear(28*28,1))
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only machine.
model.load_state_dict(torch.load(os.path.join(MODEL_DIR, MODEL_PT), map_location="cpu"))
params = list(model.parameters())
criterion = nn.BCEWithLogitsLoss()

test_idx = 278    # The first misclassified test data point

In [2]:
# Influence function of params w.r.t. upweighting
# TODO: placeholder -- not implemented yet; this cell currently does nothing.
#   (Koh & Liang, 2017 define it as I_up,params(z) = -H^{-1} ∇_θ L(z, θ).)

pass

In [3]:
# Influence function of loss w.r.t. upweighting
#   I_up,loss(z, z_test) = -∇_θL(z_test)^T H^{-1} ∇_θL(z) = -s_test · ∇_θL(z)
# Sign: > 0 means upweighting z would raise the test loss (harmful),
#       < 0 means it would lower the test loss (helpful).

# Compute gradients w.r.t. each training data point.
# train_loader uses shuffle=False, so train_grads[i] belongs to train_dataset[i].
train_grads = []
for input, target in train_loader:
    input, target = preproc_binary_MNIST(input, target)
    loss = criterion(model(input), target)
    train_grads.append(grad(loss, params))

# Compute `s_test` = H^{-1} ∇_θL(z_test) with the LiSSA approximation.
# NOTE(review): the test point is read from `test_dataset.data` / `.targets` directly,
# whereas training points come through `train_loader` (the dataset's __getitem__).
# If the dataset applies a transform in __getitem__, the two paths may be scaled
# differently -- confirm preproc_binary_MNIST normalises both consistently.
input, target = test_dataset.data[test_idx:(test_idx+1)], test_dataset.targets[test_idx:(test_idx+1)]
input, target = preproc_binary_MNIST(input, target)
loss = criterion(model(input), target)
test_grads = grad(loss, params)
lissa_params = {
    "batch_size": 10,
    "num_repeats": 10,
    "recursion_depth": 5000,
}

s_test = get_inverse_hvp(model, criterion, train_dataset, test_grads,
                         approx_type='lissa',
                         approx_params=lissa_params,
                         preproc_data_fn=preproc_binary_MNIST)

# Compute influence: one scalar per training point (dot product over all parameters).
inf_up_loss = []
for train_grad in train_grads:
    inf = 0
    for train_grad_p, s_test_p in zip(train_grad, s_test):
        assert train_grad_p.shape == s_test_p.shape
        inf += -torch.sum(train_grad_p * s_test_p)
    inf_up_loss.append(inf)

# Summarise instead of printing one tensor repr per training point.
inf_up_loss_t = torch.tensor([float(v) for v in inf_up_loss])
num_top = min(10, inf_up_loss_t.numel())
print(f"Computed I_up,loss for {inf_up_loss_t.numel()} training points")
print("Most helpful (most negative):", torch.argsort(inf_up_loss_t)[:num_top].tolist())
print("Most harmful (most positive):", torch.argsort(inf_up_loss_t, descending=True)[:num_top].tolist())

In [None]:
# Influence function of params w.r.t. perturbation
# TODO: placeholder -- not implemented yet; this cell currently does nothing.
#   (Koh & Liang, 2017 define it as I_pert,params(z) = -H^{-1} ∇_x ∇_θ L(z, θ).)

pass

In [None]:
# Influence function of loss w.r.t. perturbation

# Compute `s_test` = H^{-1} ∇_θL(z_test) for the selected test point.
# NOTE(review): this repeats the `s_test` computation from the up-weighting cell.
# If LiSSA samples randomly, this estimate will differ from that one; consider
# reusing a single `s_test` so both influence types share the same estimate.
input = test_dataset.data[test_idx:test_idx + 1]
target = test_dataset.targets[test_idx:test_idx + 1]
input, target = preproc_binary_MNIST(input, target)
loss = criterion(model(input), target)
test_grads = grad(loss, params)

# LiSSA settings: mini-batch size, number of repeated estimates, recursion length.
lissa_params = dict(batch_size=10, num_repeats=10, recursion_depth=5000)

s_test = get_inverse_hvp(
    model,
    criterion,
    train_dataset,
    test_grads,
    approx_type='lissa',
    approx_params=lissa_params,
    preproc_data_fn=preproc_binary_MNIST,
)

In [40]:
# Mixed second derivatives needed for the perturbation influence.
# For each training point z this computes s_test^T ∇_x∇_θL(z), a tensor shaped like
# the preprocessed input, so that I_pert,loss(z, z_test) = -train_twice_grads[i][0].
# It is obtained in one double-backward pass by passing `s_test` as `grad_outputs`,
# so the full mixed-derivative matrix is never built. (Weighting every entry by 1
# instead would give ∇_x Σ∇_θL(z), which is not the quantity the influence needs.)
# s_test is detached so the double-backward cannot reach into any autograd graph
# get_inverse_hvp may have left attached to it.
s_test_const = [s_test_p.detach() for s_test_p in s_test]
train_twice_grads = []
for input, target in train_loader:
    input, target = preproc_binary_MNIST(input, target)
    input.requires_grad_()

    loss = criterion(model(input), target)
    # create_graph=True keeps the parameter gradient differentiable w.r.t. `input`
    grads = grad(loss, params, create_graph=True)
    twice_grads = grad(grads, input, grad_outputs=s_test_const)
    train_twice_grads.append(twice_grads)

2
1
torch.Size([1, 784])


In [None]:
# Compute influence of perturbing each training input on the test loss:
#   I_pert,loss(z, z_test)^T = -∇_θL(z_test)^T H^{-1} ∇_x∇_θL(z) = -s_test^T ∇_x∇_θL(z)
# The result is one tensor per training point, shaped like the preprocessed input,
# not a scalar. (Taking the dot product of `train_grads` with `s_test` would just
# reproduce I_up,loss.) s_test^T ∇_x∇_θL(z) is computed as a single double-backward
# pass with `s_test` as `grad_outputs`; s_test is detached so that pass cannot reach
# into any graph get_inverse_hvp may have left attached to it.
s_test_const = [s_test_p.detach() for s_test_p in s_test]
inf_pert_loss = []
for train_input, train_target in train_loader:
    train_input, train_target = preproc_binary_MNIST(train_input, train_target)
    train_input.requires_grad_()

    loss = criterion(model(train_input), train_target)
    # create_graph=True keeps the parameter gradient differentiable w.r.t. the input
    train_grad = grad(loss, params, create_graph=True)
    for train_grad_p, s_test_p in zip(train_grad, s_test_const):
        assert train_grad_p.shape == s_test_p.shape
    (mixed_hvp,) = grad(train_grad, train_input, grad_outputs=s_test_const)
    inf_pert_loss.append(-mixed_hvp.detach())