In [None]:
%pip install llmcompressor
%pip install -U llmcompressor transformers


In [None]:
import os
import json
import re
import hashlib
import torch
import shutil
from pathlib import Path
from datetime import datetime, timezone

from google.colab import drive
drive.mount('/content/drive')

from datasets import Dataset, concatenate_datasets, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.quantization import GPTQModifier


In [None]:
# Base checkpoint to quantize, read from the mounted Drive.
MODEL_ID = "/content/drive/MyDrive/LGAimers/base_model"

# Calibration budget: total number of samples, the maximum sequence length
# handed to oneshot, and the seed used for every dataset shuffle.
NUM_CALIBRATION_SAMPLES = 1024
MAX_SEQUENCE_LENGTH = 512
CALIBRATION_SEED = 42

# Calibration mixture. Each entry names a dataset/split, how many samples to
# take from it, which normalizer handles its rows ("format"), and the order in
# which sources are loaded ("priority", ascending).
CALIB_DATASETS = [
    {
        "id": "LGAI-EXAONE/MANTA-1M",
        "split": "train",
        "n_samples": 512,
        "format": "manta",
        "priority": 1,
    },
    {
        "id": "nlpai-lab/kullm-v2",
        "split": "train",
        "n_samples": 256,
        "format": "instruction",
        "priority": 2,
    },
    {
        "id": "heegyu/OIG-small-chip2-ko",
        "split": "train",
        "n_samples": 192,
        "format": "oig_human_bot",
        "priority": 3,
    },
    {
        "id": "beomi/KoAlpaca-v1.1a",
        "split": "train",
        "n_samples": 64,
        "format": "instruction",
        "priority": 4,
    },
]

# Evaluation benchmarks that must never be used as calibration sources.
CALIBRATION_BENCHMARK_EXCLUDE = {
    "LGAI-EXAONE/KoMT-Bench",
    "LGAI-EXAONE/KMMLU-Redux",
    "LGAI-EXAONE/KMMLU-Pro",
}
# Dataset substituted when the KoAlpaca source cannot be loaded.
KOALPACA_FALLBACK_ID = "nlpai-lab/kullm-v2"

SMOOTHING_STRENGTH = 0.40
QUANT_SCHEME = "W8A8"
QUANT_SCOPE = "mlp_only"
# Raw strings keep the regex escape `\.` literal; in a normal string literal
# `\.` is an invalid escape sequence and triggers a warning on recent Pythons.
MLP_TARGETS = [
    r"re:.*mlp\.gate_proj",
    r"re:.*mlp\.up_proj",
    r"re:.*mlp\.down_proj",
]
IGNORE = ["embed_tokens", "lm_head"]

# Fail fast on an inconsistent configuration before any model or data is loaded.
assert sum(spec["n_samples"] for spec in CALIB_DATASETS) == NUM_CALIBRATION_SAMPLES, (
    "CALIB_DATASETS 샘플 합계와 NUM_CALIBRATION_SAMPLES가 일치해야 합니다."
)
assert not any(spec["id"] in CALIBRATION_BENCHMARK_EXCLUDE for spec in CALIB_DATASETS), (
    "평가용 벤치마크 데이터셋은 calibration에서 제외해야 합니다."
)

# Only the MLP-only scope is implemented.
if QUANT_SCOPE != "mlp_only":
    raise ValueError(f"지원하지 않는 QUANT_SCOPE: {QUANT_SCOPE}")

QUANT_TARGETS = MLP_TARGETS
# NOTE: this directory name is a fixed literal; it does not change when the
# settings above are edited.
OUT_DIR = "/content/drive/MyDrive/LGAimers/sq_mixcal_w8a8_mlp_s040_calmix1024"


In [None]:
print("[INFO] 모델 로드 중...")

# Tokenizer for the base checkpoint, loaded with trust_remote_code enabled.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Model weights are loaded in bfloat16.
# NOTE(review): unlike the tokenizer, the model is loaded without
# trust_remote_code — confirm this asymmetry is intentional.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)

print("[INFO] 모델/토크나이저 로드 완료")


In [None]:
def collect_target_modules(model, target_patterns):
    """Resolve target patterns against the model's module names and validate MLP coverage.

    Args:
        model: Object exposing ``named_modules()`` that yields ``(name, module)`` pairs.
        target_patterns: Iterable of patterns. A pattern prefixed with ``re:`` is a
            regular expression; any other pattern is treated as a literal name.
            Either way the pattern must match the whole module name.

    Returns:
        Sorted list of the unique module names matched by at least one pattern.

    Raises:
        AssertionError: If a pattern matches no module, or if a decoder layer
            (``model.layers.<N>.…``) does not match exactly the three MLP
            projections (``mlp.gate_proj``, ``mlp.up_proj``, ``mlp.down_proj``).
    """
    all_module_names = [name for name, _ in model.named_modules()]

    matched_by_pattern = {}
    for pattern in target_patterns:
        regex = pattern[len("re:"):] if pattern.startswith("re:") else re.escape(pattern)
        # Compile once per pattern instead of re-parsing it for every module name.
        compiled = re.compile(regex)
        matches = [name for name in all_module_names if compiled.fullmatch(name)]
        matched_by_pattern[pattern] = matches
        print(f"[TARGET] pattern={pattern} matches={len(matches)} sample={matches[:3]}")

    unmatched_patterns = [p for p, m in matched_by_pattern.items() if not m]
    if unmatched_patterns:
        raise AssertionError(f"No module matched for patterns: {unmatched_patterns}")

    unique_matches = sorted({m for matches in matched_by_pattern.values() for m in matches})

    expected_suffixes = {
        "mlp.gate_proj",
        "mlp.up_proj",
        "mlp.down_proj",
    }
    # Group matched names by decoder-layer index; names that do not look like
    # ``model.layers.<N>.<a>.<b>…`` are not part of the per-layer check.
    per_layer = {}
    for name in unique_matches:
        parts = name.split(".")
        if len(parts) < 5 or parts[0] != "model" or parts[1] != "layers" or not parts[2].isdigit():
            continue
        layer_id = int(parts[2])
        suffix = ".".join(parts[3:])
        per_layer.setdefault(layer_id, set()).add(suffix)

    # Report missing and unexpected suffixes separately: a layer that matched
    # *extra* modules also fails the check, and listing only the missing ones
    # would produce an empty, misleading entry for it.
    layer_issues = {}
    for layer_id, suffixes in per_layer.items():
        if suffixes == expected_suffixes:
            continue
        layer_issues[layer_id] = {
            "missing": sorted(expected_suffixes - suffixes),
            "unexpected": sorted(suffixes - expected_suffixes),
        }
    if layer_issues:
        raise AssertionError(f"Incomplete or unexpected MLP targets per layer: {layer_issues}")

    layer_ids = sorted(per_layer)
    print(f"[TARGET] total_unique_matches={len(unique_matches)}")
    if layer_ids:
        print(f"[TARGET] layer_count={len(layer_ids)}, layer_range=({layer_ids[0]}, {layer_ids[-1]})")

    return unique_matches


# Resolve the configured quantization targets against the loaded model; this
# raises if a pattern matches nothing or a layer's MLP projections are incomplete.
QUANT_TARGET_MATCHED_MODULES = collect_target_modules(model, QUANT_TARGETS)
print("[INFO] MLP-only target validation complete")


In [None]:
def collect_modules(model):
    """Index module names by the projection / norm kind they represent.

    A module is attributed to a tracked key when the last dotted component of
    its name equals that key. Comparing the whole component (rather than a
    plain string suffix) keeps look-alikes such as ``gate_up_proj`` from being
    counted as ``up_proj``.

    Args:
        model: Object exposing ``named_modules()`` that yields ``(name, module)`` pairs.

    Returns:
        Dict mapping each tracked key to the list of matching module names, in
        ``named_modules()`` order. Every tracked key is present, possibly with
        an empty list.
    """
    tracked_keys = [
        "q_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "input_layernorm",
        "post_attention_layernorm",
    ]
    mod_index = {key: [] for key in tracked_keys}

    for name, _ in model.named_modules():
        leaf = name.rsplit(".", 1)[-1]
        if leaf in mod_index:
            mod_index[leaf].append(name)

    return mod_index


def build_exaone_sq_mappings(mod_index):
    """Decide the SmoothQuant mappings from the module index.

    MLP smoothing (gate_proj/up_proj balanced against post_attention_layernorm)
    is enabled only when all three module kinds are present; otherwise no
    mapping is produced.

    NOTE(review): the mapping presumes the post_attention_layernorm output is
    what feeds gate_proj/up_proj — confirm against the model's decoder layer.

    Args:
        mod_index: Dict of tracked key -> list of module names.

    Returns:
        Tuple ``(mappings, mode, diag)`` where ``mappings`` is a list of
        ``[balance_layers, smooth_layer]`` entries (possibly empty), ``mode`` is
        ``"mlp_only"`` or ``"disabled"``, and ``diag`` holds per-key counts, up
        to five sample names per key, and a human-readable reason string.
    """
    module_counts = {key: len(names) for key, names in mod_index.items()}
    module_samples = {key: names[:5] for key, names in mod_index.items()}

    input_ln_found = module_counts["input_layernorm"] > 0
    post_ln_found = module_counts["post_attention_layernorm"] > 0
    gate_found = module_counts["gate_proj"] > 0
    up_found = module_counts["up_proj"] > 0

    if post_ln_found and gate_found and up_found:
        mode = "mlp_only"
        mappings = [
            [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"]
        ]
    else:
        mode = "disabled"
        mappings = []

    # Each (condition, message) pair contributes its message when the condition holds.
    reason_candidates = [
        (not input_ln_found, "input_layernorm not found -> qkv smoothing disabled"),
        (not post_ln_found, "post_attention_layernorm not found"),
        (not (gate_found and up_found), "gate_proj/up_proj pair incomplete"),
    ]
    reason_parts = [message for applies, message in reason_candidates if applies]
    if not reason_parts:
        reason_parts = ["EXAONE policy: qkv smoothing disabled, mlp-only smoothing enabled"]

    diag = {
        "module_counts": module_counts,
        "module_samples": module_samples,
        "reason": "; ".join(reason_parts),
    }
    return mappings, mode, diag


def normalize_sq_mappings(mappings):
    """Convert dict-style mappings to ``[balance_layers, smooth_layers]`` lists.

    Entries that are dicts are rewritten from their ``balance_layers`` and
    ``smooth_layers`` keys; every other entry is passed through unchanged.
    """
    return [
        [entry["balance_layers"], entry["smooth_layers"]] if isinstance(entry, dict) else entry
        for entry in mappings
    ]


# Build the module index and derive the SmoothQuant mappings from it.
module_index = collect_modules(model)
raw_sq_mappings, SQ_MODE, SQ_DIAG = build_exaone_sq_mappings(module_index)
SQ_MAPPINGS = normalize_sq_mappings(raw_sq_mappings)
SQ_ENABLED = len(SQ_MAPPINGS) > 0

# Print how many modules of each tracked kind were found, with a few examples.
print("[SQ] EXAONE mapping diagnostics")
diagnostic_keys = (
    "q_proj",
    "k_proj",
    "v_proj",
    "gate_proj",
    "up_proj",
    "input_layernorm",
    "post_attention_layernorm",
)
for key in diagnostic_keys:
    print(f"[SQ] {key}: count={SQ_DIAG['module_counts'][key]}, sample={SQ_DIAG['module_samples'][key]}")

print(f"[SQ] mode={SQ_MODE}, enabled={SQ_ENABLED}, mappings={len(SQ_MAPPINGS)}")
print(f"[SQ] reason={SQ_DIAG['reason']}")
if SQ_ENABLED:
    print(f"[SQ] mappings={SQ_MAPPINGS}")


In [None]:
# Progress banner: marks the start of the calibration-data cell.
print("[INFO] 캘리브레이션 데이터 로드 중...")

def _clean_text(value):
    if value is None:
        return ""
    return str(value).strip()

def _normalize_manta_conversations(conversations):
    """Clean a list of ``{"role", "content"}`` turns.

    Turns that are not dicts, have an empty content, or whose role is not one
    of system/user/assistant are discarded.

    Returns:
        The cleaned list of turns, or ``None`` when the input is not a list or
        the cleaned conversation lacks a user turn or an assistant turn.
    """
    if not isinstance(conversations, list):
        return None

    allowed_roles = {"system", "user", "assistant"}
    cleaned_turns = []
    for raw_turn in conversations:
        if isinstance(raw_turn, dict):
            role = _clean_text(raw_turn.get("role")).lower()
            content = _clean_text(raw_turn.get("content"))
            if role in allowed_roles and content:
                cleaned_turns.append({"role": role, "content": content})

    # A usable sample needs at least one user turn and one assistant turn.
    roles_present = {turn["role"] for turn in cleaned_turns}
    if "user" in roles_present and "assistant" in roles_present:
        return cleaned_turns
    return None

def _normalize_instruction_turn(example):
    """Turn an instruction/input/output row into a single user/assistant exchange.

    The user message is the instruction, followed by a blank line and the
    input when both are present; if only one of them is present it is used
    alone.

    Returns:
        A two-turn conversation list, or ``None`` when the user message or the
        output is empty.
    """
    instruction = _clean_text(example.get("instruction"))
    input_text = _clean_text(example.get("input"))
    output_text = _clean_text(example.get("output"))

    if instruction and input_text:
        user_content = f"{instruction}\n\n{input_text}"
    else:
        # At most one of the two is non-empty here.
        user_content = instruction or input_text

    if user_content and output_text:
        return [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": output_text},
        ]
    return None

def _normalize_oig_pair(example):
    """Extract one user/assistant exchange from an OIG-style row.

    The translated columns (``user_translated`` / ``chip2_translated``) are
    preferred, then the untranslated ``user`` / ``chip2`` columns. If that
    does not yield both sides, the first ``<human> … <bot> …`` pair found in
    the ``text`` column is used instead.

    Returns:
        A two-turn conversation list, or ``None`` when no complete pair can be
        extracted.
    """

    def _first_nonempty(*columns):
        # First cleaned, non-empty value among the given columns, else "".
        for column in columns:
            text = _clean_text(example.get(column))
            if text:
                return text
        return ""

    def _as_turns(user_text, bot_text):
        return [
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": bot_text},
        ]

    user_content = _first_nonempty("user_translated", "user")
    assistant_content = _first_nonempty("chip2_translated", "chip2")
    if user_content and assistant_content:
        return _as_turns(user_content, assistant_content)

    raw = _clean_text(example.get("text"))
    if not raw:
        return None

    # Capture the first human turn and the bot reply up to the next <human> tag or the end.
    match = re.search(
        r"<human>\s*(.*?)\s*<bot>\s*(.*?)(?=<human>|$)",
        raw,
        re.IGNORECASE | re.DOTALL,
    )
    if match is None:
        return None

    user_content, assistant_content = (_clean_text(group) for group in match.groups())
    if user_content and assistant_content:
        return _as_turns(user_content, assistant_content)
    return None

def normalize_example(example, source_format):
    """Normalize one raw dataset row into ``{"conversations": [...]}``.

    Args:
        example: Raw row (mapping) from a calibration dataset.
        source_format: One of ``"manta"``, ``"instruction"``, ``"oig_human_bot"``.

    Returns:
        ``{"conversations": turns}``, or ``None`` when the row cannot be normalized.

    Raises:
        ValueError: If ``source_format`` is not a supported format.
    """
    normalizers = {
        "manta": lambda row: _normalize_manta_conversations(row.get("conversations")),
        "instruction": _normalize_instruction_turn,
        "oig_human_bot": _normalize_oig_pair,
    }
    normalizer = normalizers.get(source_format)
    if normalizer is None:
        raise ValueError(f"지원하지 않는 source format: {source_format}")

    conversations = normalizer(example)
    return None if conversations is None else {"conversations": conversations}

def _conversation_hash(record):
    payload = json.dumps(record["conversations"], ensure_ascii=False, sort_keys=True)
    return hashlib.sha1(payload.encode("utf-8")).hexdigest()

def _load_single_spec(spec, seed, global_seen_hashes):
    """Load one calibration source and pick its normalized, de-duplicated samples.

    The dataset is shuffled with ``seed`` and scanned in order; rows that fail
    normalization are dropped, rows whose conversation hash is already in
    ``global_seen_hashes`` are skipped, and scanning stops once the target
    count is reached.

    Args:
        spec: Dict with ``id``, ``split``, ``n_samples`` and ``format``.
        seed: Shuffle seed.
        global_seen_hashes: Set of conversation hashes shared across sources;
            updated in place with every selected sample.

    Returns:
        Tuple ``(dataset, stats)``: the selected samples as a ``Dataset`` and a
        dict of selection statistics.

    Raises:
        RuntimeError: If fewer than ``n_samples`` usable rows were found.
    """
    dataset_id, split = spec["id"], spec["split"]
    target_n, source_format = spec["n_samples"], spec["format"]

    raw_ds = load_dataset(dataset_id, split=split).shuffle(seed=seed)
    dataset_columns = list(raw_ds.features.keys())

    selected = []
    scanned = dropped = duplicate = 0

    for example in raw_ds:
        scanned += 1

        normalized = normalize_example(example, source_format)
        if normalized is None:
            dropped += 1
            continue

        conv_hash = _conversation_hash(normalized)
        if conv_hash in global_seen_hashes:
            duplicate += 1
            continue

        # New, usable sample: remember its hash so later sources cannot repeat it.
        selected.append(normalized)
        global_seen_hashes.add(conv_hash)
        if len(selected) >= target_n:
            break

    if len(selected) < target_n:
        raise RuntimeError(
            f"{dataset_id}에서 목표 샘플({target_n})을 채우지 못했습니다. "
            f"selected={len(selected)}, scanned={scanned}, dropped={dropped}, duplicate={duplicate}, "
            f"columns={dataset_columns}"
        )

    stats = {
        "id": dataset_id,
        "split": split,
        "format": source_format,
        "target_n": target_n,
        "selected_n": len(selected),
        "scanned_n": scanned,
        "dropped_n": dropped,
        "duplicate_n": duplicate,
        "columns": dataset_columns,
    }
    return Dataset.from_list(selected), stats

def build_calibration_dataset(specs, seed):
    """Assemble the mixed calibration dataset from all configured sources.

    Sources are loaded in ascending ``priority`` order and share one set of
    conversation hashes, so a sample already taken from an earlier source is
    not taken again. If loading ``beomi/KoAlpaca-v1.1a`` fails, the same
    number of samples is taken from ``KOALPACA_FALLBACK_ID`` instead; a
    failure of any other source is re-raised.

    Args:
        specs: List of source spec dicts.
        seed: Seed used for per-source shuffling and for the final shuffle.

    Returns:
        Tuple ``(mixed_dataset, stats_list)``.

    Raises:
        ValueError: If ``specs`` is empty.
        RuntimeError: If the final sample count differs from
            ``NUM_CALIBRATION_SAMPLES`` or a source selected too few samples.
    """
    if not specs:
        raise ValueError("CALIB_DATASETS가 비어 있습니다.")

    subset_list = []
    stats_list = []
    global_seen_hashes = set()

    for spec in sorted(specs, key=lambda item: item["priority"]):
        try:
            subset, stats = _load_single_spec(spec, seed=seed, global_seen_hashes=global_seen_hashes)
        except Exception as exc:
            # Only the KoAlpaca source has a substitute; everything else propagates.
            if spec["id"] != "beomi/KoAlpaca-v1.1a":
                raise
            print(f"[WARN] {spec['id']} 로드 실패 -> {KOALPACA_FALLBACK_ID}로 대체: {type(exc).__name__}: {exc}")
            fallback_spec = dict(spec)
            fallback_spec.update(id=KOALPACA_FALLBACK_ID, split="train", format="instruction")
            subset, stats = _load_single_spec(fallback_spec, seed=seed, global_seen_hashes=global_seen_hashes)
            stats["fallback_from"] = spec["id"]

        subset_list.append(subset)
        stats_list.append(stats)

    mixed = concatenate_datasets(subset_list).shuffle(seed=seed)

    if len(mixed) != NUM_CALIBRATION_SAMPLES:
        raise RuntimeError(
            f"최종 calibration 샘플 수가 다릅니다: expected={NUM_CALIBRATION_SAMPLES}, actual={len(mixed)}"
        )

    for stats in stats_list:
        minimum_required = min(5, stats["target_n"])
        if stats["selected_n"] < minimum_required:
            raise RuntimeError(f"정규화 성공 샘플이 부족합니다: {stats}")

    return mixed, stats_list

def _contains_korean(text):
    return any("가" <= ch <= "힣" for ch in text)

def _compute_ko_char_ratio(records):
    total_chars = 0
    ko_chars = 0
    for record in records:
        for turn in record["conversations"]:
            content = turn["content"]
            total_chars += len(content)
            ko_chars += sum(1 for ch in content if "가" <= ch <= "힣")
    if total_chars == 0:
        return 0.0
    return ko_chars / total_chars

# Build the mixed calibration set and keep the per-source selection statistics.
ds, calibration_source_stats = build_calibration_dataset(CALIB_DATASETS, seed=CALIBRATION_SEED)

# Materialize the records once to compute the Korean-coverage diagnostics.
records = list(ds)
ko_char_ratio = _compute_ko_char_ratio(records)
korean_record_count = sum(
    1
    for item in records
    if _contains_korean(json.dumps(item["conversations"], ensure_ascii=False))
)
# max(..., 1) guards the division when there are no records.
ko_record_ratio = korean_record_count / max(len(records), 1)

print(f"[INFO] 혼합 캘리브레이션 샘플 수: {len(records)}")
print(f"[INFO] 한글 문자 비율: {ko_char_ratio:.4f}")
print(f"[INFO] 한글 포함 샘플 비율: {ko_record_ratio:.4f}")
print("[INFO] 소스별 통계:")
for stats in calibration_source_stats:
    print(f"  - {stats}")

def preprocess(example):
    """Render one conversation to plain text with the tokenizer's chat template.

    NOTE(review): ``add_generation_prompt=True`` is applied even though the
    normalized conversations contain an assistant turn — confirm a trailing
    generation prompt is wanted in the calibration text.

    Returns:
        ``{"text": rendered_conversation}``.
    """
    rendered = tokenizer.apply_chat_template(
        example["conversations"],
        tokenize=False,
        add_generation_prompt=True,
    )
    return {"text": rendered}

# Add the rendered "text" column to every calibration sample. `ds` is rebound
# to the mapped dataset, so later cells see the version that includes "text".
ds = ds.map(preprocess)

print("[INFO] 데이터 전처리 완료")


In [None]:
def build_w8a8_modifier():
    """Construct the quantization modifier for the configured scheme and targets.

    Two ``QuantizationModifier`` keyword layouts are tried in order; the first
    that constructs without raising is used. If both fail, a ``GPTQModifier``
    with the same scheme/targets/ignore settings is built instead.

    Returns:
        Tuple ``(modifier, diag)`` where ``diag`` records which builder was
        used, the failed attempts (with their errors), and the kwargs selected.
    """
    primary_kwargs = {
        "scheme": QUANT_SCHEME,
        "targets": QUANT_TARGETS,
        "ignore": IGNORE,
    }
    quant_candidates = [
        primary_kwargs,
        {
            "scheme": {QUANT_SCHEME: QUANT_TARGETS},
            "ignore": IGNORE,
        },
    ]

    quant_builder_attempts = []
    for idx, kwargs in enumerate(quant_candidates, start=1):
        try:
            modifier = QuantizationModifier(**kwargs)
        except Exception as exc:
            # Record why this layout was rejected and move on to the next one.
            quant_builder_attempts.append(
                {
                    "candidate": idx,
                    "kwargs": kwargs,
                    "error": f"{type(exc).__name__}: {exc}",
                }
            )
            continue

        return modifier, {
            "builder": "QuantizationModifier",
            "attempts": quant_builder_attempts,
            "selected_kwargs": kwargs,
        }

    print("[WARN] QuantizationModifier 생성 실패 -> GPTQModifier(scheme=W8A8)로 폴백합니다.")
    fallback_kwargs = dict(primary_kwargs)
    modifier = GPTQModifier(**fallback_kwargs)
    return modifier, {
        "builder": "GPTQModifier_fallback",
        "attempts": quant_builder_attempts,
        "selected_kwargs": fallback_kwargs,
    }


print(
    f"[INFO] SmoothQuant + {QUANT_SCHEME} 시작 (scope={QUANT_SCOPE}, strength={SMOOTHING_STRENGTH}, "
    f"samples={NUM_CALIBRATION_SAMPLES}, max_len={MAX_SEQUENCE_LENGTH}, sq_mode={SQ_MODE})"
)

sq_mappings = normalize_sq_mappings(SQ_MAPPINGS)
if SQ_ENABLED:
    print(f"[SQ] normalized mappings={sq_mappings}")


def _run_oneshot(recipe):
    """Run one oneshot pass over the notebook's model and calibration set with ``recipe``."""
    oneshot(
        model=model,
        dataset=ds,
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    )


# Bookkeeping recorded later in quant_recipe.json.
applied_mode = "w8a8_only_fallback"
last_error = None
quant_builder_diag = None
quant_modifier_impl = None

quant_modifier, quant_builder_diag = build_w8a8_modifier()
quant_modifier_impl = type(quant_modifier).__name__

if SQ_ENABLED:
    try:
        recipe = [
            SmoothQuantModifier(
                smoothing_strength=SMOOTHING_STRENGTH,
                mappings=sq_mappings,
            ),
            quant_modifier,
        ]
        _run_oneshot(recipe)
        applied_mode = f"smoothquant_{SQ_MODE}+w8a8"
    except (ValueError, RuntimeError, TypeError) as exc:
        last_error = exc
        print(f"[WARN] SmoothQuant 경로 실패: {type(exc).__name__}: {exc}")
        print("[WARN] W8A8-only 폴백으로 재실행합니다.")

    if last_error is not None:
        # NOTE(review): the failed attempt may already have modified `model`
        # in place; this retry reuses that same object. Confirm this is safe
        # for the failure modes seen in practice, or reload the model first.
        # A fresh modifier is built so the retry does not reuse the instance
        # that took part in the failed run.
        quant_modifier, quant_builder_diag = build_w8a8_modifier()
        quant_modifier_impl = type(quant_modifier).__name__
        recipe = [quant_modifier]
        _run_oneshot(recipe)
        applied_mode = "w8a8_only_fallback"
else:
    # SmoothQuant is disabled, so there is nothing to fall back to: a failure
    # here propagates instead of re-running the identical recipe.
    print("[WARN] SmoothQuant 비활성 상태라 W8A8-only로 실행합니다.")
    print(f"[WARN] reason: {SQ_DIAG['reason']}")
    recipe = [quant_modifier]
    _run_oneshot(recipe)

print(
    f"[INFO] quantization complete (applied_mode={applied_mode}, "
    f"scheme={QUANT_SCHEME}, scope={QUANT_SCOPE}, modifier={quant_modifier_impl})"
)


In [None]:
# Make sure the output directory exists, then write the model and tokenizer.
os.makedirs(OUT_DIR, exist_ok=True)

model.save_pretrained(OUT_DIR, save_compressed=True)
tokenizer.save_pretrained(OUT_DIR)

# Record of the settings and outcomes of this run, stored next to the model.
quant_recipe = {
    "timestamp_utc": datetime.now(timezone.utc).isoformat(),
    "model_id": MODEL_ID,
    "calibration_sources": CALIB_DATASETS,
    "calibration_seed": CALIBRATION_SEED,
    "calibration_benchmark_exclude": sorted(CALIBRATION_BENCHMARK_EXCLUDE),
    "calibration_source_stats": calibration_source_stats,
    "num_calibration_samples": NUM_CALIBRATION_SAMPLES,
    "max_sequence_length": MAX_SEQUENCE_LENGTH,
    "quant_scheme": QUANT_SCHEME,
    "quant_scope": QUANT_SCOPE,
    "quant_targets": QUANT_TARGETS,
    "ignore": IGNORE,
    "smoothing_strength": SMOOTHING_STRENGTH,
    "sq_mode": SQ_MODE,
    "sq_enabled": SQ_ENABLED,
    "sq_mappings": SQ_MAPPINGS,
    "sq_diag": SQ_DIAG,
    "applied_mode": applied_mode,
    "quant_modifier_impl": quant_modifier_impl,
    "quant_builder_diag": quant_builder_diag,
    # Exceptions are not JSON-serializable, so only the message is stored.
    "last_error": None if last_error is None else str(last_error),
    "out_dir": OUT_DIR,
}

# Serialize first, then write, so a serialization error leaves no partial file.
quant_recipe_json = json.dumps(quant_recipe, indent=2, ensure_ascii=False)
recipe_path = Path(OUT_DIR) / "quant_recipe.json"
recipe_path.write_text(quant_recipe_json, encoding="utf-8")

print(f"[INFO] quant recipe saved: {recipe_path}")
print(f"[INFO] 모델 저장 완료: {OUT_DIR}")


In [None]:
# Archive base path (without the ".zip" suffix), named after the output directory.
zip_name = f"/content/drive/MyDrive/LGAimers/submit/{Path(OUT_DIR).name}"
zip_path = Path(zip_name)

# Create the destination folder for the archive if it does not exist yet.
zip_path.parent.mkdir(parents=True, exist_ok=True)

from tempfile import TemporaryDirectory
with TemporaryDirectory() as tmpdir:
    tmp_root = Path(tmpdir)
    # Stage a copy of the output under the name "model" so the archive's
    # top-level folder is "model" rather than the output directory's own name.
    model_dir = tmp_root / "model"
    shutil.copytree(OUT_DIR, model_dir)
    # make_archive appends ".zip" to the base path given as its first argument.
    shutil.make_archive(str(zip_path), "zip", root_dir=tmp_root, base_dir="model")

print(f"[INFO] 생성 완료: {zip_name}.zip")


In [None]:
# Locations of the artifacts produced by the save and zip cells.
out_dir = Path(OUT_DIR)
zip_file = Path(f"{zip_name}.zip")

# Files that must be present in the output directory.
required_files = [out_dir / filename for filename in ("config.json", "quant_recipe.json")]

missing = [str(path) for path in required_files if not path.exists()]
if missing:
    raise FileNotFoundError(f"필수 파일 누락: {missing}")

# Any one of these tokenizer files is accepted as proof the tokenizer was saved.
tokenizer_artifact_names = ("tokenizer.json", "tokenizer.model", "tokenizer_config.json")
has_tokenizer_artifact = any(
    (out_dir / artifact_name).exists() for artifact_name in tokenizer_artifact_names
)
if not has_tokenizer_artifact:
    raise FileNotFoundError("토크나이저 산출물이 누락되었습니다.")

if not zip_file.exists():
    raise FileNotFoundError(f"zip 파일 누락: {zip_file}")

print("[VERIFY] 산출물 검증 성공")
print(f"[VERIFY] out_dir={out_dir}")
print(f"[VERIFY] zip={zip_file}")


In [None]:
# Smoke-test guide (apply these changes manually before running)
# 1) Set NUM_CALIBRATION_SAMPLES=64
# 2) Adjust the n_samples values in CALIB_DATASETS so that they sum to 64
# 3) Run all cells, then check applied_mode and the artifact-verification log
# NOTE: OUT_DIR is a fixed literal, so a smoke run writes to the same output
# directory as a full run unless OUT_DIR is changed as well.
print("[INFO] smoke test 가이드 셀입니다. 필요 시 설정값을 줄여 별도 실행하세요.")
