In [None]:
import os
import sys
# Make the parent directory importable and pin the notebook to physical GPU 1.
# CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been initialised yet
# in this kernel, so this must run before anything touches torch.cuda.
parent_dir = os.path.abspath(os.path.join('..'))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
module_path = parent_dir

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
import os
import sys
# Also expose the grand-parent directory (where the `llama_xl` / `llama` packages are imported from).
grandparent_dir = os.path.abspath(os.path.join('../../'))
if grandparent_dir not in sys.path:
    sys.path.append(grandparent_dir)
module_path = grandparent_dir

In [None]:
import os
import json
from enum import Enum
from libfb.py.pyre import none_throws

import torch
import transformers
import torch.nn as nn
import numpy as np

import callm.core.utils.utils as utils
from callm.metaformers.src.args.trainer import TrainerArgs
from callm.core.data import datautils
from callm.core.model_utils import get_mp_rank_size, get_consolidated_ckpt_path, ElasticQuantBinarizerSigned, get_torch_dtype
from callm.core.models.llama_xl.transformer import (
    Transformer,
    TransformerForCausalLM,
    TransformerForSequenceClassification,
    wrap_model,
    wrap_model_pt,
)
import llama_xl.quantized_transformer as quantized_transformer
from callm.core.utils.process_args import (
    QAT,
)
from fairscale.nn.model_parallel import initialize as fs_init

In [None]:
import os
import torch
import torch.distributed as dist
from torch.distributed.elastic.utils.distributed import get_free_port

# Single-process "distributed" setup (world size 1) so the fairscale
# model-parallel layers can be constructed.
# Guarded so re-running the cell is a no-op: init_process_group raises if the
# default process group already exists, and re-picking MASTER_PORT would make
# the environment disagree with the live group.
if not dist.is_initialized():
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(get_free_port())
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"

    # Initialize the basic distributed process group
    dist.init_process_group(backend='nccl')  # or 'gloo' for CPU

In [None]:
import fairscale.nn.model_parallel.initialize as fs_init

# Create fairscale's model-parallel groups (model-parallel size 1, NCCL for both
# model-parallel and data-parallel communication). Skipped when already set up,
# so the cell can be re-run safely.
already_initialized = fs_init.model_parallel_is_initialized()
if not already_initialized:
    fs_init.initialize_model_parallel(1, model_parallel_backend="nccl", ddp_backend="nccl")

In [None]:
# Load the FP model and compute the activations on the Wiki2 dataset
class ParallelImpl(str, Enum):
    """Parallelism backend choices used by this notebook's model loading.

    The ``str`` mixin makes each member compare equal to its plain string value
    (e.g. ``ParallelImpl.NONE == "NONE"``).
    """

    # NOTE(review): this is a local copy; its members are passed into callm code
    # (args.model.parallel_impl, get_consolidated_ckpt_path), so presumably the
    # values must match callm's own parallel-impl enum — confirm.
    FAIRSCALE = "FAIRSCALE"
    PT_D = "PT_D"
    NONE = "NONE"

class data_args:
    """Notebook-level data configuration.

    Used as a mutable namespace: later cells reassign ``checkpoint_path`` and add
    dataset-path attributes directly on the class.
    """

    # Checkpoint directory whose consolidated*.pth file load_model_xl() loads.
    checkpoint_path: str = "xx/1_16_16/paretoq_lr_2e5/70000/"
    # Forwarded to utils.get_path_manager(max_parallel=..., api_key=...).
    max_parallel_files: int = 1
    api_key = None

class model_args:
    """Notebook-level model configuration.

    Used as a mutable namespace: load_model_xl() writes ``input_model_local_path``
    and later cells add tokenizer options as class attributes.
    """

    # Base model directory; params.json, the HF config and the tokenizer are read from it.
    input_model_filename: str = "xx/baselines/full_precision_models/original_xl/llama/1B/"
    # Filled in by load_model_xl() with the local copy of input_model_filename.
    input_model_local_path = None
    parallel_impl: ParallelImpl = ParallelImpl.FAIRSCALE

    # Architecture overrides copied onto args.model by load_model_xl().
    share_embedding: bool = True
    layer_sharing: bool = False
    custom_bwd: bool = True
    dropout = 0  # a value > 0 also forces custom_bwd off

    # Bit-widths copied onto args.model when training_args.qat == QAT.EXPERIMENTAL;
    # w_bits additionally drives the initial weight_clip_val for checkpoints lacking one.
    w_bits: int = 1
    a_bits: int = 16
    kv_bits: int = 16
    emb_bits: int = 32
    output_w_bits: int = 32
    output_a_bits: int = 32

class training_args:
    """Notebook-level training/runtime configuration read by load_model_xl()."""

    # QAT.EXPERIMENTAL builds quantized_transformer modules instead of the plain Transformer.
    qat = QAT.EXPERIMENTAL
    # args.dtype becomes "bf16" when bf16 is True, otherwise "fp16".
    bf16: bool = True
    fp16: bool = False  # NOTE(review): not read by load_model_xl; dtype follows bf16 only
    model_max_length: int = 2048

# Step 1: Load the FP model
def _initial_weight_clip_val(x, w_bits):
    """Initial ``weight_clip_val`` for a weight whose checkpoint has no clip value.

    Args:
        x: latent weight tensor from the checkpoint; reductions run over the last dim.
        w_bits: weight bit-width (``model_args.w_bits``).

    Returns:
        Tensor with the last dim of ``x`` reduced to size 1:
        1 bit -> mean |x|; 0 or 2 bits -> max |x|;
        >= 3 bits -> max |x| / (2 ** (w_bits - 1) - 1).

    Raises:
        NotImplementedError: for any other ``w_bits`` (e.g. negative).
    """
    if w_bits == 1:
        return torch.mean(x.abs(), dim=-1, keepdim=True).detach()
    if w_bits == 0 or w_bits == 2:
        scale, _ = torch.max(torch.abs(x), dim=-1, keepdim=True)
        return scale
    if w_bits >= 3:
        xmax, _ = torch.max(torch.abs(x), dim=-1, keepdim=True)
        maxq = 2 ** (w_bits - 1) - 1
        return xmax / maxq
    raise NotImplementedError


def load_model_xl(
    local_rank: int = 0,
    wrap: bool = False,
    generate_only: bool = False):
    """Build the Llama-XL model and load weights from ``data_args.checkpoint_path``.

    Reads the notebook-level ``data_args`` / ``model_args`` / ``training_args``
    namespaces. As a side effect, it sets ``model_args.input_model_local_path``,
    which later cells use to load the tokenizer.

    Args:
        local_rank: the model summary is printed only when this is 0.
        wrap: unused; kept for call compatibility.
        generate_only: when True, ``custom_bwd`` is disabled.

    Returns:
        ``quantized_transformer.TransformerForCausalLM`` when
        ``training_args.qat == QAT.EXPERIMENTAL``, else ``TransformerForCausalLM``,
        moved to CUDA and cast to ``args.dtype``.
    """
    pathmgr = utils.get_path_manager(
        max_parallel=data_args.max_parallel_files,
        api_key=data_args.api_key,
    )
    model_args.input_model_local_path = pathmgr.get_local_path(
        none_throws(model_args.input_model_filename)
    )
    # Only the JSON read needs the open file; everything else runs after it is closed.
    with pathmgr.open(
        os.path.join(model_args.input_model_local_path, "params.json")
    ) as f:
        data = json.load(f)
    if "hive_data" in data:
        del data["hive_data"]

    args = TrainerArgs.from_dict(data)
    args.dtype = "bf16" if training_args.bf16 else "fp16"
    args.model.sequence_parallel = False
    args.model.loss_parallel = False
    args.model.max_length = none_throws(training_args.model_max_length)
    # pyre-fixme[16]: `ModelArgs` has no attribute `share_embedding`.
    args.model.share_embedding = none_throws(model_args.share_embedding)
    # pyre-fixme[16]: `ModelArgs` has no attribute `layer_sharing`.
    args.model.layer_sharing = none_throws(model_args.layer_sharing)
    args.model.parallel_impl = none_throws(model_args.parallel_impl)
    args.model.custom_bwd = (
        none_throws(model_args.custom_bwd)
        and not generate_only
        and args.model.parallel_impl != ParallelImpl.NONE
    )
    if model_args.dropout > 0:
        args.model.dropout = model_args.dropout
        # custom_bwd is switched off whenever dropout is enabled.
        args.model.custom_bwd = False

    consolidated_ckpt_local_path = None
    if data_args.checkpoint_path is not None:
        mp_rank, mp_size = get_mp_rank_size()
        consolidated_ckpt_local_path = pathmgr.get_local_path(
            get_consolidated_ckpt_path(
                none_throws(data_args.checkpoint_path),
                mp_rank,
                mp_size,
                none_throws(model_args.parallel_impl),
            )
        )

    config = transformers.AutoConfig.from_pretrained(model_args.input_model_local_path)
    args.model.vocab_size = config.vocab_size

    assert data_args.checkpoint_path is not None
    # Weights come from the checkpoint below, so skip random initialisation.
    args.model.init.no_init = True

    if training_args.qat == QAT.EXPERIMENTAL:
        args.model.emb_bits = model_args.emb_bits
        args.model.output_w_bits = model_args.output_w_bits
        args.model.output_a_bits = model_args.output_a_bits
        args.model.kv_bits = model_args.kv_bits
        args.model.w_bits = model_args.w_bits
        args.model.a_bits = model_args.a_bits
        model = quantized_transformer.Transformer(args.model)
    else:
        model = Transformer(args.model)

    if model_args.parallel_impl != ParallelImpl.PT_D:
        assert consolidated_ckpt_local_path is not None
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(consolidated_ckpt_local_path, map_location="cpu")

        # Fill in quantizer clip values the checkpoint does not provide (presumably
        # checkpoints saved without QAT, such as the full-precision baseline).
        for key in model.state_dict().keys():
            if key not in state_dict.keys() and "weight_clip_val" in key:
                weight_key = key.replace("weight_clip_val", "weight")
                state_dict[key] = _initial_weight_clip_val(
                    state_dict[weight_key], model_args.w_bits
                )

        # strict=False: print the missing/unexpected-key report instead of raising.
        print(model.load_state_dict(
            state_dict,
            strict=False,
        ))
        del state_dict

    model = model.cuda().to(get_torch_dtype(args.dtype))
    if training_args.qat == QAT.EXPERIMENTAL:
        print("Loading QAT model")
        model = quantized_transformer.TransformerForCausalLM(config, model)
    else:
        model = TransformerForCausalLM(config, model)

    print("finish loading checkpoint and wrap model...")
    if local_rank == 0:
        print(model)
        print("Complete model loading...")

    return model

model = load_model_xl(local_rank=0, wrap=False, generate_only=False)

In [None]:
# Step 2: Create a wiki2 dataset
model_args.llama_version = "3"
model_args.use_fast_tokenizer = True
training_args.cache_dir = None
data_args.add_eos_token = False
data_args.add_bos_token = False
data_args.data_path = "xx/data_third_party/wiki/wikitext-2/train.jsonl"
data_args.eval_data_path = "xx/data_third_party/eval/"

# Tokenizer: LlamaTokenizerFast when llama_version contains "3", AutoTokenizer otherwise.
tokenizer_class = (
    transformers.LlamaTokenizerFast
    if "3" in model_args.llama_version
    else transformers.AutoTokenizer
)
tokenizer = tokenizer_class.from_pretrained(
    pretrained_model_name_or_path=model_args.input_model_local_path,
    cache_dir=training_args.cache_dir,
    model_max_length=training_args.model_max_length,
    padding_side="right",
    use_fast=model_args.use_fast_tokenizer,
    # When evaluating ppl tasks, set add_eos_token to True, otherwise, set it to False.
    add_eos_token=data_args.add_eos_token,
    add_bos_token=data_args.add_bos_token,
)

# Resolve dataset paths to local copies and build the token-block datasets.
pathmgr = utils.get_path_manager(
    max_parallel=data_args.max_parallel_files,
    api_key=data_args.api_key,
)
data_args.train_data_local_path = pathmgr.get_local_path(data_args.data_path)
# NOTE(review): the eval path is resolved but not used below (valid_path is None);
# to evaluate on wiki2 test, pass
# os.path.join(data_args.eval_data_local_path, "wiki2/test.jsonl") instead.
data_args.eval_data_local_path = pathmgr.get_local_path(data_args.eval_data_path)
train_dataset, valid_dataset = datautils.get_train_val_dataset(
    train_path=data_args.train_data_local_path,
    valid_path=None,
)

block_size = min(training_args.model_max_length, 2048)
train_data = datautils.CustomJsonDataset(train_dataset, tokenizer, block_size=block_size)
valid_data = datautils.CustomJsonDataset(valid_dataset, tokenizer, block_size=block_size)

In [8]:
from callm.core.models.llama.utils_quant import QuantizedEmbedding, QuantizeLinear
from callm.core.models.llama_xl.quantized_layers import QuantizedParallelEmbedding, QuantizedColumnParallelLinear, QuantizedRowParallelLinear
from collections import defaultdict
import torch.nn.functional as F
import math

# Device every batch is moved to before the forward pass (load_model_xl put the
# model on CUDA). NOTE(review): assumes the CausalLM wrapper exposes `.device`.
device = model.device

def move_batch_to_device(batch, device):
    """Return a copy of ``batch`` with ``input_ids``/``labels`` as int64 tensors on ``device``.

    Non-tensor values (lists, tuples, numpy arrays of token ids) are converted
    to int64 tensors. Tensors are moved to ``device`` as well; previously they
    were left where they were. 1-D sequences gain a leading batch dim of 1.
    Other keys are copied through unchanged.

    The input mapping is not modified. The old version replaced entries in
    place, which writes device tensors back into whatever object the dataset
    handed out.

    Raises:
        KeyError: if ``input_ids`` or ``labels`` is missing.
    """
    moved = dict(batch)
    for key in ("input_ids", "labels"):
        value = moved[key]
        if not torch.is_tensor(value):
            value = torch.tensor(np.array(value), dtype=torch.int64)
        value = value.to(device)
        if value.dim() == 1:
            value = value.view(1, -1)
        moved[key] = value
    return moved

def evaluate_loss(model, dataset, device, maximum_count = 500):
    """Mean and std of the model's per-batch loss over the first batches of ``dataset``.

    Runs under ``torch.no_grad()`` and stops after ``maximum_count`` batches
    (at least one batch is always evaluated). It does not switch the model to eval mode.

    Returns:
        (np.mean(losses), np.std(losses)) over the evaluated batches.
    """
    per_batch_losses = []
    with torch.no_grad():
        for step, batch in enumerate(dataset, start=1):
            on_device = move_batch_to_device(batch, device)
            output = model(input_ids=on_device['input_ids'], labels=on_device['labels'])
            per_batch_losses.append(output.loss.cpu().item())
            if step >= maximum_count:
                break
    return np.mean(per_batch_losses), np.std(per_batch_losses)


def weighted_average(x, count):
    """Weighted mean of ``x`` with weights ``count``.

    Returns:
        (sum(x * count) / sum(count), sum(count)). Division by a zero total
        follows numpy semantics (nan/inf with a warning, no exception).
    """
    weighted_sum = np.sum(x * count)
    denominator = np.sum(count)
    mean = weighted_sum / denominator
    return mean, denominator


def avg_dist_dict(keys, dictionary):
    """Arithmetic mean of ``dictionary[k]`` for each ``k`` in ``keys``.

    Empty lists map to NaN. That matches what the previous implementation
    actually produced: its ``except ZeroDivisionError`` could never fire,
    because numpy scalar division by zero returns nan with a warning instead
    of raising. For non-empty lists, its ``weighted_average(mean, len)``
    round trip was an identity.

    Returns:
        dict mapping each key in ``keys`` to a float.

    Raises:
        KeyError: if a key is missing from ``dictionary``.
    """
    avg = {}
    for k in keys:
        values = dictionary[k]
        avg[k] = float(np.mean(values)) if len(values) > 0 else float("nan")
    return avg

def eval_ppl(model, dataset, device, maximum_count = 500, ignore_index=-100):
    """Token-level perplexity of ``model`` on up to ``maximum_count`` batches of ``dataset``.

    Position t's logits are scored against label t+1 (a manual causal shift).
    Cross-entropy is summed in float32, and the token count excludes
    ``ignore_index`` positions; those positions are already skipped by
    ``F.cross_entropy``, so counting them understated perplexity.

    Args:
        model: switched to eval mode by this function.
        ignore_index: label value excluded from loss and token count
            (default -100, the F.cross_entropy default).

    Returns:
        exp(total NLL / total scored tokens), or NaN when no tokens were scored.
    """
    model.eval()
    total_nll = 0.0
    total_tokens = 0
    with torch.no_grad():
        for count, batch in enumerate(dataset, start=1):
            batch = move_batch_to_device(batch, device)
            outputs = model(input_ids=batch['input_ids'], labels=batch['labels'])
            # Drop the last logit / first label so logits[t] predicts labels[t + 1].
            logits = outputs.logits[:, :-1].float()
            y = batch["labels"][:, 1:].to(device)
            loss = F.cross_entropy(
                logits.flatten(0, 1),
                y.flatten(0, 1),
                reduction="sum",
                ignore_index=ignore_index,
            )
            total_nll += loss.item()
            total_tokens += int((y != ignore_index).sum().item())
            if count >= maximum_count:
                break
    if total_tokens == 0:
        return float("nan")
    return math.exp(total_nll / total_tokens)

In [9]:
# CPU snapshot of the model weights before post-training quantisation; used
# below as the reference for interpolation and error norms.
state_dict = {
    name: tensor.detach().cpu()
    for name, tensor in model.model.state_dict().items()
}

In [None]:
loss_summary = evaluate_loss(model, train_data, device, maximum_count=200)
fp_model_loss, fp_model_loss_std = loss_summary
print("FP model loss: {:.4f} +/- {:.4f}".format(*loss_summary))

In [11]:
import numpy as np
import torch
from llama.utils_quant import QuantizedEmbedding, QuantizeLinear
from llama_xl.quantized_layers import (
    QuantizedColumnParallelLinear,
    QuantizedParallelEmbedding,
    QuantizedRowParallelLinear,
)


def prepare_inputs(batch, device):
    """Return a new dict with every value of ``batch`` moved to ``device`` via ``.to``."""
    return dict((key, value.to(device)) for key, value in batch.items())


def normalization(vs):
    """Scale each tensor in ``vs`` to (approximately) unit L2 norm independently.

    Each tensor is divided by its own norm + 1e-6. The list is not treated as
    one concatenated vector.

    Returns:
        new list of tensors, same order as ``vs``.
    """
    scaled = []
    for vec in vs:
        length = (torch.sum(vec * vec) ** 0.5).cpu().item()
        scaled.append(vec / (length + 1e-6))
    return scaled


def orthnormal(ws, vs_list):
    """Project previous directions out of ``ws`` (in place), then normalise.

    For every list ``vs`` in ``vs_list``, each tensor ``w`` in ``ws`` has its
    component along the matching ``v`` removed: w -= v * <w, v>. The projection
    is per tensor, matching the per-tensor normalisation in normalization().

    Returns:
        normalization(ws) — the tensors in ``ws`` are also modified in place.
    """
    for basis in vs_list:
        for w, v in zip(ws, basis):
            coeff = torch.sum(w * v)
            w.data.add_(-v * coeff)
    return normalization(ws)


def get_layers(
    module,
    layers=[
        torch.nn.Linear,
        QuantizeLinear,
        QuantizedEmbedding,
        torch.nn.Embedding,
        QuantizedRowParallelLinear,
        QuantizedColumnParallelLinear,
        QuantizedParallelEmbedding,
    ],
    name: str = "",
):
    """Recursively collect submodules whose exact type is listed in ``layers``.

    Matching modules are not descended into. Embedding-type matches are always
    stored under the key "embed_tokens" (so several embeddings would overwrite
    each other). Other matches are stored under their dotted path from ``module``.

    Returns:
        dict mapping name -> module, in named_children() traversal order.
    """
    module_type = type(module)
    if module_type in layers:
        if module_type in [QuantizedEmbedding, torch.nn.Embedding, QuantizedParallelEmbedding]:
            return {"embed_tokens": module}
        return {name: module}
    found = {}
    for child_name, child in module.named_children():
        child_path = f"{name}.{child_name}" if name != "" else child_name
        found.update(get_layers(child, layers=layers, name=child_path))
    return found


def compute_eigenvalue(model, loss, device, maxIter=200, tol=1e-3, top_n=1):
    """Estimate the top ``top_n`` Hessian eigen-directions by power iteration.

    The Hessian is that of ``loss`` with respect to the ``.weight`` of every
    layer returned by get_layers(model). Hessian-vector products use double
    backprop. Each new direction is orthogonalised against the previously found
    ones. Vectors are normalised per weight tensor (see normalization()), so the
    returned eigenvalues are per-tensor Rayleigh quotients.

    Args:
        model: module providing the weights.
        loss: scalar loss with a live autograd graph.
        device: unused; kept for call compatibility.
        maxIter: maximum power-iteration steps per direction.
        tol: relative change of the summed Rayleigh quotients that counts as converged.
        top_n: number of directions to compute.

    Returns:
        (topn_eigenvalues, eigenvectors): one list of per-tensor values and one
        list of per-tensor vectors for each direction.
    """
    layers = get_layers(model)
    weights = [module.weight for name, module in layers.items()]
    model.zero_grad()
    # create_graph=True so the gradients can be differentiated again to form H @ v.
    gradients = torch.autograd.grad(loss, weights, retain_graph=True, create_graph=True)

    topn_eigenvalues = []
    eigenvectors = []
    computed_dim = 0
    while computed_dim < top_n:
        eigenvalues = None
        vs = [torch.randn_like(weight) for weight in weights]  # random start vector
        vs = normalization(vs)

        for _ in range(maxIter):
            vs = orthnormal(vs, eigenvectors)
            model.zero_grad()

            Hvs = torch.autograd.grad(
                gradients, weights, grad_outputs=vs, retain_graph=True
            )
            tmp_eigenvalues = [
                torch.sum(Hv * v).cpu().item() for (Hv, v) in zip(Hvs, vs)
            ]

            vs = normalization(Hvs)

            if eigenvalues is None:
                eigenvalues = tmp_eigenvalues
            elif (
                abs(sum(eigenvalues) - sum(tmp_eigenvalues))
                / (abs(sum(eigenvalues)) + 1e-6)
                < tol
            ):
                break
            else:
                eigenvalues = tmp_eigenvalues
        topn_eigenvalues.append(eigenvalues)
        eigenvectors.append(vs)
        computed_dim += 1

    return topn_eigenvalues, eigenvectors


def compute_hessian_traces(model, loss, device, maxIter=200, tol=1e-4):
    """Per-tensor Hutchinson estimates of the Hessian trace.

    Uses Rademacher probes (+1/-1 entries) over the ``.weight`` of every layer
    from get_layers(model). Iteration stops after ``maxIter`` probes, or once the
    running mean of the total trace changes by less than ``tol`` (relative).

    Args:
        device: unused; kept for call compatibility.

    Returns:
        numpy array with one averaged trace estimate per weight tensor.
    """
    weights = [module.weight for module in get_layers(model).values()]
    model.zero_grad()
    gradients = torch.autograd.grad(loss, weights, retain_graph=True, create_graph=True)

    per_probe_traces = []
    probe_totals = []
    running_mean = 0.0
    for _ in range(maxIter):
        # Rademacher probe: draw {0, 1} and map 0 -> -1.
        probes = [torch.randint_like(w, high=2) for w in weights]
        for probe in probes:
            probe[probe == 0] = -1

        model.zero_grad()
        Hvs = torch.autograd.grad(
            gradients, weights, grad_outputs=probes, retain_graph=True
        )
        estimates = np.array(
            [torch.sum(Hv * p).cpu().item() for Hv, p in zip(Hvs, probes)]
        )

        per_probe_traces.append(estimates)
        probe_totals.append(sum(estimates))

        new_mean = np.mean(probe_totals)
        if abs(new_mean - running_mean) / (abs(running_mean) + 1e-6) < tol:
            break
        running_mean = new_mean
    return np.mean(np.array(per_probe_traces), axis=0)


In [12]:
def get_linear_layers(
    module,
    layers=[
        torch.nn.Linear,
        QuantizeLinear,
        QuantizedRowParallelLinear,
        QuantizedColumnParallelLinear,
    ],
    name: str = "",
):
    """Collect linear (plain, quantized, and model-parallel) submodules of ``module``.

    The traversal is exactly get_layers()'s (exact-type match, no descent into
    matches, dotted-path keys), so this delegates to it. Only the default
    ``layers`` list differs.

    Returns:
        dict mapping dotted submodule path -> module.
    """
    return get_layers(module, layers=layers, name=name)

def get_embedding_layers(
    module,
    layers=[
        QuantizedEmbedding,
        torch.nn.Embedding,
        QuantizedParallelEmbedding,
    ],
    name: str = "",
):
    """Collect embedding (plain, quantized, and model-parallel) submodules of ``module``.

    The traversal is exactly get_layers()'s, so this delegates to it. Only the
    default ``layers`` list differs. Every match is keyed "embed_tokens".

    Returns:
        dict mapping "embed_tokens" -> module (empty if none found).
    """
    return get_layers(module, layers=layers, name=name)

def get_model_param_list(model):
    """``.weight`` of every linear layer found by get_linear_layers(model), in traversal order.

    This list defines the parameter space used by hvp / hvp_flat / slq_density.
    """
    return [module.weight for module in get_linear_layers(model).values()]


def hvp(model, loss, v, create_graph=False):
    """
    Compute H @ v where H is the Hessian of the training loss w.r.t. parameters.
    Uses a single mini-batch to approximate the training loss Hessian.

    Args:
        model: parameters are get_model_param_list(model).
        loss: scalar loss with a live autograd graph (it is retained).
        v: flat vector with one entry per parameter element, in parameter order.
        create_graph: when True, the returned product stays attached to the
            autograd graph (for higher-order derivatives). It was previously ignored.

    Returns:
        flat tensor H @ v, detached unless ``create_graph`` is True.

    Raises:
        RuntimeError: if ``v`` does not have exactly one entry per parameter element.
    """
    param_list = get_model_param_list(model)
    grads = torch.autograd.grad(loss, param_list, retain_graph=True, create_graph=True)

    # Split v into per-parameter slices and feed them as grad_outputs: this is the
    # same vector-Jacobian product as differentiating sum(cat(grads) * v), but it
    # avoids materialising a graph-attached concatenation of the full gradient.
    v_slices = []
    offset = 0
    for g in grads:
        n = g.numel()
        v_slices.append(v[offset : offset + n].view_as(g).to(g.dtype))
        offset += n
    if v.numel() != offset:
        raise RuntimeError(
            "hvp: v has {} elements but the parameters have {}".format(v.numel(), offset)
        )

    Hv = torch.autograd.grad(
        grads, param_list, grad_outputs=v_slices, retain_graph=True, create_graph=create_graph
    )
    Hv = torch.cat([h.reshape(-1) for h in Hv])
    return Hv if create_graph else Hv.detach()


def lanczos_tridiag(hvp_fn, dim, m, v0=None, device="cuda:0", dtype=torch.float32):
    """
    Run m steps of Lanczos on implicit symmetric operator H using HVPs.
    Returns alpha (diag), beta (off-diag), and the first Lanczos vector q1.

    The three-term recurrence is carried out in ``dtype`` (float32 by default)
    regardless of the start vector's dtype. With bf16 probes, the dot products
    and norms that define alpha/beta lose most of their precision. No
    reorthogonalisation is done, so expect some ghost Ritz values for large m.
    Memory: a few ``dim``-length vectors in ``dtype`` live on ``device``.

    Args:
        hvp_fn: callable mapping a flat vector q to H @ q (any float dtype;
            the result is cast to ``dtype``).
        dim: vector length, used only when ``v0`` is None.
        m: maximum number of Lanczos steps.
        v0: optional start vector (normalised internally).
        device: device for the Lanczos vectors.
        dtype: floating dtype for the recurrence.

    Returns:
        (alphas, betas, q1): float64 arrays of length k and k-1 (k <= m; fewer
        when the Krylov space is exhausted), plus the normalised start vector.
    """
    if v0 is None:
        v = torch.randn(dim, device=device, dtype=dtype)
    else:
        v = v0.clone().detach().to(device=device, dtype=dtype)
    q = v / (v.norm() + 1e-12)
    q1 = q.clone()  # returned for quadrature-weight computation
    alphas = []
    betas = []
    prev_q = torch.zeros_like(q)

    for _ in range(m):
        z = hvp_fn(q).to(dtype)
        alpha = torch.dot(q, z).item()
        z = z - alpha * q
        if betas:
            z = z - betas[-1] * prev_q
        beta = z.norm().item()
        alphas.append(alpha)
        betas.append(beta)
        if beta <= 1e-14:
            # Invariant subspace found: T is complete, stop.
            break
        prev_q, q = q, z / (beta + 1e-12)

    # The last beta is the next-iteration beta and is not part of T.
    return (
        np.array(alphas, dtype=np.float64),
        np.array(betas[:-1], dtype=np.float64),
        q1,
    )


def hvp_flat(model, loss, v_flat):
    """H @ v for a flat vector ``v_flat`` laid out in get_model_param_list order.

    Each segment is viewed as its parameter's shape, which raises if the
    segment is too short. The segments are then flattened again for hvp().
    """
    params = get_model_param_list(model)
    pieces = []
    start = 0
    for param in params:
        stop = start + param.numel()
        pieces.append(v_flat[start:stop].view_as(param).reshape(-1))
        start = stop
    return hvp(model, loss, torch.cat(pieces), create_graph=False)


def slq_density(model, loss, n_probes=20, m=50, sigma=1e-3, grid=None, device="cuda:0"):
    """
    Returns: grid (np.array), density (np.array), ritz_values_all (list), approx_lam_min/max

    Stochastic Lanczos quadrature estimate of the Hessian spectral density over
    the parameters from get_model_param_list(model).

    Args:
        loss: scalar loss with a live autograd graph.
        n_probes: number of random probe vectors (> 0).
        m: Lanczos steps per probe (> 0).
        sigma: std-dev of the Gaussian kernel used to smooth the Ritz values.
        grid: optional 1-D array of eigenvalue locations at which to evaluate
            the density. When None, a 400-point grid spanning the observed Ritz
            values plus 5% padding is used. A provided grid was previously ignored.
        device: device for the Lanczos vectors; it is now forwarded to
            lanczos_tridiag instead of that function's own default.

    Returns:
        (grid, density, thetas_all, weights_all, approx_lam_min, approx_lam_max):
        thetas_all / weights_all hold each probe's Ritz values and quadrature
        weights. The lam values are medians over probes of the smallest and
        largest Ritz value.

    Raises:
        ValueError: if ``n_probes`` or ``m`` is not positive.
    """
    if n_probes <= 0 or m <= 0:
        raise ValueError("slq_density needs n_probes > 0 and m > 0")

    param_list = get_model_param_list(model)
    dim = sum(p.numel() for p in param_list)
    grid_min, grid_max = +np.inf, -np.inf
    thetas_all = []
    weights_all = []

    for _ in range(n_probes):
        v0 = [torch.randn_like(p) for p in param_list]
        v0 = torch.cat([vi.reshape(-1) for vi in v0])
        v0 = v0 / (v0.norm() + 1e-12)
        alphas, betas, _q1 = lanczos_tridiag(
            lambda v: hvp_flat(model, loss, v), dim, m, v0=v0, device=device
        )

        # Tridiagonal T from the Lanczos coefficients.
        k = len(alphas)
        T = np.diag(alphas[:k])
        if k > 1:
            off = betas[: k - 1]
            T += np.diag(off, 1) + np.diag(off, -1)

        # Ritz values, with quadrature weights = (first eigenvector component)^2.
        theta, U = np.linalg.eigh(T)
        w = (U[0, :] ** 2) * (v0.norm().item() ** 2)
        thetas_all.append(theta)
        weights_all.append(w)
        grid_min = min(grid_min, theta.min())
        grid_max = max(grid_max, theta.max())

    if grid is None:
        # Data-driven grid with a small margin.
        pad = 0.05 * (grid_max - grid_min + 1e-12)
        grid = np.linspace(grid_min - pad, grid_max + pad, 400)
    else:
        grid = np.asarray(grid, dtype=np.float64)

    # Accumulate Gaussian kernels on the grid, averaged over probes.
    density = np.zeros_like(grid, dtype=np.float64)
    lam_mins, lam_maxs = [], []
    for theta, w in zip(thetas_all, weights_all):
        for ti, wi in zip(theta, w):
            density += (
                wi
                * np.exp(-0.5 * ((grid - ti) / sigma) ** 2)
                / (np.sqrt(2 * np.pi) * sigma)
            )
        lam_mins.append(theta.min())
        lam_maxs.append(theta.max())

    density /= n_probes
    approx_lam_min = float(np.median(lam_mins))
    approx_lam_max = float(np.median(lam_maxs))
    return grid, density, thetas_all, weights_all, approx_lam_min, approx_lam_max


In [14]:
from llama_xl.quantized_layers import QuantizedParallelEmbedding, QuantizedColumnParallelLinear, QuantizedRowParallelLinear
from llama_xl.multiple_bits_quantized_layers import MultiBitsQuantizedRowParallelLinear, MultiBitsQuantizedColumnParallelLinear
from callm.core.models.llama_xl.utils_quant import (
    AsymQuantizer,
    ElasticQuantBinarizerSigned,
    ElasticQuantBinarizerUnsigned,
    ElasticQuantN2UQ,
    SymQuantizer,
)

def _is_full_precision(bits):
    """True when ``bits`` means "leave the weight unquantized" (None or >= 16)."""
    return bits is None or bits >= 16


def _quantize_linear_weight(real_weights, weight_clip_val, w_bits, weight_layerwise):
    """Quantize a linear layer's latent weight (``w_bits`` must not be full precision).

    > 4 bits -> SymQuantizer; 0/1/2 bits -> ElasticQuantN2UQ;
    anything else (3 or 4 bits) -> ElasticQuantBinarizerSigned.
    """
    if w_bits > 4:
        return SymQuantizer.apply(real_weights, weight_clip_val, w_bits, weight_layerwise)
    if w_bits in (0, 1, 2):
        return ElasticQuantN2UQ.apply(real_weights, weight_clip_val, w_bits, weight_layerwise)
    return ElasticQuantBinarizerSigned.apply(real_weights, weight_clip_val, w_bits, weight_layerwise)


def _quantize_embedding_weight(real_weights, weight_clip_val, w_bits):
    """Quantize an embedding's latent weight (``w_bits`` must not be full precision).

    >= 4 bits -> SymQuantizer; 2 bits -> ElasticQuantN2UQ;
    anything else -> ElasticQuantBinarizerSigned. The layerwise flag is always False.
    """
    if w_bits >= 4:
        return SymQuantizer.apply(real_weights, weight_clip_val, w_bits, False)
    if w_bits == 2:
        return ElasticQuantN2UQ.apply(real_weights, weight_clip_val, w_bits, False)
    return ElasticQuantBinarizerSigned.apply(real_weights, weight_clip_val, w_bits, False)


def convert_weight(layer, w_bits=2):
    """Overwrite ``layer.weight`` in place with its quantized value.

    Quantized parallel embedding / column / row layers use their own
    ``layer.w_bits`` and ``layer.weight_clip_val``. The MultiBits* layers use the
    ``w_bits`` argument and ``layer.weight_clip_val_list[str(int(w_bits))]``.
    A bit-width of None or >= 16 leaves the weight untouched. Two previous
    crashes are fixed: the MultiBits lookup now happens only after that check
    (``int(None)`` used to fail), and an embedding with ``w_bits=None`` is now
    treated as full precision instead of going to the binarizer. Quantization
    runs under no_grad, so no autograd graph is built.

    Raises:
        NotImplementedError: for any other layer type (exact type match).
    """
    layer_type = type(layer)
    with torch.no_grad():
        if layer_type is QuantizedParallelEmbedding:
            bits = layer.w_bits
            if _is_full_precision(bits):
                weight = layer.weight
            else:
                weight = _quantize_embedding_weight(layer.weight, layer.weight_clip_val, bits)
        elif layer_type is QuantizedColumnParallelLinear or layer_type is QuantizedRowParallelLinear:
            bits = layer.w_bits
            if _is_full_precision(bits):
                weight = layer.weight
            else:
                weight = _quantize_linear_weight(
                    layer.weight, layer.weight_clip_val, bits, layer.weight_layerwise
                )
        elif (
            layer_type is MultiBitsQuantizedColumnParallelLinear
            or layer_type is MultiBitsQuantizedRowParallelLinear
        ):
            if _is_full_precision(w_bits):
                weight = layer.weight
            else:
                weight_clip_val = layer.weight_clip_val_list[str(int(w_bits))]
                weight = _quantize_linear_weight(
                    layer.weight, weight_clip_val, w_bits, layer.weight_layerwise
                )
        else:
            raise NotImplementedError
    if weight is not layer.weight:
        layer.weight.data = weight

def post_training(model, w_bits=4):
    """Quantize, in place, the weight of every layer returned by get_layers(model).

    Args:
        model: module to convert.
        w_bits: bit-width forwarded to convert_weight(). Only the MultiBits*
            layer types use it; other quantized layers use their own
            ``layer.w_bits``. The default of 4 keeps the previous hard-coded value.
    """
    for layer in get_layers(model).values():
        convert_weight(layer, w_bits=w_bits)


post_training(model)
# CPU snapshot of the weights after post-training quantisation.
quantized_state_dict = {
    name: tensor.detach().cpu()
    for name, tensor in model.model.state_dict().items()
}
def interpolate_two_state_dict(state_dict1, state_dict2, alpha):
    """Per-key linear interpolation (1 - alpha) * sd1 + alpha * sd2.

    Keys missing from ``state_dict2`` are reported with a print and left out of
    the result. Output order follows ``state_dict1``.
    """
    blended = {}
    for key, tensor1 in state_dict1.items():
        if key not in state_dict2.keys():
            print("Key {} not found in state_dict2".format(key))
            continue
        blended[key] = (1 - alpha) * tensor1.clone() + alpha * state_dict2[key].clone()
    return blended

# alpha = 0 keeps the pre-quantisation snapshot `state_dict`; alpha = 1 uses the quantized weights.
alpha = 0
new_state_dict = {
    key: val.to(device)
    for key, val in interpolate_two_state_dict(state_dict, quantized_state_dict, alpha).items()
}
model.model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [15]:
# Loss of the interpolated model loaded above. It is stored under its own names
# so the FP baseline `fp_model_loss` from the earlier cell is not overwritten,
# and the label reports alpha (it is only the FP model when alpha == 0).
interp_model_loss, interp_model_loss_std = evaluate_loss(model, train_data, device, maximum_count = 200)
print("Interpolated model (alpha={}) loss: {:.4f} +/- {:.4f}".format(alpha, interp_model_loss, interp_model_loss_std))

FP model loss: 3.0466 +/- 0.2331


In [16]:
# Only transformer-block entries (keys containing "layers") are compared.
keys = [key for key in state_dict if "layers" in key]

quantized_errors = dict((key, []) for key in keys)
weight_norms = dict((key, []) for key in keys)

def get_quantize_error(state_dict_1, state_dict_2, key):
    """Return (||a - b||, ||a||, ||b||) for tensors a, b stored under ``key``.

    Returns (0, 0, 0) when either state dict lacks ``key``.
    """
    if key in state_dict_1 and key in state_dict_2:
        first = state_dict_1[key]
        second = state_dict_2[key]
        return (
            torch.norm(first - second).cpu().item(),
            torch.norm(first).cpu().item(),
            torch.norm(second).cpu().item(),
        )
    return 0, 0, 0

# Compare the current model weights with the pre-quantisation snapshot `state_dict`.
# Results are built from fresh local lists so the cell can be re-run. Previously,
# it appended to the dicts from the cell above and then rebound those names to
# arrays, which broke a second run.
new_state_dict = {key: val.to("cpu") for key, val in model.model.state_dict().items()}

error_rows, norm_rows = [], []
for key in keys:
    gap, norm_latent_weights, norm_quantized_weights = get_quantize_error(state_dict, new_state_dict, key)
    error_rows.append([gap])
    norm_rows.append([norm_latent_weights])

# Shape (num_keys, 1), as before.
quantized_errors = np.array(error_rows)
weight_norms = np.array(norm_rows)
error_norm = np.sqrt(np.sum(np.square(quantized_errors)))
total_weight_norm = np.sqrt(np.sum(np.square(weight_norms)))
print("error norm: {:.4f}".format(error_norm))
print("weight norm: {:.4f}".format(total_weight_norm))
print("error ratio: {:.4f}".format(error_norm / total_weight_norm))

error norm: 0.0000
weight norm: 636.4101
error ratio: 0.0000


In [None]:
from torch.nn.attention import sdpa_kernel, SDPBackend

n_probes = 20      # increase for smoother/less noisy density
m_steps  = 60      # Lanczos steps (resolution); 50–100 is common
sigma    = 5e-3    # kernel bandwidth for smoothing; tune per scale
n_batches = 10     # number of training batches analysed
out_dir = "xx/notebooks/hessian_spectrum/1_bits_steps_40k_alpha_0.0"
os.makedirs(out_dir, exist_ok=True)

# Hessian-vector products differentiate through attention twice. This forces the
# math SDPA backend, as compute_hessian_on_current_model does for its HVPs.
# NOTE(review): assumes the model's attention goes through
# F.scaled_dot_product_attention — confirm.
with sdpa_kernel(SDPBackend.MATH):
    for i, batch in enumerate(train_data):
        batch = move_batch_to_device(batch, device)
        outputs = model(input_ids=batch['input_ids'], labels=batch['labels'])
        loss = outputs.loss

        grid, dens, thetas_all, weights_all, lam_min, lam_max = slq_density(model, loss, n_probes=n_probes, m=m_steps, sigma=sigma)

        np.save(os.path.join(out_dir, f"grid_{i}.npy"), grid)
        np.save(os.path.join(out_dir, f"density_{i}.npy"), dens)
        # NOTE(review): thetas/weights are per-probe arrays; if a probe's Lanczos run
        # stops early they are ragged and np.save will refuse them.
        np.save(os.path.join(out_dir, f"thetas_{i}.npy"), thetas_all)
        np.save(os.path.join(out_dir, f"weights_{i}.npy"), weights_all)
        np.save(os.path.join(out_dir, f"min_and_max_{i}.npy"), np.array([lam_min, lam_max]))
        if i >= n_batches - 1:
            break

In [None]:
# compute gradient norm
for step in [0, 5000] + list(np.arange(10000, 80001, 10000)):
    if step == 0:
        data_args.checkpoint_path = "xx/baselines/full_precision_models/original_xl/llama/1B/"
    else:
        data_args.checkpoint_path = f"xx/1_16_16/paretoq_lr_2e5/{int(step)}/"
    model = load_model_xl()

    count = 0
    gradient_norms =[]
    param_list = get_model_param_list(model)
    weight_norm = 0
    for param in param_list:
        weight_norm += torch.norm(param).item()**2
    weight_norm = weight_norm ** 0.5
    print(weight_norm)

    for batch in train_data:
        batch = move_batch_to_device(batch, device)
        outputs = model(input_ids=batch['input_ids'], labels=batch['labels'])
        loss = outputs.loss
        gradient = torch.autograd.grad(loss, param_list, create_graph=False, retain_graph=False)
        gradient = [g.view(-1) for g in gradient]
        print("Gradient norm: {:.4f}".format(torch.norm(torch.concatenate(gradient)).item()))
        gradient_norms.append(torch.norm(torch.concatenate(gradient)).item())
        count += 1
        if count >= 20:
            break

    gradient_norms = np.array(gradient_norms)/weight_norm
    print(np.mean(gradient_norms), np.std(gradient_norms))

In [None]:
from torch.nn.attention import sdpa_kernel, SDPBackend

# Path manager used to resolve checkpoint locations to local files (see get_state_dict).
pathmgr = utils.get_path_manager(
    api_key=data_args.api_key,
    max_parallel=data_args.max_parallel_files,
)

def get_consolidated_ckpt_path(
    ckpt_dir: str,
    mp_rank: int = 0,
    mp_size: int = 1,
    parallel_impl = None
) -> str:
    """Build the path of the consolidated checkpoint file for one model-parallel rank.

    Note: this redefinition shadows the ``get_consolidated_ckpt_path`` imported
    from ``callm.core.model_utils`` earlier in the notebook.

    Args:
        ckpt_dir: checkpoint directory.
        mp_rank: model-parallel rank; must be 0 when ``mp_size == 1``.
        mp_size: model-parallel world size.
        parallel_impl: accepted but unused here.

    Returns:
        ``<ckpt_dir>/consolidated.pth`` for a single shard, otherwise
        ``<ckpt_dir>/consolidated.<rank, zero-padded to 2 digits>.pth``.
    """
    if mp_size != 1:
        shard_name = f"consolidated.{mp_rank:02d}.pth"
    else:
        assert mp_rank == 0
        shard_name = "consolidated.pth"
    return os.path.join(ckpt_dir, shard_name)

def get_state_dict(path, map_location=None):
    """Load the single-shard consolidated checkpoint stored under ``path``.

    Args:
        path: checkpoint directory; must not be None (enforced by ``none_throws``).
        map_location: forwarded to ``torch.load``. When None (the default), the
            device of the notebook-global ``model`` is used if such a global
            exists, matching the previous behavior. If it does not exist, the
            checkpoint loads onto ``"cpu"`` instead of failing with NameError.

    Returns:
        Whatever ``torch.load`` deserializes from the file (presumably a
        parameter-name -> tensor mapping; confirm against the checkpoint writer).
    """
    mp_rank, mp_size = 0, 1  # always read the consolidated (non-sharded) file
    local_path = pathmgr.get_local_path(
        get_consolidated_ckpt_path(
            none_throws(path),
            mp_rank,
            mp_size,
        )
    )
    if map_location is None:
        # Previously this silently required a global `model`; keep that default
        # when it is available, but no longer crash when it is not.
        current_model = globals().get("model")
        map_location = getattr(current_model, "device", "cpu")
    # torch.load unpickles the file: only point this at trusted checkpoints.
    state_dict = torch.load(local_path, map_location=map_location)
    return state_dict

def interpolate_two_state_dict(state_dict1, state_dict2, alpha):
    """Linearly interpolate two state dicts: ``(1 - alpha) * sd1 + alpha * sd2``.

    Args:
        state_dict1: mapping of parameter name -> tensor (weight ``1 - alpha``).
        state_dict2: mapping of parameter name -> tensor (weight ``alpha``).
        alpha: interpolation coefficient; 0 reproduces ``state_dict1`` and 1
            reproduces ``state_dict2``.

    Returns:
        A new dict, in ``state_dict1`` key order, containing only the keys
        present in both inputs. Keys found in just one input are reported with
        a printed message and skipped. The inputs are not modified.
    """
    state_dict = {}
    for key, value1 in state_dict1.items():
        if key not in state_dict2:
            print("Key {} not found in state_dict2".format(key))
            continue
        # The arithmetic already allocates fresh tensors, so the extra
        # .clone() copies the original made were pure overhead.
        state_dict[key] = (1 - alpha) * value1 + alpha * state_dict2[key]
    # Keys only present in state_dict2 used to be dropped silently.
    for key in state_dict2:
        if key not in state_dict1:
            print("Key {} not found in state_dict1".format(key))
    return state_dict

def compute_hessian_on_current_model(model, train_data, device, maximum_count=20):
    """Collect Hessian-trace estimates of the LM loss over a few training batches.

    For each of up to ``maximum_count`` batches (always at least one), runs a
    forward pass and hands the loss to ``compute_hessian_traces``. It then prints
    mean +/- std of the first trace column (labelled "Embedding layer"), of the
    sum of the remaining columns ("Transformer layers") and of the row total.

    Args:
        model: module called as ``model(input_ids=..., labels=...)`` whose output
            exposes ``.loss``.
        train_data: iterable of batches with ``'input_ids'`` and ``'labels'`` keys.
        device: forwarded to ``move_batch_to_device`` and ``compute_hessian_traces``.
        maximum_count: maximum number of batches to consume. Values below 1 still
            consume one batch, as the original check-after-processing loop did.

    Returns:
        ``(traces, eigenvalues)``. ``traces`` has one row per batch, which assumes
        ``compute_hessian_traces`` returns an equal-length per-layer sequence
        each time (TODO confirm). ``eigenvalues`` is always an empty list because
        top-eigenvalue estimation is currently disabled.

    Raises:
        ValueError: if ``train_data`` yields no batches (previously an opaque
            IndexError on ``traces[:, 0]``).
    """
    import math
    from itertools import islice

    traces = []
    eigenvalues = []  # stays empty while top-eigenvalue estimation is disabled
    # ceil() so a fractional limit stops at the same batch as `count >= maximum_count`.
    n_batches = max(1, math.ceil(maximum_count))
    # Force the math SDPA backend; presumably the fused attention kernels lack
    # the double-backward support Hessian computations need.
    with sdpa_kernel(SDPBackend.MATH):
        for batch in islice(train_data, n_batches):
            batch = move_batch_to_device(batch, device)
            outputs = model(input_ids=batch['input_ids'], labels=batch['labels'])
            loss = outputs.loss
            traces.append(compute_hessian_traces(model, loss, device))
            # Optional top-eigenvalue estimate (disabled):
            # cur_eigenvalues, _ = compute_eigenvalue(model, loss, device, top_n=1)
            # eigenvalues.append(np.array(cur_eigenvalues[0]))

    if not traces:
        raise ValueError("train_data yielded no batches; cannot estimate Hessian traces")

    traces = np.array(traces)
    transformer_totals = traces[:, 1:].sum(axis=-1)
    all_totals = traces.sum(axis=-1)
    print("Embedding layer avg trace value {:.2f} +/- {:.2f}".format(np.mean(traces[:, 0]), np.std(traces[:, 0])))
    print("Transformer layers avg trace value {:.2f} +/- {:.2f}".format(np.mean(transformer_totals), np.std(transformer_totals)))
    print("Avg trace value {:.2f} +/- {:.2f}".format(np.mean(all_totals), np.std(all_totals)))
    return traces, eigenvalues

# NOTE(review): assigning `training` directly only changes the flag on the
# top-level module; submodules keep their own `training` value (nn.Module.eval()
# would propagate it recursively). Confirm that only the root flag is meant to
# switch the loss computation here.
model.training=False # enable standard loss computation

# Per-layer Hessian trace estimates over the first 10 training batches.
traces, eigenvalues = compute_hessian_on_current_model(model, train_data, device, maximum_count=10)

In [None]:
from typing import Literal, Tuple, Optional, Union
import numpy as np

# Array-like inputs accepted by the spectral-density helpers below; they are
# converted to float arrays with np.asarray before use.
ArrayLike = Union[np.ndarray, list, tuple]

def _trapz(y: np.ndarray, x: np.ndarray) -> float:
    """Scalar trapezoidal integral; assumes x is 1D, y is broadcastable to x."""
    return float(np.trapz(y, x))

def _validate_grid(grid: np.ndarray, positive_required: bool = False) -> None:
    if grid.ndim != 1:
        raise ValueError("grid must be a 1D array")
    if np.any(~np.isfinite(grid)):
        raise ValueError("grid contains non-finite values")
    if np.any(np.diff(grid) <= 0):
        raise ValueError("grid must be strictly increasing")
    if positive_required and np.any(grid <= 0):
        raise ValueError("log-space requires grid > 0 (since x = log(lambda))")

def _to_2d(a: np.ndarray) -> np.ndarray:
    """Ensure densities are 2D: (n_samples, G)."""
    if a.ndim == 1:
        return a[None, :]
    if a.ndim == 2:
        return a
    raise ValueError("densities must be 1D or 2D (n_samples, G)")

def normalize_density(
    grid: ArrayLike,
    density: ArrayLike,
    space: Literal["linear", "log"] = "linear",
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Normalize density curves to integrate to 1 in the appropriate variable.

    Parameters
    ----------
    grid : array (G,)
        Strictly increasing λ grid. For space='log' it must also be > 0, because
        the integration variable is x = log(λ).
    density : array (G,) or (n, G)
        p(λ) if space='linear'; q(x) if space='log', where x = log(λ).
    space : {'linear','log'}
        Interpretation of 'density' and the integration variable.

    Returns
    -------
    density_norm : array (n, G)
        Normalized densities (a 1D input comes back as a single row). Rows whose
        integral is exactly 0 become all-NaN rather than raising.
    grid_out : array (G,)
        The λ grid as a float array.

    Raises
    ------
    ValueError
        If density is not 1D/2D, the grid is invalid for `space`, `space` is
        unknown, or the density rows do not have one value per grid point.
    """
    grid = np.asarray(grid, dtype=float)
    dens = _to_2d(np.asarray(density, dtype=float))

    if space == "linear":
        _validate_grid(grid, positive_required=False)
        x = grid  # integrate p(λ) dλ
    elif space == "log":
        _validate_grid(grid, positive_required=True)
        x = np.log(grid)  # integrate q(x) dx with x = log(λ)
    else:
        raise ValueError("space must be 'linear' or 'log'")

    # A row/grid length mismatch previously surfaced as an opaque broadcasting
    # error inside the integrator (or was silently broadcast for length 1).
    if dens.shape[1] != grid.shape[0]:
        raise ValueError(
            f"density has {dens.shape[1]} points per row but grid has {grid.shape[0]}"
        )

    masses = np.array([_trapz(row, x) for row in dens])
    # Avoid division by zero: zero-mass rows become NaN instead of inf/error.
    masses[masses == 0] = np.nan
    dens_norm = dens / masses[:, None]
    return dens_norm, grid

def spectral_moment(
    grid: ArrayLike,
    density: ArrayLike,
    k: int = 1,
    space: Literal["linear", "log"] = "linear",
    normalize: bool = True,
) -> np.ndarray:
    """
    Compute the k-th spectral moment μ_k = ∫ λ^k p(λ) dλ (linear space)
    or μ_k = ∫ λ^k q(x) dx with x=log(λ) and λ=e^x (log space).

    Parameters
    ----------
    grid : array (G,)
        λ grid (must be strictly increasing; also > 0 for space='log').
    density : array (G,) or (n, G)
        p(λ) if space='linear'; q(x) if space='log'.
    k : int
        Moment order (k=1 gives ∫ λ p(λ) dλ).
    space : {'linear','log'}
        Whether the provided density is with respect to dλ ('linear')
        or dx where x=log(λ) ('log').
    normalize : bool
        If True, first normalize each density to integrate to 1 in its space.

    Returns
    -------
    moments : array (n,)
        The k-th moment for each density row.

    Raises
    ------
    ValueError
        If `space` is unknown, the grid is invalid for `space` (checked whether
        or not `normalize` is set), or the density rows do not match the grid length.
    """
    if space not in ("linear", "log"):
        raise ValueError("space must be 'linear' or 'log'")

    grid = np.asarray(grid, dtype=float)
    dens = _to_2d(np.asarray(density, dtype=float))

    if normalize:
        dens, grid = normalize_density(grid, dens, space=space)
    else:
        # Without normalization the grid was never validated, so log space
        # silently produced NaN moments for λ <= 0 and unsorted grids gave
        # wrong integrals.
        _validate_grid(grid, positive_required=(space == "log"))

    if dens.shape[1] != grid.shape[0]:
        raise ValueError(
            f"density has {dens.shape[1]} points per row but grid has {grid.shape[0]}"
        )

    if space == "linear":
        # μ_k = ∫ λ^k p(λ) dλ
        x = grid
        integrands = dens * (grid[None, :] ** k)
    else:
        # μ_k = ∫ (e^x)^k q(x) dx = ∫ e^{k x} q(x) dx
        x = np.log(grid)
        integrands = dens * np.exp(k * x[None, :])
    moments = np.array([_trapz(row, x) for row in integrands])
    return moments

def trace_from_density(
    grid: ArrayLike,
    density: ArrayLike,
    matrix_size: int,
    space: Literal["linear", "log"] = "linear",
    normalize: bool = True,
) -> np.ndarray:
    """
    Estimate tr(H) as ``matrix_size`` times the first spectral moment.

    With N = matrix_size (number of eigenvalues of H):
        tr(H) = N * ∫ λ p(λ) dλ   (space='linear')
        tr(H) = N * ∫ e^x q(x) dx (space='log', x = log λ)

    Parameters
    ----------
    grid : array (G,)
        Strictly increasing λ grid.
    density : array (G,) or (n, G)
        p(λ) if space='linear'; q(x) if space='log'.
    matrix_size : int
        N, the dimension of H.
    space : {'linear','log'}
        Interpretation of density.
    normalize : bool
        Normalize each density to unit mass before taking the moment.

    Returns
    -------
    traces : array (n,)
        One trace estimate per density row.
    """
    first_moment = spectral_moment(
        grid,
        density,
        k=1,
        space=space,
        normalize=normalize,
    )
    return first_moment * matrix_size


# Scale the first spectral moment by the Hessian dimension (total parameter count).
# NOTE(review): `grid`/`dens` here are whatever the SLQ cell last produced, and
# `dens` is treated as a linear-space density p(λ) — confirm.
param_list = get_model_param_list(model)
dim = sum(map(torch.numel, param_list))
trace_from_density(grid, dens, matrix_size=dim, normalize=True)

In [None]:
import matplotlib.pyplot as plt

# Plot density
# --- Plot the estimated spectral density (y-axis shows log ρ(λ)) ---
fig, ax = plt.subplots(figsize=(7, 4))
# Grid points with zero density give log(0) = -inf; silence only that divide warning.
with np.errstate(divide="ignore"):
    log_dens = np.log(dens)
# A label is required: the original called legend() with no labeled artists.
ax.scatter(grid, log_dens, label="SLQ density estimate")

# ax.axvline(lam_min, linestyle="--", label="~ λ_min (Ritz median)")
# ax.axvline(lam_max, linestyle="--", label="~ λ_max (Ritz median)")
ax.set_xlabel("Eigenvalue λ")
ax.set_ylabel("log ρ(λ) (estimated)")
# NOTE(review): confirm `dens` really is the embedding-layer spectrum for this title.
ax.set_title("Hessian spectrum (of the embedding layers)")
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Plot density
# --- Plot the estimated spectral density (y-axis shows log ρ(λ)) ---
fig, ax = plt.subplots(figsize=(7, 4))
# Grid points with zero density give log(0) = -inf; silence only that divide warning.
with np.errstate(divide="ignore"):
    log_dens = np.log(dens)
ax.scatter(grid, log_dens, label="SLQ density estimate")
# ax.axvline(lam_min, linestyle="--", label="~ λ_min (Ritz median)")
# ax.axvline(lam_max, linestyle="--", label="~ λ_max (Ritz median)")
ax.set_xlabel("Eigenvalue λ")
# The plotted quantity is log(dens), so label it as such.
ax.set_ylabel("log ρ(λ) (estimated)")
ax.set_title("Hessian spectrum (SLQ approximation)")
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
# Persist the SLQ spectrum (grid, density, Ritz values) for the 2-bit, 10k-step run.
# NOTE(review): `ritz_all` is not assigned anywhere in the visible cells (the SLQ
# loop produces `thetas_all` / `weights_all`), so on a fresh kernel this cell may
# raise NameError — confirm which cell defines it.
np.save("xx/notebooks/slq_grid_2bit_10k.npy", grid)
np.save("xx/notebooks/slq_density_2bit_10k.npy", dens)
np.save("xx/notebooks/slq_ritz_all_2bit_10k.npy", ritz_all)
lam_min, lam_max
# Previously recorded (lam_min, lam_max) values:
# 1-bit (-15.938819445322734, 16.789593178943367)
# (-127.41929134942242, 123.44032548997839)
# 2-bit (-69.01394941459208, 66.09353891130456)

In [None]:
# Reload a saved 2-bit SLQ spectrum; this overwrites the notebook-global
# `grid`, `dens` and `ritz_all`. Note these paths (hessian_spectrum/, no "_10k"
# suffix) differ from the ones written by the save cell.
grid = np.load("xx/notebooks/hessian_spectrum/slq_grid_2bit.npy")
dens = np.load("xx/notebooks/hessian_spectrum/slq_density_2bit.npy")
ritz_all = np.load("xx/notebooks/hessian_spectrum/slq_ritz_all_2bit.npy")