## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import gc
import json
import os
from dataclasses import asdict
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from datasets import Dataset, DatasetDict, IterableDatasetDict, load_dataset
from safetensors.torch import load_file, save_file
from tqdm.notebook import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from bergson.data import DataConfig, IndexConfig, pad_and_tensor, tokenize
from bergson.gradients import (
    GradientProcessor,
)
from bergson.hessians.collector import EkfacCollector
from bergson.hessians.utils import TensorDict
from bergson.utils import assert_type

## -1. Helper functions

In [3]:
def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int] = None) -> list[list[list[int]]]:
    """
    Modification of allocate_batches to return a flat list of batches for testing(instead of returning allocation[rank])
    Allocate documents into batches that are then distributed evenly across
    a fixed number of workers.

    Parameters
    ----------
    doc_lengths : Sequence[int]
        Length (in tokens) of each document.  The *i-th* document is referred to
        internally by its index ``i``.
    workers : int
        Number of parallel workers ( 1 ≤ workers ≤ 8).
    N : int
        Hard memory budget per *batch*, expressed as
        ``max(length in batch) * (# docs in batch) ≤ N``.

    Returns
    -------
    list[list[list[int]]]
        ``allocation[w][b]`` is the list of document indices that belong to the
        *b-th* batch assigned to worker ``w``.  Every worker receives the same
        number of (non-empty) batches.

    Raises
    ------
    AllocationError
        If the three hard constraints cannot be satisfied.

    Notes
    -----
    1.  **Per-batch cost constraint**:  Each batch is padded to the maximum
        sequence length *inside that batch*, so its cost in “token × examples”
        units is ``max_len_in_batch * batch_size``.  This must stay ≤ ``N``.
    2.  **Bin-packing strategy**:  We use *first-fit decreasing* (FFD) to obtain
        an initial near-minimal set of batches, then split some of the larger
        batches (never increases cost) until

            * every worker has at least one batch,
            * the total number of batches is a multiple of ``workers``.

        Because each split only lowers the cost of the two resulting batches,
        the constraint in (1) remains satisfied throughout.
    """

    if workers is None:
        world_size = dist.get_world_size() if dist.is_initialized() else 1
    else:
        world_size = workers

    if not doc_lengths:
        raise RuntimeError("Empty document list.")
    if max(doc_lengths) > N:  # a single document would overflow any batch
        raise RuntimeError("At least one document is too long for the budget N.")

    # ---------------------------------------------------------------------
    # 1) First-fit decreasing (FFD) bin packing under the cost function
    #    cost(batch) = max_len_in_batch * len(batch)
    # ---------------------------------------------------------------------
    docs_sorted = sorted(enumerate(doc_lengths), key=lambda x: x[1], reverse=True)
    batches: list[list[int]] = []  # holds document *indices*
    batch_meta = []  # (max_len, size) for each batch

    for idx, length in docs_sorted:
        placed = False
        for j, (mx, sz) in enumerate(batch_meta):
            new_mx = max(mx, length)
            new_sz = sz + 1
            if new_mx * new_sz <= N:  # still fits
                batches[j].append(idx)
                batch_meta[j] = (new_mx, new_sz)
                placed = True
                break

        if not placed:  # open a new batch
            batches.append([idx])
            batch_meta.append((length, 1))

    # ---------------------------------------------------------------------
    # 2) Ensure every worker gets ≥ 1 batch
    # ---------------------------------------------------------------------
    if len(batches) < world_size:
        # split the largest batches (by size) until we have ≥ workers batches
        batches.sort(key=len, reverse=True)
        while len(batches) < world_size:
            big = batches.pop(0)  # take the current largest
            if len(big) == 1:  # cannot split a singleton
                raise RuntimeError("Not enough documents to give each worker at least one batch.")
            batches.append([big.pop()])  # move one doc into new batch
            batches.append(big)  # put the remainder back
            # preserve cost constraint automatically

    # ---------------------------------------------------------------------
    # 3) Pad the number of batches to a multiple of `workers`
    # ---------------------------------------------------------------------
    k = -(-len(batches) // world_size)  # ceiling division
    target_batches = world_size * k  # == k batches per worker

    # Split arbitrary (non-singleton) batches until we reach the target
    i = 0
    while len(batches) < target_batches:
        batch = batches[i % len(batches)]
        if len(batch) == 1:
            i += 1  # try another batch
            continue
        batches.append([batch.pop()])  # split off a singleton
        i += 1

    assert len(batches) == target_batches
    assert all(max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches)

    # ---------------------------------------------------------------------
    # 4) Round-robin assignment to workers
    # ---------------------------------------------------------------------
    allocation: list[list[list[int]]] = [[] for _ in range(world_size)]
    for b_idx, batch in enumerate(batches):
        allocation[b_idx % world_size].append(batch)

    # sanity: equal # of batches per worker
    assert len({len(b) for b in allocation}) == 1
    return allocation

## 0. Hyperparameters

In [4]:
current_path = os.getcwd()
parent_path = os.path.join(current_path, "test_files", "pile_10k_examples")

test_path = parent_path + "/ground_truth"
ekfac_run_path = parent_path + "/run/influence_results"


os.makedirs(test_path, exist_ok=True)
cfg = IndexConfig(run_path="")  # empty run path because we are not using it to save data
cfg.model = "EleutherAI/Pythia-14m"
cfg.precision = "fp32"
cfg.revision = None
cfg.fsdp = False
cfg.normalizer = "none"
cfg.fisher_fourth_root = False

# cfg.data = DataConfig(dataset=parent_path + "/data")
cfg.data = DataConfig(dataset="NeelNanda/pile-10k")

data_str = cfg.data.dataset

# save cfg
with open(os.path.join(test_path, "index_config.json"), "w") as f:
    json.dump(asdict(cfg), f, indent=4)


workers = 8  # simulating n workers, but we will run on a single GPU to get ground truth
device = torch.device("cuda:0")

In [5]:
match cfg.precision:
    case "bf16":
        dtype = torch.bfloat16
    case "fp16":
        dtype = torch.float16
    case "fp32":
        dtype = torch.float32
    case "int4" | "int8":
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    case other:
        raise ValueError(f"Unsupported precision: {other}")


In [6]:
debug_name = "layers.0.mlp.dense_4h_to_h"  # for debugging

## 1. Loading model and data

In [7]:
model = AutoModelForCausalLM.from_pretrained(
    cfg.model,
    device_map="cuda",
    quantization_config=(
        BitsAndBytesConfig(
            load_in_4bit=cfg.precision == "int4",
            load_in_8bit=cfg.precision == "int8",
            bnb_4bit_compute_dtype=dtype,
            bnb_4bit_quant_storage=dtype,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        if cfg.precision in ("int4", "int8")
        else None
    ),
    torch_dtype=dtype,
    revision=cfg.revision,
)


config.json:   0%|          | 0.00/595 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/53.3M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

In [8]:
data_str = cfg.data.dataset
if data_str.endswith(".csv"):
    ds = assert_type(Dataset, Dataset.from_csv(data_str))
elif data_str.endswith(".json") or data_str.endswith(".jsonl"):
    ds = assert_type(Dataset, Dataset.from_json(data_str))
else:
    try:
        ds = load_dataset(data_str, split="train", streaming=cfg.streaming)

        if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict):
            raise NotImplementedError("DatasetDicts and IterableDatasetDicts are not supported.")
    except ValueError as e:
        # Automatically use load_from_disk if appropriate
        if "load_from_disk" in str(e):
            ds = Dataset.load_from_disk(data_str, keep_in_memory=False)
        else:
            raise e


README.md:   0%|          | 0.00/373 [00:00<?, ?B/s]

dataset_infos.json:   0%|          | 0.00/921 [00:00<?, ?B/s]

(…)-00000-of-00001-4746b8785c874cc7.parquet:   0%|          | 0.00/33.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [9]:
assert isinstance(ds, Dataset)  # pleasing the typechecker
remove_columns = ds.column_names if cfg.drop_columns else None
tokenizer = AutoTokenizer.from_pretrained(cfg.model, model_max_length=cfg.token_batch_size, revision=cfg.revision)


ds = ds.map(
    tokenize,
    batched=True,
    fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer),
    remove_columns=remove_columns,
)

data = ds

tokenizer_config.json:   0%|          | 0.00/264 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [10]:
batches_world = allocate_batches_test(doc_lengths=ds["length"], N=cfg.token_batch_size, workers=workers)
assert len(batches_world) == workers

In [11]:
target_modules = None
normalizers = {}

processor = GradientProcessor(
    normalizers=normalizers,
    fisher_fourth_root=cfg.fisher_fourth_root,
    projection_dim=None,
)


## 2. Compute activation and gradient covariance

In [12]:
covariance_test_path = os.path.join(test_path, "covariances")

In [13]:
def compute_covariance(rank: int, activation_covariances={}, gradient_covariances={}):
    total_processed = 0
    batches = batches_world[rank]

    loss_list = []

    def callback_activation(name: str, a: torch.Tensor):
        activation_covariance = activation_covariances.get(name, None)  # Our stored slice

        a = a.reshape(-1, a.shape[-1])  # [N*S, O]
        if name == debug_name:
            print(a[0, 0])
        update = a.mT @ a
        if name == debug_name:
            print(update[0, 0])
        if activation_covariance is None:
            activation_covariances[name] = update

        else:
            # Add it to our permanently stored slice
            activation_covariance.add_(update)

    def callback_gradient(name: str, g: torch.Tensor):
        gradient_covariance = gradient_covariances.get(name, None)

        g = g.reshape(-1, g.shape[-1])  # [N*S, O]
        update = g.mT @ g

        if gradient_covariance is None:
            gradient_covariances[name] = update
        else:
            gradient_covariance.add_(update)

    collector = EkfacCollector(
        model.base_model,
        closure=callback_gradient,
        processor=processor,
        target_modules=target_modules,
        fwd_closure=callback_activation,
    )
    for sl in tqdm(batches):
        batch = data[sl]
        x, y = pad_and_tensor(
            batch["input_ids"],  # type: ignore
            labels=batch.get("labels"),  # type: ignore
            device=device,
        )

        total_processed += x.numel()

        with collector:
            logits = model(x).logits
            losses = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                y[:, 1:].flatten(),
                reduction="none",
            ).reshape_as(y[:, 1:])

            masks = y[:, 1:] != -100
            denoms = masks.sum(dim=1, dtype=logits.dtype)
            losses = losses.sum(1).div(denoms)

            losses.mean().backward()

            loss_list.append(losses.detach().cpu())
            model.zero_grad()

    return {"losses": loss_list, "total_processed_rank": total_processed}


In [14]:
total_processed_global = 0
for rank in range(workers):
    covariance_test_path_rank = os.path.join(covariance_test_path, f"rank_{rank}")
    os.makedirs(covariance_test_path_rank, exist_ok=True)

    activation_covariances = {}
    gradient_covariances = {}
    d = compute_covariance(
        rank=rank, activation_covariances=activation_covariances, gradient_covariances=gradient_covariances
    )

    save_file(activation_covariances, os.path.join(covariance_test_path_rank, "activation_covariance.safetensors"))
    save_file(gradient_covariances, os.path.join(covariance_test_path_rank, "gradient_covariance.safetensors"))
    with open(os.path.join(covariance_test_path_rank, "stats.json"), "w") as f:
        json.dump({"total_processed_rank": d["total_processed_rank"]}, f, indent=4)
        print(f"Rank {rank} processed {d['total_processed_rank']} tokens.")


  0%|          | 0/190 [00:00<?, ?it/s]

tensor(-0.1472, device='cuda:0')
tensor(228.2182, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(235.2228, device='cuda:0')
tensor(0.0206, device='cuda:0')
tensor(270.5420, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(193.4709, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(258.1457, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(257.5855, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(240.2396, device='cuda:0')
tensor(0.1708, device='cuda:0')
tensor(368.6412, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(318.9656, device='cuda:0')
tensor(-0.1683, device='cuda:0')
tensor(322.6793, device='cuda:0')
tensor(-0.1672, device='cuda:0')
tensor(218.0562, device='cuda:0')
tensor(-0.1676, device='cuda:0')
tensor(247.2193, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(197.4022, device='cuda:0')
tensor(-0.0017, device='cuda:0')
tensor(267.6510, device='cuda:0')
tensor(0.4969, device='cuda:0')
tensor(300.3181, device='cuda:0'

  0%|          | 0/190 [00:00<?, ?it/s]

tensor(-0.1252, device='cuda:0')
tensor(281.4669, device='cuda:0')
tensor(-0.1698, device='cuda:0')
tensor(291.3121, device='cuda:0')
tensor(-0.1664, device='cuda:0')
tensor(344.2610, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(220.9984, device='cuda:0')
tensor(-0.0719, device='cuda:0')
tensor(404.3353, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(232.8561, device='cuda:0')
tensor(-0.1649, device='cuda:0')
tensor(354.7084, device='cuda:0')
tensor(-0.1625, device='cuda:0')
tensor(275.5823, device='cuda:0')
tensor(-0.1450, device='cuda:0')
tensor(239.0916, device='cuda:0')
tensor(-0.0746, device='cuda:0')
tensor(328.4109, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(411.3898, device='cuda:0')
tensor(-0.1621, device='cuda:0')
tensor(280.5967, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(237.3250, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(210.7202, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(216.7578, device='cuda

  0%|          | 0/190 [00:00<?, ?it/s]

tensor(-0.1038, device='cuda:0')
tensor(315.0810, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(220.9362, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(294.6105, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(216.9103, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(213.0353, device='cuda:0')
tensor(-0.0863, device='cuda:0')
tensor(209.8636, device='cuda:0')
tensor(0.0997, device='cuda:0')
tensor(238.0121, device='cuda:0')
tensor(-0.1594, device='cuda:0')
tensor(282.6869, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(239.5140, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(229.7704, device='cuda:0')
tensor(-0.1696, device='cuda:0')
tensor(160.8892, device='cuda:0')
tensor(-0.0751, device='cuda:0')
tensor(405.7447, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(269.3739, device='cuda:0')
tensor(-0.0746, device='cuda:0')
tensor(374.0466, device='cuda:0')
tensor(-0.1698, device='cuda:0')
tensor(867.2820, device='cuda:

  0%|          | 0/190 [00:00<?, ?it/s]

tensor(-0.0921, device='cuda:0')
tensor(327.3918, device='cuda:0')
tensor(0.0850, device='cuda:0')
tensor(293.0247, device='cuda:0')
tensor(-0.1539, device='cuda:0')
tensor(309.1653, device='cuda:0')
tensor(-0.1198, device='cuda:0')
tensor(317.0034, device='cuda:0')
tensor(-0.1418, device='cuda:0')
tensor(293.5501, device='cuda:0')
tensor(-0.1038, device='cuda:0')
tensor(250.3243, device='cuda:0')
tensor(-0.0145, device='cuda:0')
tensor(314.8920, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(275.7524, device='cuda:0')
tensor(-0.1673, device='cuda:0')
tensor(309.0125, device='cuda:0')
tensor(-0.1038, device='cuda:0')
tensor(372.7568, device='cuda:0')
tensor(-0.1250, device='cuda:0')
tensor(219.9495, device='cuda:0')
tensor(-0.0746, device='cuda:0')
tensor(273.4466, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(402.3210, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(230.4802, device='cuda:0')
tensor(0.0850, device='cuda:0')
tensor(296.0960, device='cuda:0

  0%|          | 0/190 [00:00<?, ?it/s]

tensor(-0.0145, device='cuda:0')
tensor(282.6613, device='cuda:0')
tensor(-0.1689, device='cuda:0')
tensor(364.6546, device='cuda:0')
tensor(-0.1689, device='cuda:0')
tensor(419.3730, device='cuda:0')
tensor(-0.1038, device='cuda:0')
tensor(323.3156, device='cuda:0')
tensor(-0.0746, device='cuda:0')
tensor(312.6237, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(209.0979, device='cuda:0')
tensor(-0.1038, device='cuda:0')
tensor(294.7710, device='cuda:0')
tensor(-0.0746, device='cuda:0')
tensor(291.2008, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(290.8581, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(273.3892, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(247.8301, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(300.5929, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(245.8909, device='cuda:0')
tensor(-0.1539, device='cuda:0')
tensor(235.1472, device='cuda:0')
tensor(-0.1038, device='cuda:0')
tensor(349.6380, device='cuda

  0%|          | 0/190 [00:00<?, ?it/s]

tensor(-0.1036, device='cuda:0')
tensor(273.5028, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(198.5334, device='cuda:0')
tensor(0.3182, device='cuda:0')
tensor(263.0723, device='cuda:0')
tensor(-0.1539, device='cuda:0')
tensor(295.5519, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(245.6493, device='cuda:0')
tensor(-0.1689, device='cuda:0')
tensor(319.6798, device='cuda:0')
tensor(-0.1038, device='cuda:0')
tensor(317.0273, device='cuda:0')
tensor(-0.1038, device='cuda:0')
tensor(256.9405, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(216.8510, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(321.3602, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(205.6788, device='cuda:0')
tensor(0.0850, device='cuda:0')
tensor(280.9502, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(322.9493, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(318.1693, device='cuda:0')
tensor(-0.0746, device='cuda:0')
tensor(262.2312, device='cuda:0

  0%|          | 0/190 [00:00<?, ?it/s]

tensor(-0.1036, device='cuda:0')
tensor(237.9683, device='cuda:0')
tensor(-0.0017, device='cuda:0')
tensor(276.4297, device='cuda:0')
tensor(-0.1376, device='cuda:0')
tensor(202.7254, device='cuda:0')
tensor(-0.1252, device='cuda:0')
tensor(272.8650, device='cuda:0')
tensor(-0.1460, device='cuda:0')
tensor(448.0073, device='cuda:0')
tensor(-0.1666, device='cuda:0')
tensor(278.1480, device='cuda:0')
tensor(-0.1038, device='cuda:0')
tensor(279.9359, device='cuda:0')
tensor(-0.1198, device='cuda:0')
tensor(196.5405, device='cuda:0')
tensor(-0.1625, device='cuda:0')
tensor(346.0069, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(368.7529, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(327.9495, device='cuda:0')
tensor(-0.1038, device='cuda:0')
tensor(319.3956, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(243.7483, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(359.9354, device='cuda:0')
tensor(-0.0746, device='cuda:0')
tensor(263.8012, device='cuda

  0%|          | 0/190 [00:00<?, ?it/s]

tensor(-0.1036, device='cuda:0')
tensor(213.2959, device='cuda:0')
tensor(-0.1539, device='cuda:0')
tensor(273.2589, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(242.8246, device='cuda:0')
tensor(-0.1036, device='cuda:0')
tensor(210.1244, device='cuda:0')
tensor(-0.1698, device='cuda:0')
tensor(287.3475, device='cuda:0')
tensor(-0.0746, device='cuda:0')
tensor(314.5435, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(398.9280, device='cuda:0')
tensor(-0.0447, device='cuda:0')
tensor(397.2848, device='cuda:0')
tensor(-0.0746, device='cuda:0')
tensor(333.6378, device='cuda:0')
tensor(-0.1198, device='cuda:0')
tensor(196.9509, device='cuda:0')
tensor(-0.1038, device='cuda:0')
tensor(263.0993, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(410.9694, device='cuda:0')
tensor(-0.1678, device='cuda:0')
tensor(322.1888, device='cuda:0')
tensor(0.0850, device='cuda:0')
tensor(315.6104, device='cuda:0')
tensor(-0.1569, device='cuda:0')
tensor(256.5403, device='cuda:

In [15]:
# Combine results from all ranks
activation_covariances = TensorDict({})
gradient_covariances = TensorDict({})
total_processed_global = 0
loss_list = []

for rank in range(workers):
    covariance_test_path_rank = os.path.join(covariance_test_path, f"rank_{rank}")

    with open(os.path.join(covariance_test_path_rank, "stats.json"), "r") as f:
        d = json.load(f)
        total_processed_global += d["total_processed_rank"]

    # TensorDict wrapper to simplify tensor operations over dicts of tensors
    activation_covariances_rank = TensorDict(
        load_file(os.path.join(covariance_test_path_rank, "activation_covariance.safetensors"))
    ).to(device)

    gradient_covariances_rank = TensorDict(
        load_file(os.path.join(covariance_test_path_rank, "gradient_covariance.safetensors"))
    ).to(device)

    if not activation_covariances:
        activation_covariances = activation_covariances_rank
    else:
        activation_covariances = activation_covariances + activation_covariances_rank

    if not gradient_covariances:
        gradient_covariances = gradient_covariances_rank
    else:
        gradient_covariances = gradient_covariances + (gradient_covariances_rank)


In [16]:
activation_covariances[debug_name][0, 0]

tensor(404655.8125, device='cuda:0')

In [17]:
save_file(activation_covariances.to_dict(), os.path.join(covariance_test_path, "activation_covariance.safetensors"))
save_file(gradient_covariances.to_dict(), os.path.join(covariance_test_path, "gradient_covariance.safetensors"))
with open(os.path.join(covariance_test_path, "stats.json"), "w") as f:
    json.dump({"total_processed_global": total_processed_global}, f, indent=4)
    print(f"Global processed {total_processed_global} tokens.")

Global processed 10888691 tokens.


In [18]:
gc.collect()
torch.cuda.empty_cache()

## 3. Compute eigenvalues and eigenvectors

By default, eigh will be done in float64!

In [19]:
eigenvectors_test_path = os.path.join(test_path, "eigenvectors")
os.makedirs(eigenvectors_test_path, exist_ok=True)

eigenvectors_activations = {}
eigenvectors_gradients = {}

In [20]:
# load activation and gradient covariance, and total processed
with open(os.path.join(covariance_test_path, "stats.json"), "r") as f:
    d = json.load(f)
    total_processed_global = d["total_processed_global"]

activation_covariances = load_file(os.path.join(covariance_test_path, "activation_covariance.safetensors"))
gradient_covariances = load_file(os.path.join(covariance_test_path, "gradient_covariance.safetensors"))


In [21]:
# # Use run_path to see if everything is correct from 3. and errors don't propagate

# covariance_a_run_path = os.path.join(ekfac_run_path, "activation_covariance_sharded/shard_0.safetensors")
# covariance_g_run_path = os.path.join(ekfac_run_path, "gradient_covariance_sharded/shard_0.safetensors")
# activation_covariances = load_file(covariance_a_run_path)
# gradient_covariances = load_file(covariance_g_run_path)


In [22]:
for name in activation_covariances.keys():
    a = activation_covariances[name].to(dtype=torch.float64, device=device)
    g = gradient_covariances[name].to(dtype=torch.float64, device=device)
    a = (a + a.T).div(2)
    g = (g + g.T).div(2)
    a.div_(total_processed_global)
    g.div_(total_processed_global)

    eigenvalues_a, eigenvectors_a = torch.linalg.eigh(a)
    eigenvalues_g, eigenvectors_g = torch.linalg.eigh(g)
    print(name, eigenvectors_a.sum(), eigenvectors_g.sum())
    eigenvectors_activations[name] = eigenvectors_a.to(dtype=dtype).contiguous()
    eigenvectors_gradients[name] = eigenvectors_g.to(dtype=dtype).contiguous()

save_file(eigenvectors_activations, os.path.join(eigenvectors_test_path, "eigenvectors_activations.safetensors"))
save_file(eigenvectors_gradients, os.path.join(eigenvectors_test_path, "eigenvectors_gradients.safetensors"))


layers.0.mlp.dense_4h_to_h tensor(18.9381, device='cuda:0', dtype=torch.float64) tensor(-11.3137, device='cuda:0', dtype=torch.float64)
layers.0.mlp.dense_h_to_4h tensor(-10.5283, device='cuda:0', dtype=torch.float64) tensor(1.8116, device='cuda:0', dtype=torch.float64)
layers.1.mlp.dense_4h_to_h tensor(16.5975, device='cuda:0', dtype=torch.float64) tensor(11.3137, device='cuda:0', dtype=torch.float64)
layers.1.mlp.dense_h_to_4h tensor(11.9538, device='cuda:0', dtype=torch.float64) tensor(-2.6594, device='cuda:0', dtype=torch.float64)
layers.2.mlp.dense_4h_to_h tensor(-25.6426, device='cuda:0', dtype=torch.float64) tensor(11.3137, device='cuda:0', dtype=torch.float64)
layers.2.mlp.dense_h_to_4h tensor(11.2272, device='cuda:0', dtype=torch.float64) tensor(-30.7567, device='cuda:0', dtype=torch.float64)
layers.3.mlp.dense_4h_to_h tensor(45.5371, device='cuda:0', dtype=torch.float64) tensor(-11.3136, device='cuda:0', dtype=torch.float64)
layers.3.mlp.dense_h_to_4h tensor(18.4924, device='

In [23]:
gc.collect()
torch.cuda.empty_cache()

## 4. Compute eigenvalue correction

In [57]:
eigenvalue_correction_test_path = os.path.join(test_path, "eigenvalue_corrections")
os.makedirs(eigenvalue_correction_test_path, exist_ok=True)

In [58]:
# load eigenvectors
eigenvectors_activations = load_file(os.path.join(eigenvectors_test_path, "eigenvectors_activations.safetensors"))
eigenvectors_gradients = load_file(os.path.join(eigenvectors_test_path, "eigenvectors_gradients.safetensors"))

In [59]:
# load eigenvectors from run
eigenvectors_activations_run_path = os.path.join(ekfac_run_path, "activation_eigen_sharded")


world_size = len(os.listdir(eigenvectors_activations_run_path))  # number of shards
# load run eigenvectors shards and concatenate them
run_eigenvectors_shards = [
    os.path.join(eigenvectors_activations_run_path, f"shard_{rank}.safetensors") for rank in range(world_size)
]
run_eigenvectors_list = [(load_file(shard)) for shard in run_eigenvectors_shards]
run_eigenvectors = {}
for k, v in run_eigenvectors_list[0].items():
    run_eigenvectors[k] = torch.cat([shard[k] for shard in run_eigenvectors_list], dim=0)

eigenvectors_activations = TensorDict(run_eigenvectors)


eigenvectors_gradients_run_path = os.path.join(ekfac_run_path, "gradient_eigen_sharded")


world_size = len(os.listdir(eigenvectors_gradients_run_path))  # number of shards
# load run eigenvectors shards and concatenate them
run_eigenvectors_shards = [
    os.path.join(eigenvectors_gradients_run_path, f"shard_{rank}.safetensors") for rank in range(world_size)
]
run_eigenvectors_list = [(load_file(shard)) for shard in run_eigenvectors_shards]
run_eigenvectors = {}
for k, v in run_eigenvectors_list[0].items():
    run_eigenvectors[k] = torch.cat([shard[k] for shard in run_eigenvectors_list], dim=0)

eigenvectors_gradients = TensorDict(run_eigenvectors)


In [60]:
eigenvalue_corrections = {}
activation_cache = {}


# only for debugging
gradient_cache = {}
pseudo_grad_cache = {}
transformed_activation_cache = {}
key_debug = "layers.0.mlp.dense_h_to_4h"

In [61]:
def compute_eigenvalue_correction(
    rank: int,
    eigenvalue_corrections,
    eigenvectors_activations=eigenvectors_activations,
    eigenvectors_gradients=eigenvectors_gradients,
):
    total_processed = 0
    batches = batches_world[rank]

    def callback_activation(name: str, a: torch.Tensor):
        a = a.reshape(-1, a.shape[-1])
        # a = torch.ones_like(a)  # for debugging, pretend all activations are 1
        activation = a
        activation_cache[name] = activation

    def callback_gradient(name: str, g: torch.Tensor):
        eigenvector_a = eigenvectors_activations[name].to(device=device)
        eigenvector_g = eigenvectors_gradients[name].to(device=device)

        g = g.reshape(-1, g.shape[-1])  # [N*S, O]
        # g = torch.ones_like(g)  # for debugging, pretend all gradients are 1
        if name == debug_name:
            gradient_cache[name] = g
        gradient = torch.einsum("B I, B O -> B O I", activation_cache[name], g)
        gradient = torch.einsum(" B O I, I J -> B O J ", gradient, eigenvector_a)
        gradient = torch.einsum(" O P, B O J -> B P J ", eigenvector_g, gradient)
        gradient = gradient**2
        correction = gradient.sum(dim=0)

        if name not in eigenvalue_corrections:
            eigenvalue_corrections[name] = correction
        else:
            eigenvalue_corrections[name].add_(correction)

    collector = EkfacCollector(
        model.base_model,
        closure=callback_gradient,
        processor=processor,
        target_modules=target_modules,
        fwd_closure=callback_activation,
    )
    for sl in tqdm(batches):
        batch = data[sl]
        x, y = pad_and_tensor(
            batch["input_ids"],  # type: ignore
            labels=batch.get("labels"),  # type: ignore
            device=device,
        )

        total_processed += x.numel()

        with collector:
            logits = model(x).logits
            losses = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                y[:, 1:].flatten(),
                reduction="none",
            ).reshape_as(y[:, 1:])

            masks = y[:, 1:] != -100

            denoms = masks.sum(dim=1, dtype=logits.dtype)
            losses = losses.sum(1).div(denoms)

            losses.mean().backward()

            model.zero_grad()

    return {"losses": loss_list, "total_processed_rank": total_processed}


In [62]:
gradient_cache_amortized = {}


def compute_eigenvalue_correction_amortized(
    rank: int,
    eigenvalue_corrections,
    eigenvectors_activations=eigenvectors_activations,
    eigenvectors_gradients=eigenvectors_gradients,
):
    total_processed = 0
    batches = batches_world[rank]

    def callback_activation(name: str, a: torch.Tensor):
        a = a.reshape(-1, a.shape[-1])
        # a = torch.ones_like(a)  # for debugging, pretend all activations are 1
        activation = a
        activation_cache[name] = activation

    def callback_gradient(name: str, g: torch.Tensor):
        eigenvector_a = eigenvectors_activations[name].to(device=device)
        eigenvector_g = eigenvectors_gradients[name].to(device=device)
        g = g.reshape(-1, g.shape[-1])  # [N*S, O]
        # g = torch.ones_like(g)  # for debugging, pretend all gradients are 1

        a_transformed = torch.einsum(" B I, I J -> B J ", activation_cache[name], eigenvector_a)
        g_transformed = torch.einsum(" O P, B O -> B P ", eigenvector_g, g)
        correction = torch.einsum(" B I, B O -> O I", a_transformed**2, g_transformed**2).contiguous()

        if name == debug_name:
            gradient_cache_amortized[name] = g
        if name not in eigenvalue_corrections:
            eigenvalue_corrections[name] = correction
        else:
            eigenvalue_corrections[name].add_(correction)

    collector = EkfacCollector(
        model.base_model,
        closure=callback_gradient,
        processor=processor,
        target_modules=target_modules,
        fwd_closure=callback_activation,
    )
    for sl in tqdm(batches):
        batch = data[sl]
        x, y = pad_and_tensor(
            batch["input_ids"],  # type: ignore
            labels=batch.get("labels"),  # type: ignore
            device=device,
        )

        total_processed += x.numel()

        with collector:
            logits = model(x).logits
            losses = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                y[:, 1:].flatten(),
                reduction="none",
            ).reshape_as(y[:, 1:])

            masks = y[:, 1:] != -100

            denoms = masks.sum(dim=1, dtype=logits.dtype)
            losses = losses.sum(1).div(denoms)

            losses.mean().backward()

            model.zero_grad()

    return {"losses": loss_list, "total_processed_rank": total_processed}


In [63]:
# eigenvalue_corrections_amortized = {}
# d = compute_eigenvalue_correction_amortized(
#     rank=0,
#     eigenvalue_corrections=eigenvalue_corrections_amortized,
#     eigenvectors_activations=eigenvectors_activations,
#     eigenvectors_gradients=eigenvectors_gradients,
# )

In [64]:
# eigenvalue_corrections = {}

# d = compute_eigenvalue_correction(
#     rank=0,
#     eigenvalue_corrections=eigenvalue_corrections,
#     eigenvectors_activations=eigenvectors_activations,
#     eigenvectors_gradients=eigenvectors_gradients,
# )

In [65]:
# TensorDict(gradient_cache).allclose(TensorDict(gradient_cache_amortized), rtol=1e-3, atol=1e-3)

In [66]:
# TensorDict(eigenvalue_corrections).allclose(TensorDict(eigenvalue_corrections_amortized), rtol=1e-3, atol=1e-3)

In [67]:
eigenvalue_corrections = {}

total_processed_global = 0
for rank in range(workers):
    eigenvalue_correction_test_path_rank = os.path.join(eigenvalue_correction_test_path, f"rank_{rank}")
    os.makedirs(eigenvalue_correction_test_path_rank, exist_ok=True)

    eigenvalue_corrections = {}
    d = compute_eigenvalue_correction_amortized(
        rank=rank,
        eigenvalue_corrections=eigenvalue_corrections,
        eigenvectors_activations=eigenvectors_activations,
        eigenvectors_gradients=eigenvectors_gradients,
    )
    save_file(
        eigenvalue_corrections, os.path.join(eigenvalue_correction_test_path_rank, "eigenvalue_corrections.safetensors")
    )
    with open(os.path.join(covariance_test_path_rank, "stats.json"), "w") as f:
        json.dump({"total_processed_rank": d["total_processed_rank"]}, f, indent=4)
        print(f"Rank {rank} processed {d['total_processed_rank']} tokens.")
    total_processed_global += d["total_processed_rank"]

  0%|          | 0/190 [00:00<?, ?it/s]

Rank 0 processed 1359432 tokens.


  0%|          | 0/190 [00:00<?, ?it/s]

Rank 1 processed 1363363 tokens.


  0%|          | 0/190 [00:00<?, ?it/s]

Rank 2 processed 1362429 tokens.


  0%|          | 0/190 [00:00<?, ?it/s]

Rank 3 processed 1361472 tokens.


  0%|          | 0/190 [00:00<?, ?it/s]

Rank 4 processed 1359266 tokens.


  0%|          | 0/190 [00:00<?, ?it/s]

Rank 5 processed 1362080 tokens.


  0%|          | 0/190 [00:00<?, ?it/s]

Rank 6 processed 1361366 tokens.


  0%|          | 0/190 [00:00<?, ?it/s]

Rank 7 processed 1359283 tokens.


In [68]:
# Combine results from all ranks
eigenvalue_corrections = TensorDict({})


for rank in range(workers):
    eigenvalue_correction_test_path_rank = os.path.join(eigenvalue_correction_test_path, f"rank_{rank}")

    # TensorDict wrapper to simplify tensor operations over dicts of tensors
    eigenvalue_corrections_rank = TensorDict(
        load_file(os.path.join(eigenvalue_correction_test_path_rank, "eigenvalue_corrections.safetensors"))
    ).to(device)

    if not eigenvalue_corrections:
        eigenvalue_corrections = eigenvalue_corrections_rank
    else:
        eigenvalue_corrections = eigenvalue_corrections + (eigenvalue_corrections_rank)

eigenvalue_corrections.div_(total_processed_global)
save_file(
    eigenvalue_corrections.to_dict(),
    os.path.join(eigenvalue_correction_test_path, "eigenvalue_corrections.safetensors"),
)

In [69]:
import torch

d_length = 10
test_activations = {}
test_gradients = {}
batch_sizes = torch.randint(4000, 8000, (d_length,))
activation_length = torch.randint(3000, 5000, (d_length,))
gradient_length = torch.randint(3000, 10000, (d_length,))
# Create test activations
for i in range(d_length):
    test_activations[f"layer_{i}"] = torch.rand(batch_sizes[i].item(), activation_length[i].item())
    test_gradients[f"layer_{i}"] = torch.rand(batch_sizes[i].item(), gradient_length[i].item())

In [70]:
TensorDict(test_activations).size()

TensorDict({'layer_0': torch.Size([5853, 3893]), 'layer_1': torch.Size([5941, 4286]), 'layer_2': torch.Size([4318, 3204]), 'layer_3': torch.Size([7423, 3726]), 'layer_4': torch.Size([7827, 4966]), 'layer_5': torch.Size([5599, 4826]), 'layer_6': torch.Size([7552, 3738]), 'layer_7': torch.Size([7499, 4529]), 'layer_8': torch.Size([7513, 3662]), 'layer_9': torch.Size([4860, 3773])})

In [71]:
gc.collect()
torch.cuda.empty_cache()