In [None]:
import torch
import gzip
import pickle
import numpy as np
import os

from transformers import (
    AutoModelForSequenceClassification,
    AutoConfig,
    AutoTokenizer,
    DataCollatorWithPadding,
)
from torch.utils.data import DataLoader
from datasets import load_dataset
from tqdm import tqdm
import evaluate

# ================= CONFIG =================
# Gzip-compressed pickle holding the compressed weights; read by
# load_compressed_model().
MODEL_PATH = "compressed_models_structured/google_electra-small-discriminator_qnli.pkl.gz"
# Hub id used to build the model skeleton and the tokenizer.
BASE_MODEL = "google/electra-small-discriminator"   # must match compressed model

# GLUE task name: passed to load_dataset(), evaluate.load() and the model
# config (finetuning_task).
TASK_NAME = "qnli"
# Size of the classification head requested from the model config. The
# classifier weights stored in MODEL_PATH must have the same label count.
NUM_LABELS = 2
# Number of examples per evaluation batch.
BATCH_SIZE = 32
# Maximum tokenized length of a (question, sentence) pair.
MAX_LEN = 128

# Evaluate on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# =========================================


def unpack_int4_to_tensor(packed_data, original_count):
    """Unpack two offset-encoded 4-bit values per byte into a float tensor.

    Every byte carries two values: the high nibble first, then the low
    nibble. Each nibble is stored with an offset of 8, so subtracting 8
    maps the raw range 0..15 back to -8..7.

    Args:
        packed_data: NumPy array of packed bytes. Any shape is accepted; the
            bytes are read in C order. An ``int8`` buffer is reinterpreted as
            ``uint8`` so the right shift never sign-extends.
        original_count: Number of values to return. Anything beyond it (for
            example a padding nibble in the last byte) is dropped.

    Returns:
        A 1-D ``float32`` tensor holding at most ``original_count`` values.
    """
    raw = np.ascontiguousarray(packed_data).reshape(-1)
    if raw.dtype == np.int8:
        # Same bits, unsigned view: a signed byte such as 0xF0 would
        # otherwise shift arithmetically and yield -1 instead of 15.
        raw = raw.view(np.uint8)

    packed = torch.from_numpy(raw)
    high = (packed >> 4).to(torch.int8)
    low = (packed & 0x0F).to(torch.int8)

    # Stacking along dim=1 and flattening interleaves the nibbles as
    # high0, low0, high1, low1, ...
    unpacked = torch.stack((high, low), dim=1).flatten() - 8

    return unpacked[:original_count].float()


def _restore_tensor(entry):
    """Rebuild one dense tensor from a single checkpoint entry.

    A non-dict entry is treated as a plain NumPy array and wrapped as-is.
    A dict entry is dequantized: its values are unpacked, reshaped to
    (number of kept rows, original column count), multiplied by the stored
    scale and written into a zero tensor of the original shape at the row
    indices listed in ``row_map``. Rows not listed stay zero.
    """
    if not isinstance(entry, dict):
        return torch.from_numpy(entry)

    packed = entry["data"]
    scale = entry["scale"]
    row_map = entry["row_map"]
    orig_shape = entry["original_shape"]
    fmt = entry["format"]

    n_rows, n_cols = len(row_map), orig_shape[1]

    if "int4" in fmt:
        flat = unpack_int4_to_tensor(packed, n_rows * n_cols)
    else:
        flat = torch.from_numpy(packed).float().flatten()

    rows = flat.view(n_rows, n_cols)
    # NOTE(review): the scale is broadcast against (n_rows, n_cols) as stored;
    # a per-row scale is presumably saved with shape (n_rows, 1) - confirm
    # against the compression script.
    scale = torch.tensor(scale, dtype=torch.float32)

    full = torch.zeros(orig_shape)
    full[row_map] = rows * scale
    return full


def load_compressed_model(file_path, base_model_name):
    """Build the base classifier and overwrite its weights from a compressed file.

    Args:
        file_path: Path to a gzip-compressed pickle containing a dict with a
            ``"weights"`` mapping of parameter name to either a NumPy array
            or a compressed-entry dict.
        base_model_name: Model id used to create the config and the model
            skeleton the weights are loaded into.

    Returns:
        The sequence-classification model with the decompressed weights
        loaded. Parameters absent from the file keep the values produced by
        ``from_pretrained``.

    Raises:
        RuntimeError: If a stored tensor's shape differs from the matching
            model parameter, e.g. a head trained with a different number of
            labels than ``NUM_LABELS``.
    """
    print(f"\nðŸ“¦ Loading compressed model: {file_path}")

    config = AutoConfig.from_pretrained(
        base_model_name,
        num_labels=NUM_LABELS,
        finetuning_task=TASK_NAME,
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        base_model_name,
        config=config,
    )

    # NOTE(security): pickle.load can execute arbitrary code. Only open
    # checkpoint files that come from a trusted source.
    with gzip.open(file_path, "rb") as f:
        wrapper = pickle.load(f)

    weights = wrapper["weights"]
    state_dict = model.state_dict()

    # state_dict starts out with every model key, so load_state_dict() can
    # never report a missing key. Compare against the file contents instead.
    missing = [key for key in state_dict if key not in weights]

    print("ðŸ”“ Decompressing layers...")
    shape_conflicts = []
    for name, data in weights.items():
        restored = _restore_tensor(data)
        current = state_dict.get(name)
        if current is not None and tuple(current.shape) != tuple(restored.shape):
            shape_conflicts.append(
                f"{name}: checkpoint {tuple(restored.shape)} "
                f"vs model {tuple(current.shape)}"
            )
        state_dict[name] = restored

    if shape_conflicts:
        # Fail with an actionable message rather than the generic
        # load_state_dict size-mismatch error.
        raise RuntimeError(
            f"Checkpoint {file_path} does not fit {base_model_name} built with "
            f"NUM_LABELS={NUM_LABELS} (TASK_NAME={TASK_NAME!r}); it was probably "
            "compressed for a different task:\n  " + "\n  ".join(shape_conflicts)
        )

    model.load_state_dict(state_dict, strict=False)

    if any(k.startswith("classifier.") for k in missing):
        print("\nðŸš¨ CRITICAL WARNING ðŸš¨")
        print("Classifier head is MISSING.")
        print("You compressed a BASE model, not a QNLI-finetuned model.")
        print("Accuracy will be ~50%.\n")
    else:
        print("âœ… Classifier head loaded correctly.")

    return model


def main():
    """Evaluate the compressed model on the GLUE validation split and print accuracy.

    Loads the model from MODEL_PATH, tokenizes the (question, sentence)
    pairs of the validation split, runs batched inference on DEVICE and
    reports the accuracy computed by the GLUE metric for TASK_NAME.
    """
    model = load_compressed_model(MODEL_PATH, BASE_MODEL).to(DEVICE)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        # Fall back to the EOS token so batches can still be padded.
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.eos_token_id

    print("\nðŸ“Š Loading GLUE/QNLI")
    raw = load_dataset("nyu-mll/glue", TASK_NAME)

    metric = evaluate.load("glue", TASK_NAME)

    def preprocess(batch):
        # Truncate only. Padding is left to the collator so each batch is
        # padded to its own longest sequence instead of always to MAX_LEN.
        return tokenizer(
            batch["question"],
            batch["sentence"],
            truncation=True,
            max_length=MAX_LEN,
        )

    dataset = raw["validation"].map(preprocess, batched=True)
    dataset = dataset.rename_column("label", "labels")

    keep_cols = ["input_ids", "attention_mask", "labels"]
    if "token_type_ids" in dataset.column_names:
        keep_cols.append("token_type_ids")

    dataset.set_format("torch", columns=keep_cols)

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        collate_fn=DataCollatorWithPadding(tokenizer),
    )

    print("ðŸš€ Evaluating...")
    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Labels are only consumed by the metric, so they stay on the CPU.
            labels = batch["labels"]
            inputs = {k: v.to(DEVICE) for k, v in batch.items() if k != "labels"}

            outputs = model(**inputs)

            # Bring predictions back to the CPU before handing them to the metric.
            preds = torch.argmax(outputs.logits, dim=-1).cpu()
            metric.add_batch(predictions=preds, references=labels)

    result = metric.compute()
    print(f"\nðŸŽ¯ Final QNLI Accuracy: {result['accuracy']:.4f}")


# Run the evaluation only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()


  from .autonotebook import tqdm as notebook_tqdm



📦 Loading compressed model: compressed_models_structured/google_electra-small-discriminator_mnli.pkl.gz


Loading weights: 100%|██████████| 199/199 [00:00<00:00, 2490.39it/s, Materializing param=electra.encoder.layer.11.output.dense.weight]
ElectraForSequenceClassification LOAD REPORT from: google/electra-small-discriminator
Key                                               | Status     | 
--------------------------------------------------+------------+-
discriminator_predictions.dense_prediction.bias   | UNEXPECTED | 
discriminator_predictions.dense.weight            | UNEXPECTED | 
discriminator_predictions.dense.bias              | UNEXPECTED | 
discriminator_predictions.dense_prediction.weight | UNEXPECTED | 
classifier.dense.weight                           | MISSING    | 
classifier.out_proj.bias                          | MISSING    | 
classifier.out_proj.weight                        | MISSING    | 
classifier.dense.bias                             | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; n

🔓 Decompressing layers...


RuntimeError: Error(s) in loading state_dict for ElectraForSequenceClassification:
	size mismatch for classifier.out_proj.weight: copying a param with shape torch.Size([3, 256]) from checkpoint, the shape in current model is torch.Size([2, 256]).
	size mismatch for classifier.out_proj.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([2]).