In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch as ch
import numpy as np
import torch.nn as nn
import os
import argparse
import seaborn as sns
import copy
from tqdm import tqdm
import matplotlib.pyplot as plt

from mib.models.utils import get_model
from sklearn.datasets import make_spd_matrix
from mib.utils import get_models_path
from mib.dataset.utils import get_dataset
from mib.train import get_loader, train_model, evaluate_model
from mib.attacks.theory_new import compute_hessian, fast_ihvp

from torch.utils.data import TensorDataset, DataLoader

from torch_influence.modules import HVPModule, LiSSAInfluenceModule
from torch_influence.base import BaseObjective
from livelossplot import PlotLosses

In [3]:
def flatten_vec(vec):
    """Concatenate a sequence of tensors into a single 1-D tensor.

    Args:
        vec: Iterable of tensors (e.g. per-parameter gradients or
            ``model.parameters()``).

    Returns:
        A 1-D tensor holding every element of every tensor in ``vec``, in order.
        An empty 1-D tensor if ``vec`` is empty.
    """
    # Flatten the argument itself; `reshape` (unlike `view`) also accepts
    # non-contiguous tensors such as transposes.
    pieces = [p.reshape(-1) for p in vec]
    if not pieces:
        # ch.cat raises on an empty list.
        return ch.empty(0)
    return ch.cat(pieces)

In [4]:
def compute_epsilon_acceleration(
    source_sequence,
    num_applications: int=1,):
    """Compute `num_applications` recursive Shanks transformation of
    `source_sequence` (preferring later elements) using `Samelson` inverse and the
    epsilon-algorithm, with Sablonniere modifier.

    Args:
        source_sequence: Sequence of partial sums (tensors). Each element is
            squeezed along dim 1 and later combined with `.dot`, so after the
            squeeze it must be 1-D (presumably shape (D, 1) on input -- confirm).
        num_applications: Number of Shanks transformations to apply. Only the
            last `2 * num_applications + 1` elements of `source_sequence` are
            read by the table recursion.

    Returns:
        The 1-D tensor `epsilon[initial_m, 2 * num_applications]`.

    NOTE(review): requires len(source_sequence) >= 2 * num_applications + 1;
    otherwise `initial_m` is negative and the table lookups raise KeyError.
    """

    def inverse(vector):
        # Samelson inverse
        # Samelson (vector) inverse: v / <v, v>.
        return vector / vector.dot(vector)

    # Epsilon table keyed by (m, s): column 0 holds the source partial sums,
    # column -1 holds zeros (the usual epsilon-algorithm initialization).
    epsilon = {}
    for m, source_m in enumerate(source_sequence):
        epsilon[m, 0] = source_m.squeeze(1)
        epsilon[m + 1, -1] = 0

    s = 1
    # First row index that takes part in the recursion.
    m = (len(source_sequence) - 1) - 2 * num_applications
    initial_m = m
    while m < len(source_sequence) - 1:
        # Walk one anti-diagonal of the table: m decreases while s increases.
        while m >= initial_m:
            # Sablonniere modifier
            inverse_scaling = np.floor(s / 2) + 1

            # eps[m, s] = eps[m+1, s-2] + c_s * inv(eps[m+1, s-1] - eps[m, s-1])
            epsilon[m, s] = epsilon[m + 1, s - 2] + inverse_scaling * inverse(
                epsilon[m + 1, s - 1] - epsilon[m, s - 1]
            )
            # Drop an entry that (presumably) is not read again, to bound memory.
            epsilon.pop((m + 1, s - 2))
            m -= 1
            s += 1
        # Step back to the last cell computed on this diagonal.
        m += 1
        s -= 1
        epsilon.pop((m, s - 1))
        # Start the next diagonal.
        m = initial_m + s
        s = 1

    return epsilon[initial_m, 2 * num_applications]

In [5]:
@ch.no_grad()
def hso_hvp(vec,
            hvp_module,
            acceleration_order: int = 8,
            initial_scale_factor: float = 100,
            num_update_steps: int = 20,
            reference_ihvp=None,):
    """Approximate the inverse-Hessian-vector product H^{-1} @ vec with a
    Hessian-squared binomial series, optionally followed by Shanks (epsilon)
    acceleration.

    Uses H^{-1} v = s^{-1/2} (I - (I - H^2 / s))^{-1/2} v together with
    (1 - x)^{-1/2} = sum_k c_k x^k, where c_0 = 1 and
    c_k = c_{k-1} * (2k - 1) / (2k). Only products with H are needed, and they
    come from ``hvp_module.hvp``. The series converges when every eigenvalue of
    I - H^2 / s lies in (-1, 1), i.e. H is nonsingular and
    s > spectral_radius(H^2) / 2.

    NOTE(review): this runs under ``ch.no_grad()``. An autograd-based
    ``hvp_module.hvp`` may need grad enabled internally -- confirm.

    Args:
        vec: Vector to multiply by H^{-1}. With acceleration enabled, each
            partial sum is squeezed along dim 1 by
            ``compute_epsilon_acceleration``, so a column of shape (D, 1) is
            expected there.
        hvp_module: Any object exposing ``hvp(v) -> H @ v``.
        acceleration_order: Number of Shanks transformations, applied to the
            last ``2 * acceleration_order + 1`` partial sums. 0 disables
            acceleration.
        initial_scale_factor: Lower bound for the scale s.
        num_update_steps: Number of series terms, including the zeroth term.
        reference_ihvp: Optional ground-truth H^{-1} @ vec. When given (or,
            for backward compatibility, when a global ``ACTUAL_IHVP`` exists),
            the L2 distance and cosine similarity to it are plotted with
            livelossplot at every step.

    Returns:
        The approximate H^{-1} @ vec. Without acceleration it has the same
        shape as ``vec``. With acceleration, the epsilon result is unsqueezed
        at dim 1.

    Raises:
        ValueError: if acceleration is enabled and ``num_update_steps`` is
            smaller than ``2 * acceleration_order + 1`` (too few partial sums).
    """
    num_cached = 2 * acceleration_order + 1
    if acceleration_order > 0 and num_update_steps < num_cached:
        raise ValueError(
            f"num_update_steps ({num_update_steps}) must be >= "
            f"2 * acceleration_order + 1 ({num_cached}) when acceleration is enabled"
        )

    # H^{-1} @ 0 = 0. Short-circuit so the norm ratio below never divides by zero.
    if not vec.ne(0).any():
        return ch.zeros_like(vec)

    if reference_ihvp is None:
        # Backward compatibility: the notebook keeps the ground truth in a global.
        reference_ihvp = globals().get("ACTUAL_IHVP")
    plotlosses = PlotLosses() if reference_ihvp is not None else None

    def _track(estimate):
        """Plot L2 distance and cosine similarity of `estimate` vs. the reference."""
        if plotlosses is None:
            return
        plotlosses.update({
            'L2 norm distance': ch.norm(reference_ihvp - estimate, 2).item(),
            'Cosine similarity': nn.functional.cosine_similarity(estimate, reference_ihvp, eps=0, dim=0).item(),
        })
        plotlosses.send()

    # Detach and clone input
    vector_cache = vec.detach().clone()   # (I - H^2 / s)^k @ vec
    update_sum = vec.detach().clone()     # running partial sum of the series
    coefficient_cache = 1                 # c_k
    # Used unchanged only when the loop below does not run (num_update_steps <= 1).
    scale_factor = initial_scale_factor

    cached_update_sums = []
    if acceleration_order > 0 and num_update_steps == num_cached:
        # The zeroth partial sum is needed too. Clone it, because update_sum is
        # modified in place below.
        cached_update_sums.append(update_sum.clone())

    for update_step in range(1, num_update_steps):
        hessian2_vector_cache = hvp_module.hvp(hvp_module.hvp(vector_cache))

        if update_step == 1:
            # For symmetric H, ||H^2 v|| / ||v|| only lower-bounds
            # spectral_radius(H^2), so never go below the user-given floor.
            scale_factor = ch.norm(hessian2_vector_cache, p=2) / ch.norm(vec, p=2)
            scale_factor = max(scale_factor.item(), initial_scale_factor)

        vector_cache = vector_cache - (1 / scale_factor) * hessian2_vector_cache
        coefficient_cache *= (2 * update_step - 1) / (2 * update_step)
        update_sum += coefficient_cache * vector_cache

        if acceleration_order > 0 and update_step >= num_update_steps - num_cached:
            cached_update_sums.append(update_sum.clone())

        _track(update_sum / np.sqrt(scale_factor))

    # Perform series acceleration (Shanks acceleration)
    if acceleration_order > 0:
        accelerated_sum = compute_epsilon_acceleration(
            cached_update_sums, num_applications=acceleration_order
        )
        accelerated_sum = (accelerated_sum / np.sqrt(scale_factor)).unsqueeze(1)
        _track(accelerated_sum)
        return accelerated_sum

    return update_sum / np.sqrt(scale_factor)

In [None]:
# NOTE(review): this cell calls hso_hvp on `flat_grad` and `wpr`, which are only
# defined further down (the random (1024, 1) `flat_grad` cell and
# `wpr = PseudoWrapper()`). It fails on a fresh kernel -- move it after those cells.
# With acceleration_order=30, hso_hvp caches the last 2 * 30 + 1 = 61 partial
# sums, so num_update_steps must be at least 61.
wanted = hso_hvp(flat_grad,
                 # hvp_module,
                 wpr,
                 initial_scale_factor=1.2e6,
                 acceleration_order=30,
                 num_update_steps=100)

In [None]:
# NOTE(review): `spec_radius` is defined in a later cell and `V` is assigned in the
# cells below, so this line only works after out-of-order execution.
print(spec_radius(wpr.H.cpu() @ wpr.H.cpu()) / V)

In [None]:
# Sorted absolute eigenvalues of H^2, with H = wpr.H.
ch.sort(ch.abs(ch.linalg.eigh(wpr.H.cpu() @ wpr.H.cpu()).eigenvalues)).values

In [None]:
# Try a scale factor V: print the spectral radius of (I - H^2 / V)^k for
# k = 1..10 to see whether the powers shrink (convergence of the series in hso_hvp).
# NOTE(review): 1024 is hard-coded and must equal wpr.H.shape[0].
V = 1.1e6
term = ch.eye(1024)
initial = (ch.eye(1024) - (wpr.H.cpu() @ wpr.H.cpu())/V)
for _ in range(10):
    term = term @ initial
    print(spec_radius(term))

In [None]:
# Alternative V: spectral_radius(H)^2 (the largest eigenvalue of H^2 for
# symmetric H) plus a margin of 10000.
V = (spec_radius(wpr.H) ** 2).item() + 10000

In [None]:
# Smallest |eigenvalue| of H^2 / V.
ch.min(ch.abs(ch.linalg.eigh(wpr.H @ wpr.H / V).eigenvalues))

In [None]:
# Eigenvalues of I - H^2 / V, computed on the GPU because wpr.H lives there.
ch.linalg.eigh(ch.eye(1024).cuda() - wpr.H @ wpr.H / V).eigenvalues

In [6]:
def is_pos_def(x, tol: float = 0.0):
    """Return True if the real square matrix ``x`` is positive definite.

    A real matrix is positive definite when v^T x v > 0 for all v != 0. That
    depends only on the symmetric part (x + x^T) / 2, whose eigenvalues are
    real, so ``eigvalsh`` is used. ``eigvals`` on the raw matrix can return
    complex values and, for non-symmetric input, answers a different question.

    Args:
        x: Square array-like (e.g. ``tensor.cpu().numpy()``).
        tol: Every eigenvalue of the symmetric part must exceed this (default 0).

    Returns:
        A Python bool.
    """
    x = np.asarray(x)
    symmetric_part = (x + x.T) / 2
    return bool(np.all(np.linalg.eigvalsh(symmetric_part) > tol))

In [None]:
# H must be positive definite for H^{-1} (and the exact iHVP) to exist.
is_pos_def(wpr.H.cpu().numpy())

In [None]:
print(wanted)
print(ch.norm(wanted))

In [None]:
# NOTE(review): `exact_inverse_h` is assigned in the next cell; this only works
# after out-of-order execution.
exact_ihvp = exact_inverse_h @ flat_grad.cpu()
print(exact_ihvp)
print(ch.norm(exact_ihvp))

In [None]:
exact_inverse_h = ch.linalg.inv(wpr.H.cpu())

In [None]:
# I - H^2 / s with s = 1.1e6 (the matrix whose powers drive the hso_hvp series).
inner_matrix = (ch.eye(wpr.H.shape[0]) - (wpr.H @ wpr.H).cpu() / 1.1e6)

In [None]:
# Eigendecomposition of H^2: eigv = eigenvalues, eigh = eigenvector matrix.
eigv, eigh = ch.linalg.eigh((wpr.H.cpu() @ wpr.H.cpu()))

In [None]:
# Reconstruct H^2 = Q diag(lambda) Q^T to sanity-check the decomposition.
eigh @ ch.diag(eigv) @ eigh.T

In [None]:
wpr.H.cpu() @ wpr.H.cpu()

In [None]:
ch.max(ch.abs(eigv))

In [None]:
# NOTE(review): dividing by 1 means scale s = 1 here; presumably a placeholder.
eigv, eigh = ch.linalg.eigh(ch.eye(wpr.H.shape[0]) - (wpr.H.cpu() @ wpr.H.cpu())/1)

In [None]:
ch.abs(eigv)

In [None]:
ch.max(ch.abs(eigv))

In [None]:
# NOTE(review): `spec_radius` is defined in a later cell.
spec_radius(inner_matrix)

In [None]:
# `**` is elementwise in PyTorch; the matrix power (I - H^2 / s)^10 needs matrix_power.
ch.norm(ch.linalg.matrix_power(inner_matrix, 10))

In [7]:
def spec_radius(x):
    """Spectral radius (largest absolute eigenvalue) of a symmetric/Hermitian matrix.

    Uses ``eigh``, so ``x`` is assumed symmetric (only one triangle is read).
    Returns a 0-d tensor.
    """
    eigenvalues = ch.linalg.eigh(x).eigenvalues
    return eigenvalues.abs().max()

In [None]:
# spectral_radius(H)^2, which equals the spectral radius of H^2 for symmetric H.
spec_radius(exact_H.cpu()) ** 2

In [None]:
spec_radius(wpr.H.cpu() @ wpr.H.cpu())

In [None]:
# Contraction factor 1 - rho / s for s = 1.2e6.
# NOTE(review): 1049591.25 looks copied from an earlier output (presumably the
# spectral radius of H^2); recompute it with spec_radius instead of hard-coding.
1 - 1049591.2500 / 1.2e6

In [None]:
print(ch.linalg.inv(wpr.H) @ flat_grad)

In [None]:
# NOTE(review): duplicate of an earlier cell that computes the same inverse.
exact_inverse_h = ch.linalg.inv(wpr.H.cpu())

In [None]:
print(wpr.H.cuda() @ (wpr.H.cuda() @ flat_grad))

In [None]:
# LiSSA-style iHVP via mib's fast_ihvp. NOTE(review): `model`, `dataloader` and
# `criterion` are defined in later cells (hidden-state dependency).
lissa_ivhp = fast_ihvp(model, flat_grad, dataloader, criterion, device="cuda")

In [None]:
lissa_ivhp

Either the LiSSA scale was chosen poorly, or LiSSA's inverse-HVP (iHVP) estimate is **far off** from the exact value.

In [10]:
# Small MLP (200 -> 4 -> 2, i.e. 200*4+4 + 4*2+2 = 814 parameters) used for the
# influence / iHVP experiments. The 200 input features must match the dataset below.
model = nn.Sequential(nn.Linear(200, 4), nn.ReLU(), nn.Linear(4, 2))
criterion = nn.CrossEntropyLoss()
model.cuda()

class MyObjective(BaseObjective):
    """torch_influence objective: cross-entropy via the notebook-global `criterion`, no regularization."""
    def train_outputs(self, model, batch):
        # batch is indexed as (inputs, labels); forward the inputs.
        return model(batch[0])

    def train_loss_on_outputs(self, outputs, batch):
        return criterion(outputs, batch[1])  # mean reduction required (CrossEntropyLoss() defaults to 'mean')

    def train_regularization(self, params):
        # No regularization term is added to the training loss.
        return 0

    def test_loss(self, model, params, batch):
        return criterion(model(batch[0]), batch[1])  # no regularization in test loss

In [9]:
# Create TensorDataset
y_0 = ch.randn(200, 200) + 1.
y_1 = ch.randn(200, 200) + 2.
y = ch.tensor([0] * 200 + [1] * 200)
dataset = TensorDataset(ch.cat((y_0, y_1), 0), y)

# Create DataLoader
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

In [None]:
# Train the MLP. The positional 1e-3 and 60 are presumably the learning rate and
# epoch count -- confirm against mib.train.train_model.
# NOTE(review): the same `dataloader` is passed as both train and evaluation loader.
model, best_acc, best_loss = train_model(
    model,
    criterion,
    dataloader,
    dataloader,
    1e-3,
    60,
    pick_n=1,
    pick_mode="last",
)
model.cuda()

In [None]:
# Gradient of the loss at one random class-0-like point (every feature shifted by
# +1, like y_0). Its width must match the model's first Linear layer, which takes
# the same number of features as the dataset (y_0.shape[1] == 200).
random_x, random_y = ch.randn(1, y_0.shape[1]) + 1, ch.tensor([0])
random_x = random_x.cuda()
random_y = random_y.cuda()
# Compute gradient for this point
model.zero_grad()
loss = criterion(model(random_x), random_y)
loss.backward()
# Flatten all parameter gradients into one vector (parameter order of model.parameters()).
flat_grad = ch.cat([p.grad.detach().view(-1) for p in model.parameters()])
model.zero_grad()

In [None]:
model

In [None]:
# Exact Hessian of the training loss of the model (mib helper).
exact_H = compute_hessian(model, dataloader, criterion, device = "cuda")

In [None]:
exact_H.shape

In [None]:
ch.linalg.inv(exact_H.cpu())

In [None]:
z = ch.abs(ch.linalg.eigh(exact_H.cpu()).eigenvalues)

In [None]:
ch.sort(z).values

In [None]:
# Condition number of H^2 (ratio of largest to smallest |eigenvalue|).
# Builtin min/max iterate element-wise; abs2_vals.min()/max() are the vectorized equivalents.
abs2_vals = ch.abs(ch.linalg.eigh(exact_H.cpu() @ exact_H.cpu()).eigenvalues)
print(min(abs2_vals), max(abs2_vals))
condition = max(abs2_vals) / min(abs2_vals)
print(condition)

In [None]:
# NOTE(review): same inverse as the CPU cell above, here on the GPU.
ch.linalg.inv(exact_H)

In [None]:
# Replace flat_grad with a random (1024, 1) vector for the synthetic-Hessian test.
# NOTE(review): this overwrites the model gradient computed earlier; 1024 matches
# PseudoWrapper's matrix size, not the MLP's 814 parameters.
flat_grad = ch.randn(1024, 1).cuda()

In [None]:
class PseudoWrapper:
    """Stand-in for an HVP module backed by an explicit random SPD matrix.

    Exposes the same ``hvp(v)`` interface that ``hso_hvp`` calls, so the series
    can be checked against an exactly invertible "Hessian" ``self.H``.

    Args:
        dim: Side length of the SPD matrix (default 1024, the original setup).
        random_state: Seed forwarded to ``make_spd_matrix`` for reproducibility
            (default None: a different matrix every time).
        device: Device for ``self.H`` (default "cuda", the original setup).
    """

    def __init__(self, dim: int = 1024, random_state=None, device: str = "cuda"):
        self.H = ch.from_numpy(
            make_spd_matrix(dim, random_state=random_state)
        ).float().to(device)

    def hvp(self, v):
        """Return H @ v."""
        return self.H @ v

In [None]:
wpr = PseudoWrapper()

In [None]:
# From here on exact_H refers to the synthetic SPD matrix, not the model Hessian.
exact_H = wpr.H

In [None]:
wpr.H.float().dtype

In [None]:
# Ground-truth iHVP; hso_hvp reads this global to plot its L2 / cosine error.
ACTUAL_IHVP = (ch.linalg.inv(wpr.H.cpu()) @ flat_grad.cpu()).cuda()

In [None]:
ACTUAL_IHVP

In [None]:
wpr.H

In [None]:
# exact_H is on the GPU, so this inverse is a CUDA tensor too.
exact_inverse_h = ch.linalg.inv(exact_H)

In [None]:
exact_inverse_h

In [None]:
# exact_inverse_h is a CUDA tensor here (exact_H = wpr.H lives on the GPU), so move
# it to the CPU to match flat_grad.cpu(); mixing devices in `@` raises a RuntimeError.
exact_inverse_h.cpu() @ flat_grad.cpu()

In [None]:
# HVP module over the MLP's training objective. Presumably hvp(v) computes H @ v
# for the loss of `model` over `dataloader` -- confirm against torch_influence.modules.
hvp_module = HVPModule(
    model,
    MyObjective(),
    dataloader,
    device="cuda"
)

In [None]:
# NOTE(review): flat_grad is now the random (1024, 1) vector, but `model` has 814
# parameters, so this call likely fails with a shape mismatch; use the model gradient.
hvp_module.hvp(flat_grad)

In [None]:
# NOTE(review): exact_H was rebound to wpr.H (synthetic SPD matrix), so this is not
# the model Hessian and is not comparable with hvp_module.hvp above.
exact_H.cpu() @ flat_grad.cpu()

In [None]:
# acceleration_order=2 needs num_update_steps >= 2 * 2 + 1 = 5.
wanted = hso_hvp(flat_grad, hvp_module,
                 acceleration_order=2,
                 num_update_steps=10)

In [None]:
wanted

In [None]:
exact_H

In [None]:
# Compute ihvp with LiSSA
module = LiSSAInfluenceModule(
    model=model,
    objective=MyObjective(),
    train_loader=dataloader,
    test_loader=None,
    device="cuda",
    damp=0,
    repeat=10,
    depth=10, #5000 for MLP and Transformer, 10000 for CNN
    scale=10 # test in {10, 25, 50, 100, 150, 200, 250, 300, 400, 500} for convergence
)

In [None]:
ihvp = module.inverse_hvp(flat_grad)

##  Synthetic Gaussian Data

In [None]:
# Dimensions of the synthetic linear-regression setup.
D = 1_000
n = 100
gamma = 0.5

# NOTE(review): p = gamma * n = 50 is not used in the visible cells and differs from
# D = 1000; if gamma is meant to be D / n, D should be gamma * n -- confirm.
p = gamma * n

# Sample x_0
x_0 = ch.randn(1, D)

# Sample beta
# NOTE(review): torch.normal's second argument is the standard deviation, so each
# entry has std 1/D (variance 1/D^2). If N(0, I/D) is intended, use ch.ones(D) / D ** 0.5.
beta = ch.normal(ch.zeros(D, ), ch.ones(D, ) / D)

In [None]:
# Sample n x D matrix
X = ch.randn(n, D)

In [None]:
def get_y(a, coef=None, noise_std: float = 1.0):
    """Sample linear-model targets y = a @ coef + noise_std * eps, with eps ~ N(0, 1).

    Args:
        a: Design matrix (e.g. shape (num_samples, D)).
        coef: Coefficient vector. Defaults to the notebook-global ``beta``.
        noise_std: Standard deviation of the additive Gaussian noise (default 1.0).

    Returns:
        Tensor with the shape of ``a @ coef``.
    """
    if coef is None:
        coef = beta
    wx = a @ coef
    return wx + noise_std * ch.randn(wx.shape)

In [None]:
def get_hessian(a, num_samples=None):
    """Empirical second-moment matrix a^T a / num_samples.

    For a design matrix of shape (num_samples, D), this is the Hessian with
    respect to w of the least-squares loss ||a w - y||^2 / (2 * num_samples).

    Args:
        a: Design matrix.
        num_samples: Normalizer. Defaults to the number of rows of ``a``, so
            the result stays correct for matrices with any row count (the
            notebook's X has n rows, so results there are unchanged).
    """
    if num_samples is None:
        num_samples = a.shape[0]
    return (a.T @ a) / num_samples

In [None]:
# Two neighbouring design matrices: X_0 is X, and X_1 is X with its first row
# replaced by x_0 (presumably to compare fits with vs. without the target point -- confirm).
X_0 = X.clone()
X_1 = X.clone()
X_1[0] = x_0.clone()