In [None]:
%env CUDA_VISIBLE_DEVICES=7
%env TRANSFORMERS_CACHE=/mnt/LLM/hub

%env HF_HOME=/mnt/LLM
%env OMP_NUM_THREADS=16

import os
import sys
sys.path.insert(0, '..')

import math
from argparse import Namespace
from typing import Sequence, Optional, List, Tuple

import torch
import torch.nn as nn
import transformers
from tqdm import trange

from prekv import datautils, modelutils
from prekv.quantizers import QuantizerBase, HiggsQuantizer
from prekv.cache_utils import TreatPrefixSeparately,PredictorHiggsCache,SingleChunkQuantizedCacheWithPredictors
from functools import partial
from ppl import evaluate_perplexity
from datasets import load_dataset

In [None]:
def make_arg_parser():
    """Build the command-line parser for this KV-cache perplexity evaluation.

    The notebook parses an empty argument list, so in practice every option takes
    its default value; the parser mainly gathers the tunable knobs in one place.

    Returns:
        argparse.ArgumentParser: parser with model-loading, evaluation and HIGGS options.
    """
    import argparse

    parser = argparse.ArgumentParser(add_help=True)

    # --- model loading ---
    parser.add_argument(
        "--model_name",
        default="unsloth/Llama-3.2-3B",
        type=str,
        help="path to llama model to load, as in LlamaForCausalLM.from_pretrained()",
    )
    parser.add_argument(
        "--torch_dtype",
        type=str,
        default="auto",
        choices=["auto", "float16", "float32", "bfloat16"],
        help="dtype to load the model in",
    )
    parser.add_argument(
        "--model_seqlen",
        type=int,
        default=8192,
        help="Model seqlen and calibration data context length.",
    )
    # Device strings are converted to torch.device objects after parsing.
    parser.add_argument("--devices",
                        metavar="N",
                        type=str,
                        nargs="+",
                        default=None,
                        help="List of devices")
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Seed for calibration data and initialization. "
             "Note that the main training is not strictly deterministic.",
    )

    # --- perplexity evaluation ---
    parser.add_argument(
        "--ppl_chunk_size",  # TODO: needs a clearer name
        type=int,
        default=32,
        help="Number of tokens in one chunk.",
    )
    parser.add_argument(
        "--ppl_buffer_size",  # TODO: needs a clearer name
        type=int,
        default=128,
        help="Minimum buffer size of the quantized KV cache "
             "(passed to PredictorHiggsCache as min_buffer_size).",
    )

    # --- HIGGS quantizer ---
    parser.add_argument(
        "--hadamard_groupsize",
        type=int,
        default=1024,
        help="Groupsize of Hadamard transform for HIGGS.",
    )
    parser.add_argument(
        "--edenn_d",
        type=int,
        default=6,
        help="The grid dimension d for HIGGS.",
    )
    parser.add_argument(
        "--edenn_n",
        type=int,
        default=4096,
        help="The grid size n for HIGGS.",
    )
    # TODO: implement wandb logging; the flag is currently parsed but unused.
    parser.add_argument("--wandb", action="store_true", help="Whether to use wandb or store locally.")
    return parser

### Parsing Arguments

In [None]:
parser = make_arg_parser()
# Cap intra-op CPU threads at 16, the same value exported as OMP_NUM_THREADS in the setup cell.
torch.set_num_threads(min(16, torch.get_num_threads()))
# Parse an empty argv: inside a notebook, sys.argv belongs to the kernel process, not to us.
args = parser.parse_args(args=[])
# Apply the parsed seed; otherwise --seed is parsed but never used anywhere in this notebook.
torch.manual_seed(args.seed)

In [None]:
# Resolve target devices: honour explicit --devices strings; otherwise take every
# visible CUDA device (CUDA_VISIBLE_DEVICES in the setup cell narrows this), falling
# back to the CPU when CUDA is unavailable.
if args.devices is not None:
    args.devices = list(map(torch.device, args.devices))
elif torch.cuda.is_available():
    args.devices = [torch.device(f"cuda:{idx}") for idx in range(torch.cuda.device_count())]
else:
    args.devices = [torch.device("cpu")]
# Only single-device evaluation is supported for now.
assert len(args.devices) == 1, "parallelism is still WIP"

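## PPL evaluation

Compare WikiText-2 test perplexity in two settings: a baseline with no custom cache, and a HIGGS-quantized KV cache that uses the loaded key/value predictors.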

In [None]:
# Load the per-key key/value predictors consumed by the quantized cache further below.
# map_location places the stored tensors directly on the evaluation device, so the file
# also loads when the saving machine had a different device layout.
# NOTE(review): if this file pickles whole nn.Module objects, torch>=2.6 defaults to
# weights_only=True and may refuse to load it — confirm against the installed torch version.
key_values = torch.load('../key_value_predictors.pt', map_location=args.devices[0])
key_predictors, value_predictors = key_values["key_predictors"], key_values["value_predictors"]

# Move every predictor to the evaluation device. The result of .to() is stored back:
# nn.Module.to moves in place, but Tensor.to returns a new tensor, so discarding the
# result (as a bare list comprehension does) would leave tensors on the old device.
for predictors in (key_predictors, value_predictors):
    for name in predictors:
        predictors[name] = predictors[name].to(args.devices[0])

In [None]:
with torch.no_grad():
    # WikiText-2 (raw) test split; all documents are tokenized together below.
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    config = transformers.AutoConfig.from_pretrained(args.model_name)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        args.model_name,
        config=config,
        torch_dtype=args.torch_dtype,
        low_cpu_mem_usage=True,
    )
    model = model.to(args.devices[0])
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model_name,
        config=config,
        padding_side="left",
    )

    # Join documents with blank lines and encode them as one sequence of token ids.
    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")["input_ids"]
    # Tokens per evaluation step, forwarded to evaluate_perplexity.
    step_size = args.ppl_chunk_size

In [None]:
with torch.no_grad():
    # Baseline run: no cache_factory is supplied, so no quantized cache is involved.
    cache_factory = None
    ppl = evaluate_perplexity(
        model,
        testenc,
        args.model_seqlen,
        device=args.devices[0],
        step_size=step_size,
        cache_factory=cache_factory,
    )

In [None]:
with torch.no_grad():
    # One HIGGS quantizer instance, shared by every cache the factory below creates.
    quantizer = HiggsQuantizer(args.hadamard_groupsize, args.edenn_d, args.edenn_n)

    def cache_factory():
        """Build a new KV cache for evaluate_perplexity.

        Each call returns a fresh ``TreatPrefixSeparately`` cache that presumably keeps
        the first ``prefix_size=4`` tokens in a plain ``transformers.DynamicCache`` and
        routes the rest to a ``PredictorHiggsCache``. Its quantized chunks use the shared
        HIGGS ``quantizer`` plus the loaded key/value predictors.
        NOTE(review): how often evaluate_perplexity calls this (per window?) is not visible here.
        """
        make_quantized_chunk = partial(
            SingleChunkQuantizedCacheWithPredictors,
            quantizer=quantizer,
            key_predictors=key_predictors,
            value_predictors=value_predictors,
        )
        return TreatPrefixSeparately(
            prefix_size=4,
            prefix_cache=transformers.DynamicCache(),
            suffix_cache=PredictorHiggsCache(
                config=model.config,
                min_buffer_size=args.ppl_buffer_size,
                save_dequantized_values=True,
                make_quantized_cache=make_quantized_chunk,
            ),
        )

    ppl_quantized = evaluate_perplexity(
        model,
        testenc,
        args.model_seqlen,
        device=args.devices[0],
        step_size=step_size,
        cache_factory=cache_factory,
    )

In [None]:
# The baseline run used cache_factory=None (no quantized cache) — not a transformers StaticCache.
print(f"PPL without KV-cache quantization: {ppl}\nPPL with quantized KV cache: {ppl_quantized}\n")