In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils import data
from tqdm import tqdm
import copy
import time
import abc
from typing import Any, List, Optional, Callable
import torch.optim as optim
from torchvision import datasets, transforms, models
import os
import torch.nn as nn
import torch.optim as optim

# Code adapted from the torch-influence library

In [None]:
def _set_attr(obj, names, val):
    """Set a (possibly nested) attribute on ``obj``.

    ``names`` is a sequence of attribute names forming a path: every entry but
    the last is looked up with ``getattr`` starting from ``obj``, and ``val`` is
    then assigned to the final entry on the object reached that way.
    """
    owner = obj
    # Walk down to the object that directly holds the final attribute.
    for name in names[:-1]:
        owner = getattr(owner, name)
    setattr(owner, names[-1], val)


def _del_attr(obj, names):
    """Delete a (possibly nested) attribute from ``obj``.

    ``names`` is a sequence of attribute names forming a path: every entry but
    the last is looked up with ``getattr`` starting from ``obj``, and the final
    entry is removed from the object reached that way.
    """
    owner = obj
    # Walk down to the object that directly holds the final attribute.
    for name in names[:-1]:
        owner = getattr(owner, name)
    delattr(owner, names[-1])

class BaseObjective(abc.ABC):
    """An abstract adapter that provides torch-influence with project-specific information
    about how training and test objectives are computed.

    In order to use torch-influence in your project, a subclass of this module should be
    created that implements this module's four abstract methods.
    """

    @abc.abstractmethod
    def train_outputs(self, model: nn.Module, batch: Any) -> torch.Tensor:
        """Returns a batch of model outputs (e.g., logits, probabilities) from a batch of data.

        Args:
            model: the model.
            batch: a batch of training data.

        Returns:
            the model outputs produced from the batch.
        """

        raise NotImplementedError()

    @abc.abstractmethod
    def train_loss_on_outputs(self, outputs: torch.Tensor, batch: Any) -> torch.Tensor:
        """Returns the **mean**-reduced loss of the model outputs produced from a batch of data.

        Args:
            outputs: a batch of model outputs.
            batch: a batch of training data.

        Returns:
            the loss of the outputs over the batch.

        Note:
            There may be some ambiguity in how to define :meth:`train_outputs()` and
            :meth:`train_loss_on_outputs()`: what point in the forward pass deliniates
            outputs from loss function? For example, in binary classification, the
            outputs can reasonably be taken to be the model logits or normalized probabilities.

            For standard use of influence functions, both choices produce the same behaviour.
            However, if using the Gauss-Newton Hessian approximation for influence functions,
            we require that :meth:`train_loss_on_outputs()` be convex in the model
            outputs.

        See also:
            :class:`CGInfluenceModule`
            :class:`LiSSAInfluenceModule`
        """

        raise NotImplementedError()

    @abc.abstractmethod
    def train_regularization(self, params: torch.Tensor) -> torch.Tensor:
        """Returns the regularization loss at a set of model parameters.

        Args:
            params: a flattened vector of model parameters.

        Returns:
            the regularization loss.
        """

        raise NotImplementedError()

    def train_loss(self, model: nn.Module, params: torch.Tensor, batch: Any) -> torch.Tensor:
        """Returns the **mean**-reduced regularized loss of a model over a batch of data.

        This method should not be overridden for most use cases. By default, torch-influence
        takes and expects the overall training loss to be::

            outputs = train_outputs(model, batch)
            loss = train_loss_on_outputs(outputs, batch) + train_regularization(params)

        Args:
            model: the model.
            params: a flattened vector of the model's parameters.
            batch: a batch of training data.

        Returns:
            the training loss over the batch.
        """

        outputs = self.train_outputs(model, batch)
        return self.train_loss_on_outputs(outputs, batch) + self.train_regularization(params)

    @abc.abstractmethod
    def test_loss(self, model: nn.Module, params: torch.Tensor, batch: Any) -> torch.Tensor:
        """Returns the **mean**-reduced loss of a model over a batch of data.

        Args:
            model: the model.
            params: a flattened vector of the model's parameters.
            batch: a batch of test data.

        Returns:
            the test loss over the batch.
        """

        raise NotImplementedError()



class BaseInfluenceModule(abc.ABC):
    """The core module that contains convenience methods for computing influence functions.

    Args:
        model: the model of interest.
        objective: an implementation of :class:`BaseObjective`.
        train_loader: a training dataset loader.
        test_loader: a test dataset loader.
        device: the device on which operations are performed.
    """

    def __init__(
            self,
            model: nn.Module,
            objective: BaseObjective,
            train_loader: data.DataLoader,
            test_loader: data.DataLoader,
            device: torch.device
    ):
        model.eval()
        self.model = model.to(device)
        self.device = device

        self.is_model_functional = False
        self.params_names = tuple(name for name, _ in self._model_params())
        self.params_shape = tuple(p.shape for _, p in self._model_params())

        self.objective = objective
        self.train_loader = train_loader
        self.test_loader = test_loader

    @abc.abstractmethod
    def inverse_hvp(self, vec: torch.Tensor) -> torch.Tensor:
        """Computes an inverse-Hessian vector product, where the Hessian is specifically
        that of the (mean) empirical risk over the training dataset.

        Args:
            vec: a vector.

        Returns:
            the inverse-Hessian vector product.
        """

        raise NotImplementedError()

    # ====================================================
    # Interface functions
    # ====================================================

    def train_loss_grad(self, train_idxs: List[int]) -> torch.Tensor:
        """Returns the gradient of the (mean) training loss over a set of training
        data points with respect to the model's flattened parameters.

        Args:
            train_idxs: the indices of the training points.

        Returns:
            the loss gradient at the training points.
        """

        return self._loss_grad(train_idxs, train=True)

    def test_loss_grad(self, test_idxs: List[int]) -> torch.Tensor:
        """Returns the gradient of the (mean) test loss over a set of test
        data points with respect to the model's flattened parameters.

        Args:
           test_idxs: the indices of the test points.

        Returns:
           the loss gradient at the test points.
        """

        return self._loss_grad(test_idxs, train=False)

    def stest(self, test_idxs: List[int]) -> torch.Tensor:
        """This function simply composes :func:`inverse_hvp` with :func:`test_loss_grad`.

        In the original influence function paper, the resulting vector was called
        :math:`\mathbf{s}_{\mathrm{test}}`.

        Args:
            test_idxs: the indices of the test points.

        Returns:
            the :math:`\mathbf{s}_{\mathrm{test}}` vector.
        """

        return self.inverse_hvp(self.test_loss_grad(test_idxs))

    def influences(
            self,
            train_idxs: List[int],
            test_idxs: List[int],
            stest: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Returns the influence scores of a set of training data points with respect to
        the (mean) test loss over a set of test data points.

        Specifically, this method returns a 1D tensor of ``len(train_idxs)`` influence scores.
        These scores estimate the following quantities:

            Let :math:`\mathcal{L}_0` be the (mean) test loss of the current model
            over the input test points. Suppose we produce a new model by (1) removing
            the ``train_idxs[i]``-th example from the training dataset and (2) retraining
            the model on this one-smaller dataset. Let :math:`\mathcal{L}` be the (mean)
            test loss of the **new** model over the input test points. Then the ``i``-th
            influence score estimates :math:`\mathcal{L} - \mathcal{L}_0`.

        Args:
            train_idxs: the indices of the training points.
            test_idxs: the indices of the test points.
            stest: this method requires the :math:`\mathbf{s}_{\mathrm{test}}` vector of
                the input test points. If not ``None``, this argument will be used taken as
                :math:`\mathbf{s}_{\mathrm{test}}`. Otherwise, :math:`\mathbf{s}_{\mathrm{test}}`
                will be computed internally with :meth:`stest`.

        Returns:
            the influence scores.
        """

        stest = self.stest(test_idxs) if (stest is None) else stest.to(self.device)

        scores = []
        for grad_z, _ in self._loss_grad_loader_wrapper(batch_size=1, subset=train_idxs, train=True):
            s = grad_z @ stest
            scores.append(s)
        return torch.tensor(scores) / len(self.train_loader.dataset)

    def unlearning(
            self,
            train_idxs: List[int]
    ) -> torch.Tensor:
        """Unlearns pre-specified training samples from a trained model .

        Returns:
            the unlearned model
        """
        curr_vec = self.inverse_hvp(self.train_loss_grad(train_idxs), train_len = len(self.train_loader.dataset), unlearning=True)
        return self.model

    # ====================================================
    # Private helper functions
    # ====================================================

    # Model and parameter helpers

    def _model_params(self, with_names=True):
        assert not self.is_model_functional
        return tuple((name, p) if with_names else p for name, p in self.model.named_parameters() if p.requires_grad)

    def _model_make_functional(self):
        assert not self.is_model_functional
        params = tuple(p.detach().requires_grad_() for p in self._model_params(False))

        for name in self.params_names:
            _del_attr(self.model, name.split("."))
        self.is_model_functional = True

        return params

    def _model_reinsert_params(self, params, register=False):
        for name, p in zip(self.params_names, params):
            _set_attr(self.model, name.split("."), torch.nn.Parameter(p) if register else p)
        self.is_model_functional = not register

    def _flatten_params_like(self, params_like):
        vec = []
        for p in params_like:
            vec.append(p.view(-1))
        return torch.cat(vec)

    def _reshape_like_params(self, vec):
        pointer = 0
        split_tensors = []
        for dim in self.params_shape:
            num_param = dim.numel()
            split_tensors.append(vec[pointer: pointer + num_param].view(dim))
            pointer += num_param
        return tuple(split_tensors)

    # Data helpers

    def _transfer_to_device(self, batch):
        if isinstance(batch, torch.Tensor):
            return batch.to(self.device)
        elif isinstance(batch, (tuple, list)):
            return type(batch)(self._transfer_to_device(x) for x in batch)
        elif isinstance(batch, dict):
            return {k: self._transfer_to_device(x) for k, x in batch.items()}
        else:
            raise NotImplementedError()

    def _loader_wrapper(self, train, batch_size=None, subset=None, sample_n_batches=-1):
        loader = self.train_loader if train else self.test_loader
        batch_size = loader.batch_size if (batch_size is None) else batch_size

        if subset is None:
            dataset = loader.dataset
        else:
            subset = np.array(subset)
            if len(subset.shape) != 1 or len(np.unique(subset)) != len(subset):
                raise ValueError()
            if np.any((subset < 0) | (subset >= len(loader.dataset))):
                raise IndexError()
            dataset = data.Subset(loader.dataset, indices=subset)

        if sample_n_batches > 0:
            num_samples = sample_n_batches * batch_size
            sampler = data.RandomSampler(data_source=dataset, replacement=True, num_samples=num_samples)
        else:
            sampler = None

        new_loader = data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            sampler=sampler,
            collate_fn=loader.collate_fn,
            num_workers=loader.num_workers,
            worker_init_fn=loader.worker_init_fn,
        )

        data_left = len(dataset)
        for batch in new_loader:
            batch = self._transfer_to_device(batch)
            size = min(batch_size, data_left)  # deduce batch size
            yield batch, size
            data_left -= size

    # Loss and autograd helpers

    def _loss_grad_loader_wrapper(self, train, **kwargs):
        params = self._model_params(with_names=False)
        flat_params = self._flatten_params_like(params)

        for batch, batch_size in self._loader_wrapper(train=train, **kwargs):
            loss_fn = self.objective.train_loss if train else self.objective.test_loss
            loss = loss_fn(model=self.model, params=flat_params, batch=batch)
            yield self._flatten_params_like(torch.autograd.grad(loss, params)), batch_size

    def _loss_grad(self, idxs, train):
        grad = 0.0
        for grad_batch, batch_size in self._loss_grad_loader_wrapper(subset=idxs, train=train):
            grad = grad + grad_batch * batch_size
        return grad / len(idxs)

    def _hvp_at_batch(self, batch, flat_params, vec, gnh):

        def f(theta_):
            self._model_reinsert_params(self._reshape_like_params(theta_))
            return self.objective.train_loss(self.model, theta_, batch)

        def out_f(theta_):
            self._model_reinsert_params(self._reshape_like_params(theta_))
            return self.objective.train_outputs(self.model, batch)

        def loss_f(out_):
            return self.objective.train_loss_on_outputs(out_, batch)

        def reg_f(theta_):
            return self.objective.train_regularization(theta_)

        if gnh:
            y, jvp = torch.autograd.functional.jvp(out_f, flat_params, v=vec)
            hjvp = torch.autograd.functional.hvp(loss_f, y, v=jvp)[1]
            gnhvp_batch = torch.autograd.functional.vjp(out_f, flat_params, v=hjvp)[1]
            return gnhvp_batch + torch.autograd.functional.hvp(reg_f, flat_params, v=vec)[1]
        else:
            return torch.autograd.functional.hvp(f, flat_params, v=vec)[1]


class LiSSAInfluenceModule(BaseInfluenceModule):
    r"""An influence module that computes inverse-Hessian vector products
    using the Linear time Stochastic Second-Order Algorithm (LiSSA).

    At a high level, LiSSA estimates an inverse-Hessian vector product
    by using truncated Neumann iterations:

    .. math::
        \mathbf{H}^{-1}\mathbf{v} \approx \frac{1}{R}\sum\limits_{r = 1}^R
        \left(\sigma^{-1}\sum_{t = 1}^{T}(\mathbf{I} - \sigma^{-1}\mathbf{H}_{r, t})^t\mathbf{v}\right)

    Here, :math:`\mathbf{H}` is the risk Hessian matrix and :math:`\mathbf{H}_{r, t}` are
    loss Hessian matrices over batches of training data drawn randomly with replacement (we
    also use a batch size in ``train_loader``). In addition, :math:`\sigma > 0` is a scaling
    factor chosen sufficiently large such that :math:`\sigma^{-1} \mathbf{H} \preceq \mathbf{I}`.

    In practice, we can compute each inner sum recursively. Starting with
    :math:`\mathbf{h}_{r, 0} = \mathbf{v}`, we can iteratively update for :math:`T` steps:

    .. math::
        \mathbf{h}_{r, t} = \mathbf{v} + \mathbf{h}_{r, t - 1} - \sigma^{-1}\mathbf{H}_{r, t}\mathbf{h}_{r, t - 1}

    where :math:`\mathbf{h}_{r, T}` will be equal to the :math:`r`-th inner sum.

    Args:
        model: the model of interest.
        objective: an implementation of :class:`BaseObjective`.
        train_loader: a training dataset loader.
        test_loader: a test dataset loader.
        device: the device on which operations are performed.
        damp: the damping strength :math:`\lambda`. Influence functions assume that the
            risk Hessian :math:`\mathbf{H}` is positive-definite, which often fails to
            hold for neural networks. Hence, a damped risk Hessian :math:`\mathbf{H} + \lambda\mathbf{I}`
            is used instead, for some sufficiently large :math:`\lambda > 0` and
            identity matrix :math:`\mathbf{I}`.
        repeat: the number of trials :math:`R`.
        depth: the recurrence depth :math:`T`.
        scale: the scaling factor :math:`\sigma`.
        gnh: if ``True``, the risk Hessian :math:`\mathbf{H}` is approximated with
            the Gauss-Newton Hessian, which is positive semi-definite.
            Otherwise, the risk Hessian is used.
        debug_callback: a callback function which is passed in :math:`(r, t, \mathbf{h}_{r, t})`
            at each recurrence step.
     """

    def __init__(
            self,
            model: nn.Module,
            objective: BaseObjective,
            train_loader: data.DataLoader,
            test_loader: data.DataLoader,
            device: torch.device,
            damp: float,
            repeat: int,
            depth: int,
            scale: float,
            gnh: bool = False,
            debug_callback: Optional[Callable[[int, int, torch.Tensor], None]] = None
    ):

        super().__init__(
            model=model,
            objective=objective,
            train_loader=train_loader,
            test_loader=test_loader,
            device=device,
        )

        self.damp = damp
        self.gnh = gnh
        self.repeat = repeat
        self.depth = depth
        self.scale = scale
        self.debug_callback = debug_callback

    def inverse_hvp(self, vec, unlearning=False):
        """Estimates the damped inverse-Hessian vector product of ``vec`` with LiSSA.

        Args:
            vec: the vector to multiply with the inverse (damped) risk Hessian.
            unlearning: accepted for call compatibility only; it is not used by
                this implementation.

        Returns:
            the average over ``repeat`` trials of the ``depth``-step recurrence,
            each divided by ``scale``.
        """

        # From here on the model holds plain tensors instead of registered parameters.
        params = self._model_make_functional()
        flat_params = self._flatten_params_like(params)

        ihvp = 0.0

        try:
            for r in range(self.repeat):

                h_est = vec.clone()

                for t, (batch, _) in enumerate(self._loader_wrapper(sample_n_batches=self.depth, train=True)):

                    hvp_batch = self._hvp_at_batch(batch, flat_params, vec=h_est, gnh=self.gnh)

                    with torch.no_grad():
                        # damping: (H + damp * I) h
                        hvp_batch = hvp_batch + self.damp * h_est
                        h_est = vec + h_est - hvp_batch / self.scale

                    if self.debug_callback is not None:
                        self.debug_callback(r, t, h_est)

                ihvp = ihvp + h_est / self.scale
        finally:
            # Always hand the model back with registered parameters, even when an
            # iteration raised; otherwise the module is left unusable.
            with torch.no_grad():
                self._model_reinsert_params(self._reshape_like_params(flat_params), register=True)

        return ihvp / self.repeat


Common Functions

In [None]:
# L2 regularisation strength shared by training and the influence objective
# (an earlier setting was 1e-4).
L2_WEIGHT = 1e-6
# Run on the GPU when torch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

###########################################
# train resnet
def train_model(model, train_loader, test_loader, DEVICE, lr = 0.001, num_epochs = 10, L2_WEIGHT= 1e-6):
    """Fine-tune ``model`` with AdamW and cross-entropy, reporting accuracy every epoch.

    Args:
        model: the network to train; it is updated in place.
        train_loader: loader yielding ``(images, labels)`` training batches.
        test_loader: loader yielding ``(images, labels)`` batches for evaluation.
        DEVICE: device the batches are moved to before the forward pass.
        lr: AdamW learning rate.
        num_epochs: number of passes over ``train_loader``.
        L2_WEIGHT: value passed to AdamW as ``weight_decay``.

    Returns:
        the same ``model`` object after training.
    """
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=L2_WEIGHT)

    for epoch_idx in range(num_epochs):
        # ---- optimisation pass over the training data ----
        model.train()
        epoch_loss = 0.0

        for images, labels in tqdm(train_loader, leave=False):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            batch_loss = loss_fn(model(images), labels)

            opt.zero_grad()
            batch_loss.backward()
            opt.step()

            epoch_loss += batch_loss.item()

        print(f'Epoch [{epoch_idx+1}/{num_epochs}], Loss: {epoch_loss/len(train_loader):.4f}')

        # ---- accuracy on the test data ----
        model.eval()
        n_correct = 0
        n_seen = 0

        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(DEVICE)
                labels = labels.to(DEVICE)
                predicted = torch.max(model(images).data, 1)[1]
                n_seen += labels.size(0)
                n_correct += (predicted == labels).sum().item()

        print(f'Accuracy of the model on the test images: {100 * n_correct / n_seen:.2f}%')
    return model

###########################################
# calculate test losses
def return_test_losses(test_loader, trained_model):
    """Compute the per-sample cross-entropy loss of ``trained_model`` on a loader.

    The model is put in ``eval()`` mode and batches are moved to the module-level
    ``DEVICE`` before the forward pass.

    Args:
        test_loader: loader yielding ``(images, labels)`` batches.
        trained_model: the model to evaluate.

    Returns:
        a 1-D CPU tensor with one loss value per sample, in loader order.
    """
    # reduction='none' keeps one loss value per image instead of a batch mean
    per_sample_ce = nn.CrossEntropyLoss(reduction='none')
    collected = []

    trained_model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            collected.append(per_sample_ce(trained_model(images), labels).cpu())
        # Stitch the per-batch results into a single tensor
        return torch.cat(collected)

###########################################
# caption CIFAR10 images
def load_cifar10_examples(split, desired_classes, idx):
    """Return the ``idx``-th untransformed CIFAR-10 example among the desired classes.

    Args:
        split: ``"train"`` selects the training split; any other value selects
            the test split.
        desired_classes: class labels to keep.
        idx: position of the wanted example within the filtered split.

    Returns:
        the ``(image, label)`` pair as produced by the dataset (no transform is
        applied here).
    """
    dataset = datasets.CIFAR10(root='./data', train=(split == "train"), download=True)

    # Filter on the stored labels; enumerating the dataset itself would decode
    # every image of the split just to read its label.
    wanted = set(desired_classes)
    indices = [i for i, label in enumerate(dataset.targets) if label in wanted]

    image, label = dataset[indices[idx]]
    return image, label

def captioned_image(model, datasample, split, idx, scores, desired_classes, device):
    """Build a displayable image and a caption for one data sample.

    Args:
        model: classifier used to predict the sample's class.
        datasample: ``(image_tensor, label)`` pair; the tensor is fed to ``model``.
        split: ``"test"`` captions the score as a test loss, anything else as an
            influence; also forwarded to :func:`load_cifar10_examples`.
        idx: index used both to fetch the raw image and to look up ``scores``.
        scores: indexable collection of per-sample scores.
        desired_classes: class labels forwarded to :func:`load_cifar10_examples`.
        device: device the image tensor is moved to for the forward pass.

    Returns:
        ``(image, caption)`` where ``image`` is a NumPy array of the raw image and
        ``caption`` holds prediction, label and score on two lines.
    """
    image_tensor, target = datasample[0], datasample[1]

    # Pure inference: disable autograd so no graph is built for this single forward pass.
    with torch.no_grad():
        logits = model(image_tensor.unsqueeze(0).to(device))
    y_hat = logits.argmax(dim=1).item()  # argmax for multi-class prediction

    # fetch the raw (untransformed) image so it can be shown as an RGB picture
    image, _ = load_cifar10_examples(split, desired_classes, idx)
    image = np.array(image)

    # turn labels into human-readable strings
    class_names = ["plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
    label = class_names[int(target)]
    pred = class_names[int(y_hat)]

    if split == "test":
        score_caption = f"Test Loss: {scores[idx]:.5f}"
    else:
        score_caption = f"Influence: {scores[idx]:+.5f}"
    label_caption = f"Pred: {pred}, Label: {label}"
    return image, label_caption + "\n" + score_caption


# Main Code

Load dataset

In [None]:
print(f"Using {DEVICE}")

# ===========
# Load model and data
# ==========
batch_size = 128
# Data transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match ResNet input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
])
# we take only two classes out of the CIFAR10 data set
desired_classes = [0, 1]  # Change to the class indices you want

# Load CIFAR-10 dataset.
# The class filter reads the stored labels (``.targets``) rather than enumerating the
# dataset: enumeration would resize and normalise every image only to discard it.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
indices = [i for i, label in enumerate(train_dataset.targets) if label in desired_classes]
filtered_train_dataset = data.Subset(train_dataset, indices)
print('Loaded 1')

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
indices = [i for i, label in enumerate(test_dataset.targets) if label in desired_classes]
filtered_test_dataset = data.Subset(test_dataset, indices)
print('Loaded 2')

# Only the training loader is shuffled; the test loader keeps dataset order.
train_loader = data.DataLoader(dataset=filtered_train_dataset, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(dataset=filtered_test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
# load pre-trained model and replace its head
model_type = 'ResNet18'
if model_type == 'VGG13':
    model = models.vgg13(weights="IMAGENET1K_V1")
    model.classifier[6] = nn.Linear(4096, len(desired_classes))
    model.to(DEVICE)
else:
    model = models.resnet18(weights="IMAGENET1K_V1")
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, len(desired_classes))  # one output per selected class
    model.to(DEVICE)

# initialize hyperparameters
lr = 0.00001
num_epochs = 50
num_test_points = 30
num_test_point_plot = 3

# lissa hyperparameters
repeat_lissa = 5
depth_lissa = 1000
scale_lissa = 500

# Build the checkpoint name once so that the existence check and the save target
# refer to the same file.
save_filename = model_type + '_trained_cifar10_lr_' + str(lr) + '_epochs_' + str(num_epochs) \
        + '_l2weight_' + str(L2_WEIGHT) + '_net.pth'
model_path = os.path.join(os.getcwd(), save_filename)
if os.path.exists(model_path):
    # Restore the fine-tuned weights into the network built above (already on DEVICE).
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    trained_model = model
else:
    trained_model = train_model(model, train_loader, test_loader, DEVICE, lr = lr, num_epochs = num_epochs, L2_WEIGHT = L2_WEIGHT)
    save_path = model_path
    torch.save(model.state_dict(), save_path)
    print('Model saved at: ' + save_path)
print('Train model is loaded')

# ===========
# Get indices of top 'num_test_points' test images with highest test loss
# ===========
trained_model.eval()
test_losses = return_test_losses(test_loader, trained_model)
test_idxs = torch.argsort(test_losses, dim=0, descending=True)[:num_test_points]
test_idxs = test_idxs.tolist()
test_idxs_to_plot = test_idxs[:num_test_point_plot]
test_images = [captioned_image(trained_model, filtered_test_dataset[idx], "test", idx, test_losses, desired_classes, device=DEVICE) for idx in test_idxs_to_plot]

# ===========
# Initialize influence module using custom objective
# ===========
class BinClassObjective(BaseObjective):
    """Cross-entropy classification objective with an L2 penalty on the parameters.

    Batches are indexed as ``batch[0]`` (model inputs) and ``batch[1]`` (targets).
    """

    def train_outputs(self, model, batch):
        # the model outputs are taken to be its raw forward result on the inputs
        inputs = batch[0]
        return model(inputs)

    def train_loss_on_outputs(self, outputs, batch):
        # mean-reduced cross-entropy (the default reduction)
        criterion = torch.nn.CrossEntropyLoss()
        targets = batch[1]
        return criterion(outputs, targets)

    def train_regularization(self, params):
        # L2_WEIGHT * ||params||^2
        squared_norm = torch.square(params.norm())
        return L2_WEIGHT * squared_norm

    def test_loss(self, model, params, batch):
        # same cross-entropy as in training, without the regularisation term
        return self.train_loss_on_outputs(self.train_outputs(model, batch), batch)
# GNH: LiSSA with the Gauss-Newton Hessian approximation, on its own copy of the network
curr_net_gnh = copy.deepcopy(trained_model)
lissa_gnh = LiSSAInfluenceModule(
    model=curr_net_gnh,
    objective=BinClassObjective(),
    train_loader=data.DataLoader(filtered_train_dataset, batch_size=32),
    test_loader=data.DataLoader(filtered_test_dataset, batch_size=32),
    device=DEVICE,
    damp=0.001,
    repeat=repeat_lissa,
    depth=depth_lissa,
    scale=scale_lissa,
    gnh=True,
)
# Newton: LiSSA with the exact Hessian, on its own copy of the network
curr_net_Newton = copy.deepcopy(trained_model)
lissa_Newton = LiSSAInfluenceModule(
    model=curr_net_Newton,
    objective=BinClassObjective(),
    train_loader=data.DataLoader(filtered_train_dataset, batch_size=32),
    test_loader=data.DataLoader(filtered_test_dataset, batch_size=32),
    device=DEVICE,
    damp=0.001,
    repeat=repeat_lissa,
    depth=depth_lissa,
    scale=scale_lissa,
    gnh=False,
)
# ===========
# For each test point:
#   1. Get the influence scores of the training points listed in influence_train_indices
#      (currently every point of the filtered training set)
#   2. Find the most helpful and the most harmful training point among them
# The most helpful point is the one whose removal most increases the loss at the test
# point of interest (as predicted by the influence scores). Conversely, the most harmful
# training point is the one whose removal most decreases that test loss.
# The procedure is run with both the Newton-based and the GNH-based influence module so
# that running time and results can be compared.
# ===========
helpful_images_newton, harmful_images_newton = [], []
helpful_images_gnh, harmful_images_gnh = [], []

time_Newton, time_gnh = [], []


influence_train_indices = list(range(len(filtered_train_dataset)))
test_point_ctr = 0

for test_idx in tqdm(test_idxs, desc="Computing Influences"):

    # --- GNH-based influences, timed ---
    tic = time.time()
    influences_gnh = lissa_gnh.influences(train_idxs=influence_train_indices, test_idxs=[test_idx])
    elapsed_gnh = time.time() - tic
    time_gnh.append(elapsed_gnh)
    print(f'curr time gnh:{elapsed_gnh:.2f}')

    # --- Hessian-based influences, timed ---
    tic = time.time()
    influences_hessian = lissa_Newton.influences(train_idxs=influence_train_indices, test_idxs=[test_idx])
    elapsed_newton = time.time() - tic
    time_Newton.append(elapsed_newton)
    print(f'curr time Newton:{elapsed_newton:.2f}')

    # collect captioned images only for the test points that will be plotted
    if test_idx in test_idxs_to_plot:
        max_index_gnh = influence_train_indices[np.argmax(influences_gnh)]
        max_index_hessian = influence_train_indices[np.argmax(influences_hessian)]
        min_index_gnh = influence_train_indices[np.argmin(influences_gnh)]
        min_index_hessian = influence_train_indices[np.argmin(influences_hessian)]

        helpful_gnh = captioned_image(
            trained_model, filtered_train_dataset[max_index_gnh], "train",
            influences_gnh.argmax(), influences_gnh, desired_classes, device=DEVICE,
        )
        harmful_gnh = captioned_image(
            trained_model, filtered_train_dataset[min_index_gnh], "train",
            influences_gnh.argmin(), influences_gnh, desired_classes, device=DEVICE,
        )
        helpful_images_gnh.append(helpful_gnh)
        harmful_images_gnh.append(harmful_gnh)

        helpful_newton = captioned_image(
            trained_model, filtered_train_dataset[max_index_hessian], "train",
            influences_hessian.argmax(), influences_hessian, desired_classes, device=DEVICE,
        )
        harmful_newton = captioned_image(
            trained_model, filtered_train_dataset[min_index_hessian], "train",
            influences_hessian.argmin(), influences_hessian, desired_classes, device=DEVICE,
        )
        helpful_images_newton.append(helpful_newton)
        harmful_images_newton.append(harmful_newton)

    print('======================')
    print('Finished test point #' + str(test_idx))
    test_point_ctr += 1

# assemble the image grids and plot the running-time histogram
image_grid_gnh = [test_images, helpful_images_gnh, harmful_images_gnh]
image_grid_Newton = [test_images, helpful_images_newton, harmful_images_newton]

print('=======Finished Influence calculation, saving to a file=======')
plt.figure(figsize=(10, 6))
for timings, label_template, bar_color in (
        (time_Newton, 'Hessian, mean: {:.2f}', 'blue'),
        (time_gnh, 'Fisher, mean: {:.2f}', 'orange'),
):
    plt.hist(timings, bins=10, alpha=0.5, label=label_template.format(np.mean(timings)), color=bar_color, edgecolor='black')
# Add labels and title
plt.xlabel('Running Time [sec]', fontweight="bold")
plt.ylabel('Frequency', fontweight="bold")
plt.title('Histogram of Running Times', fontweight="bold")
plt.legend(prop=dict(weight='bold'))
plt.savefig('running_times_histogram.pdf')
print('Mean Newton:' + str(np.mean(time_Newton)))
print('Mean GNH:' + str(np.mean(time_gnh)))
print('=======Finished=======')

# Plot image grid

In [None]:
# ===========
# Plot image grid: one PDF per influence method, with rows
# "test image / most helpful / most harmful" and one column per plotted test point
# ==========
image_grid_gnh = [test_images, helpful_images_gnh, harmful_images_gnh]
image_grid_Newton = [test_images, helpful_images_newton, harmful_images_newton]

row_labels = ["Test Image", "Most Helpful", "Most Harmful"]

for curr_image_grid, fig_title in (
        (image_grid_gnh, 'analyze_CIFAR10_gnh.pdf'),
        (image_grid_Newton, 'analyze_CIFAR10_Newton.pdf'),
):
    fig, axes = plt.subplots(nrows=3, ncols=num_test_point_plot, sharex=True, sharey=True, figsize=(12, 10))
    plt.subplots_adjust(wspace=0.25, hspace=1)

    # with a single column matplotlib returns a 1-D array; make it (rows, 1)
    if axes.ndim == 1:
        axes = np.expand_dims(axes, axis=1)

    # plot images
    for row_images, row_axes in zip(curr_image_grid, axes):
        for ax, (image, caption) in zip(row_axes, row_images):
            ax.set_title(caption, size="large", weight="bold")
            ax.set(aspect="equal")
            ax.imshow(image)
            ax.set_xticks([])
            ax.set_yticks([])
    fig.tight_layout()

    # write row labels on the first column
    for ax, label in zip(axes[:, 0], row_labels):
        ax.set_ylabel(label, rotation=90, size="x-large", fontweight="bold")
    fig.savefig(fig_title, dpi=300)
    plt.close(fig)
print("Finished printing the image grid")

In [None]:
# Running-time histogram: overlay the per-test-point timings of both methods
plt.figure(figsize=(10, 6))
for timings, label_template, bar_color in (
        (time_Newton, 'Hessian, mean: {:.2f}', 'blue'),
        (time_gnh, 'Fisher, mean: {:.2f}', 'orange'),
):
    plt.hist(timings, bins=10, alpha=0.5, label=label_template.format(np.mean(timings)), color=bar_color, edgecolor='black')
plt.xlabel('Running Time [sec]', fontweight="bold")
plt.ylabel('Frequency', fontweight="bold")
plt.title('Histogram of Running Times', fontweight="bold")
plt.legend(prop=dict(weight='bold'))
plt.savefig('running_times_histogram.pdf')
# report the mean running time of each method
print('Mean Newton:' + str(np.mean(time_Newton)))
print('Mean GNH:' + str(np.mean(time_gnh)))
print('=======Finished=======')