In [None]:
# Environment setup. NOTE(review): the GPU index and the cache paths below are machine-specific;
# adjust them for your setup. They are set before torch/transformers are imported further down;
# CUDA_VISIBLE_DEVICES in particular only takes effect if set before CUDA is initialized.
# (Comments must stay on their own lines: %env treats the rest of its line as the value.)
%env CUDA_VISIBLE_DEVICES=6
%env TRANSFORMERS_CACHE=/mnt/LLM/hub

%env HF_HOME=/mnt/LLM
%env OMP_NUM_THREADS=16

import os
import sys
# make the parent directory importable (presumably the repo root holding `prekv`, `train_predictors`, `ppl`)
sys.path.insert(0, '..')

import math
from argparse import Namespace
from typing import Sequence, Optional, List, Tuple

import torch
import torch.nn as nn
import transformers
from tqdm import trange

from prekv import datautils, modelutils
from prekv.quantizers import QuantizerBase, HiggsQuantizer
from prekv.linear_utils import fit_linear_regression
from train_predictors import OutputCatcher, get_predictor, get_dequant_values, compute_relative_mse
from ppl import evaluate_perplexity
from datasets import load_dataset
from prekv.cache_utils import TreatPrefixSeparately,PredictorHiggsCache,SingleChunkQuantizedCacheWithPredictors
from functools import partial

In [None]:
def make_arg_parser():
    """Build the argument parser holding every tunable option of this notebook.

    The notebook parses an empty argv (``parse_args(args=[])``), so the defaults below are
    what actually runs unless they are edited here or overridden on ``args`` afterwards.

    :returns: a configured ``argparse.ArgumentParser``.
    """
    import argparse

    parser = argparse.ArgumentParser(add_help=True)

    # ---- model & calibration data ----
    parser.add_argument(
        "--model_name",
        default="unsloth/Llama-3.2-3B",
        type=str,
        help="path to llama model to load, as in LlamaForCausalLM.from_pretrained()",
    )
    parser.add_argument(
        "--dataset",
        default="pajama",
        type=str,
        help="Dataset name [c4, pajama] or path to data where to extract calibration data from.",
    )
    parser.add_argument(
        "--torch_dtype",
        type=str,
        default="auto",
        choices=["auto", "float16", "float32", "bfloat16"],
        help="dtype to load the model in",
    )
    parser.add_argument(
        "--compute_dtype",
        type=str,
        default=None,
        # validated here; the calibration loop converts the name to a torch.dtype
        choices=["float16", "float32", "bfloat16"],
        help="dtype for computing activations during calibration; defaults to each layer's own dtype",
    )
    parser.add_argument(
        "--model_seqlen",
        type=int,
        default=8192,
        help="Model seqlen and calibration data context length.",
    )
    parser.add_argument(
        "--devices",
        metavar="N",
        type=str,
        nargs="+",
        default=None,
        help="List of devices",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Seed for calibration data and initialization. "
             "Note that the main training is not strictly deterministic.",
    )
    parser.add_argument(
        "--offload_activations",
        action="store_true",
        help="Offload activations to RAM to save GPU memory.",
    )

    # ---- predictor training ----
    parser.add_argument(
        "--total_nsamples",
        type=int,
        default=256,
        help="Number of calibration data samples. If None, take all calibration data.",
    )
    parser.add_argument(
        "--valid_nsamples",
        type=int,
        default=32,
        help="Number of samples used to validate the trained predictors.",
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=4096,
        help="Number of tokens in one chunk.",
    )
    parser.add_argument(
        "--percdamp",
        type=float,
        default=1e-3,
        help="Percent of the average Hessian diagonal to use for dampening.",
    )

    # ---- HIGGS quantizer ----
    parser.add_argument(
        "--hadamard_groupsize",
        type=int,
        default=128,
        help="Groupsize of Hadamard transform for HIGGS.",
    )
    parser.add_argument(
        "--edenn_d",
        type=int,
        default=6,
        help="The grid dimension d for HIGGS.",
    )
    parser.add_argument(
        "--edenn_n",
        type=int,
        default=4096,
        help="The grid size n for HIGGS.",
    )

    # ---- perplexity evaluation ----
    parser.add_argument(
        "--ppl_chunk_size",  # TODO: rename; only used as evaluate_perplexity's step_size
        type=int,
        default=32,
        help="Number of tokens per step during perplexity evaluation (passed as step_size to evaluate_perplexity).",
    )
    parser.add_argument(
        "--ppl_buffer_size",  # TODO: rename; only used as PredictorHiggsCache's min_buffer_size
        type=int,
        default=128,
        help="Minimum buffer size (in tokens) of the quantized KV cache during perplexity evaluation.",
    )

    # ---- output ----
    parser.add_argument(
        "--output_path",
        type=str,
        default="./key_value_predictors_perheadchannel.pt",
        help="Path to save trained predictors for Key and Values",
    )
    parser.add_argument("--wandb", action="store_true", help="Whether to use wandb or store locally.")
    return parser

### Parsing Arguments

In [None]:
# Cap intra-op CPU threads at 16 (never raise the current setting).
torch.set_num_threads(min(16, torch.get_num_threads()))
# Inside a notebook there is no real argv: parse an empty list so every option takes its default.
parser = make_arg_parser()
args = parser.parse_args(args=[])
# Always offload activations to RAM, regardless of the parser default.
args.offload_activations = True

In [None]:
from prekv.quantizers import QuantizerBase, HadLinear, quantize_linear_weight
class PerChannelHiggsQuantizer(QuantizerBase):
    """HIGGS quantizer that quantizes KV tensors channel-wise instead of vector-wise.

    The last input dimension is split into ``num_heads`` head vectors of size ``head_dim``. Blocks of
    ``channel_group_size * num_heads`` consecutive head vectors are transposed, so each row handed to
    ``quantize_linear_weight`` holds one channel across that whole block of vectors.

    :param hadamard_groupsize: Hadamard transform group size passed to ``quantize_linear_weight``.
    :param edenn_d: HIGGS grid dimension.
    :param edenn_n: HIGGS grid size.
    :param num_heads: number of heads the last input dimension is split into. If None, falls back to
        ``model.config.num_key_value_heads`` of the notebook-global ``model`` (the original behavior,
        which requires ``model`` to be loaded before the quantizer is constructed).
    :param channel_group_size: number of vectors per head grouped into one channel row (default 128).
    """

    def __init__(self, hadamard_groupsize: int, edenn_d: int, edenn_n: int,
                 num_heads: Optional[int] = None, channel_group_size: int = 128):
        super().__init__()
        self.hadamard_groupsize, self.edenn_d, self.edenn_n = hadamard_groupsize, edenn_d, edenn_n
        self.channel_group_size = channel_group_size
        if num_heads is None:
            # backward-compatible fallback: read the head count from the notebook-global `model`
            num_heads = model.config.num_key_value_heads
        self.num_heads = num_heads

    @torch.no_grad()
    def quantize(self, x: torch.Tensor):
        """Quantize ``x`` (last dim = num_heads * head_dim) channel-wise.

        :returns: the quantized layer produced by ``quantize_linear_weight``, annotated with the
            metadata (``_head_dim``, ``_size_before_pad``, ``_og_shape``, ``_og_dtype``) that
            ``dequantize`` needs to undo the grouping, padding and reshaping.
        """
        assert x.shape[-1] % self.num_heads == 0, (
            f"last dim {x.shape[-1]} is not divisible by num_heads={self.num_heads}")
        head_dim = x.shape[-1] // self.num_heads
        x_flat = x.reshape(-1, head_dim)  # [num tokens * num_heads, head_dim]
        num_head_vectors = x_flat.shape[0]
        total_block_size = self.channel_group_size * self.num_heads

        # zero-pad so the number of head vectors is a multiple of total_block_size
        padding_size = (-num_head_vectors) % total_block_size
        if padding_size:
            x_flat = torch.cat([x_flat, x_flat.new_zeros(padding_size, head_dim)])

        # [num_blocks, total_block_size, head_dim] -> [num_blocks * head_dim, total_block_size]:
        # every row is one channel across a block of head vectors
        x_channelwise_with_padding = x_flat.reshape(
            -1, total_block_size, head_dim).swapaxes(1, 2).flatten(0, 1).contiguous()

        quantized = quantize_linear_weight(
            x_channelwise_with_padding, self.hadamard_groupsize, self.edenn_d, self.edenn_n)[0]

        quantized._head_dim = head_dim
        quantized._size_before_pad = num_head_vectors
        quantized._og_shape = x.shape
        quantized._og_dtype = x.dtype
        return quantized

    @torch.no_grad()
    def dequantize(self, quantized: HadLinear) -> torch.Tensor:
        """Reconstruct the tensor passed to ``quantize`` (shape and dtype restored, padding removed)."""
        total_block_size = self.channel_group_size * self.num_heads
        # Recover the quantized matrix by feeding the identity through the layer. Allocate it on the
        # layer's own device rather than assuming the default CUDA device; fp16 as before.
        identity = torch.eye(quantized.weight.shape[1], device=quantized.weight.device, dtype=torch.float16)
        x_channelwise_with_padding = quantized(identity).T[:, :total_block_size].contiguous()

        assert x_channelwise_with_padding.shape[-1] == total_block_size, "dequantized shape mismatch; did you set hadamard_groupsize correctly?"
        # undo the channel-wise transposition: back to [num head vectors (padded), head_dim]
        x_flat = x_channelwise_with_padding.reshape(
            -1, quantized._head_dim, total_block_size).swapaxes(1, 2).reshape(-1, quantized._head_dim)
        x_flat = x_flat[:quantized._size_before_pad]
        return x_flat.reshape(quantized._og_shape).to(quantized._og_dtype)


### Training predictors from scratch

In [None]:
# infer defaults: explicit --devices win, otherwise every visible GPU, otherwise the CPU
if args.devices is not None:
    args.devices = [torch.device(spec) for spec in args.devices]
elif torch.cuda.is_available():
    args.devices = [torch.device(f"cuda:{idx}") for idx in range(torch.cuda.device_count())]
else:
    args.devices = [torch.device("cpu")]
assert len(args.devices) == 1, "parallelism is still WIP"

# load model (KV caching disabled for calibration forward passes) and calibration data
model = transformers.AutoModelForCausalLM.from_pretrained(
    args.model_name,
    torch_dtype=args.torch_dtype,
    low_cpu_mem_usage=True,
    use_cache=False,
)

data = datautils.get_loaders(
    args.dataset,
    nsamples=args.total_nsamples, seed=args.seed,
    model_path=args.model_name, seqlen=args.model_seqlen,
)

# keys: per-channel HIGGS (may read model.config, so build it after the model is loaded); values: plain HIGGS
key_quantizer = PerChannelHiggsQuantizer(args.hadamard_groupsize, args.edenn_d, args.edenn_n)
value_quantizer = HiggsQuantizer(args.hadamard_groupsize, args.edenn_d, args.edenn_n)

# Calibration: propagate a set of inputs through one layer at a time, train predictors as we go
layers = modelutils.get_layers(model)

inps, forward_args = modelutils.get_inps(
    model, data, args.model_seqlen, args.devices, args.offload_activations)

# move tensor-valued forward kwargs onto the compute device; everything else is left untouched
for k, v in forward_args.items():
    forward_args[k] = v.to(args.devices[0]) if isinstance(v, torch.Tensor) else v

# output buffers mirror the inputs (shape/dtype/device/pinning); inps and outs are swapped after every layer
outs = [torch.zeros_like(inp_tensor, pin_memory=inp_tensor.is_pinned()) for inp_tensor in inps]
old_attn_keys = None
old_attn_values = None

key_predictors = {}
value_predictors = {}

# argparse delivers --compute_dtype as a name such as "bfloat16", but nn.Module.to(dtype=...) needs a
# torch.dtype; None keeps each layer's own dtype
compute_dtype = getattr(torch, args.compute_dtype) if isinstance(args.compute_dtype, str) else args.compute_dtype


def _equivalent_bits(relative_mse: float) -> float:
    """Express a relative MSE as "equivalent bits", i.e. -log_4(relative_mse).

    Base 4 because each extra bit of precision roughly quarters the quantization MSE.
    """
    return - math.log(relative_mse) / math.log(4)


for layer_index in range(len(layers)):
    print(f"\n---------------- Layer {layer_index} of {len(layers)} ----------------")
    layer_device_original = next(layers[layer_index].parameters()).device
    layer_dtype_original = next(layers[layer_index].parameters()).dtype
    layer = layers[layer_index].to(device=args.devices[0], dtype=compute_dtype or layer_dtype_original)

    # wrap k/v projections so their outputs are recorded during the forward pass below
    layer.self_attn.k_proj = OutputCatcher(layer.self_attn.k_proj, args.offload_activations)
    layer.self_attn.v_proj = OutputCatcher(layer.self_attn.v_proj, args.offload_activations)

    # presumably runs `layer` on every tensor in `inps`, writing the results into `outs` in place
    modelutils.update_outs_inplace_(args.devices, layer, inps, outs, **forward_args, compute_mse=False)

    # recorded outputs have a leading batch dim of 1; strip it
    attn_keys = layer.self_attn.k_proj.outputs
    assert all(elem.shape[0] == 1 for elem in attn_keys)
    attn_keys = [elem[0] for elem in attn_keys]

    attn_values = layer.self_attn.v_proj.outputs
    assert all(elem.shape[0] == 1 for elem in attn_values)
    attn_values = [elem[0] for elem in attn_values]

    # unwrap the catchers and put the layer back where (and in the dtype) it was
    layer.self_attn.k_proj = layer.self_attn.k_proj.inner
    layer.self_attn.v_proj = layer.self_attn.v_proj.inner

    layers[layer_index] = layer.to(device=layer_device_original, dtype=layer_dtype_original)
    del layer
    torch.cuda.empty_cache()

    # this layer's outputs are the next layer's inputs
    inps, outs = outs, inps

    if layer_index == 0:
        # layer 0 has no predecessor: its keys/values are kept as-is and only seed the predictors
        old_attn_keys = attn_keys
        old_attn_values = attn_values
        continue

    ### training predictor below ###
    # keys of layer i are predicted from the (reconstructed) keys of layer i - 1
    key_predictor_inputs = list(old_attn_keys)
    key_predictor, mse_train_keys, mse_valid_keys = get_predictor(args, key_predictor_inputs, attn_keys)
    # replace exact keys by their predictor + quantizer reconstruction, matching what the cache sees at inference
    attn_keys = get_dequant_values(args, key_quantizer, key_predictor, key_predictor_inputs, attn_keys)
    del key_predictor_inputs
    key_predictors[layer_index] = key_predictor.cpu()
    train_bits_keys = _equivalent_bits(mse_train_keys)
    valid_bits_keys = _equivalent_bits(mse_valid_keys)
    print(f'{layer_index=}\tPREDICTOR_KEYS   \t| relMSE train: {mse_train_keys:.4f} valid: {mse_valid_keys:.4f} '
          f'| equiv.bits train: {train_bits_keys:.2f} valid: {valid_bits_keys:.2f}')

    # values of layer i are predicted from [reconstructed keys of layer i, values of layer i - 1]
    value_predictor_inputs = [
        torch.cat([k_i, old_v_i], dim=-1) for k_i, old_v_i in zip(attn_keys, old_attn_values)]
    value_predictor, mse_train_values, mse_valid_values = get_predictor(args, value_predictor_inputs, attn_values)
    attn_values = get_dequant_values(args, value_quantizer, value_predictor, value_predictor_inputs, attn_values)
    value_predictors[layer_index] = value_predictor.cpu()
    del value_predictor_inputs
    train_bits_values = _equivalent_bits(mse_train_values)
    valid_bits_values = _equivalent_bits(mse_valid_values)
    print(
        f'{layer_index=}\tPREDICTOR_VALUES \t| relMSE train: {mse_train_values:.4f} valid: {mse_valid_values:.4f} '
        f'| equiv.bits train: {train_bits_values:.2f} valid: {valid_bits_values:.2f}')

    old_attn_keys, old_attn_values = attn_keys, attn_values

torch.save(dict(key_predictors=key_predictors, value_predictors=value_predictors), args.output_path)
print("Saved predictors to", args.output_path)

# The calibration activations are no longer needed; release them before perplexity evaluation.
del inps, outs
attn_keys = attn_values = old_attn_keys = old_attn_values = None

## PPL evaluation 

In [None]:
from typing import Any, Tuple, Optional, Dict, List
from prekv.cache_utils import apply_rotary_to_keys,split_heads,combine_heads

In [None]:
class SingleChunkQuantizedCacheWithPredictorsPerChannel(transformers.cache_utils.Cache):
    """A **write-once** cache that uses cumulative predictors; assumes that inputs are pre-grouped

    Layer 0 keys/values are stored as-is (with the rotary embedding removed from the keys). For every
    later layer:
      * keys are predicted from the previous layer's reconstructed keys;
      * values are predicted from [this layer's reconstructed keys, previous layer's reconstructed values];
      * only the residual (actual - predicted) is quantized and stored.
    Reads replay the same prediction chain, so layers have to be visited in order 0, 1, 2, ...
    (enforced by the assert at the top of ``update``).
    """

    def __init__(self, *, quantizer_key: QuantizerBase, quantizer_value: QuantizerBase,
                 key_predictors: Optional[Dict[int, nn.Module]] = None,
                 value_predictors: Optional[Dict[int, nn.Module]] = None):
        """
        :param quantizer_key: quantizer applied to key residuals of layers >= 1
        :param quantizer_value: quantizer applied to value residuals of layers >= 1
        :param key_predictors: layer index -> module predicting that layer's keys from the previous
            layer's reconstructed keys; if None, the prediction is all-zeros (residual == full keys)
        :param value_predictors: layer index -> module predicting that layer's values from the
            concatenation [this layer's reconstructed keys, previous layer's reconstructed values];
            if None, the prediction is all-zeros
        """
        super().__init__()
        self.quantizer_key, self.quantizer_value, self.key_predictors, self.value_predictors = quantizer_key, quantizer_value, key_predictors, value_predictors
        # layer_idx -> layer 0: un-rotated, head-combined tensor; layers >= 1: quantized residual
        self.key_states_cache, self.value_states_cache = dict(), dict()
        # reconstructions of the most recently visited layer; inputs of the next layer's predictors
        self.previous_key_reconstruction = self.previous_value_reconstruction = None
        # layer that the next `update` call is expected to touch (a call for layer 0 is always allowed)
        self.next_layer_idx = 0
        # NOTE(review): not read anywhere in this class
        self.seq_length = 0
        # rotary cos/sin captured on the first write; reused to undo and re-apply the rotary embedding
        self.cos = self.sin = None
        # per-head dimension, taken from the last dim of the first written key_states
        self.head_dim = None

    def predict_next_key_states(self) -> torch.Tensor:
        """Predict the (head-combined, un-rotated) keys of layer ``self.next_layer_idx`` from the previous
        layer's reconstructed keys; zeros shaped like that reconstruction when no key predictors are set."""
        if self.key_predictors is not None:
            return self.key_predictors[self.next_layer_idx](self.previous_key_reconstruction)
        else:
            return torch.zeros_like(self.previous_key_reconstruction)

    def predict_next_value_states(self, reconstructed_key_states: torch.Tensor) -> torch.Tensor:
        """Predict the values of layer ``self.next_layer_idx`` from this layer's reconstructed keys
        concatenated (last dim) with the previous layer's reconstructed values; zeros shaped like the
        previous value reconstruction when no value predictors are set."""
        if self.value_predictors is not None:
            value_predictor_inputs = torch.cat([reconstructed_key_states, self.previous_value_reconstruction], dim=-1)
            return self.value_predictors[self.next_layer_idx](value_predictor_inputs)
        else:
            return torch.zeros_like(self.previous_value_reconstruction)

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Number of cached tokens: dim -2 of the layer-0 key cache (stored after `combine_heads`,
        i.e. [batch, seq_length, hidden_size] per the comment in `update`), or 0 if nothing is cached.
        Only ``layer_idx == 0`` is supported."""
        assert layer_idx == 0
        return self.key_states_cache[0].shape[-2] if self.key_states_cache else 0

    @torch.no_grad()
    def update(self,
               key_states: Optional[torch.Tensor],
               value_states: Optional[torch.Tensor],
               layer_idx: int,
               cache_kwargs: Optional[Dict[str, Any]] = None,
               ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write (first call for a layer) or read (any later call) the cache entry of ``layer_idx``.

        Write mode (non-empty ``key_states``): undo the rotary embedding on the keys using
        ``cache_kwargs['cos']``/``cache_kwargs['sin']``, combine heads, then store layer 0 as-is or,
        for later layers, the quantized residuals w.r.t. the predictors. The ORIGINAL key/value
        states are returned, so the current step sees no quantization error.

        Read mode (``key_states`` None or empty): rebuild this layer's keys/values from the stored
        data and the predictors, split heads again and re-apply the rotary embedding to the keys.

        In both modes the reconstruction is remembered as input for the next layer's predictors.

        :param key_states: presumably [batch, num_heads, seq_length, head_dim] (see the comment
            before `combine_heads` below); None or empty to read
        :param value_states: same shape as ``key_states``, or None
        :param layer_idx: must be 0 or ``self.next_layer_idx``
        :param cache_kwargs: must contain 'cos' and 'sin' in write mode
        :returns: (key_states, value_states) in the per-head layout
        """
        # layers must be visited in order; layer 0 restarts the chain
        assert layer_idx in (self.next_layer_idx, 0), (layer_idx, self.next_layer_idx, 0)
        assert (key_states is None and value_states is None) or (key_states.shape == value_states.shape)
        saving_new_entries = key_states is not None and key_states.numel() != 0
        # each layer is written exactly once; every later call for it must be a read
        assert saving_new_entries == (layer_idx not in self.key_states_cache), "can only write once per layer"

        if saving_new_entries:  # write mode
            key_states_original, value_states_original = key_states, value_states
            # NOTE(review): raises TypeError (not AssertionError) if cache_kwargs is None
            assert 'sin' in cache_kwargs and 'cos' in cache_kwargs
            if self.cos is None:  # save the (identical) sin/cos for future reuse
                self.cos, self.sin = cache_kwargs['cos'], cache_kwargs['sin']

            if self.head_dim is None:
                self.head_dim = key_states.shape[-1]
            # undo rotation using cos(-alpha) = cos(alpha) and sin(-alpha) = -sin(alpha)
            key_states = apply_rotary_to_keys(key_states, cos=self.cos, sin=-self.sin)

            # v-- from [batch, num_heads, seq_length, head_dim] to [batch, seq_length, hidden_size]
            key_states, value_states = map(combine_heads, (key_states, value_states))

            if layer_idx == 0:
                # first layer: nothing to predict from, store exactly
                reconstructed_key_states = self.key_states_cache[0] = key_states
                reconstructed_value_states = self.value_states_cache[0] = value_states
            else:
                # quantize only the residual; the reconstruction (prediction + dequantized residual)
                # is what later layers' predictors will consume, matching the read path
                predicted_key_states = self.predict_next_key_states()
                self.key_states_cache[layer_idx] = self.quantizer_key.quantize(
                    (key_states - predicted_key_states).flatten(0, -2))
                reconstructed_key_states = predicted_key_states + self.quantizer_key.dequantize(
                    self.key_states_cache[layer_idx]).view_as(key_states).to(predicted_key_states.dtype)
                predicted_value_states = self.predict_next_value_states(reconstructed_key_states)
                self.value_states_cache[layer_idx] = self.quantizer_value.quantize(
                    (value_states - predicted_value_states).flatten(0, -2))
                reconstructed_value_states = predicted_value_states + self.quantizer_value.dequantize(
                    self.value_states_cache[layer_idx]).view_as(value_states).to(predicted_value_states.dtype)

            # return original data since it's available, avoid quantization errors for that one step
            result_key, result_value = key_states_original, value_states_original
        else:  # read mode
            if layer_idx == 0:
                reconstructed_key_states = self.key_states_cache[0]
                reconstructed_value_states = self.value_states_cache[0]
            else:
                # in-place add of the dequantized residual onto the freshly computed prediction
                reconstructed_key_states = self.predict_next_key_states()
                reconstructed_key_states += self.quantizer_key.dequantize(
                    self.key_states_cache[layer_idx]).view_as(
                    reconstructed_key_states).to(reconstructed_key_states.dtype)

                reconstructed_value_states = self.predict_next_value_states(reconstructed_key_states)
                reconstructed_value_states += self.quantizer_value.dequantize(self.value_states_cache[layer_idx]).view_as(
                    reconstructed_value_states).to(reconstructed_value_states.dtype)

            # apply rotary embedding again
            assert self.sin is not None and self.cos is not None and self.head_dim is not None
            result_key_without_rotary = split_heads(reconstructed_key_states, self.head_dim)
            result_key = apply_rotary_to_keys(result_key_without_rotary, cos=self.cos, sin=self.sin)
            result_value = split_heads(reconstructed_value_states, self.head_dim)

        # advance the chain: the next call is expected for layer_idx + 1 (or a restart at layer 0)
        self.next_layer_idx = layer_idx + 1
        self.previous_key_reconstruction = reconstructed_key_states
        self.previous_value_reconstruction = reconstructed_value_states
        return result_key, result_value

    def __repr__(self):
        """Class name with the number of cached tokens."""
        return f"{self.__class__.__name__}({self.get_seq_length()})"

In [None]:
# Perplexity data: WikiText-2 test split, joined into one long document and tokenized once.
with torch.no_grad():
    model.to(args.devices[0])
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    config = transformers.AutoConfig.from_pretrained(args.model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model_name,
        config=config,
        padding_side="left",
    )

    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")["input_ids"]
    step_size = args.ppl_chunk_size  # forwarded to evaluate_perplexity as step_size

In [None]:
# Baseline perplexity: no custom cache factory is passed to evaluate_perplexity.
with torch.no_grad():
    cache_factory = None
    ppl = evaluate_perplexity(
        model,
        testenc,
        args.model_seqlen,
        device=args.devices[0],
        step_size=step_size,
        cache_factory=cache_factory,
    )

In [None]:
# Move the trained predictors onto the evaluation device. nn.Module.to() works in place, so plain
# loops suffice; list comprehensions here only built throwaway lists and echoed them as cell output.
for predictor_table in (key_predictors, value_predictors):
    for predictor in predictor_table.values():
        predictor.to(args.devices[0])

In [None]:
# Perplexity with the quantized, predictor-based KV cache.
with torch.no_grad():
    # per-channel HIGGS for keys, standard HIGGS for values
    key_quantizer = PerChannelHiggsQuantizer(args.hadamard_groupsize, args.edenn_d, args.edenn_n)
    value_quantizer = HiggsQuantizer(args.hadamard_groupsize, args.edenn_d, args.edenn_n)

    def cache_factory():
        """Build a fresh cache for evaluate_perplexity.

        Presumably the first ``prefix_size=4`` tokens go to an ordinary ``DynamicCache``, while the
        rest go to a ``PredictorHiggsCache`` that buffers at least ``--ppl_buffer_size`` tokens and builds its
        quantized chunks with the predictor-based per-channel cache class. TODO confirm against prekv.cache_utils.
        """
        prefix = transformers.DynamicCache()
        chunk_cache_builder = partial(
            SingleChunkQuantizedCacheWithPredictorsPerChannel,
            quantizer_key=key_quantizer,
            quantizer_value=value_quantizer,
            key_predictors=key_predictors,
            value_predictors=value_predictors,
        )
        suffix = PredictorHiggsCache(
            config=model.config,
            min_buffer_size=args.ppl_buffer_size,
            save_dequantized_values=True,
            make_quantized_cache=chunk_cache_builder,
        )
        return TreatPrefixSeparately(prefix_size=4, prefix_cache=prefix, suffix_cache=suffix)

    ppl_quantized = evaluate_perplexity(
        model,
        testenc,
        args.model_seqlen,
        device=args.devices[0],
        step_size=step_size,
        cache_factory=cache_factory,
    )

In [None]:
# Summary: baseline perplexity vs. perplexity with the quantized predictor cache.
print(f"PPL with static cache {ppl}\nPPL with quantized cache {ppl_quantized}\n")