## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import os
from dataclasses import asdict
from datetime import timedelta

import torch
import torch.distributed as dist
from datasets import Dataset, IterableDataset
from torch.distributed.fsdp import fully_shard
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedModel

from bergson.data import IndexConfig, allocate_batches
from bergson.distributed import distributed_computing
from bergson.gradients import GradientProcessor
from bergson.hessians.covariance_all_factors import EkfacComputer
from bergson.utils import assert_type, get_layer_list


## Playground

In [3]:
torch.set_default_device("cuda")

In [4]:
batch_size = 500
in_dim = 53
out_dim = 61
x = torch.randn(batch_size, in_dim)
A = torch.randn(in_dim, out_dim)
y = x @ A

In [5]:
xty = torch.einsum("bi,bo->bio", x, y)

In [None]:
xty[0]

tensor([[ -1.6143,   0.1879,   1.5810,  ...,  -0.6866,   0.8174,   4.8755],
        [ -6.5209,   0.7591,   6.3867,  ...,  -2.7734,   3.3020,  19.6950],
        [ -6.0131,   0.7000,   5.8893,  ...,  -2.5574,   3.0448,  18.1613],
        ...,
        [ -3.9590,   0.4608,   3.8775,  ...,  -1.6838,   2.0047,  11.9574],
        [ 14.2626,  -1.6602, -13.9690,  ...,   6.0660,  -7.2221, -43.0772],
        [ -1.0868,   0.1265,   1.0644,  ...,  -0.4622,   0.5503,   3.2823]],
       device='cuda:0')

In [None]:
xty_flat = xty.flatten(1)
xty_flat.shape

torch.Size([500, 3233])

In [None]:
xty_flat_outer = torch.einsum("bi,bj->bij", xty_flat, xty_flat)
xty_flat_outer.shape

torch.Size([500, 3233, 3233])

In [9]:
final_1 = xty_flat_outer.sum(dim=0)

In [10]:
xtx = torch.einsum("bi,bj->bij", x, x)

In [None]:
xtx_transformed = torch.einsum("bji,ik->bkj", xtx, A)
xtx_transformed.shape

torch.Size([500, 61, 53])

In [12]:
xtx_transformed_flat = xtx_transformed.flatten(1)

In [13]:
xtx_transformed_flat_outer = torch.einsum("bi,bj->bij", xtx_transformed_flat, xtx_transformed_flat)

In [14]:
final_2 = xtx_transformed_flat_outer.sum(dim=0)

In [None]:
final_2 - final_1

tensor([[ 0.0000e+00,  8.9059e+02,  1.0722e+04,  ...,  9.5906e+02,
         -2.0118e+03, -1.2207e-04],
        [ 8.9059e+02,  1.2562e+04, -3.1391e+03,  ..., -3.9731e+02,
         -9.6996e+01, -1.0428e+03],
        [ 1.0722e+04, -3.1391e+03,  1.0075e+04,  ..., -3.2259e+02,
          1.1691e+03, -1.4450e+02],
        ...,
        [ 9.5906e+02, -3.9731e+02, -3.2259e+02,  ..., -1.0917e+03,
         -2.2044e+03,  9.2240e+03],
        [-2.0118e+03, -9.6996e+01,  1.1691e+03,  ..., -2.2044e+03,
         -3.0092e+03, -8.2640e+03],
        [-1.2207e-04, -1.0428e+03, -1.4450e+02,  ...,  9.2240e+03,
         -8.2640e+03, -1.9531e-03]], device='cuda:0')