# Influence Analysis on MNIST

## Setup 

In [None]:
import argparse
import logging
import os
from typing import Tuple
from typing import Literal

import torch
import torch.nn.functional as F
from torch import nn

from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments, ScoreArguments
from kronfluence.task import Task
from kronfluence.utils.common.factor_arguments import all_low_precision_factor_arguments
from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments
from kronfluence.utils.dataset import DataLoaderKwargs
from examples.mnist.pipeline import get_mnist_dataset, construct_mnist_classifier

## Analysis

### Task

In [3]:
# A batch is an ``(inputs, labels)`` pair of tensors.
BATCH_TYPE = Tuple[torch.Tensor, torch.Tensor]


class ClassificationTask(Task):
    """Task definition for a classifier that consumes ``(inputs, labels)`` batches.

    Provides the training loss and the query measurement, both as scalar
    tensors summed over the batch.
    """

    def compute_train_loss(
        self,
        batch: BATCH_TYPE,
        model: nn.Module,
        sample: bool = False,
    ) -> torch.Tensor:
        """Return the cross-entropy loss summed over ``batch``.

        Args:
            batch: ``(inputs, labels)`` pair.
            model: Module called as ``model(inputs)`` to obtain the logits.
            sample: When ``False`` the loss is taken against the given labels.
                When ``True`` the targets are instead drawn from the model's
                own softmax distribution over the last logit dimension.

        Returns:
            Scalar tensor holding the summed cross-entropy.
        """
        inputs, labels = batch
        logits = model(inputs)

        if sample:
            # Sampling happens outside the autograd graph; only the logits
            # passed to the loss below carry gradients.
            with torch.no_grad():
                class_probs = F.softmax(logits, dim=-1)
                targets = torch.multinomial(class_probs, num_samples=1).flatten()
        else:
            targets = labels

        return F.cross_entropy(logits, targets, reduction="sum")

    def compute_measurement(
        self,
        batch: BATCH_TYPE,
        model: nn.Module,
    ) -> torch.Tensor:
        """Return the negated, batch-summed margin of the correct class.

        For each example the margin is the correct-class logit minus the
        log-sum-exp (a soft maximum) of all the other logits.

        Adapted from:
        https://github.com/MadryLab/trak/blob/main/trak/modelout_functions.py

        Args:
            batch: ``(inputs, labels)`` pair; ``labels`` index the last logit
                dimension (assumes logits are ``(batch, classes)`` - TODO confirm).
            model: Module called as ``model(inputs)`` to obtain the logits.

        Returns:
            Scalar tensor equal to ``-sum(margins)``.
        """
        inputs, labels = batch
        logits = model(inputs)

        # One row index per example, on the same device as the logits.
        row_ids = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False)
        neg_inf = torch.tensor(-torch.inf, device=logits.device, dtype=logits.dtype)

        # Mask the correct class with -inf so it drops out of the log-sum-exp.
        masked_logits = logits.clone()
        masked_logits[row_ids, labels] = neg_inf
        wrong_class_lse = masked_logits.logsumexp(dim=-1)

        true_class_logits = logits[row_ids, labels]
        return -(true_class_logits - wrong_class_lse).sum()

## Analysis

In [None]:
# Notebook-level settings for this analysis run.

# Directory holding the MNIST data.
dataset_dir: str = "/h/maxk/kronfluence/data"
# Directory handed to the Analyzer for its outputs.
output_dir: str = "/h/maxk/kronfluence/examples/mnist/influence_results"
# Trained classifier checkpoint. The previous value had a stray "F" and an
# embedded space, so it did not sit under the same root as the paths above.
# NOTE(review): confirm this is where the checkpoint actually lives.
model_path: str = "/h/maxk/kronfluence/checkpoints/model.pth"
# Factor strategy used to fit the influence factors.
factor_strategy: Literal["identity", "diagonal", "kfac", "ekfac"] = "ekfac"
# Whether the Analyzer should profile its computations.
profile_computations: bool = False
# Whether to use the bfloat16 low-precision factor/score arguments.
use_half_precision: bool = False
# Intended per-device batch size for query examples when scoring.
query_batch_size: int = 1500

In [None]:
# Prepare the datasets. Both splits are read from the data directory; the
# evaluation split was previously (incorrectly) pointed at the results
# directory `output_dir`.
train_dataset = get_mnist_dataset(split="eval_train", dataset_dir=dataset_dir, in_memory=False)
eval_dataset = get_mnist_dataset(split="test", dataset_dir=dataset_dir, in_memory=False)

# Prepare the trained model. Loading onto the CPU first lets a checkpoint
# saved on a GPU be restored on any machine; `load_state_dict` copies the
# values into the model's existing parameters.
model = construct_mnist_classifier()
model.load_state_dict(torch.load(model_path, map_location="cpu"))

# Define task and prepare model.
task = ClassificationTask()
model = prepare_model(model, task)

# The analyzer is given `output_dir` and optionally profiles its work.
analyzer = Analyzer(
    analysis_name="mnist",
    model=model,
    task=task,
    output_dir=output_dir,
    profile=profile_computations,
)
# Configure parameters for DataLoader.
dataloader_kwargs = DataLoaderKwargs(num_workers=4)
analyzer.set_dataloader_kwargs(dataloader_kwargs)

In [None]:
# Compute influence factors.
# Factors are stored under the strategy name, with a "_half" suffix when the
# low-precision (bfloat16) arguments are used.
factors_name = factor_strategy
factor_args = FactorArguments(strategy=factor_strategy)
if use_half_precision:
    factor_args = all_low_precision_factor_arguments(strategy=factor_strategy, dtype=torch.bfloat16)
    factors_name += "_half"

analyzer.fit_all_factors(
    factors_name=factors_name,
    factor_args=factor_args,
    dataset=train_dataset,
    per_device_batch_size=None,
    overwrite_output_dir=False,
)

# Compute pairwise scores between the first 2000 query examples and the
# training set, using the factors fitted above.
score_args = ScoreArguments()
scores_name = factor_args.strategy

if use_half_precision:
    score_args = all_low_precision_score_arguments(dtype=torch.bfloat16)
    scores_name += "_half"

analyzer.compute_pairwise_scores(
    scores_name=scores_name,
    score_args=score_args,
    factors_name=factors_name,
    query_dataset=eval_dataset,
    query_indices=list(range(2000)),
    train_dataset=train_dataset,
    # Use the notebook-level setting; there is no `args` namespace here, so
    # the former `args.query_batch_size` raised NameError.
    per_device_query_batch_size=query_batch_size,
    overwrite_output_dir=False,
)
scores = analyzer.load_pairwise_scores(scores_name)["all_modules"]
# NOTE: the root logger defaults to WARNING, so this INFO record is only
# shown if logging has been configured elsewhere.
logging.info("Scores shape: %s", scores.shape)