In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils import data
from tqdm import tqdm
import copy
import time
import abc
from typing import Any, List, Optional, Callable
import torch.optim as optim
import random
import os
import torch.utils.data as data_utils
import urllib
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from collections import namedtuple
from sklearn.preprocessing import StandardScaler



# Load Adult dataset

In [None]:

class NetRegressionCV(nn.Module):
    """Single-hidden-layer network: Linear -> SELU -> Linear.

    Args:
        input_size: number of input features.
        num_classes: number of output units (one logit per class).
        hidden_size: width of the hidden layer. Defaults to 30000, the value
            that used to be hard-coded, so existing callers are unaffected.
    """

    def __init__(self, input_size, num_classes, hidden_size=30000):
        super().__init__()
        self.first = nn.Linear(input_size, hidden_size)
        self.last = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return the raw (un-normalized) outputs of the last linear layer."""
        hidden = F.selu(self.first(x))
        return self.last(hidden)


def load_adult(nTrain=None, scaler=True, shuffle=False):
    '''
    Load the UCI Adult dataset, encode it numerically and return the fixed train/test split.

    The files ``adult.data`` and ``adult.test`` are read from the current working
    directory and downloaded first if either one is missing.

    :param nTrain: number of training rows to keep (taken from the start of the
        training split). ``None`` keeps the whole training split.
    :param scaler: if True it applies a StandardScaler() (from sklearn.preprocessing) to the data.
    :param shuffle: ignored; a warning is printed when it is truthy.
    :return: ``(train, to_protect, test, to_protect_test)`` where ``train`` / ``test`` are
        DataFrames holding the encoded features (the gender column, position 9, removed) plus
        a ``'Target'`` column (0.0 for <=50K, 1.0 for >50K), and ``to_protect`` /
        ``to_protect_test`` are float arrays that are 1.0 where the row's gender differs
        from the gender of the first training row and 0.0 otherwise.

    Features of the Adult dataset:
    0. age: continuous.
    1. workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
    2. fnlwgt: continuous.
    3. education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th,
    Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
    4. education-num: continuous.
    5. marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed,
    Married-spouse-absent, Married-AF-spouse.
    6. occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty,
    Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv,
    Protective-serv, Armed-Forces.
    7. relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
    8. race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
    9. sex: Female, Male.
    10. capital-gain: continuous.
    11. capital-loss: continuous.
    12. hours-per-week: continuous.
    13. native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc),
    India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico,
    Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala,
    Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.
    (14. label: <=50K, >50K)
    '''
    # A bare `import urllib` does not guarantee that the `request` submodule
    # is loaded, so import it explicitly where it is used.
    import urllib.request

    if shuffle:
        print('Warning: I wont shuffle because adult has fixed test set')

    # Fetch each file independently so a missing adult.test is also recovered.
    base_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/"
    for filename in ("adult.data", "adult.test"):
        if not os.path.isfile(filename):
            urllib.request.urlretrieve(base_url + filename, filename)

    column_names = [
        "Age", "workclass", "fnlwgt", "education", "education-num", "marital-status",
        "occupation", "relationship", "race", "gender", "capital gain", "capital loss",
        "hours per week", "native-country", "income"]
    missing_marker = "?"
    columns_with_missing = ["workclass", "occupation", "native-country"]

    def _read_complete_rows(path, **read_kwargs):
        # The raw files separate fields with ", "; skipinitialspace strips the
        # leading blank so categorical values compare equal to their plain names.
        frame = pd.read_csv(path, names=column_names, skipinitialspace=True, **read_kwargs)
        # Considering the relative low portion of missing data, we discard rows with missing data
        complete = (frame[columns_with_missing] != missing_marker).all(axis=1)
        return frame[complete]

    train_frame = _read_complete_rows("adult.data")
    # The first line of adult.test is not a data record and is skipped.
    test_frame = _read_complete_rows("adult.test", skiprows=1, header=None)

    # Count the training rows AFTER dropping incomplete rows, so that the
    # train/test boundary used below lines up with the filtered matrix.
    len_train = len(train_frame)
    data = pd.concat([train_frame, test_frame], ignore_index=True)

    # Here we apply discretisation on column marital_status
    data["marital-status"] = data["marital-status"].replace({
        'Divorced': 'not married',
        'Married-AF-spouse': 'married',
        'Married-civ-spouse': 'married',
        'Married-spouse-absent': 'married',
        'Never-married': 'not married',
        'Separated': 'not married',
        'Widowed': 'not married',
    })

    # Care there is a final dot in the class only in test set ('>50K.' vs '>50K'),
    # so the label is derived from the leading character rather than from the code.
    target = np.where(data["income"].str.startswith(">"), 1.0, -1.0)

    # categorical fields -> integer codes (index into the sorted unique values)
    category_col = ['workclass', 'race', 'education', 'marital-status', 'occupation',
                    'relationship', 'gender', 'native-country', 'income']
    for col in category_col:
        _, codes = np.unique(data[col], return_inverse=True)
        data[col] = codes

    # Drop the last column (income) from the feature matrix.
    datamat = data.values[:, :-1]
    if scaler:
        # NOTE(review): the scaler is fitted on training and test rows together.
        feature_scaler = StandardScaler()
        feature_scaler.fit(datamat)
        datamat = feature_scaler.transform(datamat)

    if nTrain is None:
        nTrain = len_train
    # Never let the training slice run into the test rows.
    nTrain = min(nTrain, len_train)

    # Protected attribute: 1.0 where gender differs from the first training row.
    # One shared reference keeps the encoding identical for both splits.
    gender = datamat[:, 9]
    protected = 1. * (gender != gender[0])
    to_protect = protected[:nTrain]
    to_protect_test = protected[len_train:]

    encoded_data = pd.DataFrame(datamat[:nTrain, :])
    encoded_data['Target'] = (target[:nTrain] + 1) / 2

    encoded_data_test = pd.DataFrame(datamat[len_train:, :])
    encoded_data_test['Target'] = (target[len_train:] + 1) / 2

    # Variable to protect (9:Sex) is removed from dataset
    return encoded_data.drop(columns=9), to_protect, encoded_data_test.drop(columns=9), to_protect_test


# Influence Calculation

In [None]:
def _set_attr(obj, names, val):
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        _set_attr(getattr(obj, names[0]), names[1:], val)


def _del_attr(obj, names):
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        _del_attr(getattr(obj, names[0]), names[1:])

class BaseObjective(abc.ABC):
    """An abstract adapter that provides torch-influence with project-specific information
    about how training and test objectives are computed.

    In order to use torch-influence in your project, a subclass of this module should be
    created that implements this module's four abstract methods.
    """

    @abc.abstractmethod
    def train_outputs(self, model: nn.Module, batch: Any) -> torch.Tensor:
        """Returns a batch of model outputs (e.g., logits, probabilities) from a batch of data.

        Args:
            model: the model.
            batch: a batch of training data.

        Returns:
            the model outputs produced from the batch.
        """

        raise NotImplementedError()

    @abc.abstractmethod
    def train_loss_on_outputs(self, outputs: torch.Tensor, batch: Any) -> torch.Tensor:
        """Returns the **mean**-reduced loss of the model outputs produced from a batch of data.

        Args:
            outputs: a batch of model outputs.
            batch: a batch of training data.

        Returns:
            the loss of the outputs over the batch.

        Note:
            There may be some ambiguity in how to define :meth:`train_outputs()` and
            :meth:`train_loss_on_outputs()`: what point in the forward pass deliniates
            outputs from loss function? For example, in binary classification, the
            outputs can reasonably be taken to be the model logits or normalized probabilities.

            For standard use of influence functions, both choices produce the same behaviour.
            However, if using the Gauss-Newton Hessian approximation for influence functions,
            we require that :meth:`train_loss_on_outputs()` be convex in the model
            outputs.

        See also:
            :class:`CGInfluenceModule`
            :class:`LiSSAInfluenceModule`
        """

        raise NotImplementedError()

    @abc.abstractmethod
    def train_regularization(self, params: torch.Tensor) -> torch.Tensor:
        """Returns the regularization loss at a set of model parameters.

        Args:
            params: a flattened vector of model parameters.

        Returns:
            the regularization loss.
        """

        raise NotImplementedError()

    def train_loss(self, model: nn.Module, params: torch.Tensor, batch: Any) -> torch.Tensor:
        """Returns the **mean**-reduced regularized loss of a model over a batch of data.

        This method should not be overridden for most use cases. By default, torch-influence
        takes and expects the overall training loss to be::

            outputs = train_outputs(model, batch)
            loss = train_loss_on_outputs(outputs, batch) + train_regularization(params)

        Args:
            model: the model.
            params: a flattened vector of the model's parameters.
            batch: a batch of training data.

        Returns:
            the training loss over the batch.
        """

        outputs = self.train_outputs(model, batch)
        return self.train_loss_on_outputs(outputs, batch) + self.train_regularization(params)

    @abc.abstractmethod
    def test_loss(self, model: nn.Module, params: torch.Tensor, batch: Any) -> torch.Tensor:
        """Returns the **mean**-reduced loss of a model over a batch of data.

        Args:
            model: the model.
            params: a flattened vector of the model's parameters.
            batch: a batch of test data.

        Returns:
            the test loss over the batch.
        """

        raise NotImplementedError()



class BaseInfluenceModule(abc.ABC):
    """The core module that contains convenience methods for computing influence functions.

    Args:
        model: the model of interest.
        objective: an implementation of :class:`BaseObjective`.
        train_loader: a training dataset loader.
        test_loader: a test dataset loader.
        device: the device on which operations are performed.
    """

    def __init__(
            self,
            model: nn.Module,
            objective: BaseObjective,
            train_loader: data.DataLoader,
            test_loader: data.DataLoader,
            device: torch.device
    ):
        model.eval()
        self.model = model.to(device)
        self.device = device

        self.is_model_functional = False
        self.params_names = tuple(name for name, _ in self._model_params())
        self.params_shape = tuple(p.shape for _, p in self._model_params())

        self.objective = objective
        self.train_loader = train_loader
        self.test_loader = test_loader

    @abc.abstractmethod
    def inverse_hvp(self, vec: torch.Tensor) -> torch.Tensor:
        """Computes an inverse-Hessian vector product, where the Hessian is specifically
        that of the (mean) empirical risk over the training dataset.

        Args:
            vec: a vector.

        Returns:
            the inverse-Hessian vector product.
        """

        raise NotImplementedError()

    # ====================================================
    # Interface functions
    # ====================================================

    def train_loss_grad(self, train_idxs: List[int]) -> torch.Tensor:
        """Returns the gradient of the (mean) training loss over a set of training
        data points with respect to the model's flattened parameters.

        Args:
            train_idxs: the indices of the training points.

        Returns:
            the loss gradient at the training points.
        """

        return self._loss_grad(train_idxs, train=True)

    def test_loss_grad(self, test_idxs: List[int]) -> torch.Tensor:
        """Returns the gradient of the (mean) test loss over a set of test
        data points with respect to the model's flattened parameters.

        Args:
           test_idxs: the indices of the test points.

        Returns:
           the loss gradient at the test points.
        """

        return self._loss_grad(test_idxs, train=False)

    def stest(self, test_idxs: List[int]) -> torch.Tensor:
        """This function simply composes :func:`inverse_hvp` with :func:`test_loss_grad`.

        In the original influence function paper, the resulting vector was called
        :math:`\mathbf{s}_{\mathrm{test}}`.

        Args:
            test_idxs: the indices of the test points.

        Returns:
            the :math:`\mathbf{s}_{\mathrm{test}}` vector.
        """

        return self.inverse_hvp(self.test_loss_grad(test_idxs))

    def influences(
            self,
            train_idxs: List[int],
            test_idxs: List[int],
            stest: Optional[torch.Tensor] = None,
            target_grad: Optional[torch.Tensor] = None,
            influence_objective: Optional[str] = 'Taylor'
    ) -> torch.Tensor:
        """Returns the influence scores of a set of training data points with respect to
        the (mean) test loss over a set of test data points.

        Specifically, this method returns a 1D tensor of ``len(train_idxs)`` influence scores.
        These scores estimate the following quantities:

            Let :math:`\mathcal{L}_0` be the (mean) test loss of the current model
            over the input test points. Suppose we produce a new model by (1) removing
            the ``train_idxs[i]``-th example from the training dataset and (2) retraining
            the model on this one-smaller dataset. Let :math:`\mathcal{L}` be the (mean)
            test loss of the **new** model over the input test points. Then the ``i``-th
            influence score estimates :math:`\mathcal{L} - \mathcal{L}_0`.

        Args:
            train_idxs: the indices of the training points.
            test_idxs: the indices of the test points.
            stest: this method requires the :math:`\mathbf{s}_{\mathrm{test}}` vector of
                the input test points. If not ``None``, this argument will be used taken as
                :math:`\mathbf{s}_{\mathrm{test}}`. Otherwise, :math:`\mathbf{s}_{\mathrm{test}}`
                will be computed internally with :meth:`stest`.

        Returns:
            the influence scores.
        """

        if len(test_idxs) == 0:
            time_start = time.time()
            stest = self.inverse_hvp(self._flatten_params_like(target_grad), train_len = len(self.train_loader.dataset), unlearning=False)
            time_end = time.time()
        else:
            stest = self.stest(test_idxs) if (stest is None) else stest.to(self.device)

        scores = []
        for grad_z, _ in self._loss_grad_loader_wrapper(batch_size=1, subset=train_idxs, train=True):
            s = grad_z @ stest
            scores.append(s)
        return torch.tensor(scores) / len(self.train_loader.dataset), time_end-time_start

    # Unlearning
    def unlearning(
            self,
            train_idxs: List[int]
    ) -> torch.Tensor:
        """Unlearns pre-specified training samples from a trained model .

        Returns:
            the unlearned model
        """
        time_start = time.time()
        curr_vec = self.inverse_hvp(self.train_loss_grad(train_idxs), train_len = len(self.train_loader.dataset), unlearning=True)
        time_end = time.time()
        return self.model, time_end-time_start

    # ACV
    def ACV(
            self,
            train_idx: List[int],
            target_grad: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Unlearns pre-specified training samples from a trained model .

        Returns:
            the unlearned model
        """
        time_start = time.time()
        curr_vec = self.inverse_hvp(self._flatten_params_like(target_grad), train_len = len(self.train_loader.dataset), unlearning=True)
        time_end = time.time()

        # # calculate loss
        # for z_i, _ in self._loader_wrapper(batch_size=1, subset=train_idx, train=True):
        #     params = self._model_params(with_names=False)
        #     flat_params = self._flatten_params_like(params)
        #     loss_est = self.objective.train_loss(self.model, flat_params, z_i)
        return self.model, time_end-time_start

    # ====================================================
    # Private helper functions
    # ====================================================

    # Model and parameter helpers

    def _model_params(self, with_names=True):
        assert not self.is_model_functional
        return tuple((name, p) if with_names else p for name, p in self.model.named_parameters() if p.requires_grad)

    def _model_make_functional(self):
        assert not self.is_model_functional
        params = tuple(p.detach().requires_grad_() for p in self._model_params(False))

        for name in self.params_names:
            _del_attr(self.model, name.split("."))
        self.is_model_functional = True

        return params

    def _model_reinsert_params(self, params, register=False):
        for name, p in zip(self.params_names, params):
            _set_attr(self.model, name.split("."), torch.nn.Parameter(p) if register else p)
        self.is_model_functional = not register

    def _flatten_params_like(self, params_like):
        vec = []
        for p in params_like:
            vec.append(p.view(-1))
        return torch.cat(vec)

    def _reshape_like_params(self, vec):
        pointer = 0
        split_tensors = []
        for dim in self.params_shape:
            num_param = dim.numel()
            split_tensors.append(vec[pointer: pointer + num_param].view(dim))
            pointer += num_param
        return tuple(split_tensors)

    # Data helpers

    def _transfer_to_device(self, batch):
        if isinstance(batch, torch.Tensor):
            return batch.to(self.device)
        elif isinstance(batch, (tuple, list)):
            return type(batch)(self._transfer_to_device(x) for x in batch)
        elif isinstance(batch, dict):
            return {k: self._transfer_to_device(x) for k, x in batch.items()}
        else:
            raise NotImplementedError()

    def _loader_wrapper(self, train, batch_size=None, subset=None, sample_n_batches=-1):
        loader = self.train_loader if train else self.test_loader
        batch_size = loader.batch_size if (batch_size is None) else batch_size

        if subset is None:
            dataset = loader.dataset
        else:
            subset = np.array(subset)
            if len(subset.shape) != 1 or len(np.unique(subset)) != len(subset):
                raise ValueError()
            if np.any((subset < 0) | (subset >= len(loader.dataset))):
                raise IndexError()
            dataset = data.Subset(loader.dataset, indices=subset)

        if sample_n_batches > 0:
            num_samples = sample_n_batches * batch_size
            sampler = data.RandomSampler(data_source=dataset, replacement=True, num_samples=num_samples)
        else:
            sampler = None

        new_loader = data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            sampler=sampler,
            collate_fn=loader.collate_fn,
            num_workers=loader.num_workers,
            worker_init_fn=loader.worker_init_fn,
        )

        data_left = len(dataset)
        for batch in new_loader:
            batch = self._transfer_to_device(batch)
            size = min(batch_size, data_left)  # deduce batch size
            yield batch, size
            data_left -= size

    # Loss and autograd helpers

    def _loss_grad_loader_wrapper(self, train, **kwargs):
        params = self._model_params(with_names=False)
        flat_params = self._flatten_params_like(params)

        for batch, batch_size in self._loader_wrapper(train=train, **kwargs):
            loss_fn = self.objective.train_loss if train else self.objective.test_loss
            loss = loss_fn(model=self.model, params=flat_params, batch=batch)
            yield self._flatten_params_like(torch.autograd.grad(loss, params)), batch_size

    def _loss_grad(self, idxs, train):
        grad = 0.0
        for grad_batch, batch_size in self._loss_grad_loader_wrapper(subset=idxs, train=train):
            grad = grad + grad_batch * batch_size
        return grad / len(idxs)

    def _hvp_at_batch(self, batch, flat_params, vec, gnh):

        def f(theta_):
            self._model_reinsert_params(self._reshape_like_params(theta_))
            return self.objective.train_loss(self.model, theta_, batch)

        def out_f(theta_):
            self._model_reinsert_params(self._reshape_like_params(theta_))
            return self.objective.train_outputs(self.model, batch)

        def loss_f(out_):
            return self.objective.train_loss_on_outputs(out_, batch)

        def reg_f(theta_):
            return self.objective.train_regularization(theta_)

        def train_loss_unregularized(theta_):
            self._model_reinsert_params(self._reshape_like_params(theta_))
            return self.objective.test_loss(self.model, theta_, batch)

        if gnh=='gnh':
            y, jvp = torch.autograd.functional.jvp(out_f, flat_params, v=vec)
            hjvp = torch.autograd.functional.hvp(loss_f, y, v=jvp)[1]
            gnhvp_batch = torch.autograd.functional.vjp(out_f, flat_params, v=hjvp)[1]
            return gnhvp_batch + torch.autograd.functional.hvp(reg_f, flat_params, v=vec)[1]
        elif gnh=='emp':
            y, jvp = torch.autograd.functional.jvp(train_loss_unregularized, flat_params, v=vec)
            gnhvp_batch = torch.autograd.functional.vjp(train_loss_unregularized, flat_params, v=jvp)[1]
            return gnhvp_batch + torch.autograd.functional.hvp(reg_f, flat_params, v=vec)[1]
        else:
            return torch.autograd.functional.hvp(f, flat_params, v=vec)[1]


class LiSSAInfluenceModule(BaseInfluenceModule):
    r"""An influence module that computes inverse-Hessian vector products
    using the Linear time Stochastic Second-Order Algorithm (LiSSA).

    At a high level, LiSSA estimates an inverse-Hessian vector product
    by using truncated Neumann iterations:

    .. math::
        \mathbf{H}^{-1}\mathbf{v} \approx \frac{1}{R}\sum\limits_{r = 1}^R
        \left(\sigma^{-1}\sum_{t = 1}^{T}(\mathbf{I} - \sigma^{-1}\mathbf{H}_{r, t})^t\mathbf{v}\right)

    Here, :math:`\mathbf{H}` is the risk Hessian matrix and :math:`\mathbf{H}_{r, t}` are
    loss Hessian matrices over batches of training data drawn randomly with replacement (we
    also use a batch size in ``train_loader``). In addition, :math:`\sigma > 0` is a scaling
    factor chosen sufficiently large such that :math:`\sigma^{-1} \mathbf{H} \preceq \mathbf{I}`.

    In practice, we can compute each inner sum recursively. Starting with
    :math:`\mathbf{h}_{r, 0} = \mathbf{v}`, we can iteratively update for :math:`T` steps:

    .. math::
        \mathbf{h}_{r, t} = \mathbf{v} + \mathbf{h}_{r, t - 1} - \sigma^{-1}\mathbf{H}_{r, t}\mathbf{h}_{r, t - 1}

    where :math:`\mathbf{h}_{r, T}` will be equal to the :math:`r`-th inner sum.

    Args:
        model: the model of interest.
        objective: an implementation of :class:`BaseObjective`.
        train_loader: a training dataset loader.
        test_loader: a test dataset loader.
        device: the device on which operations are performed.
        damp: the damping strength :math:`\lambda`. Influence functions assume that the
            risk Hessian :math:`\mathbf{H}` is positive-definite, which often fails to
            hold for neural networks. Hence, a damped risk Hessian :math:`\mathbf{H} + \lambda\mathbf{I}`
            is used instead, for some sufficiently large :math:`\lambda > 0` and
            identity matrix :math:`\mathbf{I}`.
        repeat: the number of trials :math:`R`.
        depth: the recurrence depth :math:`T`.
        scale: the scaling factor :math:`\sigma`.
        gnh: curvature selector, forwarded unchanged to ``_hvp_at_batch`` on
            every recurrence step; see that method for the accepted values.
        debug_callback: a callback function which is passed in :math:`(r, t, \mathbf{h}_{r, t})`
            at each recurrence step.
     """

    def __init__(
            self,
            model: nn.Module,
            objective: BaseObjective,
            train_loader: data.DataLoader,
            test_loader: data.DataLoader,
            device: torch.device,
            damp: float,
            repeat: int,
            depth: int,
            scale: float,
            gnh: bool = False,
            debug_callback: Optional[Callable[[int, int, torch.Tensor], None]] = None
    ):

        super().__init__(
            model=model,
            objective=objective,
            train_loader=train_loader,
            test_loader=test_loader,
            device=device,
        )

        self.damp = damp
        self.gnh = gnh
        self.repeat = repeat
        self.depth = depth
        self.scale = scale
        self.debug_callback = debug_callback

    def inverse_hvp(self, vec, train_len=None, unlearning=False):
        """Estimate the inverse-(damped-)Hessian vector product of ``vec`` with LiSSA.

        Args:
            vec: the flat vector to multiply.
            train_len: accepted for call compatibility; it is not used by the
                computation. NOTE(review): confirm whether the update was meant
                to be scaled by the training-set size.
            unlearning: if True, the estimate is added to the model's flattened
                parameters and the result is written back into ``self.model``;
                if False, the model gets its original parameters back.

        Returns:
            the estimate, averaged over ``repeat`` trials.
        """

        params = self._model_make_functional()
        flat_params = self._flatten_params_like(params)

        # What gets written back into the model at the end. It stays at the
        # original parameters unless the unlearning update is reached, so a
        # failure part-way through never leaves the model without parameters.
        final_params = flat_params

        try:
            ihvp = 0.0

            for r in range(self.repeat):

                h_est = vec.clone()

                for t, (batch, _) in enumerate(self._loader_wrapper(sample_n_batches=self.depth, train=True)):

                    hvp_batch = self._hvp_at_batch(batch, flat_params, vec=h_est, gnh=self.gnh)

                    with torch.no_grad():
                        # Damped product followed by one Neumann recurrence step.
                        hvp_batch = hvp_batch + self.damp * h_est
                        h_est = vec + h_est - hvp_batch / self.scale

                    if self.debug_callback is not None:
                        self.debug_callback(r, t, h_est)

                ihvp = ihvp + h_est / self.scale

            ihvp = ihvp / self.repeat

            if unlearning:
                with torch.no_grad():
                    final_params = flat_params + ihvp
        finally:
            with torch.no_grad():
                self._model_reinsert_params(self._reshape_like_params(final_params), register=True)

        return ihvp


# Common Functions

In [None]:
###########################################
# train model
def train_model(model, train_loader, test_loader, DEVICE, lr = 0.001, num_epochs = 10, l2_weight= 1e-6):
    """Train ``model`` with AdamW and cross-entropy, then report test metrics.

    The model is evaluated on ``test_loader`` after every epoch; the metrics of
    the last evaluation are printed and returned. If ``num_epochs`` yields no
    epochs, the untouched model is evaluated once.

    Args:
        model: the network to train; it is used on whatever device it is
            already on (only the batches are moved to ``DEVICE``).
        train_loader: loader yielding ``(inputs, labels)`` training batches.
        test_loader: loader yielding ``(inputs, labels)`` test batches.
        DEVICE: device the batches are moved to.
        lr: AdamW learning rate.
        num_epochs: number of passes over ``train_loader``.
        l2_weight: AdamW ``weight_decay``.

    Returns:
        a tuple ``(model, test_loss)`` where ``test_loss`` is the average of
        the per-batch mean cross-entropy losses on ``test_loader`` (NaN if the
        loader has no batches).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=l2_weight)

    def _evaluate():
        # One pass over the test set: (correct predictions, samples, summed batch loss).
        model.eval()
        n_correct, n_seen, loss_sum = 0, 0, 0.0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                outputs = model(inputs)
                # Use the loss of THIS test batch, not the leftover loss of the
                # last training batch.
                loss_sum += criterion(outputs, labels).item()
                _, predicted = torch.max(outputs, 1)
                n_seen += labels.size(0)
                n_correct += (predicted == labels).sum().item()
        return n_correct, n_seen, loss_sum

    metrics = None
    for _ in tqdm(range(num_epochs)):
        model.train()

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            # Forward pass
            loss = criterion(model(inputs), labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate the model on the test set
        metrics = _evaluate()

    if metrics is None:
        # No epoch ran; still report metrics for the model as given.
        metrics = _evaluate()

    correct, total, running_loss = metrics
    n_batches = len(test_loader)
    accuracy = 100 * correct / total if total else float('nan')
    mean_loss = running_loss / n_batches if n_batches else float('nan')
    print(f'Accuracy of the model on the test data: {accuracy:.2f}%')
    print(f'Final Loss: {mean_loss:.4f}')
    return model, mean_loss

def calculate_error_rate(model, test_loader, DEVICE):
    """Print the test accuracy of ``model`` and return it as a fraction.

    NOTE: despite the function name, the returned value is
    ``correct / total`` (the share of correctly classified samples), not
    the share of misclassified ones.

    Args:
        model: module producing one score per class for each input row.
        test_loader: iterable of ``(inputs, labels)`` batches.
        DEVICE: device every batch is moved to before the forward pass.

    Returns:
        Fraction of samples whose highest-scoring class equals the label.
    """
    model.eval()

    n_seen, n_hits = 0, 0
    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs = batch_inputs.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            scores = model(batch_inputs)
            # Index of the largest score along the class dimension.
            predicted = torch.max(scores.data, 1)[1]

            n_hits += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    print(f'Accuracy of the model on the test data: {100 * n_hits / n_seen:.2f}%')
    return n_hits / n_seen


# Main Code

Load data

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using {DEVICE}")

# Adult dataset via load_adult() with its default arguments: two data frames
# (train / test) plus the matching protected-attribute values.
(encoded_data, to_protect,
 encoded_data_test, to_protect_test) = load_adult()

Initialize hyperparameters and data loaders

In [None]:

# Hyper Parameters
num_classes = 2
input_size = encoded_data.shape[1] - 1   # every column except 'Target' is a feature
batch_size = 256
num_epochs = [1, 4, 7, 10]    # epoch counts swept by the main loop (500 was noted as an alternative)
lr = 1e-4
l2_weights = 1e-8   # other values noted: [0.0, 1e-8, 1e-6, 1e-4, 1e-2, 1e-1]


def _shuffled_loader(features, targets):
    """Wrap two aligned tensors in a shuffling DataLoader of ``batch_size`` rows."""
    paired = data_utils.TensorDataset(features, targets)
    return data_utils.DataLoader(dataset=paired, batch_size=batch_size, shuffle=True)


# Training split: labels as int64, features as float32.
train_target = torch.tensor(encoded_data['Target']).long()
train_data = torch.tensor(encoded_data.drop('Target', axis=1).values.astype(np.float32))
train_loader = _shuffled_loader(train_data, train_target)
train_protect = torch.tensor(to_protect).long().to(DEVICE)

# Test split, prepared the same way.
test_target = torch.tensor(encoded_data_test['Target']).long()
test_data = torch.tensor(encoded_data_test.drop('Target', axis=1).values.astype(np.float32))
test_loader = _shuffled_loader(test_data, test_target)
test_protect = torch.tensor(to_protect_test).long().to(DEVICE)


LiSSA hyperparameters

In [None]:
# LiSSA parameters, forwarded to LiSSAInfluenceModule as repeat / depth / scale.
# Alternatives noted for depth: 1000, 5000; for scale: 750.
repeat_lissa, depth_lissa, scale_lissa = 5, 15000, 500

Initialize result lists

In [None]:
# How many times each configuration is trained and evaluated from scratch.
num_tests = 1

# One entry is appended to each list per epoch setting visited by the main loop:
# the mean / variance (over experiments) of the ERM, Fisher and Hessian estimates.
loss_est_erm_mean, loss_est_fisher_mean, loss_est_hessian_mean = [], [], []
loss_est_erm_var, loss_est_fisher_var, loss_est_hessian_var = [], [], []

Main loop

In [None]:

class BinClassObjective(BaseObjective):
    """Objective handed to LiSSAInfluenceModule: cross-entropy plus an L2 penalty.

    Defined once, outside the sweep; it reads the notebook-level ``l2_weights``
    when ``train_regularization`` is called.
    """

    def train_outputs(self, model, batch):
        # batch is indexed as (inputs, labels)
        return model(batch[0])

    def train_loss_on_outputs(self, outputs, batch):
        return torch.nn.CrossEntropyLoss()(outputs, batch[1])

    def train_regularization(self, params):
        return l2_weights * torch.square(params.norm())

    def test_loss(self, model, params, batch):
        outputs = model(batch[0])
        return torch.nn.CrossEntropyLoss()(outputs, batch[1])


# Number of approximate-CV folds per trained model, and training points per fold.
cv_len = 5
cv_fold_points = 5000

# ``epoch`` is the number of training epochs used for this configuration.
for epoch in num_epochs:
    # Per-configuration accumulators: one entry per experiment.
    loss_ERM = []
    loss_est_ERM = []
    loss_est_fisher = []
    loss_est_Hess = []

    # Wall-clock time of each ACV call. These are reset for every epoch setting,
    # so after the sweep they hold the timings of the last setting only.
    time_Newton = []
    time_gnh = []

    for exp_idx in tqdm(range(num_tests)):
        print(f"Experiment {exp_idx + 1}/{num_tests}")
        # ===========
        # Train a fresh model
        model = NetRegressionCV(input_size, num_classes).to(DEVICE)
        trained_model, orig_loss = train_model(model, train_loader, test_loader, DEVICE, lr=lr,
                                               num_epochs=epoch, l2_weight=l2_weights)
        loss_ERM.append(orig_loss)

        print(f'original loss is: {orig_loss:.2f}')

        # Fold losses of THIS experiment only. Keeping these lists per experiment
        # prevents the per-experiment mean below from absorbing folds of earlier
        # experiments when num_tests > 1.
        loss_cv_Fisher = []
        loss_cv_Hess = []

        # ===========
        # Approximate CV: remove a random fold, then evaluate the loss on that fold
        for idx in tqdm(range(cv_len)):

            # Each influence module gets its own copy of the trained model so the
            # trained model itself is never modified by the approximation.
            # GNH (Fisher)
            curr_net_gnh = copy.deepcopy(trained_model)
            lissa_gnh = LiSSAInfluenceModule(model=curr_net_gnh, objective=BinClassObjective(),
                                             train_loader=train_loader, test_loader=test_loader,
                                             device=DEVICE, damp=0.001, repeat=repeat_lissa,
                                             depth=depth_lissa, scale=scale_lissa, gnh='gnh')
            # Newton (Hessian)
            curr_net_Newton = copy.deepcopy(trained_model)
            lissa_Newton = LiSSAInfluenceModule(model=curr_net_Newton, objective=BinClassObjective(),
                                                train_loader=train_loader, test_loader=test_loader,
                                                device=DEVICE, damp=0.001, repeat=repeat_lissa,
                                                depth=depth_lissa, scale=scale_lissa, gnh='Hessian')

            # Draw the fold and move it to the device once; it is reused three times below.
            random_indices = random.sample(range(len(train_data)), cv_fold_points)
            fold_inputs = train_data[random_indices].to(DEVICE)
            fold_labels = train_target[random_indices].to(DEVICE)

            # Gradient of the training loss on the fold, taken at the trained parameters.
            loss_i = nn.CrossEntropyLoss()(trained_model(fold_inputs), fold_labels)
            loss_grad = torch.autograd.grad(loss_i, trained_model.parameters(),
                                            retain_graph=False, create_graph=False)

            # estimate CV loss using Fisher
            new_model_gnh, curr_time_gnh = lissa_gnh.ACV(train_idx=[random_indices], target_grad=loss_grad)
            with torch.no_grad():  # evaluation only; no autograd graph needed
                cv_est_gnh = torch.nn.CrossEntropyLoss()(new_model_gnh.to(DEVICE)(fold_inputs), fold_labels)
            loss_cv_Fisher.append(cv_est_gnh.detach().cpu().numpy())
            time_gnh.append(curr_time_gnh)
            print('CV estimation Fisher: ' + str(np.mean(loss_cv_Fisher)))
            print(f'Time Fisher: {curr_time_gnh:.2f}')

            # estimate CV loss using Hessian
            new_model_hess, curr_time_hessian = lissa_Newton.ACV(train_idx=[random_indices], target_grad=loss_grad)
            with torch.no_grad():
                cv_est_hess = torch.nn.CrossEntropyLoss()(new_model_hess.to(DEVICE)(fold_inputs), fold_labels)
            loss_cv_Hess.append(cv_est_hess.detach().cpu().numpy())
            time_Newton.append(curr_time_hessian)
            print('CV estimation Hessian: ' + str(np.mean(loss_cv_Hess)))
            print(f'Time Hessian: {curr_time_hessian:.2f}')

        print('Finished epoch: ' + str(epoch))
        print('=====================================')
        print('=====================================')
        print('=====================================')
        # One scalar per experiment (the original appended the growing list itself).
        loss_est_ERM.append(orig_loss)
        loss_est_fisher.append(np.mean(loss_cv_Fisher))
        loss_est_Hess.append(np.mean(loss_cv_Hess))

    # final estimates: mean and variance across experiments for this epoch setting
    loss_est_erm_mean.append(np.mean(loss_est_ERM))
    loss_est_erm_var.append(np.var(loss_est_ERM))
    loss_est_fisher_mean.append(np.mean(loss_est_fisher))
    loss_est_fisher_var.append(np.var(loss_est_fisher))
    loss_est_hessian_mean.append(np.mean(loss_est_Hess))
    loss_est_hessian_var.append(np.var(loss_est_Hess))
print('=======Finished: start saving figures=======')

Plots

In [None]:
# Plot the mean loss estimate of each method against the number of training epochs.
plt.figure(figsize=(10, 8))

# (y-values, legend label, marker) for every curve, drawn in this order.
_curves = (
    (loss_est_erm_mean, 'ERM', '*'),
    (loss_est_fisher_mean, f'Fisher-based CV, avg time: {np.mean(time_gnh):.2f} [seconds]', 'o'),
    (loss_est_hessian_mean, f'Hessian-based CV, avg time: {np.mean(time_Newton):.2f} [seconds]', 's'),
)
for _values, _label, _marker in _curves:
    plt.plot(num_epochs, _values, label=_label, linewidth = 3.0, marker=_marker)

# Titles, axis labels and legend share the same bold styling.
_bold_20 = dict(fontweight="bold", fontsize=20)
plt.xlabel('Number of epochs', **_bold_20)
plt.ylabel('Loss', **_bold_20)
plt.title('CV Loss Estimates', **_bold_20)
plt.legend(loc='best', prop=dict(weight='bold', size = 18))
plt.grid(True)

# Tick styling; the y tick positions are hand-picked, adjust as needed.
plt.xticks(fontsize=16)
y_ticks = [tenths * 1e-1 for tenths in (3, 4, 5, 6)] + [1.5]
plt.yticks(y_ticks, fontsize=20)
# A logarithmic y axis is an option here: plt.yscale("log")

plt.savefig('CV_Est.pdf')