In [1]:
DEBUG = False

MODEL_NAME = "rinna/youri-7b-instruction"
MODEL_BASE_NAME = MODEL_NAME.split("/")[-1]
LORA_DIR = f"./pretrained_lora_{MODEL_BASE_NAME}"

OUTPUT_MERGED_DIR = f"./pretrained_merged_{MODEL_BASE_NAME}"
OUTPUT_QUANTIZED_DIR = f"./pretrained_awq_{MODEL_BASE_NAME}"

In [2]:
from peft import PeftModel  # type: ignore
from transformers import AutoTokenizer, AutoModelForCausalLM
import os

if not os.path.exists(OUTPUT_MERGED_DIR):
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
    )
    model = PeftModel.from_pretrained(base_model, LORA_DIR)
    model = model.merge_and_unload().half()
    model.save_pretrained(OUTPUT_MERGED_DIR)
    del model  # unload
    del base_model  # unload
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # save to OUTPUT_SAVE_DIR
    tokenizer.save_pretrained(OUTPUT_MERGED_DIR)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from datasets import load_dataset


# base: https://github.com/casper-hansen/AutoAWQ/blob/main/examples/quant_custom_data.py
# Define data loading methods
def load_wiki_ja():
    data = load_dataset(
        "singletongue/wikipedia-utils",
        split="train",
        name="passages-c400-jawiki-20230403",
    )

    return [text for text in data["text"]]  # type: ignore

In [4]:
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = OUTPUT_MERGED_DIR
quant_path = OUTPUT_QUANTIZED_DIR

q_version = "GEMM"
quant_config = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": q_version,
}

# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path, safetensors=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Quantize
model.quantize(tokenizer, quant_config=quant_config, calib_data=load_wiki_ja())  # type: ignore

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.10it/s]
AWQ: 100%|██████████| 32/32 [11:47<00:00, 22.11s/it]


('./pretrained_awq_youri-7b-instruction/tokenizer_config.json',
 './pretrained_awq_youri-7b-instruction/special_tokens_map.json',
 './pretrained_awq_youri-7b-instruction/tokenizer.model',
 './pretrained_awq_youri-7b-instruction/added_tokens.json',
 './pretrained_awq_youri-7b-instruction/tokenizer.json')