In [2]:
import sys

# Fake a command line for `process_args_ptq()` (called in a later cell): everything
# after the program name is parsed as if the notebook had been launched from a shell.
# NOTE(review): the save path is tagged "4-4-4-128" while --a_bits is 16 — confirm the
# file name reflects w/k/v bits rather than w/a/kv bits.
sys.argv = ["python"] + [
    # --- model & run mode ---
    "--input_model", "llama2-7b-atlas2",
    "--do_train", "False",
    "--do_eval", "True",
    "--per_device_eval_batch_size", "4",
    "--model_max_length", "2048",
    # --- numeric precision / serialization ---
    "--fp16", "True",
    "--bf16", "False",
    "--save_safetensors", "False",
    # --- bit widths ---
    "--w_bits", "4",
    "--a_bits", "16",
    "--k_bits", "4",
    "--v_bits", "4",
    # --- quantizer switches (given without a value) ---
    "--w_clip",
    "--a_asym",
    "--k_asym",
    "--v_asym",
    # --- KV-cache group sizes ---
    "--k_groupsize", "128",
    "--v_groupsize", "128",
    # --- rotation & artifacts ---
    "--rotate",
    "--save_qmodel_path", "saved_models/qllama2-7b-4-4-4-128-fp16.pt",
    "--optimized_rotation_path", "rotation_llama-2-7b/a4w4kv4_fp16/R.bin",
]

In [3]:
import os

# Pin this kernel to physical GPU 2. CUDA only honors this variable if it is set
# before CUDA is initialized, so it must come before torch is imported and used.
os.environ.update(CUDA_VISIBLE_DEVICES="2")

In [4]:
import datetime
from logging import Logger

import torch
import torch.distributed as dist
from transformers import LlamaTokenizerFast
import transformers
from eval_utils.main import ptq_model
from eval_utils.modeling_llama import LlamaForCausalLM
from utils import data_utils, eval_utils, utils
from utils.process_args import process_args_ptq

# Module-level logger obtained from the project helper `utils.get_logger` under the name "spinquant".
log: Logger = utils.get_logger("spinquant")

import evaluate
from lm_eval import evaluator
from lm_eval.utils import make_table

from utils.quant_utils import find_qlayers, ActQuantWrapper
from functools import partial
import pickle

In [5]:
# Evaluation task names (presumably for lm_eval's `evaluator`, imported above — not used
# in the visible cells). Single-task alternatives for quick runs: ["openbookqa"], ["arc_easy"].
task_names = ["hellaswag", "arc_easy", "arc_challenge", "winogrande", "openbookqa", "wikitext"]

# Physical GPU ids exposed to this process; defaults to "0" when the variable is unset.
CUDA_DEVICES = [device.strip() for device in os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")]
FIRST_GPU_ID = int(CUDA_DEVICES[0])
# Logical device index inside this process (CUDA renumbers visible devices from 0).
GPU_ID = 0

In [6]:
# Parse the flags injected through sys.argv into the three argument dataclasses.
# This notebook runs single-process: the torch.distributed setup used by the script
# version of this flow (init_process_group, barriers, local-rank gating) is skipped.
model_args, training_args, ptq_args = process_args_ptq()
print("------- ARGS ----------")
print("-----model args-----")
print(model_args)
print("------train args-------")
print(training_args)
print("-------- ptq args ---------")
print(ptq_args)
print("------- ARGS END ----------")

config = transformers.AutoConfig.from_pretrained(
    model_args.input_model, token=model_args.access_token
)
# Llama v3.2 specific: SpinQuant is not compatible with tie_word_embeddings, so untie
# here and clone lm_head from embed_tokens once the weights are loaded.
process_word_embeddings = False
if config.tie_word_embeddings:
    config.tie_word_embeddings = False
    process_word_embeddings = True
# fp16 unless --bf16 is set; the --fp16 flag itself is not consulted here.
dtype = torch.bfloat16 if training_args.bf16 else torch.float16
model = LlamaForCausalLM.from_pretrained(
    pretrained_model_name_or_path=model_args.input_model,
    config=config,
    torch_dtype=dtype,
    token=model_args.access_token,
)
if process_word_embeddings:
    model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()
model.cuda()

# Rotation + quantization. Expensive: the rotation pass alone takes several minutes.
model = ptq_model(ptq_args, model, model_args)

# NOTE(review): `seqlen` is not read in the visible cells; presumably consumed by eval helpers.
model.seqlen = training_args.model_max_length
log.info("Model PTQ completed %s", model)
log.info("Start to load tokenizer...")
# add_bos_token=False: prompts encoded with this tokenizer get no leading BOS token.
# NOTE(review): Llama-2 was trained with BOS; consider prepending tokenizer.bos_token
# for free-form generation experiments.
tokenizer = LlamaTokenizerFast.from_pretrained(
    pretrained_model_name_or_path=model_args.input_model,
    cache_dir=training_args.cache_dir,
    model_max_length=training_args.model_max_length,
    padding_side="right",
    add_eos_token=False,
    add_bos_token=False,
    token=model_args.access_token,
)
log.info("Complete tokenizer loading...")
model.config.use_cache = False

------- ARGS ----------
-----model args-----
ModelArguments(input_model='llama2-7b-atlas2', output_rotation_path='test-output', optimized_rotation_path='rotation_llama-2-7b/a4w4kv4_fp16/R.bin', access_token=None)
------train args-------
TrainingArguments(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
batch_eval_metrics=False,
bf16=False,
bf16_full_eval=False,
cache_dir=None,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=None,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Rotating:  28%|██▊       | 9/32 [00:52<02:15,  5.88s/layer]


KeyboardInterrupt: 

In [None]:
# Move the model to the current CUDA device; a no-op if the load cell already did so.
model = model.to("cuda")

In [73]:
# Tokenize the first prompt and move every tensor of the encoding onto the GPU.
# `BatchEncoding.to` updates the underlying dict, so `inputs.input_ids`,
# `inputs["input_ids"]` and `**inputs` all see the CUDA tensors. Assigning
# `inputs.input_ids = ...` would only shadow the attribute and leave the dict entries on CPU.
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

In [74]:
# Single forward pass over the first prompt, keeping the KV cache for the next turn.
# The model's parameters have requires_grad=True, so run without autograd to avoid
# recording a graph (and holding its activations) for a pure inference call.
# torch.no_grad (not inference_mode) keeps the cache usable by later grad-enabled calls.
with torch.no_grad():
    output = model(
        input_ids=inputs.input_ids,
        use_cache=True,
    )
# argmax over the last dim = highest-scoring token at every prompt position, i.e. the
# model's guess for the token following each prefix (so the decoded text is shifted by one).
next_tokens = torch.argmax(output.logits, dim=-1)
result = tokenizer.batch_decode(next_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(result)

[', I you going?\n you hear to me?\n']


In [75]:
# Second turn. Moved to the GPU with `BatchEncoding.to` so the dict entries and
# attribute access agree (plain attribute assignment would leave the dict on CPU).
prompt = "Hey, how are you?"
inputs_2 = tokenizer(prompt, return_tensors="pt").to("cuda")

# KV cache from the first forward pass, reused so the second prompt is processed as a
# continuation of the first instead of concatenating the token ids.
past_key_values = output.past_key_values

In [83]:
# Continue from the first prompt's KV cache, again without recording an autograd graph.
# NOTE(review): if `past_key_values` is a mutable Cache object (e.g. DynamicCache) rather
# than a tuple, this call extends it in place, so re-running this cell on its own would
# feed a cache that already contains this prompt. Re-run the previous cells first; confirm
# which cache type this model returns.
with torch.no_grad():
    output2 = model(
        input_ids=inputs_2.input_ids,
        use_cache=True,
        past_key_values=past_key_values,
    )

In [79]:
# Highest-scoring token at every position of the second prompt, decoded to text.
next_tokens2 = output2.logits.argmax(dim=-1)
result2 = tokenizer.batch_decode(
    next_tokens2,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)
print(result2)

[', are are you feeling\n']


In [50]:
# Generate a continuation of the first prompt. Decoding settings come from the model's
# generation config; max_length=1024 caps prompt + generated tokens together.
# The attention mask is passed explicitly rather than letting generate() infer it.
generate_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_length=1024,
)
result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(result)

Hey, are you conscious? Can you talk to me?
I was able to answer, "Yes, I am conscious."
The doctor said, "Well, that's good. Now, can you talk to me?"
I said, "Yes, I can talk to you."
The doctor said, "Well, that's good. Now, can you hear me?"
I said, "Yes, I can hear you."
The doctor said, "Well, that's good. Now, can you move?"
I said, "Yes, I can move."
The doctor said, "Well, that's good. Now, can you stand?"
I said, "Yes, I can stand."
The doctor said, "Well, that's good. Now, can you sit?"
I said, "Yes, I can sit."
The doctor said, "Well, that's good. Now, can you lie down?"
I said, "Yes, I can lie down."
The doctor said, "Well, that's good. Now, can you get up?"
I said, "Yes, I can get up."
The doctor said, "Well, that's good. Now, can you get out of bed?"
I said, "Yes, I can get out of bed."
The doctor said, "Well, that's good. Now, can you get downstairs?"
I said, "Yes, I can get downstairs."
The doctor said, "Well, that's good. Now, can you get down the street?"
I said, "Ye

In [84]:
# Random fp16 input for comparing SiLU computed natively in fp16 against an fp32 upcast.
# A fixed-seed generator makes the values (and the comparison) reproducible across restarts.
SILU_SEED = 0
data = torch.rand(1, 128, dtype=torch.float16, generator=torch.Generator().manual_seed(SILU_SEED))

In [86]:
# Out-of-place SiLU activation module used by the comparison cells below.
a = torch.nn.SiLU(inplace=False)

In [87]:
# SiLU evaluated directly in float16 (the input does not require grad, so no graph either way).
with torch.no_grad():
    print(a(data))

tensor([[4.0869e-01, 3.0933e-01, 8.8440e-02, 3.5791e-01, 7.2205e-02, 6.7041e-01,
         4.5239e-01, 3.4863e-01, 6.2207e-01, 7.0801e-02, 4.8828e-04, 6.3770e-01,
         2.0020e-01, 4.4281e-02, 3.5156e-01, 7.5256e-02, 5.9570e-02, 6.8408e-01,
         5.8887e-01, 1.0278e-01, 4.2114e-01, 6.5625e-01, 6.8604e-02, 5.2881e-01,
         2.5317e-01, 3.1128e-01, 4.1382e-02, 3.2056e-01, 3.8184e-01, 6.6602e-01,
         2.8418e-01, 7.2119e-01, 5.7227e-01, 1.9043e-01, 5.6299e-01, 6.5820e-01,
         6.7200e-02, 6.7200e-02, 4.4556e-01, 3.5376e-01, 4.1656e-02, 1.1481e-01,
         9.7290e-02, 3.1836e-01, 3.9819e-01, 3.3539e-02, 2.0898e-01, 3.1079e-01,
         3.4155e-01, 1.9788e-01, 2.1326e-01, 4.0820e-01, 6.5381e-01, 3.6694e-01,
         3.2861e-01, 2.5415e-01, 1.9006e-01, 3.9209e-01, 6.7676e-01, 6.9238e-01,
         1.8433e-01, 5.1123e-01, 6.2012e-02, 2.7539e-01, 5.2277e-02, 5.5566e-01,
         2.0435e-01, 6.1572e-01, 1.1505e-01, 5.8545e-01, 6.1493e-02, 5.6689e-01,
         2.1790e-01, 1.0605e

In [89]:
# SiLU computed in float32 and rounded back to float16. The printed tensors are truncated
# in the notebook output, so report the fp16-vs-fp32-path comparison numerically
# instead of relying on eyeballing the two dumps.
silu_fp16 = a(data)
silu_fp32_to_fp16 = a(data.to(torch.float32)).to(torch.float16)
print(silu_fp32_to_fp16)
print("bit-identical to native fp16 SiLU:", torch.equal(silu_fp16, silu_fp32_to_fp16))
print("max abs diff:", (silu_fp16.float() - silu_fp32_to_fp16.float()).abs().max().item())

tensor([[4.0869e-01, 3.0933e-01, 8.8440e-02, 3.5791e-01, 7.2205e-02, 6.7041e-01,
         4.5239e-01, 3.4863e-01, 6.2207e-01, 7.0801e-02, 4.8828e-04, 6.3770e-01,
         2.0020e-01, 4.4281e-02, 3.5156e-01, 7.5256e-02, 5.9570e-02, 6.8408e-01,
         5.8887e-01, 1.0278e-01, 4.2114e-01, 6.5625e-01, 6.8604e-02, 5.2881e-01,
         2.5317e-01, 3.1128e-01, 4.1382e-02, 3.2056e-01, 3.8184e-01, 6.6602e-01,
         2.8418e-01, 7.2119e-01, 5.7227e-01, 1.9043e-01, 5.6299e-01, 6.5820e-01,
         6.7200e-02, 6.7200e-02, 4.4556e-01, 3.5376e-01, 4.1656e-02, 1.1481e-01,
         9.7290e-02, 3.1836e-01, 3.9819e-01, 3.3539e-02, 2.0898e-01, 3.1079e-01,
         3.4155e-01, 1.9788e-01, 2.1326e-01, 4.0820e-01, 6.5381e-01, 3.6694e-01,
         3.2861e-01, 2.5415e-01, 1.9006e-01, 3.9209e-01, 6.7676e-01, 6.9238e-01,
         1.8433e-01, 5.1123e-01, 6.2012e-02, 2.7539e-01, 5.2277e-02, 5.5566e-01,
         2.0435e-01, 6.1572e-01, 1.1505e-01, 5.8545e-01, 6.1493e-02, 5.6689e-01,
         2.1790e-01, 1.0605e

In [93]:
# Display the output-projection weight (compare with the input embedding table below).
model.get_parameter("lm_head.weight")

Parameter containing:
tensor([[-0.0379, -0.0127, -0.0067,  ...,  0.0037, -0.0528,  0.0105],
        [-0.0534, -0.0088,  0.0450,  ..., -0.0033, -0.0148,  0.0345],
        [ 0.0424, -0.0146,  0.0022,  ...,  0.0226,  0.0007,  0.0328],
        ...,
        [-0.0692,  0.0705, -0.0648,  ...,  0.0278, -0.0333,  0.0130],
        [ 0.0017,  0.0330, -0.0479,  ..., -0.0121, -0.0211,  0.0172],
        [-0.0619, -0.0010, -0.0347,  ...,  0.0504,  0.0083, -0.0307]],
       device='cuda:0', dtype=torch.float16, requires_grad=True)

In [None]:
# Display the input embedding table (compare with lm_head.weight above).
model.get_parameter("model.embed_tokens.weight")

Parameter containing:
tensor([[-1.1206e-05, -3.1590e-06,  3.6955e-06,  ..., -8.5235e-06,
          3.6359e-06, -2.0862e-06],
        [-1.7986e-03,  3.9062e-03, -6.9847e-03,  ...,  2.5768e-03,
          9.1970e-05, -6.0692e-03],
        [-5.0316e-03,  7.6675e-03, -6.0577e-03,  ...,  2.5234e-03,
          5.1003e-03,  1.9318e-02],
        ...,
        [-2.4429e-02,  8.8272e-03,  2.4139e-02,  ..., -1.1566e-02,
         -1.3702e-02, -4.4746e-03],
        [-3.7781e-02,  5.1193e-03,  1.3397e-02,  ...,  1.5327e-02,
         -8.6136e-03, -3.2806e-02],
        [-2.7588e-02,  1.2939e-02, -8.2397e-03,  ...,  8.6746e-03,
          1.4412e-02,  1.1208e-02]], device='cuda:0', dtype=torch.float16,
       requires_grad=True)

: 