## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import gc
import json
import os
import random
from dataclasses import asdict
from typing import Literal, Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from datasets import Dataset, DatasetDict, IterableDatasetDict, load_dataset, load_from_disk
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm.notebook import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel

from bergson.approx_unrolling.utils import TensorDict
from bergson.data import DataConfig, IndexConfig, pad_and_tensor, tokenize
from bergson.gradients import (
    GradientProcessor,
)
from bergson.hessians.collector import EkfacCollector
from bergson.utils import assert_type

## -1. Helper functions

In [3]:
def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int] = None) -> list[list[list[int]]]:
    """
    Test variant of ``allocate_batches`` that returns the allocation of *every*
    worker instead of only ``allocation[rank]``.

    Documents are packed into batches that are then distributed evenly across
    a fixed number of workers.

    Parameters
    ----------
    doc_lengths : list[int]
        Length (in tokens) of each document.  The *i-th* document is referred to
        internally by its index ``i``.
    N : int
        Hard memory budget per *batch*, expressed as
        ``max(length in batch) * (# docs in batch) ≤ N``.
    workers : int, optional
        Number of parallel workers.  When ``None``, the ``torch.distributed``
        world size is used if a process group is initialised, otherwise 1.

    Returns
    -------
    list[list[list[int]]]
        ``allocation[w][b]`` is the list of document indices that belong to the
        *b-th* batch assigned to worker ``w``.  Every worker receives the same
        number of (non-empty) batches.

    Raises
    ------
    ValueError
        If the number of workers is smaller than 1.
    RuntimeError
        If ``doc_lengths`` is empty, if a single document exceeds ``N``, or if
        there are too few documents to give every worker the same, non-zero
        number of batches.

    Notes
    -----
    1.  **Per-batch cost constraint**:  Each batch is padded to the maximum
        sequence length *inside that batch*, so its cost in "token × examples"
        units is ``max_len_in_batch * batch_size``.  This must stay ≤ ``N``.
    2.  **Bin-packing strategy**:  We use *first-fit decreasing* (FFD) to obtain
        an initial set of batches, then split non-singleton batches until

            * every worker has at least one batch,
            * the total number of batches is a multiple of ``workers``.

        Removing a document from a batch never increases that batch's cost and
        a singleton costs at most ``max(doc_lengths) ≤ N``, so the constraint
        in (1) remains satisfied throughout.
    """

    if workers is None:
        world_size = dist.get_world_size() if dist.is_initialized() else 1
    else:
        world_size = workers

    if world_size < 1:
        raise ValueError(f"Number of workers must be at least 1, got {world_size}.")
    if not doc_lengths:
        raise RuntimeError("Empty document list.")
    if max(doc_lengths) > N:  # a single document would overflow any batch
        raise RuntimeError("At least one document is too long for the budget N.")

    n_docs = len(doc_lengths)

    # ---------------------------------------------------------------------
    # 1) First-fit decreasing (FFD) bin packing under the cost function
    #    cost(batch) = max_len_in_batch * len(batch)
    # ---------------------------------------------------------------------
    docs_sorted = sorted(enumerate(doc_lengths), key=lambda x: x[1], reverse=True)
    batches: list[list[int]] = []  # holds document *indices*
    batch_meta = []  # (max_len, size) for each batch

    for idx, length in docs_sorted:
        placed = False
        for j, (mx, sz) in enumerate(batch_meta):
            new_mx = max(mx, length)
            new_sz = sz + 1
            if new_mx * new_sz <= N:  # still fits
                batches[j].append(idx)
                batch_meta[j] = (new_mx, new_sz)
                placed = True
                break

        if not placed:  # open a new batch
            batches.append([idx])
            batch_meta.append((length, 1))

    # ---------------------------------------------------------------------
    # 2) Ensure every worker gets ≥ 1 batch
    # ---------------------------------------------------------------------
    if len(batches) < world_size:
        # Every batch holds at least one document, so we can never create more
        # batches than there are documents.
        if n_docs < world_size:
            raise RuntimeError("Not enough documents to give each worker at least one batch.")

        # Start with the largest batches and work through the list as a queue.
        batches.sort(key=len, reverse=True)
        while len(batches) < world_size:
            head = batches.pop(0)
            if len(head) == 1:
                # A singleton cannot be split.  Rotate it to the back instead of
                # giving up: since n_docs >= world_size > len(batches), some
                # other batch in the queue still holds more than one document.
                batches.append(head)
                continue
            batches.append([head.pop()])  # move one doc into a new batch
            batches.append(head)  # put the remainder back

    # ---------------------------------------------------------------------
    # 3) Pad the number of batches to a multiple of `workers`
    # ---------------------------------------------------------------------
    k = -(-len(batches) // world_size)  # ceiling division
    target_batches = world_size * k  # == k batches per worker

    # Without this guard the loop below would spin forever once all batches
    # are singletons and the target has not been reached.
    if target_batches > n_docs:
        raise RuntimeError(
            "Not enough documents to pad the number of batches to a multiple of the number of workers."
        )

    # Split arbitrary (non-singleton) batches until we reach the target
    i = 0
    while len(batches) < target_batches:
        batch = batches[i % len(batches)]
        if len(batch) == 1:
            i += 1  # try another batch
            continue
        batches.append([batch.pop()])  # split off a singleton
        i += 1

    assert len(batches) == target_batches
    assert all(max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches)

    # ---------------------------------------------------------------------
    # 4) Round-robin assignment to workers
    # ---------------------------------------------------------------------
    allocation: list[list[list[int]]] = [[] for _ in range(world_size)]
    for b_idx, batch in enumerate(batches):
        allocation[b_idx % world_size].append(batch)

    # sanity: equal # of batches per worker
    assert len({len(b) for b in allocation}) == 1
    return allocation

## 0. Hyperparameters

In [4]:
# --- Execution setup ---------------------------------------------------------
# `workers` is the number of simulated ranks; everything runs on a single GPU
# to produce the ground truth.
device = torch.device("cuda:0")
workers = 1  # simulating n workers, but we will run on a single GPU to get ground truth

# --- Output location ---------------------------------------------------------
test_path = "/root/bergson-approx-unrolling/tests/ekfac_tests/test_files/pile_100_examples/ground_truth"
os.makedirs(test_path, exist_ok=True)

# --- Index configuration -----------------------------------------------------
# The run path stays empty because this notebook does not use it to save data.
cfg = IndexConfig(run_path="")
cfg.data = DataConfig(dataset="/root/bergson-approx-unrolling/tests/ekfac_tests/test_files/pile_100_examples" + "/data")
cfg.model = "EleutherAI/Pythia-14m"
cfg.revision = None
cfg.precision = "fp32"
cfg.fsdp = False
cfg.normalizer = "none"
cfg.fisher_fourth_root = False

data_str = cfg.data.dataset

# Persist the configuration next to the ground-truth artefacts.
config_file = os.path.join(test_path, "index_config.json")
with open(config_file, "w") as f:
    json.dump(asdict(cfg), f, indent=4)

In [5]:
# Translate the configured precision string into the torch dtype used for the
# model weights (and as compute/storage dtype for quantized loading).
if cfg.precision == "bf16":
    dtype = torch.bfloat16
elif cfg.precision == "fp16":
    dtype = torch.float16
elif cfg.precision == "fp32":
    dtype = torch.float32
elif cfg.precision in ("int4", "int8"):
    # Quantized weights: prefer bf16 when the GPU supports it, else fp16.
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    raise ValueError(f"Unsupported precision: {cfg.precision}")


## 1. Loading model and data

In [6]:
# A bitsandbytes config is only built for the integer precisions; for every
# other precision the model is loaded unquantized.
if cfg.precision in ("int4", "int8"):
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=cfg.precision == "int4",
        load_in_8bit=cfg.precision == "int8",
        bnb_4bit_compute_dtype=dtype,
        bnb_4bit_quant_storage=dtype,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
else:
    quantization_config = None

model = AutoModelForCausalLM.from_pretrained(
    cfg.model,
    device_map="cuda",
    quantization_config=quantization_config,
    torch_dtype=dtype,
    revision=cfg.revision,
)


In [7]:
# Pick the loader from the dataset string: CSV / JSON(L) files are read
# directly, anything else goes through `load_dataset` first.
data_str = cfg.data.dataset
if data_str.endswith(".csv"):
    ds = assert_type(Dataset, Dataset.from_csv(data_str))
elif data_str.endswith((".json", ".jsonl")):
    ds = assert_type(Dataset, Dataset.from_json(data_str))
else:
    try:
        ds = load_dataset(data_str, split="train", streaming=cfg.streaming)

        if isinstance(ds, (DatasetDict, IterableDatasetDict)):
            raise NotImplementedError("DatasetDicts and IterableDatasetDicts are not supported.")
    except ValueError as err:
        # Only the "use load_from_disk" hint is recoverable; re-raise the rest.
        if "load_from_disk" not in str(err):
            raise err
        ds = Dataset.load_from_disk(data_str, keep_in_memory=False)


In [8]:
assert isinstance(ds, Dataset)  # pleasing the typechecker

# Optionally drop the original columns once the tokenized ones are added.
if cfg.drop_columns:
    remove_columns = ds.column_names
else:
    remove_columns = None

tokenizer = AutoTokenizer.from_pretrained(
    cfg.model,
    model_max_length=cfg.token_batch_size,
    revision=cfg.revision,
)

# Tokenize in batches; `tokenize` receives the data config and the tokenizer.
ds = ds.map(
    tokenize,
    batched=True,
    fn_kwargs={"args": cfg.data, "tokenizer": tokenizer},
    remove_columns=remove_columns,
)

data = ds

In [9]:
# Allocate documents to batches for every simulated worker:
# batches_world[rank] is the list of batches handled by that rank.
batches_world = allocate_batches_test(
    doc_lengths=ds["length"],
    N=cfg.token_batch_size,
    workers=workers,
)
assert workers == len(batches_world)

In [10]:
# No normalizers and no projection are configured for the gradient processor.
normalizers = {}
# Passed unchanged to the EKFAC collectors below.
target_modules = None

processor = GradientProcessor(
    projection_dim=None,
    fisher_fourth_root=cfg.fisher_fourth_root,
    normalizers=normalizers,
)


## 2. Compute activation and gradient covariance

In [11]:
# Directory below `test_path` where the per-rank and combined covariance files are written.
covariance_test_path = os.path.join(test_path, "covariances")

In [12]:
def compute_covariance(
    rank: int,
    activation_covariances: Optional[dict[str, torch.Tensor]] = None,
    gradient_covariances: Optional[dict[str, torch.Tensor]] = None,
):
    """
    Accumulate un-normalized activation and gradient covariances for one rank.

    For every batch in ``batches_world[rank]`` the model is run forward and
    backward inside an ``EkfacCollector``.  The collector's forward closure
    receives an activation tensor and its backward closure a gradient tensor;
    each is flattened to 2-D (all leading dimensions merged) and ``x.mT @ x``
    is added to the entry for that module name.

    Parameters
    ----------
    rank : int
        Index into ``batches_world`` selecting the batches to process.
    activation_covariances : dict[str, torch.Tensor], optional
        Accumulator updated **in place**, keyed by module name.  A fresh dict
        is created when omitted (a shared ``{}`` default would leak sums
        between calls).
    gradient_covariances : dict[str, torch.Tensor], optional
        Same as ``activation_covariances`` but for the gradients.

    Returns
    -------
    dict
        ``"losses"``: list with one CPU tensor of per-sequence losses per batch.
        ``"total_processed_rank"``: number of elements of the padded input
        tensors (``x.numel()`` summed over batches).
        ``"activation_covariances"`` / ``"gradient_covariances"``: the
        accumulator dicts (the ones passed in, or the freshly created ones).
    """
    # Never use a mutable default: create per-call accumulators instead.
    if activation_covariances is None:
        activation_covariances = {}
    if gradient_covariances is None:
        gradient_covariances = {}

    total_processed = 0
    batches = batches_world[rank]

    loss_list = []

    def callback_activation(name: str, a: torch.Tensor):
        activation_covariance = activation_covariances.get(name, None)  # Our stored slice

        a = a.reshape(-1, a.shape[-1])  # merge all leading dims into one
        update = a.mT @ a

        if activation_covariance is None:
            activation_covariances[name] = update
        else:
            # Add it to our permanently stored slice
            activation_covariance.add_(update)

    def callback_gradient(name: str, g: torch.Tensor):
        gradient_covariance = gradient_covariances.get(name, None)

        g = g.reshape(-1, g.shape[-1])  # merge all leading dims into one
        update = g.mT @ g

        if gradient_covariance is None:
            gradient_covariances[name] = update
        else:
            gradient_covariance.add_(update)

    collector = EkfacCollector(
        model.base_model,
        closure=callback_gradient,
        processor=processor,
        target_modules=target_modules,
        fwd_closure=callback_activation,
    )
    for sl in tqdm(batches):
        batch = data[sl]
        x, y = pad_and_tensor(
            batch["input_ids"],  # type: ignore
            labels=batch.get("labels"),  # type: ignore
            device=device,
        )

        # Counts padded positions as well, since x is the padded input tensor.
        total_processed += x.numel()

        with collector:
            logits = model(x).logits
            # Next-token prediction: position t predicts token t + 1.
            losses = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                y[:, 1:].flatten(),
                reduction="none",
            ).reshape_as(y[:, 1:])

            # Average each sequence over its non-ignored (label != -100) targets.
            masks = y[:, 1:] != -100
            denoms = masks.sum(dim=1, dtype=logits.dtype)
            losses = losses.sum(1).div(denoms)

            losses.mean().backward()
            loss_list.append(losses.detach().cpu())
            model.zero_grad()

    return {
        "losses": loss_list,
        "total_processed_rank": total_processed,
        "activation_covariances": activation_covariances,
        "gradient_covariances": gradient_covariances,
    }


In [13]:
# Run the covariance pass once per simulated rank and store that rank's
# results in its own `rank_<i>` sub-directory.
total_processed_global = 0
for rank in range(workers):
    covariance_test_path_rank = os.path.join(covariance_test_path, f"rank_{rank}")
    os.makedirs(covariance_test_path_rank, exist_ok=True)

    # Fresh accumulators per rank; compute_covariance fills them in place.
    activation_covariances, gradient_covariances = {}, {}
    d = compute_covariance(
        rank=rank,
        activation_covariances=activation_covariances,
        gradient_covariances=gradient_covariances,
    )

    for tensors, filename in (
        (activation_covariances, "activation_covariance.safetensors"),
        (gradient_covariances, "gradient_covariance.safetensors"),
    ):
        save_file(tensors, os.path.join(covariance_test_path_rank, filename))

    stats_file = os.path.join(covariance_test_path_rank, "stats.json")
    with open(stats_file, "w") as f:
        json.dump({"total_processed_rank": d["total_processed_rank"]}, f, indent=4)
    print(f"Rank {rank} processed {d['total_processed_rank']} tokens.")


  0%|          | 0/14 [00:00<?, ?it/s]

Rank 0 processed 102000 tokens.


In [14]:
# Combine results from all ranks
# Sum the per-rank covariances and token counts into global totals.
activation_covariances = TensorDict({})
gradient_covariances = TensorDict({})
total_processed_global = 0
loss_list = []

for rank in range(workers):
    covariance_test_path_rank = os.path.join(covariance_test_path, f"rank_{rank}")

    with open(os.path.join(covariance_test_path_rank, "stats.json"), "r") as f:
        d = json.load(f)
    total_processed_global += d["total_processed_rank"]

    # TensorDict applies tensor operations to every entry of the dict at once.
    activation_covariances_rank = TensorDict(
        load_file(os.path.join(covariance_test_path_rank, "activation_covariance.safetensors"))
    ).to(device)
    gradient_covariances_rank = TensorDict(
        load_file(os.path.join(covariance_test_path_rank, "gradient_covariance.safetensors"))
    ).to(device)

    # The first rank seeds the totals; later ranks are added in place.
    if activation_covariances:
        activation_covariances.add_(activation_covariances_rank)
    else:
        activation_covariances = activation_covariances_rank

    if gradient_covariances:
        gradient_covariances.add_(gradient_covariances_rank)
    else:
        gradient_covariances = gradient_covariances_rank


In [None]:
# Write the combined covariances and the global token count to the top-level
# covariance directory.
for tensor_dict, filename in (
    (activation_covariances, "activation_covariance.safetensors"),
    (gradient_covariances, "gradient_covariance.safetensors"),
):
    save_file(tensor_dict.to_dict(), os.path.join(covariance_test_path, filename))

global_stats_file = os.path.join(covariance_test_path, "stats.json")
with open(global_stats_file, "w") as f:
    json.dump({"total_processed_global": total_processed_global}, f, indent=4)
print(f"Global processed {total_processed_global} tokens.")

Global processed 102000 tokens.


In [16]:
# Collect unreachable Python objects, then release unused cached GPU memory.
gc.collect()
torch.cuda.empty_cache()

## 3. Compute eigenvalues and eigenvectors

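Note: the eigendecomposition (`torch.linalg.eigh`) is computed in float64. The covariances are cast to float64 first, and the resulting eigenvectors are cast back to `dtype` before they are saved.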

In [17]:
# Module name -> eigenvector matrix, filled by the eigendecomposition loop.
eigenvectors_activations, eigenvectors_gradients = {}, {}

# Output directory for the ground-truth eigenvectors.
eigenvectors_test_path = os.path.join(test_path, "eigenvectors")
os.makedirs(eigenvectors_test_path, exist_ok=True)

In [18]:
# load activation and gradient covariance, and total processed
with open(os.path.join(covariance_test_path, "stats.json"), "r") as f:
    d = json.load(f)
    total_processed_global = d["total_processed_global"]

activation_covariances = load_file(os.path.join(covariance_test_path, "activation_covariance.safetensors"))
gradient_covariances = load_file(os.path.join(covariance_test_path, "gradient_covariance.safetensors"))


In [19]:
# # Use run_path to see if everything is correct from 3. and errors don't propagate
# covariance_a_run_path = "/root/bergson-approx-unrolling/tests/ekfac_tests/test_files/pile_100_examples/run/influence_results/activation_covariance_sharded/shard_0.safetensors"
# covariance_g_run_path = "/root/bergson-approx-unrolling/tests/ekfac_tests/test_files/pile_100_examples/run/influence_results/gradient_covariance_sharded/shard_0.safetensors"
# activation_covariances = load_file(covariance_a_run_path)
# gradient_covariances = load_file(covariance_g_run_path)


In [20]:
# Eigendecompose the normalized covariances of every module in float64 and
# store the eigenvectors in the model dtype.
for name, covariance_a in activation_covariances.items():
    covariance_a = covariance_a.to(dtype=torch.float64, device=device)
    covariance_g = gradient_covariances[name].to(dtype=torch.float64, device=device)

    # Turn the accumulated sums into averages over the processed token count.
    covariance_a.div_(total_processed_global)
    covariance_g.div_(total_processed_global)

    # Only the eigenvectors are kept; the eigenvalues are not used here.
    eigenvectors_a = torch.linalg.eigh(covariance_a).eigenvectors
    eigenvectors_g = torch.linalg.eigh(covariance_g).eigenvectors
    print(name, eigenvectors_a.sum(), eigenvectors_g.sum())

    eigenvectors_activations[name] = eigenvectors_a.to(dtype=dtype).contiguous()
    eigenvectors_gradients[name] = eigenvectors_g.to(dtype=dtype).contiguous()

for eigenvectors, filename in (
    (eigenvectors_activations, "eigenvectors_activations.safetensors"),
    (eigenvectors_gradients, "eigenvectors_gradients.safetensors"),
):
    save_file(eigenvectors, os.path.join(eigenvectors_test_path, filename))


layers.0.mlp.dense_4h_to_h tensor(-1.7092, device='cuda:0', dtype=torch.float64) tensor(-11.3137, device='cuda:0', dtype=torch.float64)
layers.0.mlp.dense_h_to_4h tensor(-11.6024, device='cuda:0', dtype=torch.float64) tensor(12.3206, device='cuda:0', dtype=torch.float64)
layers.1.mlp.dense_4h_to_h tensor(25.6717, device='cuda:0', dtype=torch.float64) tensor(-11.3137, device='cuda:0', dtype=torch.float64)
layers.1.mlp.dense_h_to_4h tensor(13.0474, device='cuda:0', dtype=torch.float64) tensor(5.1661, device='cuda:0', dtype=torch.float64)
layers.2.mlp.dense_4h_to_h tensor(-10.1438, device='cuda:0', dtype=torch.float64) tensor(11.3137, device='cuda:0', dtype=torch.float64)
layers.2.mlp.dense_h_to_4h tensor(13.2539, device='cuda:0', dtype=torch.float64) tensor(27.3991, device='cuda:0', dtype=torch.float64)
layers.3.mlp.dense_4h_to_h tensor(16.1342, device='cuda:0', dtype=torch.float64) tensor(11.3137, device='cuda:0', dtype=torch.float64)
layers.3.mlp.dense_h_to_4h tensor(-11.7757, device='

In [21]:
# Collect unreachable Python objects, then release unused cached GPU memory.
gc.collect()
torch.cuda.empty_cache()

## 4. Compute eigenvalue correction

In [22]:
# Directory below `test_path` where the per-rank and combined eigenvalue corrections are written.
eigenvalue_correction_test_path = os.path.join(test_path, "eigenvalue_corrections")
os.makedirs(eigenvalue_correction_test_path, exist_ok=True)

In [23]:
# load eigenvectors
eigenvectors_activations = load_file(os.path.join(eigenvectors_test_path, "eigenvectors_activations.safetensors"))
eigenvectors_gradients = load_file(os.path.join(eigenvectors_test_path, "eigenvectors_gradients.safetensors"))

In [None]:
# Load eigenvectors from the first shard of a run's `influence_results` directory.
# NOTE(review): this rebinds `eigenvectors_activations` / `eigenvectors_gradients`,
# replacing the eigenvectors loaded from `eigenvectors_test_path` in the preceding
# cell. The eigenvalue corrections below are then computed with the run's
# eigenvectors. Presumably this is intentional, so that both sides use the same
# eigenbasis; confirm, and skip this cell to use the eigenvectors computed in this
# notebook instead.
eigenvectors_activations_run_path = "/root/bergson-approx-unrolling/tests/ekfac_tests/test_files/pile_100_examples/run/influence_results/activation_eigen_sharded/shard_0.safetensors"
eigenvector_gradient_run_path = "/root/bergson-approx-unrolling/tests/ekfac_tests/test_files/pile_100_examples/run/influence_results/gradient_eigen_sharded/shard_0.safetensors"

eigenvectors_activations = load_file(eigenvectors_activations_run_path)
eigenvectors_gradients = load_file(eigenvector_gradient_run_path)


In [24]:
# Module name -> accumulated eigenvalue correction.
eigenvalue_corrections = {}
# Module name -> flattened activations of the current forward pass; written by the
# forward callback and read by the gradient callback of compute_eigenvalue_correction.
activation_cache = {}


In [25]:
def compute_eigenvalue_correction(
    rank: int,
    eigenvalue_corrections,
    eigenvectors_activations=eigenvectors_activations,
    eigenvectors_gradients=eigenvectors_gradients,
):
    """
    Accumulate un-normalized eigenvalue corrections for one rank.

    For every batch in ``batches_world[rank]`` the model is run forward and
    backward inside an ``EkfacCollector``.  The forward closure caches the
    flattened activations per module in the notebook-level
    ``activation_cache``.  The backward closure builds the per-row outer
    products of activations and gradients, rotates them into the eigenbasis
    given by the two eigenvector matrices, squares them, and sums over rows.

    Parameters
    ----------
    rank : int
        Index into ``batches_world`` selecting the batches to process.
    eigenvalue_corrections : dict[str, torch.Tensor]
        Accumulator updated **in place**, keyed by module name.
    eigenvectors_activations : mapping of str to torch.Tensor
        Per-module eigenvector matrices for the activation side, eigenvectors
        stored as columns (the layout returned by ``torch.linalg.eigh``).  The
        default is the notebook variable of the same name *at the time this
        function is defined*.
    eigenvectors_gradients : mapping of str to torch.Tensor
        Same as ``eigenvectors_activations`` but for the gradient side.

    Returns
    -------
    dict
        ``"losses"``: list with one CPU tensor of per-sequence losses per batch.
        ``"total_processed_rank"``: number of elements of the padded input
        tensors (``x.numel()`` summed over batches).
    """
    total_processed = 0
    batches = batches_world[rank]

    # Local list: previously the notebook-global `loss_list` was returned, which
    # depended on an unrelated earlier cell and never contained these losses.
    loss_list = []

    def callback_activation(name: str, a: torch.Tensor):
        # Keep the flattened activations until the matching gradient arrives.
        activation = a.reshape(-1, a.shape[-1])
        activation_cache[name] = activation

    def callback_gradient(name: str, g: torch.Tensor):
        eigenvector_a = eigenvectors_activations[name].to(device=device)
        eigenvector_g = eigenvectors_gradients[name].to(device=device)

        g = g.reshape(-1, g.shape[-1])  # merge all leading dims into one

        a = activation_cache[name]
        # One outer product a_b ⊗ g_b per row b.
        gradient = torch.einsum("B I, B O -> B I O", a, g)

        # Rotate the activation axis into the eigenbasis: contract I with the
        # rows of Q_A so that axis J indexes eigenvectors (Q_A^T applied to I).
        transformed_gradient = torch.einsum("B I O, I J -> B J O", gradient, eigenvector_a)

        # Same for the gradient axis: contract O with the rows of Q_G.
        correction = torch.einsum("B J O, O P -> B J P", transformed_gradient, eigenvector_g)
        correction = correction**2
        correction = correction.sum(dim=0)  # sum over batch dimension

        if name not in eigenvalue_corrections:
            eigenvalue_corrections[name] = correction
        else:
            eigenvalue_corrections[name].add_(correction)

    collector = EkfacCollector(
        model.base_model,
        closure=callback_gradient,
        processor=processor,
        target_modules=target_modules,
        fwd_closure=callback_activation,
    )
    for sl in tqdm(batches):
        batch = data[sl]
        x, y = pad_and_tensor(
            batch["input_ids"],  # type: ignore
            labels=batch.get("labels"),  # type: ignore
            device=device,
        )

        # Counts padded positions as well, since x is the padded input tensor.
        total_processed += x.numel()

        with collector:
            logits = model(x).logits
            # Next-token prediction: position t predicts token t + 1.
            losses = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                y[:, 1:].flatten(),
                reduction="none",
            ).reshape_as(y[:, 1:])

            # Average each sequence over its non-ignored (label != -100) targets.
            masks = y[:, 1:] != -100
            denoms = masks.sum(dim=1, dtype=logits.dtype)
            losses = losses.sum(1).div(denoms)

            losses.mean().backward()
            loss_list.append(losses.detach().cpu())

            model.zero_grad()

    return {"losses": loss_list, "total_processed_rank": total_processed}


In [26]:
# Run the eigenvalue-correction pass once per simulated rank and store that
# rank's results in its own `rank_<i>` sub-directory.
total_processed_global = 0
for rank in range(workers):
    eigenvalue_correction_test_path_rank = os.path.join(eigenvalue_correction_test_path, f"rank_{rank}")
    os.makedirs(eigenvalue_correction_test_path_rank, exist_ok=True)

    # Fresh accumulator per rank; compute_eigenvalue_correction fills it in place.
    eigenvalue_corrections = {}
    d = compute_eigenvalue_correction(
        rank=rank,
        eigenvalue_corrections=eigenvalue_corrections,
        eigenvectors_activations=eigenvectors_activations,
        eigenvectors_gradients=eigenvectors_gradients,
    )
    save_file(
        eigenvalue_corrections, os.path.join(eigenvalue_correction_test_path_rank, "eigenvalue_corrections.safetensors")
    )
    # Write the stats next to this rank's eigenvalue corrections. The previous
    # target, `covariance_test_path_rank`, was a leftover loop variable from the
    # covariance section and overwrote that section's stats file.
    with open(os.path.join(eigenvalue_correction_test_path_rank, "stats.json"), "w") as f:
        json.dump({"total_processed_rank": d["total_processed_rank"]}, f, indent=4)
    print(f"Rank {rank} processed {d['total_processed_rank']} tokens.")
    total_processed_global += d["total_processed_rank"]

  0%|          | 0/14 [00:00<?, ?it/s]

Rank 0 processed 102000 tokens.


In [27]:
# Combine results from all ranks
# Sum the per-rank eigenvalue corrections, normalize by the global token
# count, and save the combined result.
eigenvalue_corrections = TensorDict({})

for rank in range(workers):
    eigenvalue_correction_test_path_rank = os.path.join(eigenvalue_correction_test_path, f"rank_{rank}")
    rank_file = os.path.join(eigenvalue_correction_test_path_rank, "eigenvalue_corrections.safetensors")

    # TensorDict applies tensor operations to every entry of the dict at once.
    eigenvalue_corrections_rank = TensorDict(load_file(rank_file)).to(device)

    # The first rank seeds the total; later ranks are added in place.
    if eigenvalue_corrections:
        eigenvalue_corrections.add_(eigenvalue_corrections_rank)
    else:
        eigenvalue_corrections = eigenvalue_corrections_rank

eigenvalue_corrections.div_(total_processed_global)

combined_file = os.path.join(eigenvalue_correction_test_path, "eigenvalue_corrections.safetensors")
save_file(eigenvalue_corrections.to_dict(), combined_file)

In [28]:
# Display the current value of the global token count.
total_processed_global

102000