In [None]:
import sys
import os
sys.path.append("../")

import datetime
from logging import Logger

import torch
import torch.distributed as dist
from transformers import LlamaTokenizerFast
import transformers
from eval_utils.main import ptq_model
from eval_utils.modeling_llama import LlamaForCausalLM
from utils import data_utils, eval_utils, utils
from utils.process_args import process_args_ptq

import evaluate
from lm_eval import evaluator
from lm_eval.utils import make_table

from utils.quant_utils import find_qlayers, ActQuantWrapper
from functools import partial
import pickle

from utils.profile import (
  measure, profile, get_profiler, 
  get_profiled_df, plot_profiled_df,
  run_profile
)
import pstats
import importlib

import pandas as pd
pd.set_option('display.max_colwidth', 100)

from matplotlib import pyplot as plt

import torch
import functools
import random

In [None]:
# Notebook-wide logger.
log: Logger = utils.get_logger("spinquant")

# Expose only GPU 0 to CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Fake a command line for this notebook session.
# NOTE(review): presumably consumed by process_args_ptq() in a later cell — confirm in utils.process_args.
sys.argv = (
  ["python"]
  # Model checkpoint plus trainer / evaluation switches.
  + [
    "--input_model", "../models/llama2-7b",
    "--do_train", "False",
    "--do_eval", "True",
    "--per_device_eval_batch_size", "4",
    "--model_max_length", "2048",
    "--fp16", "True",
    "--bf16", "False",
    "--save_safetensors", "False",
  ]
  # Quantization settings: bit widths, clipping / asymmetry flags, rotation, KV group sizes.
  + [
    "--w_bits", "4",
    "--a_bits", "16",
    "--k_bits", "4",
    "--v_bits", "4",
    "--w_clip",
    "--a_asym",
    "--k_asym",
    "--v_asym",
    "--rotate",
    "--k_groupsize", "128",
    "--v_groupsize", "128",
  ]
  # Paths to the saved quantized model and the optimized rotation file.
  + [
    "--load_qmodel_path", "../saved_models/llama2-7b/a16w4kv4-vasym.pt",
    "--optimized_rotation_path", "../rotation_llama-2-7b/a16w4kv4-vasym/R.bin",
  ]
)

In [None]:
# Distributed initialisation is intentionally left disabled in this notebook:
# dist.init_process_group(backend="nccl", timeout=datetime.timedelta(hours=8))
model_args, training_args, ptq_args = process_args_ptq()

# Echo the three parsed argument groups, one item per line.
print(
    "------- ARGS ----------",
    "-----model args-----",
    model_args,
    "------train args-------",
    training_args,
    "-------- ptq args ---------",
    ptq_args,
    "------- ARGS END ----------",
    sep="\n",
)

config = transformers.AutoConfig.from_pretrained(
    model_args.input_model,
    attn_implementation="eager",
    token=model_args.access_token,
)

# Llama v3.2 specific: SpinQuant is not compatible with tie_word_embeddings,
# so untie them in the config and clone lm_head from embed_tokens after loading.
process_word_embeddings = bool(config.tie_word_embeddings)
if process_word_embeddings:
    config.tie_word_embeddings = False

# Half-precision flavour follows the --bf16 flag; fp16 is the fallback.
if training_args.bf16:
    dtype = torch.bfloat16
else:
    dtype = torch.float16

model = LlamaForCausalLM.from_pretrained(
    pretrained_model_name_or_path=model_args.input_model,
    token=model_args.access_token,
    torch_dtype=dtype,
    config=config,
)
if process_word_embeddings:
    # Give lm_head its own copy of the embedding matrix now that the tie is removed.
    model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()
model.cuda()

# Apply post-training quantization; the returned model replaces the loaded one.
model = ptq_model(ptq_args, model, model_args)

In [None]:
# Record the maximum sequence length on the model object.
model.seqlen = training_args.model_max_length

# Tokenizer is loaded from the same checkpoint as the model; no BOS/EOS tokens
# are added automatically and padding goes on the right.
tokenizer = LlamaTokenizerFast.from_pretrained(
    model_args.input_model,
    token=model_args.access_token,
    cache_dir=training_args.cache_dir,
    model_max_length=training_args.model_max_length,
    use_fast=True,
    padding_side="right",
    add_bos_token=False,
    add_eos_token=False,
)

In [None]:
def append_to_dict(d, key, value, max_length=None):
  """Append ``value`` to the list stored at ``d[key]``, optionally capping its length.

  The list is created on first use. When ``max_length`` is given and the list
  is already full, randomly chosen existing entries are evicted until there is
  room, so the newest value is always kept and the list never exceeds
  ``max_length`` entries after the call.

  Args:
    d: Dictionary mapping keys to lists; mutated in place.
    key: Key whose list receives ``value``.
    value: Item to append.
    max_length: Maximum number of entries to keep for ``key``. ``None``
      (default) means unbounded. A value of zero or less keeps nothing:
      the list is emptied and ``value`` is dropped.
  """
  bucket = d.setdefault(key, [])
  if max_length is not None:
    if max_length <= 0:
      # No room for any entry; previously this path raised IndexError
      # because a random element was drawn from an empty list.
      bucket.clear()
      return
    # Loop (rather than a single pop) so a list that is already over the
    # cap is brought back under it.
    while len(bucket) >= max_length:
      bucket.pop(random.randrange(len(bucket)))
  bucket.append(value)


def is_linear(module):
  """Return True when ``module`` is a ``torch.nn.Linear`` that has a ``weight`` attribute."""
  if not isinstance(module, torch.nn.Linear):
    return False
  return hasattr(module, "weight")

def hook_to_linear(model, inputs, outputs, max_length=1024):
  """Register a forward hook on every linear layer of ``model`` that records its I/O.

  On each forward pass of a hooked layer, the first positional input and the
  output are detached, copied to CPU, and appended to ``inputs[name]`` and
  ``outputs[name]`` respectively, where ``name`` is the layer's name as
  reported by ``model.named_modules()``.

  Args:
    model: Module whose sub-modules are scanned with ``is_linear``.
    inputs: Dict filled in place with per-layer lists of input tensors.
    outputs: Dict filled in place with per-layer lists of output tensors.
    max_length: Cap on each per-layer list, forwarded to ``append_to_dict``.

  Returns:
    List of the handles returned by ``register_forward_hook``, one per hooked layer.
  """
  def make_recorder(layer_name):
    # Bind the layer name now so every hook writes under its own key.
    def record(module, input, output):
      append_to_dict(inputs, layer_name, input[0].detach().cpu(), max_length=max_length)
      append_to_dict(outputs, layer_name, output.detach().cpu(), max_length=max_length)
    return record

  return [
    layer.register_forward_hook(make_recorder(layer_name))
    for layer_name, layer in model.named_modules()
    if is_linear(layer)
  ]

def extract_weights(model):
  """Collect CPU copies of the weights of all linear layers except ``lm_head``.

  Returns:
    Dict mapping each qualifying module name (from ``model.named_modules()``)
    to its detached weight tensor on CPU. Modules whose name contains
    ``"lm_head"`` are skipped.
  """
  return {
    layer_name: layer.weight.detach().cpu()
    for layer_name, layer in model.named_modules()
    if is_linear(layer) and "lm_head" not in layer_name
  }

In [None]:
# Detach hooks left over from a previous run of this cell. Without this,
# re-running the cell stacks a second set of hooks on every linear layer, and
# the old ones keep copying activations into the previous, orphaned dicts.
for _stale_handle in globals().get("handles", ()):
  _stale_handle.remove()

# Fresh capture buffers, keyed by layer name and filled by the forward hooks.
inputs, outputs, weights = {}, {}, {}
handles = hook_to_linear(model, inputs, outputs, max_length=10000)

In [None]:
# Snapshot the linear-layer weights and drop every ".module" fragment from the
# layer names.
# NOTE(review): the ".module" fragment presumably comes from a wrapper such as ActQuantWrapper — confirm.
weights = {
  layer_name.replace(".module", ""): tensor
  for layer_name, tensor in extract_weights(model).items()
}

In [None]:
# Full task list, kept for reference:
# task_names = ['hellaswag', 'arc_easy','arc_challenge', 'winogrande', 'openbookqa']
task_names = ['hellaswag']

# Metrics for every evaluated task. `results` alone is overwritten on each
# iteration, so with more than one task only the last result would survive.
results_by_task = {}

for task_name in task_names:
  # Reset the capture buffers so they only hold activations from this task.
  # After the loop they therefore contain the last task's activations only.
  inputs.clear()
  outputs.clear()

  # Zero-shot evaluation restricted to a single example (limit=1) with batch
  # size 1; the forward hooks record activations while it runs.
  results = evaluator.simple_evaluate(
      model="hf",
      model_args={"pretrained" : model.cuda(),
                  "tokenizer" : tokenizer},
      tasks=[task_name],
      num_fewshot=0,
      batch_size=1,
      limit=1,
      device="cuda"
  )
  results_by_task[task_name] = results

In [None]:
import pickle

In [None]:
# Persist the captured activations and weights next to the notebook.
# `task_name` is the loop variable left over from the evaluation cell, so the
# files are named after the last task that was evaluated.
# Each file is opened in a `with` block so it is flushed and closed
# deterministically instead of relying on garbage collection.
with open(f"inputs_{task_name}.pkl", "wb") as fh:
  pickle.dump(inputs, fh)
with open(f"outputs_{task_name}.pkl", "wb") as fh:
  pickle.dump(outputs, fh)
with open(f"weights_{task_name}.pkl", "wb") as fh:
  pickle.dump(weights, fh)