## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import gc
import hashlib
import json
import os
import random
from contextlib import nullcontext
from typing import Literal, Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from datasets import Dataset
from jaxtyping import Float
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from torch import Tensor

from tqdm.auto import tqdm
from transformers import PreTrainedModel

from bergson.collection import collect_gradients
from bergson.data import DataConfig, IndexConfig, create_index, load_gradients, pad_and_tensor
from bergson.distributed import distributed_computing, setup_data_pipeline
from bergson.gradients import (
    GradientProcessor,
)
from bergson.hessians.collector import EkfacCollector
from bergson.hessians.logger import get_logger

In [3]:
def _projection(
    name: str,
    m: int,
    n: int,
    device: str,
    side: Literal["left", "right"],
    dtype: torch.dtype,
) -> Tensor:
    """Return the `side` projection matrix for parameter `name` of shape [m, n]."""
    # Seed the PRNG with the name of the layer and what "side" we are projecting
    message = bytes(f"{name}/{side}", "utf-8")
    digest = hashlib.md5(message).digest()
    seed = int.from_bytes(digest, byteorder="big") % (2**63 - 1)
    prng = torch.Generator(device).manual_seed(seed)

    A = torch.randn(m, n, device=device, dtype=dtype, generator=prng)
    A /= A.norm(dim=1, keepdim=True)
    return A

## -1. Compute gradients

In [4]:
test_path = "/root/bergson/tests/ekfac_tests/test_files/pile_10k_examples"
test_gradients_path = test_path + "/test_gradients"

In [5]:
cfg_json = json.load(open(os.path.join(test_path, "ground_truth", "index_config.json"), "r"))
cfg = IndexConfig(**cfg_json)
cfg.data = DataConfig(**(cfg_json["data"]))
original_proj_dim = cfg.projection_dim

In [6]:
data = setup_data_pipeline(cfg)
assert isinstance(data, Dataset)
data = data.select(range(10))  # only a small number of examples since we also want to store uncompressed gradients
# save data to test_gradients_path
test_gradients_path_data = os.path.join(test_gradients_path, "gradient_data")
data.save_to_disk(test_gradients_path_data)

Saving the dataset (0/1 shards):   0%|          | 0/10 [00:00<?, ? examples/s]

In [7]:
cfg.ekfac = True
cfg.skip_preconditioners = True
cfg.world_size = 1
cfg.data.dataset = test_gradients_path_data
cfg.data.completion_column = ""
cfg.data.conversation_column = ""


In [8]:
# Run compressed version
cfg.run_path = test_gradients_path + "/proj_dim_16"
cfg.projection_dim = 16

distributed_computing(
    cfg=cfg,
    worker_fn=collect_gradients,
)

Building index:   0%|          | 0/2 [00:00<?, ?it/s]

layers.0.attention.query_key_value tensor(-0.0463, device='cuda:0')
layers.0.attention.dense tensor(0.5877, device='cuda:0')
layers.0.mlp.dense_h_to_4h tensor(5.8712, device='cuda:0')
layers.0.mlp.dense_4h_to_h tensor(-0.1341, device='cuda:0')
layers.1.attention.query_key_value tensor(-2.3742, device='cuda:0')
layers.1.attention.dense tensor(2.5233, device='cuda:0')
layers.1.mlp.dense_h_to_4h tensor(-4.1161, device='cuda:0')
layers.1.mlp.dense_4h_to_h tensor(2.1344, device='cuda:0')
layers.2.attention.query_key_value tensor(-0.4214, device='cuda:0')
layers.2.attention.dense tensor(-1.4546, device='cuda:0')
layers.2.mlp.dense_h_to_4h tensor(-3.9674, device='cuda:0')
layers.2.mlp.dense_4h_to_h tensor(6.5737, device='cuda:0')
layers.3.attention.query_key_value tensor(-1.0522, device='cuda:0')
layers.3.attention.dense tensor(0.6521, device='cuda:0')
layers.3.mlp.dense_h_to_4h tensor(-1.2894, device='cuda:0')
layers.3.mlp.dense_4h_to_h tensor(1.3050, device='cuda:0')
layers.4.attention.quer

Saving the dataset (0/1 shards):   0%|          | 0/10 [00:00<?, ? examples/s]

In [9]:
# Run uncompression version
cfg.run_path = test_gradients_path + "/proj_dim_0"
cfg.projection_dim = 0
distributed_computing(
    cfg=cfg,
    worker_fn=collect_gradients,
)

Building index:   0%|          | 0/2 [00:00<?, ?it/s]

Saving the dataset (0/1 shards):   0%|          | 0/10 [00:00<?, ? examples/s]

## 0. Load EKFAC

In [10]:
influence_path = test_path + "/run/influence_results"

# all paths inside ekfac_path
world_size = len(os.listdir(influence_path + "/activation_covariance_sharded"))


In [11]:
def merge_shards(path: str):
    shard_paths = [path + f"/shard_{rank}.safetensors" for rank in range(world_size)]
    shard_factor = [load_file(path, device="cuda") for path in shard_paths]
    tensor_dict = {}
    for k, v in shard_factor[0].items():
        tensor_dict[k] = torch.cat([shard_factor[rank][k] for rank in range(world_size)], dim=0)
    return tensor_dict

In [12]:
eigen_a_full = merge_shards(influence_path + "/activation_eigen_sharded")
eigen_g_full = merge_shards(influence_path + "/gradient_eigen_sharded")
lambda_factor_full = merge_shards(influence_path + "/eigenvalue_correction_sharded")

## 1. Load the gradient


In [13]:
gradient_path_16 = test_gradients_path + "/proj_dim_16"

mmap_16 = load_gradients(gradient_path_16)
with open(os.path.join(gradient_path_16, "info.json")) as f:
    info = json.load(f)


In [14]:
gradient_path_0 = test_gradients_path + "/proj_dim_0"
mmap_0 = load_gradients(gradient_path_0)
with open(os.path.join(gradient_path_0, "info.json")) as f:
    info_0 = json.load(f)

## 2. Apply EKFAC

In [15]:
names = eigen_a_full.keys()

In [35]:
def apply_ekfac_module(gradient, eigen_a, eigen_g, lambda_matrix, proj_right, proj_left):
    """Apply EKFAC preconditioning to a gradient.

    Args:
        gradient: A tensor of shape [out_features, in_features].
        eigen_a: Eigenvectors of the activation covariance, shape [in_features, r].
        eigen_g: Eigenvectors of the gradient covariance, shape [out_features, r].
        lambda_matrix: Eigenvalue correction matrix, shape [r, r].

    Returns:
        The preconditioned gradient of shape [out_features, in_features].
    """
    # Project the gradient into the subspace defined by eigen_a and eigen_g
    projected = eigen_g.T @ gradient @ eigen_a  # Shape: [r, r]
    corrected = projected
    # Apply the eigenvalue correction
    inverse_lambda = (lambda_matrix + cfg.lambda_damp_factor * lambda_matrix.mean()).reciprocal()
    inverse_lambda = (lambda_matrix).reciprocal()
    inverse_lambda = (0 + cfg.lambda_damp_factor * lambda_matrix.mean()).reciprocal())

    print(f"y_{i}= {lambda_matrix.mean().item()}")
    for i in range(8):
        shard_size = lambda_matrix.shape[0] // 8
        start = i * shard_size
        end = start + shard_size
        shard = lambda_matrix[start:end]
        shard_mean = shard.mean()
        print(f"y_{i}= {shard_mean.item()}")

    corrected = projected * inverse_lambda  # Element-wise multiplication

    # Reconstruct the preconditioned gradient
    preconditioned = eigen_g @ corrected @ eigen_a.T  # Shape: [o, i]

    projected = proj_left @ preconditioned @ proj_right.T

    return projected

SyntaxError: unmatched ')' (2473352953.py, line 19)

In [36]:
preconditioned_grad_dict = {}
for name in names:
    eigen_a_tensor = eigen_a_full[name]
    eigen_g_tensor = eigen_g_full[name]
    lambda_tensor = lambda_factor_full[name]

    i, o = eigen_a_tensor.shape[0], eigen_g_tensor.shape[0]
    gradient_tensor = torch.from_numpy(mmap_0[name].copy()).to("cuda", dtype=eigen_a_full[name].dtype).view(-1, o, i)

    proj_pi = _projection(
        name,
        original_proj_dim,  # type: ignore
        i,
        device="cuda",
        side="right",
        dtype=eigen_a_full[name].dtype,
    )

    proj_qo = _projection(
        name,
        original_proj_dim,  # type: ignore
        o,
        device="cuda",
        side="left",
        dtype=eigen_g_full[name].dtype,
    )
    preconditioned_gradient = apply_ekfac_module(
        gradient_tensor,
        eigen_a_tensor,
        eigen_g_tensor,
        lambda_tensor,
        proj_pi,
        proj_qo,
    )
    preconditioned_grad_dict[name] = preconditioned_gradient


5.500789370671555e-07
Shard 0 mean: 9.141173507032363e-08
Shard 1 mean: 1.3952188737675897e-07
Shard 2 mean: 1.8090653952640423e-07
Shard 3 mean: 2.2734712956662406e-07
Shard 4 mean: 2.9106851684446156e-07
Shard 5 mean: 3.9171223420453316e-07
Shard 6 mean: 6.236740546228248e-07
Shard 7 mean: 2.4549894988012966e-06
3.055223487535841e-07
Shard 0 mean: 3.150902116999532e-08
Shard 1 mean: 4.221812233140554e-08
Shard 2 mean: 5.456433171957542e-08
Shard 3 mean: 7.203694707413888e-08
Shard 4 mean: 9.989233973328737e-08
Shard 5 mean: 1.5433980138368497e-07
Shard 6 mean: 2.8525721518235514e-07
Shard 7 mean: 1.7043611251210677e-06
2.965434191537497e-07
Shard 0 mean: 4.2788229848156334e-08
Shard 1 mean: 6.242325412131322e-08
Shard 2 mean: 8.104353099724904e-08
Shard 3 mean: 1.0884258472287911e-07
Shard 4 mean: 1.4980199125602667e-07
Shard 5 mean: 2.1152183649064682e-07
Shard 6 mean: 3.559083268100949e-07
Shard 7 mean: 1.360017449769657e-06
2.3346341038177343e-07
Shard 0 mean: 2.986843128383043e-0

In [29]:
x_7 = 2.4549894988012966e-06
x_5 = 3.9171223420453316e-07
x_3 = 2.2734712956662406e-07
x_0 = 9.141173507032363e-08
x_2 = 1.8090653952640423e-07
x_1 = 1.3952188737675897e-07
x_6 = 6.236740546228248e-07
x_4 = 2.9106851684446156e-07

x_mean = x_1 + x_2 + x_3 + x_4 + x_5 + x_6 + x_7 + x_0
x_mean /= 8

In [30]:
x_mean

5.500789495016534e-07

In [18]:
preconditioned_grad_dict

{'layers.0.mlp.dense_4h_to_h': tensor([[[-254694.5938, -366830.5312,   84742.3906,  ...,  -72391.2500,
             35746.0312,   84781.8047],
          [-370850.2812,   -2695.0781,  183879.3750,  ..., -214038.8750,
            339558.8750,  126594.1250],
          [ 411689.3438,  263341.3125,   45840.7227,  ..., -218862.7812,
            -12723.5781,  -45963.3633],
          ...,
          [ 750122.1875,  204666.0938, -217900.0469,  ..., -106158.1719,
           -165504.3750, -412334.9375],
          [ -32869.5273,   56840.7969,  208140.5156,  ...,   98977.5078,
            -43380.3945,   98793.6875],
          [ 376956.0000,   79309.1094, -179397.3906,  ...,   25765.8203,
             21491.7852,  140923.7188]],
 
         [[ 461567.2500,  262055.6875,  437672.9062,  ..., -259536.3438,
            522418.6875, -102002.3438],
          [ 106191.7188, -153854.7031, -255312.6562,  ...,  189738.5625,
             55503.3438,  -59298.3164],
          [ -98582.8828,  -98150.0938,   53175.5

In [19]:
os.makedirs(test_gradients_path + "/gradients_after_ekfac", exist_ok=True)
save_file(preconditioned_grad_dict, test_gradients_path + "/gradients_after_ekfac" + "/gradients.safetensors")

In [20]:
n, o, i = 10, 32, 16
c = 32 // 8


shard_ci = torch.rand(c, i).to("cuda")

matrix_noi = torch.rand(n, o, i).to("cuda")
matrix_noi_copy = matrix_noi.clone()
start_row = 12
end_row = 16

In [21]:
result_1 = matrix_noi[:, start_row:end_row, :].multiply(
    (shard_ci + cfg.lambda_damp_factor * shard_ci.mean()).reciprocal().unsqueeze(0)
)


In [22]:
A = matrix_noi_copy[:, start_row:end_row, :]
B = (shard_ci + cfg.lambda_damp_factor * shard_ci.mean()).reciprocal()

In [23]:
A[0, :, :] * B

tensor([[0.7659, 2.6546, 1.4608, 1.2419, 1.1269, 1.1895, 2.1062, 0.4292, 0.5772,
         0.1421, 1.5482, 0.5690, 8.9974, 0.7116, 1.2178, 1.3367],
        [0.8764, 2.5108, 1.4513, 2.4445, 0.4616, 0.6970, 0.6242, 0.7996, 1.7006,
         0.5489, 1.3874, 0.9121, 1.0678, 4.2754, 0.4991, 0.5481],
        [3.9777, 0.0281, 5.2895, 2.2531, 1.0785, 1.7859, 0.3435, 0.7319, 0.9600,
         1.9925, 0.3357, 2.3339, 1.9611, 0.1425, 0.5402, 1.5362],
        [0.0583, 0.4650, 1.5651, 1.0459, 2.0959, 0.8396, 1.3358, 2.4505, 1.8738,
         0.2054, 0.4600, 2.2889, 1.6646, 1.7419, 2.1397, 1.0930]],
       device='cuda:0')

In [24]:
A[0, 0, 0] * B[0, 0]

tensor(0.7659, device='cuda:0')

In [None]:
matrix_noi[0][start_row:end_row, :]

tensor([[0.7542, 0.4275, 0.5320, 0.6569, 0.5523, 0.2436, 0.6513, 0.4454, 0.2740,
         0.1172, 0.5467, 0.4251, 0.6068, 0.7427, 0.5097, 0.6713],
        [0.7542, 0.3751, 0.6358, 0.3567, 0.0241, 0.6464, 0.4201, 0.7671, 0.9487,
         0.4990, 0.3968, 0.6267, 0.4948, 0.7221, 0.2798, 0.3859],
        [0.9352, 0.0153, 0.6294, 0.9950, 0.8613, 0.4615, 0.1293, 0.7399, 0.9321,
         0.5810, 0.2457, 0.9929, 0.1702, 0.0993, 0.3307, 0.1110],
        [0.0550, 0.1587, 0.8747, 0.2274, 0.9632, 0.7865, 0.1697, 0.6067, 0.5572,
         0.1885, 0.2298, 0.7441, 0.5814, 0.3240, 0.9716, 0.5489]],
       device='cuda:0')

In [26]:
result_1[0]

tensor([[0.7659, 2.6546, 1.4608, 1.2419, 1.1269, 1.1895, 2.1062, 0.4292, 0.5772,
         0.1421, 1.5482, 0.5690, 8.9974, 0.7116, 1.2178, 1.3367],
        [0.8764, 2.5108, 1.4513, 2.4445, 0.4616, 0.6970, 0.6242, 0.7996, 1.7006,
         0.5489, 1.3874, 0.9121, 1.0678, 4.2754, 0.4991, 0.5481],
        [3.9777, 0.0281, 5.2895, 2.2531, 1.0785, 1.7859, 0.3435, 0.7319, 0.9600,
         1.9925, 0.3357, 2.3339, 1.9611, 0.1425, 0.5402, 1.5362],
        [0.0583, 0.4650, 1.5651, 1.0459, 2.0959, 0.8396, 1.3358, 2.4505, 1.8738,
         0.2054, 0.4600, 2.2889, 1.6646, 1.7419, 2.1397, 1.0930]],
       device='cuda:0')