# Influence Analysis of a GPT-2 Language Model (OOCR)

## Setup 

In [1]:
import argparse
import logging
import os
from typing import Tuple
from typing import Literal
import torch
from pydantic_core import from_json
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Conv1D
from torch import nn
import torchvision
import matplotlib
from transformers import Conv1D
from pathlib import Path
from transformers import default_data_collator
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments, ScoreArguments
from typing import cast
from kronfluence.task import Task
from kronfluence.utils.common.factor_arguments import all_low_precision_factor_arguments
from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments
from kronfluence.utils.dataset import DataLoaderKwargs
from oocr_influence.data import get_data_collator_with_padding
# from examples.mnist.pipeline import get_mnist_dataset, construct_mnist_classifier, add_box_to_mnist_dataset
import numpy as np
from matplotlib import pyplot as plt
from typing import Sequence
import math
from oocr_influence.data import get_datasets
from tqdm import tqdm
from train import TrainingArgs
from copy import deepcopy
from collections import defaultdict

  from .autonotebook import tqdm as notebook_tqdm


## Analysis

### Task

In [2]:
# A batch maps field names (the task reads "input_ids", "attention_mask" and
# "labels") to tensors.
BATCH_TYPE = dict[str, torch.Tensor]


class LanguageModelingTask(Task):
    """Kronfluence task describing next-token prediction for a causal LM.

    The loss and the measurement are both the cross-entropy between the
    model's shifted logits and the target tokens, summed over all positions.
    """

    def compute_train_loss(
        self,
        batch: BATCH_TYPE,
        model: nn.Module,
        sample: bool = False,
    ) -> torch.Tensor:
        """Return the summed next-token cross-entropy for ``batch``.

        Args:
            batch: Must contain ``input_ids`` and ``attention_mask``; must also
                contain ``labels`` when ``sample`` is False.
            model: Module whose output exposes a ``.logits`` attribute.
            sample: If False, score against ``batch["labels"]``. If True, score
                against labels sampled from the model's own predictive
                distribution.

        Returns:
            A scalar tensor holding the loss summed (not averaged) over positions.
        """
        logits = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        ).logits
        # The last position has no next token to predict; drop it and flatten
        # to (num_positions, vocab_size).
        logits = logits[..., :-1, :].contiguous()
        logits = logits.view(-1, logits.size(-1))

        if not sample:
            # Shift labels left by one so position t is scored against token t+1.
            labels = batch["labels"][..., 1:].contiguous()
            return F.cross_entropy(logits, labels.view(-1), reduction="sum")

        with torch.no_grad():
            probs = F.softmax(logits.detach(), dim=-1)
            sampled_labels = torch.multinomial(probs, num_samples=1).flatten()
            if "labels" in batch:
                # Positions the real labels mark with -100 (the default
                # ignore_index of F.cross_entropy) are skipped by the
                # non-sampled loss above. Skip them here too so both losses
                # cover the same positions.
                ignored = batch["labels"][..., 1:].contiguous().view(-1) == -100
                sampled_labels[ignored] = -100
        return F.cross_entropy(
            logits, sampled_labels, ignore_index=-100, reduction="sum"
        )

    def compute_measurement(
        self,
        batch: BATCH_TYPE,
        model: nn.Module,
    ) -> torch.Tensor:
        """Return the query measurement, here the same summed loss used for training."""
        # We could also compute the log-likelihood or averaged margin.
        return self.compute_train_loss(batch, model)

    def get_influence_tracked_modules(self) -> list[str]:
        """Return the module names tracked for influence.

        Covers the attention and MLP projections of blocks 0-7, attention
        modules first and MLP modules second.
        """
        # NOTE(review): 8 presumably equals the number of transformer blocks in
        # the analysed checkpoint -- confirm if the model size changes.
        attention_modules = [
            f"transformer.h.{i}.attn.{proj}"
            for i in range(8)
            for proj in ("c_attn", "c_proj")
        ]
        mlp_modules = [
            f"transformer.h.{i}.mlp.{proj}"
            for i in range(8)
            for proj in ("c_fc", "c_proj")
        ]
        return attention_modules + mlp_modules

    def get_attention_mask(self, batch: BATCH_TYPE) -> torch.Tensor:
        """Return the batch's attention mask unchanged."""
        return batch["attention_mask"]

## Analysis: configuration and influence scores

In [3]:
# Output directory of the training run that is being analysed.
experiment_outputs = "/mfs1/u/max/oocr-influence/outputs/phi_1.0_num_entities_100_num_relations_10_relations_per_entity_10_20250107_181128"

# Training arguments that were saved next to the run's outputs.
old_args = TrainingArgs.model_validate_json(
    Path(f"{experiment_outputs}/args.json").read_text()
)

# Default output directory for get_pairwise_influence_scores.
influence_output_dir = f"{experiment_outputs}/influence"

# TODO: tighten the typing of the factor strategy (the original note was left unfinished).
factor_strategy: Literal["identity", "diagonal", "kfac", "ekfac"] = "ekfac"

# Switches read by get_pairwise_influence_scores.
profile_computations = False
use_half_precision = False
compute_per_token_scores = False
use_compile = False  # in this notebook it only changes the factor/score names

# Per-device batch sizes used when computing pairwise scores.
query_batch_size = 32
train_batch_size = 32

# -1 means no low-rank approximation of the query gradient is configured.
query_gradient_rank = -1

device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

In [4]:
@torch.no_grad()
def replace_conv1d_modules(model: nn.Module) -> None:
    """Recursively replace every ``Conv1D`` submodule of ``model`` with ``nn.Linear``.

    GPT-2 is defined in terms of Conv1D. However, this does not work for
    Kronfluence, so each Conv1D is swapped in place for a linear layer that
    carries the same parameters.

    Args:
        model: Module that is modified in place. Nothing is returned.
    """
    for name, module in model.named_children():
        if isinstance(module, Conv1D):
            weight = module.weight
            # The weight is read as (in_features, out_features), the transpose
            # of nn.Linear's layout. Build the replacement on the same device
            # and dtype so a model already moved to GPU or cast to half
            # precision stays consistent.
            linear = nn.Linear(
                in_features=weight.shape[0],
                out_features=weight.shape[1],
                device=weight.device,
                dtype=weight.dtype,
            )
            linear.weight.data.copy_(weight.data.t())
            linear.bias.data.copy_(module.bias.data)
            setattr(model, name, linear)
        else:
            # A module without children makes this call a no-op.
            replace_conv1d_modules(module)


# Tokenizer: use the EOS token for padding.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # type: ignore
tokenizer.pad_token = tokenizer.eos_token  # type: ignore

# Datasets: built with the settings stored in the training run's arguments.
train_dataset, test_dataset = get_datasets(
    tokenizer=tokenizer,
    num_proc=old_args.num_proc_dataset_creation,
    num_entities=old_args.num_entities,
    num_relations=old_args.num_relations,
    relations_per_entity=old_args.relations_per_entity,
    phi=old_args.phi,
    proportion_ood_facts=old_args.proportion_ood_facts,
    proportion_iid_test_set_facts=old_args.proportion_iid_test_set_facts,
    data_dir=Path(old_args.data_dir),
)

# Model: load the run's checkpoint, then convert its Conv1D layers to
# nn.Linear before handing it to Kronfluence.
model_path = Path(experiment_outputs, "checkpoint")
model = GPT2LMHeadModel.from_pretrained(model_path)
replace_conv1d_modules(model)

# Wrap the model for influence analysis under the language-modelling task.
task = LanguageModelingTask()
model_for_analysis = prepare_model(model, task)

In [5]:
# Fit influence factors on the training set, then compute pairwise influence scores.
def get_pairwise_influence_scores(
    analysis_name: str,
    train_dataset: torch.utils.data.Dataset,
    eval_dataset: torch.utils.data.Dataset,
    model: nn.Module,
    task: Task,
    output_dir: str = influence_output_dir,
) -> torch.Tensor:
    """Compute pairwise influence scores of ``train_dataset`` on ``eval_dataset``.

    Besides its arguments, the function reads these notebook-level settings:
    ``tokenizer``, ``factor_strategy``, ``profile_computations``,
    ``use_half_precision``, ``use_compile``, ``compute_per_token_scores``,
    ``query_gradient_rank``, ``query_batch_size`` and ``train_batch_size``.

    Args:
        analysis_name: Name given to the ``Analyzer``; also part of the scores name.
        train_dataset: Dataset the factors are fitted on and scores are computed for.
        eval_dataset: Query dataset.
        model: Model passed to the ``Analyzer``.
        task: Task passed to the ``Analyzer``.
        output_dir: Directory passed to the ``Analyzer`` as its output directory.

    Returns:
        The ``"all_modules"`` entry of the pairwise scores loaded back from the analyzer.
    """
    analyzer = Analyzer(
        analysis_name=analysis_name,
        model=model,
        task=task,
        profile=profile_computations,
        output_dir=output_dir,
    )
    # Configure parameters for DataLoader.
    dataloader_kwargs = DataLoaderKwargs(
        collate_fn=get_data_collator_with_padding(tokenizer)
    )
    analyzer.set_dataloader_kwargs(dataloader_kwargs)

    train_dataset = _drop_text_columns(train_dataset)
    eval_dataset = _drop_text_columns(eval_dataset)

    # Compute influence factors.
    factors_name = factor_strategy
    factor_args = FactorArguments(strategy=factor_strategy)
    if use_half_precision:
        factor_args = all_low_precision_factor_arguments(
            strategy=factor_strategy, dtype=torch.bfloat16
        )
        factors_name += "_half"
    if use_compile:
        factors_name += "_compile"
    # NOTE(review): with overwrite_output_dir=False, results already stored
    # under the same name are presumably reused rather than recomputed --
    # confirm against the Kronfluence documentation.
    analyzer.fit_all_factors(
        factors_name=factors_name,
        dataset=train_dataset,
        per_device_batch_size=None,
        factor_args=factor_args,
        initial_per_device_batch_size_attempt=64,
        overwrite_output_dir=False,
    )

    # Compute pairwise scores. The scores name records every option that is
    # switched on, so different configurations are stored separately.
    score_args = ScoreArguments()
    scores_name = factor_args.strategy + f"_{analysis_name}"
    if use_half_precision:
        score_args = all_low_precision_score_arguments(dtype=torch.bfloat16)
        scores_name += "_half"
    if use_compile:
        scores_name += "_compile"
    if compute_per_token_scores:
        score_args.compute_per_token_scores = True
        scores_name += "_per_token"
    # A rank of -1 means no low-rank approximation of the query gradient.
    rank = query_gradient_rank if query_gradient_rank != -1 else None
    if rank is not None:
        score_args.query_gradient_low_rank = rank
        score_args.query_gradient_accumulation_steps = 10
        scores_name += f"_qlr{rank}"
    analyzer.compute_pairwise_scores(
        scores_name=scores_name,
        score_args=score_args,
        factors_name=factors_name,
        query_dataset=eval_dataset,
        train_dataset=train_dataset,
        per_device_query_batch_size=query_batch_size,
        per_device_train_batch_size=train_batch_size,
        overwrite_output_dir=False,
    )
    scores = analyzer.load_pairwise_scores(scores_name)["all_modules"]
    logging.info(f"Scores shape: {scores.shape}")

    return scores


def _drop_text_columns(dataset, columns=("prompt", "completion", "type")):
    """Return ``dataset`` without the given columns, skipping any it does not have.

    Datasets that expose no ``column_names``/``remove_columns`` API are
    returned unchanged.
    """
    # NOTE(review): these columns presumably hold raw text and metadata that
    # the padding collator cannot batch -- confirm against the data module.
    existing = getattr(dataset, "column_names", None)
    if existing is None or not hasattr(dataset, "remove_columns"):
        return dataset
    to_remove = [column for column in columns if column in existing]
    if not to_remove:
        return dataset
    return dataset.remove_columns(to_remove)


In [19]:
# Use the first 100 training examples as queries and score them against the
# full training set.
query_subset = train_dataset.select(list(range(100)))

influence = get_pairwise_influence_scores(
    analysis_name="pairwise_influence_train_to_train",
    train_dataset=train_dataset,
    eval_dataset=query_subset,
    model=model_for_analysis,
    task=task,
)

  scaler = GradScaler(init_scale=factor_args.amp_scale, enabled=enable_grad_scaler)
Fitting covariance matrices [32/32] 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ [time left: 00:00, time spent: 00:01]
Performing Eigendecomposition [32/32] 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ [time left: 00:00, time spent: 00:03]
  scaler = GradScaler(init_scale=factor_args.amp_scale, enabled=enable_grad_scaler)
Fitting Lambda matrices [32/32] 100%|█████████████████████████████████████████