In [21]:
import torch
import json
import os
from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoConfig, BitsAndBytesConfig

def run_quantization_lab(model_id, mode="4bit", double_quant=True, quant_type="nf4"):
    """
    Quantize a Hugging Face model with bitsandbytes, report the memory saved,
    and save the result under a standardized name.

    The baseline footprint is measured on a 'meta'-device instance built from
    the config, using the dtype the config declares (FP32 when it declares
    none). The model is then loaded for real with the requested quantization,
    its footprint is compared against the baseline, and it is saved to
    ``./<model-name>-bnb-<mode>[-<quant_type>][-dq]``.

    Args:
        model_id: Model identifier passed to ``from_pretrained``.
        mode: "4bit" or "8bit" selects bitsandbytes quantization; any other
            value loads the model without a quantization config.
        double_quant: 4-bit only. Forwarded as ``bnb_4bit_use_double_quant``
            and reflected as the ``-dq`` suffix in the save name.
        quant_type: 4-bit only. Forwarded as ``bnb_4bit_quant_type`` and
            reflected in the save name.

    Returns:
        The loaded (and, for "4bit"/"8bit", quantized) model.
    """
    print(f"\n{'='*60}\n🚀 QUANTIZATION LAB: {model_id}\n{'='*60}")

    # --- 1. DETECT NATIVE DTYPE & CALC ORIGINAL FOOTPRINT ---
    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

    # Fall back to FP32 when the config does not declare a dtype.
    native_dtype = getattr(config, "torch_dtype", None)
    if native_dtype is None:
        native_dtype = torch.float32

    # Pick the head class from the declared architectures. An architecture that
    # names itself causal (e.g. BertLMHeadModel, RobertaForCausalLM) must not be
    # routed to the masked-LM class just because its name contains "Bert"/"Roberta".
    arch_names = [str(arch) for arch in (getattr(config, "architectures", None) or [])]
    declares_causal = any(
        "CausalLM" in arch or "LMHeadModel" in arch for arch in arch_names
    )
    looks_masked = any(
        tag in arch for arch in arch_names for tag in ("Masked", "Bert", "Roberta")
    )
    is_masked = looks_masked and not declares_causal
    model_class = AutoModelForMaskedLM if is_masked else AutoModelForCausalLM

    # Virtual load on the 'meta' device using the native dtype. trust_remote_code
    # is passed here too, matching the from_pretrained calls in this function.
    with torch.device("meta"):
        temp_model = model_class.from_config(
            config, torch_dtype=native_dtype, trust_remote_code=True
        )

    orig_mem = temp_model.get_memory_footprint() / 1e6
    dtype_name = str(native_dtype).split('.')[-1].upper()
    print(f"📦 Original Footprint ({dtype_name}): {orig_mem:.2f} MB")

    # --- 2. CONFIGURATION & NAMING ---
    dq_init = "-dq" if (mode == "4bit" and double_quant) else ""
    type_init = f"-{quant_type}" if mode == "4bit" else ""
    # Generates name like: bert-base-uncased-bnb-4bit-nf4-dq
    hf_save_name = f"{model_id.split('/')[-1]}-bnb-{mode}{type_init}{dq_init}"

    if mode == "4bit":
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=quant_type,
            bnb_4bit_use_double_quant=double_quant,
            bnb_4bit_compute_dtype=torch.float16,  # Standard for 4bit compute
        )
    elif mode == "8bit":
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    else:
        # Any other mode loads the model without quantization.
        bnb_config = None

    # --- 3. LOAD QUANTIZED MODEL ---
    print(f"🛠️  Loading in {mode} mode...")
    model = model_class.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    # --- 4. FINAL PRECISION-AWARE REPORT ---
    quant_mem = model.get_memory_footprint() / 1e6
    # Guard the ratio so a zero-sized baseline cannot raise ZeroDivisionError.
    reduction = ((orig_mem - quant_mem) / orig_mem) * 100 if orig_mem else 0.0

    print(f"\n📊 COMPARISON:")
    print(f"📦 Original ({dtype_name}):  {orig_mem:.2f} MB")
    print(f"✅ Quantized:           {quant_mem:.2f} MB")
    print(f"📉 Actual Reduction:    {reduction:.1f}%")
    print(f"🏷️  Suggested HF Name:  {hf_save_name}")

    # --- 5. SAVE WITH METADATA ---
    save_path = f"./{hf_save_name}"
    print(f"\n💾 Saving to {save_path}...")
    model.save_pretrained(save_path)

    if bnb_config is not None:
        # Inject bnb_config into config.json for auto-loading.
        config_file = os.path.join(save_path, "config.json")
        with open(config_file, "r", encoding="utf-8") as f:
            saved_json = json.load(f)

        saved_json["quantization_config"] = bnb_config.to_dict()

        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(saved_json, f, indent=2)

        print(f"✅ Successfully saved as {hf_save_name} with metadata injection!")
    else:
        # Do not write an empty quantization_config: it would label an
        # unquantized checkpoint as quantized without naming a method.
        print(f"✅ Successfully saved as {hf_save_name} (no quantization metadata to inject).")

    print(f"{'='*60}\n")
    return model

# --- EXECUTION ---
# Run the lab on BERT-base (4-bit, NF4, double quantization) and keep the
# returned model in `model` for the following cells.
model = run_quantization_lab(
    model_id="bert-base-uncased", mode="4bit", double_quant=True, quant_type="nf4"
)


🚀 QUANTIZATION LAB: bert-base-uncased
📦 Original Footprint (FLOAT32): 438.07 MB
🛠️  Loading in 4bit mode...


Loading weights: 100%|██████████| 202/202 [00:00<00:00, 739.86it/s, Materializing param=cls.predictions.transform.dense.weight]
BertForMaskedLM LOAD REPORT from: bert-base-uncased
Key                         | Status     |  | 
----------------------------+------------+--+-
bert.pooler.dense.weight    | UNEXPECTED |  | 
cls.seq_relationship.bias   | UNEXPECTED |  | 
cls.seq_relationship.weight | UNEXPECTED |  | 
bert.pooler.dense.bias      | UNEXPECTED |  | 

Notes:
- UNEXPECTED: can be ignored when loading from different task/architecture; not ok if you expect identical arch.



📊 COMPARISON:
📦 Original (FLOAT32):  438.07 MB
✅ Quantized:           138.73 MB
📉 Actual Reduction:    68.3%
🏷️  Suggested HF Name:  bert-base-uncased-bnb-4bit-nf4-dq

💾 Saving to ./bert-base-uncased-bnb-4bit-nf4-dq...


Writing model shards: 100%|██████████| 1/1 [00:00<00:00,  9.10it/s]


✅ Successfully saved as bert-base-uncased-bnb-4bit-nf4-dq with metadata injection!

