## Imports

In [None]:
# Reload modified Python modules (e.g. the local bergson / ground_truth packages) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import gc
import json
import os
from dataclasses import asdict
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from datasets import Dataset, DatasetDict, IterableDatasetDict, load_dataset
from safetensors.torch import load_file, save_file
from tqdm.notebook import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from bergson.data import DataConfig, IndexConfig, pad_and_tensor, tokenize
from bergson.gradients import (
    GradientProcessor,
)
from bergson.hessians.collector import CovarianceCollector, HookCollectorBase, LambdaCollector
from bergson.hessians.utils import TensorDict
from bergson.utils import assert_type
from ground_truth.collector import (
    GroundTruthAmortizedLambdaCollector,
    GroundTruthCovarianceCollector,
    GroundTruthNonAmortizedLambdaCollector,
)

## -1. Helper functions

In [None]:
def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int] = None) -> list[list[list[int]]]:
    """
    Test variant of ``allocate_batches`` that returns the allocation of *every*
    worker (the full ``allocation`` list) instead of only ``allocation[rank]``.

    Documents are packed into batches under a per-batch token budget, and the
    batches are then distributed evenly across a fixed number of workers.

    Parameters
    ----------
    doc_lengths : Sequence[int]
        Length (in tokens) of each document.  The *i-th* document is referred to
        internally by its index ``i``.
    N : int
        Hard memory budget per *batch*, expressed as
        ``max(length in batch) * (# docs in batch) ≤ N``.
    workers : int, optional
        Number of workers to allocate for.  When ``None``, the world size of the
        initialized ``torch.distributed`` process group is used (1 if it is not
        initialized).

    Returns
    -------
    list[list[list[int]]]
        ``allocation[w][b]`` is the list of document indices that belong to the
        *b-th* batch assigned to worker ``w``.  Every worker receives the same
        number of (non-empty) batches.

    Raises
    ------
    RuntimeError
        If ``doc_lengths`` is empty, a single document is longer than ``N``, or
        there are too few documents to give every worker the same non-zero
        number of batches.

    Notes
    -----
    1.  **Per-batch cost constraint**:  Each batch is padded to the maximum
        sequence length *inside that batch*, so its cost in “token × examples”
        units is ``max_len_in_batch * batch_size``.  This must stay ≤ ``N``.
    2.  **Bin-packing strategy**:  We use *first-fit decreasing* (FFD) to obtain
        an initial near-minimal set of batches, then split batches (a split
        never increases cost) until

            * every worker has at least one batch,
            * the total number of batches is a multiple of ``workers``.

        Because each split only lowers the cost of the two resulting batches,
        the constraint in (1) remains satisfied throughout.
    3.  In step 2 batches are split in their size-sorted queue order; only when
        the front of the queue is a singleton does it fall back to the largest
        remaining batch.
    """

    if workers is None:
        world_size = dist.get_world_size() if dist.is_initialized() else 1
    else:
        world_size = workers

    if not doc_lengths:
        raise RuntimeError("Empty document list.")
    if max(doc_lengths) > N:  # a single document would overflow any batch
        raise RuntimeError("At least one document is too long for the budget N.")

    # ---------------------------------------------------------------------
    # 1) First-fit decreasing (FFD) bin packing under the cost function
    #    cost(batch) = max_len_in_batch * len(batch)
    # ---------------------------------------------------------------------
    docs_sorted = sorted(enumerate(doc_lengths), key=lambda x: x[1], reverse=True)
    batches: list[list[int]] = []  # holds document *indices*
    batch_meta = []  # (max_len, size) for each batch

    for idx, length in docs_sorted:
        placed = False
        for j, (mx, sz) in enumerate(batch_meta):
            new_mx = max(mx, length)
            new_sz = sz + 1
            if new_mx * new_sz <= N:  # still fits
                batches[j].append(idx)
                batch_meta[j] = (new_mx, new_sz)
                placed = True
                break

        if not placed:  # open a new batch
            batches.append([idx])
            batch_meta.append((length, 1))

    # ---------------------------------------------------------------------
    # 2) Ensure every worker gets ≥ 1 batch
    # ---------------------------------------------------------------------
    if len(batches) < world_size:
        # Split batches, largest first, until we have ≥ workers batches.
        batches.sort(key=len, reverse=True)
        while len(batches) < world_size:
            big = batches.pop(0)  # front of the (initially size-sorted) queue
            if len(big) == 1:
                # The front is a singleton (possibly one split off earlier), but a
                # larger batch may still exist further back: use the largest one
                # instead of giving up.
                batches.insert(0, big)
                j = max(range(len(batches)), key=lambda t: len(batches[t]))
                if len(batches[j]) == 1:  # only singletons left: cannot split further
                    raise RuntimeError("Not enough documents to give each worker at least one batch.")
                big = batches.pop(j)
            batches.append([big.pop()])  # move one doc into new batch
            batches.append(big)  # put the remainder back
            # preserve cost constraint automatically

    # ---------------------------------------------------------------------
    # 3) Pad the number of batches to a multiple of `workers`
    # ---------------------------------------------------------------------
    k = -(-len(batches) // world_size)  # ceiling division
    target_batches = world_size * k  # == k batches per worker

    # Every batch holds at least one document, so at most len(doc_lengths) batches
    # can exist. Without this check the loop below would spin forever once every
    # batch has become a singleton.
    if len(doc_lengths) < target_batches:
        raise RuntimeError("Not enough documents to give every worker the same number of batches.")

    # Split arbitrary (non-singleton) batches until we reach the target
    i = 0
    while len(batches) < target_batches:
        batch = batches[i % len(batches)]
        if len(batch) == 1:
            i += 1  # try another batch
            continue
        batches.append([batch.pop()])  # split off a singleton
        i += 1

    assert len(batches) == target_batches
    assert all(max(doc_lengths[d] for d in batch) * len(batch) <= N for batch in batches)

    # ---------------------------------------------------------------------
    # 4) Round-robin assignment to workers
    # ---------------------------------------------------------------------
    allocation: list[list[list[int]]] = [[] for _ in range(world_size)]
    for b_idx, batch in enumerate(batches):
        allocation[b_idx % world_size].append(batch)

    # sanity: equal # of batches per worker
    assert len({len(b) for b in allocation}) == 1
    return allocation

## 0. Hyperparameters

In [None]:
import torch
import numpy as np
import random
import os
from torch.backends import cudnn


# Set all random seeds
def set_all_seeds(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU plus every CUDA device) with ``seed``.

    ``PYTHONHASHSEED`` is exported as well. Setting it from inside a running
    interpreter only affects processes started afterwards, not this kernel's own
    string hashing.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # for multi-GPU
    )
    for seeder in seeders:
        seeder(seed)


# Environment knobs for determinism. They are exported first so they are already in place
# before anything below could touch CUDA / cuBLAS.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

set_all_seeds(42)  # or whatever seed you prefer

# Force deterministic kernels (trades speed for reproducibility).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)


In [None]:
current_path = os.getcwd()
parent_path = os.path.join(current_path, "test_files", "pile_100_examples")

# Where this notebook writes its ground truth, and where the run under test wrote its results.
test_path = f"{parent_path}/ground_truth"
ekfac_run_path = f"{parent_path}/run/influence_results"
os.makedirs(test_path, exist_ok=True)

# Index configuration; run_path is left empty because nothing is saved through it here.
cfg = IndexConfig(run_path="")
cfg.model = "EleutherAI/Pythia-14m"
cfg.precision = "fp32"
cfg.fsdp = False
cfg.data = DataConfig(dataset=f"{parent_path}/data")
# cfg.data = DataConfig(dataset="NeelNanda/pile-10k")
data_str = cfg.data.dataset

# On first use, materialize the first 100 rows of NeelNanda/pile-10k on disk.
if not os.path.exists(data_str):
    full_dataset = load_dataset("NeelNanda/pile-10k", split="train")
    subset = full_dataset.select(range(100))
    os.makedirs(os.path.dirname(data_str), exist_ok=True)
    subset.save_to_disk(data_str)
    print(f"Generated pile-100 in {data_str}")

# Keep a copy of the configuration next to the ground-truth outputs.
with open(os.path.join(test_path, "index_config.json"), "w") as f:
    json.dump(asdict(cfg), f, indent=4)

workers = 8  # number of simulated workers; everything actually runs on a single GPU
device = torch.device("cuda:0")

In [None]:
# Map the configured precision to a compute dtype. Quantized (int4/int8) models compute in
# bf16 when the GPU supports it and fall back to fp16 otherwise.
if cfg.precision == "bf16":
    dtype = torch.bfloat16
elif cfg.precision == "fp16":
    dtype = torch.float16
elif cfg.precision == "fp32":
    dtype = torch.float32
elif cfg.precision in ("int4", "int8"):
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    raise ValueError(f"Unsupported precision: {cfg.precision}")


In [None]:
# Name of a single module kept around for debugging; no other visible cell in this notebook reads it.
debug_name = "layers.0.mlp.dense_h_to_4h"  # for debugging

## 1. Loading model and data

In [None]:
# bitsandbytes quantization is only configured for the int4/int8 precisions.
if cfg.precision in ("int4", "int8"):
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=cfg.precision == "int4",
        load_in_8bit=cfg.precision == "int8",
        bnb_4bit_compute_dtype=dtype,
        bnb_4bit_quant_storage=dtype,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
else:
    quantization_config = None

model = AutoModelForCausalLM.from_pretrained(
    cfg.model,
    device_map="cuda",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)


In [None]:
# Load the dataset: CSV/JSON(L) files directly, anything else through `load_dataset`,
# falling back to `load_from_disk` when `load_dataset` says the path was saved to disk.
data_str = cfg.data.dataset
if data_str.endswith(".csv"):
    ds = assert_type(Dataset, Dataset.from_csv(data_str))
elif data_str.endswith((".json", ".jsonl")):
    ds = assert_type(Dataset, Dataset.from_json(data_str))
else:
    try:
        ds = load_dataset(data_str, split="train")
        if isinstance(ds, (DatasetDict, IterableDatasetDict)):
            raise NotImplementedError("DatasetDicts and IterableDatasetDicts are not supported.")
    except ValueError as e:
        if "load_from_disk" not in str(e):
            raise e
        ds = Dataset.load_from_disk(data_str, keep_in_memory=False)


In [None]:
assert isinstance(ds, Dataset)  # narrows the type for the type checker

# Tokenizer whose model_max_length is set to the per-batch token budget.
tokenizer = AutoTokenizer.from_pretrained(cfg.model, model_max_length=cfg.token_batch_size)

ds = ds.map(tokenize, batched=True, fn_kwargs={"args": cfg.data, "tokenizer": tokenizer})
data = ds

In [None]:
# batches_world[w] is the list of batches (lists of row indices into `data`) for simulated worker w.
batches_world = allocate_batches_test(
    ds["length"],
    cfg.token_batch_size,
    workers,
)
assert workers == len(batches_world)

In [None]:
# target_modules=None is forwarded to every collector below (presumably "all supported modules" — confirm).
target_modules = None
# `normalizers` and `processor` are not used by the visible cells of this notebook.
normalizers = {}
processor = GradientProcessor(projection_dim=None)


## 2. Compute activation and gradient covariance

In [None]:
# Root directory for the per-rank (rank_<i>/) and combined covariance files written in this section.
covariance_test_path = os.path.join(test_path, "covariances")

In [None]:
def compute_covariance(
    rank: int,
    activation_covariances: Optional[dict] = None,
    gradient_covariances: Optional[dict] = None,
):
    """Run the ground-truth covariance collector over the batches of simulated worker ``rank``.

    ``activation_covariances`` / ``gradient_covariances`` are handed to
    ``GroundTruthCovarianceCollector``, which presumably accumulates per-module
    covariances into them in place (confirm against the collector). Pass your own
    dicts to read the results back.

    Args:
        rank: index into the global ``batches_world`` allocation.
        activation_covariances: dict given to the collector; a fresh one if ``None``.
        gradient_covariances: dict given to the collector; a fresh one if ``None``.

    Returns:
        dict with ``"losses"`` (one CPU tensor per batch holding each sequence's summed
        next-token loss) and ``"total_processed_rank"`` (sum of ``x.numel()`` over the
        padded batches, i.e. presumably including padding positions).
    """
    # A literal `{}` default is created once and shared by every call that relies on it,
    # so results would silently accumulate across calls. Create fresh dicts instead.
    if activation_covariances is None:
        activation_covariances = {}
    if gradient_covariances is None:
        gradient_covariances = {}

    total_processed = 0
    batches = batches_world[rank]

    loss_list = []

    collector = GroundTruthCovarianceCollector(
        model=model.base_model,
        activation_covariances=activation_covariances,
        gradient_covariances=gradient_covariances,
        target_modules=target_modules,
    )

    for sl in tqdm(batches):
        batch = data[sl]
        x, y = pad_and_tensor(
            batch["input_ids"],  # type: ignore
            labels=batch.get("labels"),  # type: ignore
            device=device,
        )

        total_processed += x.numel()

        with collector:
            logits = model(x).logits
            # Next-token loss: logits at position t are scored against label t + 1.
            losses = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                y[:, 1:].flatten(),
                reduction="none",
            ).reshape_as(y[:, 1:])

            losses = losses.sum(1)  # one summed loss per sequence

            losses.mean().backward()

            loss_list.append(losses.detach().cpu())

            model.zero_grad()

    return {"losses": loss_list, "total_processed_rank": total_processed}

In [None]:
total_processed_global = 0
for rank in range(workers):
    covariance_test_path_rank = os.path.join(covariance_test_path, f"rank_{rank}")
    os.makedirs(covariance_test_path_rank, exist_ok=True)

    # Fresh dicts for this rank; they are passed to the collector and presumably filled in place.
    activation_covariances, gradient_covariances = {}, {}
    d = compute_covariance(
        rank=rank,
        activation_covariances=activation_covariances,
        gradient_covariances=gradient_covariances,
    )

    for tensors, filename in (
        (activation_covariances, "activation_covariance.safetensors"),
        (gradient_covariances, "gradient_covariance.safetensors"),
    ):
        save_file(tensors, os.path.join(covariance_test_path_rank, filename))

    with open(os.path.join(covariance_test_path_rank, "stats.json"), "w") as f:
        json.dump({"total_processed_rank": d["total_processed_rank"]}, f, indent=4)
    print(f"Rank {rank} processed {d['total_processed_rank']} tokens.")


In [None]:
# Load the ground-truth covariance and the sharded covariance from the run under test.
# NOTE(review): the ground-truth file read here is the *combined* covariance that is only
# written by the save cell further down in section 2; on a fresh kernel this cell fails
# until that save cell has run once.
test_dir = "./test_files/pile_100_examples"
covariance_type = "activation"
run_path = os.path.join(test_dir, "run/influence_results")

ground_truth_path = os.path.join(test_dir, "ground_truth")
covariances_ground_truth_path = os.path.join(ground_truth_path, f"covariances/{covariance_type}_covariance.safetensors")
covariances_run_path = os.path.join(run_path, f"{covariance_type}_covariance_sharded")

ground_truth_covariances = TensorDict(load_file(covariances_ground_truth_path))

# One shard file per rank; each tensor is concatenated along dim 0 across all shards.
world_size = len(os.listdir(covariances_run_path))
run_covariances_shards = [os.path.join(covariances_run_path, f"shard_{rank}.safetensors") for rank in range(world_size)]
run_covariances_list = list(map(load_file, run_covariances_shards))
run_covariances = TensorDict(
    {key: torch.cat([shard[key] for shard in run_covariances_list], dim=0) for key in run_covariances_list[0]}
)

In [None]:
# Sum the per-rank covariance files into global activation / gradient covariances.
activation_covariances = TensorDict({})
gradient_covariances = TensorDict({})
total_processed_global = 0
loss_list = []

for rank in range(workers):
    covariance_test_path_rank = os.path.join(covariance_test_path, f"rank_{rank}")

    with open(os.path.join(covariance_test_path_rank, "stats.json"), "r") as f:
        d = json.load(f)
    total_processed_global += d["total_processed_rank"]

    # TensorDict wraps a dict of tensors so the whole dict can be moved / added at once.
    activation_covariances_rank = TensorDict(
        load_file(os.path.join(covariance_test_path_rank, "activation_covariance.safetensors"))
    ).to(device)
    gradient_covariances_rank = TensorDict(
        load_file(os.path.join(covariance_test_path_rank, "gradient_covariance.safetensors"))
    ).to(device)

    # The first (empty) accumulator is replaced by the rank's tensors; later ranks are added.
    activation_covariances = (
        activation_covariances + activation_covariances_rank if activation_covariances else activation_covariances_rank
    )
    gradient_covariances = (
        gradient_covariances + gradient_covariances_rank if gradient_covariances else gradient_covariances_rank
    )


In [None]:
# Persist the summed covariances and the global token count next to the per-rank folders.
for tensors, filename in (
    (activation_covariances, "activation_covariance.safetensors"),
    (gradient_covariances, "gradient_covariance.safetensors"),
):
    save_file(tensors.to_dict(), os.path.join(covariance_test_path, filename))

with open(os.path.join(covariance_test_path, "stats.json"), "w") as f:
    json.dump({"total_processed_global": total_processed_global}, f, indent=4)
print(f"Global processed {total_processed_global} tokens.")

In [None]:
# Free unreachable Python objects and release cached, unused CUDA memory before the next section.
gc.collect()
torch.cuda.empty_cache()

## 3. Compute eigenvalues and eigenvectors

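Note: the covariances are explicitly cast to float64 before `torch.linalg.eigh`, so the eigendecompositions below run in float64; the resulting eigenvectors are cast back to the model dtype before saving.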

In [None]:
# Eigendecomposition outputs go to <test_path>/eigenvectors.
eigenvectors_test_path = os.path.join(test_path, "eigenvectors")
os.makedirs(eigenvectors_test_path, exist_ok=True)

# Per-module eigenvector matrices, keyed like the covariance dicts.
eigenvectors_activations, eigenvectors_gradients = {}, {}

In [None]:
# load activation and gradient covariance, and total processed
with open(os.path.join(covariance_test_path, "stats.json"), "r") as f:
    d = json.load(f)
    total_processed_global = d["total_processed_global"]

activation_covariances = load_file(os.path.join(covariance_test_path, "activation_covariance.safetensors"))
gradient_covariances = load_file(os.path.join(covariance_test_path, "gradient_covariance.safetensors"))


In [None]:
# # Use run_path to see if everything is correct from 3. and errors don't propagate

# covariance_a_run_path = os.path.join(ekfac_run_path, "activation_covariance_sharded/shard_0.safetensors")
# covariance_g_run_path = os.path.join(ekfac_run_path, "gradient_covariance_sharded/shard_0.safetensors")
# activation_covariances = load_file(covariance_a_run_path)
# gradient_covariances = load_file(covariance_g_run_path)


In [None]:
for name in activation_covariances.keys():
    a = activation_covariances[name].to(dtype=torch.float64, device=device)
    g = gradient_covariances[name].to(dtype=torch.float64, device=device)
    # Symmetrize in float64 to remove round-off asymmetry, then normalize by the global token count
    # (assumes the saved covariances are un-normalized sums — confirm against the collector).
    a = (a + a.T).div(2).div_(total_processed_global)
    g = (g + g.T).div(2).div_(total_processed_global)

    eigenvalues_a, eigenvectors_a = torch.linalg.eigh(a)
    eigenvalues_g, eigenvectors_g = torch.linalg.eigh(g)
    print(name, eigenvectors_a.sum(), eigenvectors_g.sum())  # quick fingerprint for eyeballing

    # Cast back to the model dtype; safetensors only serializes contiguous tensors.
    eigenvectors_activations[name] = eigenvectors_a.to(dtype=dtype).contiguous()
    eigenvectors_gradients[name] = eigenvectors_g.to(dtype=dtype).contiguous()

for tensors, filename in (
    (eigenvectors_activations, "eigenvectors_activations.safetensors"),
    (eigenvectors_gradients, "eigenvectors_gradients.safetensors"),
):
    save_file(tensors, os.path.join(eigenvectors_test_path, filename))


In [None]:
# Free unreachable Python objects and release cached, unused CUDA memory before the next section.
gc.collect()
torch.cuda.empty_cache()

## 4. Compute eigenvalue correction

In [None]:
# Output directory for the per-rank and combined eigenvalue corrections.
os.makedirs(eigenvalue_correction_test_path := os.path.join(test_path, "eigenvalue_corrections"), exist_ok=True)

In [None]:
# load eigenvectors
eigenvectors_activations, eigenvectors_gradients = (
    load_file(os.path.join(eigenvectors_test_path, f"eigenvectors_{kind}.safetensors"))
    for kind in ("activations", "gradients")
)

In [None]:
# # load eigenvectors from run
# eigenvectors_activations_run_path = os.path.join(ekfac_run_path, "activation_eigen_sharded")


# world_size = len(os.listdir(eigenvectors_activations_run_path))  # number of shards
# # load run eigenvectors shards and concatenate them
# run_eigenvectors_shards = [
#     os.path.join(eigenvectors_activations_run_path, f"shard_{rank}.safetensors") for rank in range(world_size)
# ]
# run_eigenvectors_list = [(load_file(shard)) for shard in run_eigenvectors_shards]
# run_eigenvectors = {}
# for k, v in run_eigenvectors_list[0].items():
#     run_eigenvectors[k] = torch.cat([shard[k] for shard in run_eigenvectors_list], dim=0)

# eigenvectors_activations = TensorDict(run_eigenvectors)


# eigenvectors_gradients_run_path = os.path.join(ekfac_run_path, "gradient_eigen_sharded")


# world_size = len(os.listdir(eigenvectors_gradients_run_path))  # number of shards
# # load run eigenvectors shards and concatenate them
# run_eigenvectors_shards = [
#     os.path.join(eigenvectors_gradients_run_path, f"shard_{rank}.safetensors") for rank in range(world_size)
# ]
# run_eigenvectors_list = [(load_file(shard)) for shard in run_eigenvectors_shards]
# run_eigenvectors = {}
# for k, v in run_eigenvectors_list[0].items():
#     run_eigenvectors[k] = torch.cat([shard[k] for shard in run_eigenvectors_list], dim=0)

# eigenvectors_gradients = TensorDict(run_eigenvectors)


In [None]:
eigenvalue_corrections = {}
activation_cache = {}

# Debugging aids only.
gradient_cache, pseudo_grad_cache, transformed_activation_cache = {}, {}, {}
key_debug = "layers.0.mlp.dense_h_to_4h"

In [None]:
def compute_eigenvalue_correction(
    rank: int,
    eigenvalue_corrections,
    eigenvectors_activations=eigenvectors_activations,
    eigenvectors_gradients=eigenvectors_gradients,
):
    """Non-amortized ground-truth eigenvalue correction for simulated worker ``rank``.

    Kept for reference; the driver loop calls ``compute_eigenvalue_correction_amortized``.
    ``eigenvalue_corrections`` is handed to ``GroundTruthNonAmortizedLambdaCollector``,
    which presumably fills it in place (confirm against the collector). The default
    eigenvector dicts are bound when this cell is executed.

    Returns:
        dict with ``"losses"`` (one CPU tensor per batch holding each sequence's summed
        next-token loss) and ``"total_processed_rank"`` (sum of ``x.numel()`` over the
        padded batches).
    """
    total_processed = 0
    batches = batches_world[rank]
    # Local list: previously the function returned an undefined-here name that resolved
    # to an unrelated module-level `loss_list`.
    loss_list = []

    collector = GroundTruthNonAmortizedLambdaCollector(
        model=model.base_model,
        eigenvalue_corrections=eigenvalue_corrections,
        eigenvectors_activations=eigenvectors_activations,
        eigenvectors_gradients=eigenvectors_gradients,
        device=device,
        target_modules=target_modules,
    )

    for sl in tqdm(batches):
        batch = data[sl]
        x, y = pad_and_tensor(
            batch["input_ids"],  # type: ignore
            labels=batch.get("labels"),  # type: ignore
            device=device,
        )

        total_processed += x.numel()

        with collector:
            logits = model(x).logits
            # Next-token loss: logits at position t are scored against label t + 1.
            losses = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                y[:, 1:].flatten(),
                reduction="none",
            ).reshape_as(y[:, 1:])

            losses = losses.sum(1)  # one summed loss per sequence

            losses.mean().backward()

            loss_list.append(losses.detach().cpu())

            model.zero_grad()

    return {"losses": loss_list, "total_processed_rank": total_processed}

In [None]:
gradient_cache_amortized = {}


def compute_eigenvalue_correction_amortized(
    rank: int,
    eigenvalue_corrections,
    eigenvectors_activations=eigenvectors_activations,
    eigenvectors_gradients=eigenvectors_gradients,
):
    """Amortized ground-truth eigenvalue correction for simulated worker ``rank``.

    ``eigenvalue_corrections`` is handed to ``GroundTruthAmortizedLambdaCollector``,
    which presumably fills it in place (confirm against the collector). The default
    eigenvector dicts are bound when this cell is executed.

    Returns:
        dict with ``"losses"`` (one CPU tensor per batch holding each sequence's summed
        next-token loss) and ``"total_processed_rank"`` (sum of ``x.numel()`` over the
        padded batches).
    """
    total_processed = 0
    batches = batches_world[rank]
    # Local list: previously the function returned an undefined-here name that resolved
    # to an unrelated module-level `loss_list`.
    loss_list = []

    collector = GroundTruthAmortizedLambdaCollector(
        model=model.base_model,
        eigenvalue_corrections=eigenvalue_corrections,
        eigenvectors_activations=eigenvectors_activations,
        eigenvectors_gradients=eigenvectors_gradients,
        device=device,
        target_modules=target_modules,
    )

    for sl in tqdm(batches):
        batch = data[sl]
        x, y = pad_and_tensor(
            batch["input_ids"],  # type: ignore
            labels=batch.get("labels"),  # type: ignore
            device=device,
        )

        total_processed += x.numel()

        with collector:
            logits = model(x).logits
            # Next-token loss: logits at position t are scored against label t + 1.
            losses = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                y[:, 1:].flatten(),
                reduction="none",
            ).reshape_as(y[:, 1:])

            losses = losses.sum(1)  # one summed loss per sequence

            losses.mean().backward()

            loss_list.append(losses.detach().cpu())

            model.zero_grad()

    return {"losses": loss_list, "total_processed_rank": total_processed}

In [None]:
eigenvalue_corrections = {}

total_processed_global = 0
for rank in range(workers):
    eigenvalue_correction_test_path_rank = os.path.join(eigenvalue_correction_test_path, f"rank_{rank}")
    os.makedirs(eigenvalue_correction_test_path_rank, exist_ok=True)

    # Fresh dict per rank; handed to the collector (presumably filled in place) and saved below.
    eigenvalue_corrections = {}
    d = compute_eigenvalue_correction_amortized(
        rank=rank,
        eigenvalue_corrections=eigenvalue_corrections,
        eigenvectors_activations=eigenvectors_activations,
        eigenvectors_gradients=eigenvectors_gradients,
    )

    # d = compute_eigenvalue_correction(
    #     rank=rank,
    #     eigenvalue_corrections=eigenvalue_corrections,
    #     eigenvectors_activations=eigenvectors_activations,
    #     eigenvectors_gradients=eigenvectors_gradients,
    # )

    save_file(
        eigenvalue_corrections, os.path.join(eigenvalue_correction_test_path_rank, "eigenvalue_corrections.safetensors")
    )
    # Stats go next to this rank's eigenvalue corrections. `covariance_test_path_rank` is a
    # leftover from the covariance cells; writing there would overwrite that rank's covariance stats.
    with open(os.path.join(eigenvalue_correction_test_path_rank, "stats.json"), "w") as f:
        json.dump({"total_processed_rank": d["total_processed_rank"]}, f, indent=4)
        print(f"Rank {rank} processed {d['total_processed_rank']} tokens.")
    total_processed_global += d["total_processed_rank"]

In [None]:
# Sum the per-rank eigenvalue corrections, then normalize by the token count accumulated in the loop above.
eigenvalue_corrections = TensorDict({})

for rank in range(workers):
    eigenvalue_correction_test_path_rank = os.path.join(eigenvalue_correction_test_path, f"rank_{rank}")

    # TensorDict wraps a dict of tensors so the whole dict can be moved / added at once.
    eigenvalue_corrections_rank = TensorDict(
        load_file(os.path.join(eigenvalue_correction_test_path_rank, "eigenvalue_corrections.safetensors"))
    ).to(device)

    # The first (empty) accumulator is replaced by the rank's tensors; later ranks are added.
    eigenvalue_corrections = (
        eigenvalue_corrections + eigenvalue_corrections_rank if eigenvalue_corrections else eigenvalue_corrections_rank
    )

eigenvalue_corrections.div_(total_processed_global)
combined_path = os.path.join(eigenvalue_correction_test_path, "eigenvalue_corrections.safetensors")
save_file(eigenvalue_corrections.to_dict(), combined_path)

In [None]:
import torch

d_length = 10
batch_sizes = torch.randint(4000, 8000, (d_length,))
activation_length = torch.randint(3000, 5000, (d_length,))
gradient_length = torch.randint(3000, 10000, (d_length,))

# Random (rows x cols) CPU tensors keyed "layer_<i>". The sizes above make these several GB in total.
test_activations = {}
test_gradients = {}
for i, (rows, act_cols, grad_cols) in enumerate(
    zip(batch_sizes.tolist(), activation_length.tolist(), gradient_length.tolist())
):
    test_activations[f"layer_{i}"] = torch.rand(rows, act_cols)
    test_gradients[f"layer_{i}"] = torch.rand(rows, grad_cols)

In [None]:
# Inspect TensorDict.size() for the random test activations (shown as the cell output).
TensorDict(test_activations).size()

In [None]:
# Free unreachable Python objects and release cached, unused CUDA memory.
gc.collect()
torch.cuda.empty_cache()

## DEBUG ZONE

In [None]:
# Scratch allocation for debugging: an uninitialized (n, s, b) float32 tensor on the GPU.
# NOTE(review): this rebinds the notebook-wide `device` and `dtype` used by earlier cells.
device = "cuda"
dtype = torch.float32
n = 10
s = 23
b = 30
x = torch.empty(n, s, b, dtype=dtype, device=device)