## Imports

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel (per the IPython docs,
# mode 2 reloads all modules before each cell is executed).
%load_ext autoreload
%autoreload 2

In [10]:
import torch
import numpy as np
from transformers import GPTNeoXForCausalLM, AutoTokenizer
from safetensors.torch import load_file
from bergson.data import IndexConfig
from bergson.approx_unrolling.utils import TensorDict
import json
import os
from dataclasses import asdict
from datetime import timedelta

import torch
import torch.distributed as dist
from datasets import Dataset, IterableDataset
from torch.distributed.fsdp import fully_shard
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedModel

from bergson.data import IndexConfig, allocate_batches, DataConfig
from bergson.distributed import distributed_computing
from bergson.gradients import GradientProcessor
from bergson.hessians.covariance_all_factors import (
    compute_covariance,
    compute_eigendecomposition,
    compute_eigenvalue_correction,
)
from bergson.utils import assert_type, get_layer_list
import os

## 0. Hyperparameters

In [11]:
# Build the index configuration. The run path is left empty because this
# notebook does not use it to save data.
cfg = IndexConfig(run_path="")

# Every setting this notebook overrides, gathered in one table so the values
# can be read (and tuned) at a glance.
hyperparameter_overrides = {
    "model": "EleutherAI/Pythia-14m",
    "precision": "fp16",
    "revision": None,
    "fsdp": False,
    "normalizer": "none",
    "fisher_fourth_root": False,
    "data": DataConfig("NeelNanda/pile-10k"),
}
for field_name, field_value in hyperparameter_overrides.items():
    setattr(cfg, field_name, field_value)

In [12]:
# Translate the configured precision string into the torch dtype used when
# loading the model. Unknown precision strings are rejected with a ValueError.
precision = cfg.precision
if precision == "bf16":
    dtype = torch.bfloat16
elif precision == "fp16":
    dtype = torch.float16
elif precision == "fp32":
    dtype = torch.float32
elif precision in ("int4", "int8"):
    # For the integer precisions, pick bf16 when the CUDA device reports
    # support for it and fall back to fp16 otherwise.
    if torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
else:
    raise ValueError(f"Unsupported precision: {precision}")

## 1. Loading model and data

In [None]:
# Only the integer precisions use bitsandbytes quantization; for every other
# precision no quantization config is passed to `from_pretrained`.
quant_config = None
if cfg.precision in ("int4", "int8"):
    quant_config = BitsAndBytesConfig(
        load_in_4bit=cfg.precision == "int4",
        load_in_8bit=cfg.precision == "int8",
        bnb_4bit_compute_dtype=dtype,
        bnb_4bit_quant_storage=dtype,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

# Load the causal language model named in the config onto the CUDA device,
# using the dtype resolved from `cfg.precision`.
model = AutoModelForCausalLM.from_pretrained(
    cfg.model,
    device_map="cuda",
    quantization_config=quant_config,
    torch_dtype=dtype,
    revision=cfg.revision,
)