# In [None]:
import torch
import onnx
import numpy as np
import pandas as pd
import onnxruntime
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, AutoConfig
from torch.utils.data import DataLoader, Dataset
from onnxruntime.quantization import quantize_static, QuantType, CalibrationDataReader

# ======= Load Model Properly =======
MODEL_PATH = "/kaggle/input/distilbertforclimatedisinfo/pytorch/default/1/distilbert_trained.pth"
config = AutoConfig.from_pretrained("distilbert-base-uncased", num_labels=8)
model = DistilBertForSequenceClassification(config)
# The checkpoint is handed straight to load_state_dict below, so it must be a
# plain mapping of tensors. weights_only=True restricts unpickling to such
# data and refuses arbitrary pickled objects in the input file.
state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'), weights_only=True)
# strict=False: report key mismatches instead of failing outright.
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
print("Missing keys:", missing_keys)
print("Unexpected keys:", unexpected_keys)

# Check if any classification-head weights were left at their fresh
# initialisation. This assumes the standard DistilBertForSequenceClassification
# parameter names (pre_classifier.* / classifier.*); if they are missing,
# predictions are effectively random.
# NOTE(review): if nearly every key is missing and unexpected_keys is large,
# the checkpoint may use a key prefix (e.g. from a wrapper module) - confirm.
head_prefixes = ("pre_classifier.", "classifier.")
missing_head_keys = [key for key in missing_keys if key.startswith(head_prefixes)]
if missing_head_keys:
    print("⚠️ Warning: Classifier weights missing! Ensure fine-tuned model was saved correctly.")
    print("   Missing head parameters:", missing_head_keys)

model.eval()

# ======= Convert to ONNX =======
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
dummy_text = "This is a dummy input for ONNX conversion."
# Padded to a fixed length of 365 tokens. Only the batch axis is exported as
# dynamic, so every input fed to the ONNX model must use this sequence length.
dummy_inputs = tokenizer(dummy_text, return_tensors="pt", max_length=365, padding="max_length", truncation=True)

onnx_path = "distilbert_model.onnx"
torch.onnx.export(
    model,
    (dummy_inputs["input_ids"], dummy_inputs["attention_mask"]),
    onnx_path,
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    # The output's batch axis must be dynamic too. Otherwise the graph
    # declares logits with the traced batch size (1, from the single dummy
    # string), while the evaluation batches have many rows.
    dynamic_axes={
        "input_ids": {0: "batch_size"},
        "attention_mask": {0: "batch_size"},
        "logits": {0: "batch_size"},
    },
    opset_version=14,
)
print(f"✅ Model exported to {onnx_path}")

# ======= Static Quantization with Calibration =======
class ONNXCalibrationDataset(CalibrationDataReader):
    """Calibration reader that feeds DataLoader batches to onnxruntime.

    ``quantize_static`` pulls batches through :meth:`get_next` until it
    returns ``None``. Each batch is converted into the feeds of the exported
    graph, whose only inputs are ``input_ids`` and ``attention_mask``. Any
    other entries in the batch (such as ``labels``) are ignored.

    Args:
        dataloader: Re-iterable source of batches (e.g. a ``DataLoader``).
            Each batch is a mapping with ``"input_ids"`` and
            ``"attention_mask"`` entries convertible to NumPy arrays.
    """

    def __init__(self, dataloader):
        super().__init__()
        # Keep the source itself so rewind() restarts *this* reader's data,
        # not whatever loader happens to be global.
        self._source = dataloader
        self._iterator = iter(dataloader)

    def get_next(self):
        """Return the next batch as ONNX input feeds, or ``None`` when exhausted."""
        try:
            batch = next(self._iterator)
        except StopIteration:
            return None
        # Both feeds are cast to int64 to match the graph inputs. The graph was
        # traced from the tokenizer's PyTorch tensors, presumably int64 (confirm
        # via session.get_inputs()). A float32 mask would be rejected by
        # onnxruntime during calibration.
        return {
            "input_ids": np.array(batch["input_ids"], dtype=np.int64),
            "attention_mask": np.array(batch["attention_mask"], dtype=np.int64),
        }

    def rewind(self):
        """Restart iteration from the beginning of this reader's own source."""
        self._iterator = iter(self._source)

# Load Data
df = pd.read_parquet("/kaggle/input/test-parquet/test-00000-of-00001.parquet")
# Presumably labels look like "<int>_<name>"; the numeric prefix before the
# first underscore becomes the integer class id.
df['label_int'] = (
    df['label']
    .str.split("_", n=1)
    .str[0]
    .astype(int)
)

class QuotesDataset(Dataset):
    """Map-style dataset pairing pre-tokenized encodings with integer labels.

    Args:
        encodings: Mapping from feature name (e.g. ``"input_ids"``) to a
            per-sample indexable sequence.
        labels: Indexable sequence of class ids, aligned with ``encodings``.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The dataset size is taken from the labels, not the encodings.
        return len(self.labels)

    def __getitem__(self, idx):
        """Return one sample: each encoding row as an array plus an int64 label."""
        sample = {}
        for name, column in self.encodings.items():
            sample[name] = np.array(column[idx])
        sample['labels'] = np.array(self.labels[idx], dtype=np.int64)
        return sample

def encode_data(tokenizer, texts, labels, max_length):
    """Tokenize ``texts`` and wrap the result with ``labels`` in a QuotesDataset.

    Args:
        tokenizer: Callable tokenizer accepting a list of strings plus the
            keyword options below.
        texts: Text column exposing ``.tolist()`` (e.g. a pandas Series).
        labels: Label column exposing ``.tolist()``, aligned with ``texts``.
        max_length: Target length passed to the tokenizer. Truncation and
            ``'max_length'`` padding are requested.

    Returns:
        A QuotesDataset over the NumPy-typed encodings and the label list.
    """
    tokenizer_options = dict(
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors='np',
    )
    batch_encoding = tokenizer(texts.tolist(), **tokenizer_options)
    return QuotesDataset(batch_encoding, labels.tolist())

MAX_LENGTH = 365
texts = df['quote']
labels = df['label_int']
val_dataset = encode_data(tokenizer, texts, labels, MAX_LENGTH)

# Use at most 2 loader workers, fewer if torch reports fewer threads.
num_workers = min(2, torch.get_num_threads())
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=num_workers,
)

# NOTE(review): this same loader drives both calibration and the accuracy
# evaluation, so the reported accuracy is measured on the calibration data.
calibration_reader = ONNXCalibrationDataset(val_loader)
quantized_onnx_path = "distilbert_quantized_05.onnx"

from onnxruntime.quantization import QuantFormat, quantize_dynamic

try:
    quantize_static(
        model_input=onnx_path,
        model_output=quantized_onnx_path,
        # quant_format expects the QuantFormat enum, not the string "QDQ".
        quant_format=QuantFormat.QDQ,
        weight_type=QuantType.QInt8,
        data_reader=calibration_reader,
    )
    print(f"✅ Quantized model saved at {quantized_onnx_path}")
except Exception as e:
    # Calibration executes the model in onnxruntime, so failures may surface
    # as onnxruntime errors (e.g. an input dtype mismatch), not only
    # TypeError. Report the cause, then fall back to dynamic quantization.
    print(
        f"⚠️ Warning: Static quantization failed ({type(e).__name__}: {e}). "
        "Falling back to dynamic quantization."
    )
    quantized_onnx_path = "distilbert_quantized_dynamic.onnx"

    quantize_dynamic(
        model_input=onnx_path,
        model_output=quantized_onnx_path,
        weight_type=QuantType.QInt8,
    )

    print(f"✅ Dynamic quantized model saved at {quantized_onnx_path}")


# ======= Optimized ONNX Evaluation =======
# CPU-only inference session over whichever quantized model was produced.
session = onnxruntime.InferenceSession(
    quantized_onnx_path,
    providers=["CPUExecutionProvider"],
)

def evaluate_onnx_model(session, val_loader, *, verbose=True):
    """Compute top-1 accuracy of an ONNX classifier over a data loader.

    Args:
        session: Object exposing ``run(output_names, feeds)`` like
            ``onnxruntime.InferenceSession``. Its ``"logits"`` output must be
            shaped ``(batch, num_classes)``.
        val_loader: Iterable of batches. Each batch is a mapping with
            ``"input_ids"``, ``"attention_mask"`` and ``"labels"`` entries
            convertible to NumPy arrays.
        verbose: When True (the default), print per-batch progress lines.

    Returns:
        Fraction of correctly classified samples, or ``nan`` when
        ``val_loader`` yields no samples.
    """
    print("⚡ Starting ONNX evaluation...")
    total_correct, total_samples = 0, 0
    for batch_idx, batch in enumerate(val_loader, start=1):
        if verbose:
            print(f"📝 Processing batch {batch_idx}...")
        # Cast to int64 to match the exported graph's integer inputs.
        feeds = {
            "input_ids": np.asarray(batch["input_ids"], dtype=np.int64),
            "attention_mask": np.asarray(batch["attention_mask"], dtype=np.int64),
        }
        logits = session.run(["logits"], feeds)[0]
        predictions = np.argmax(logits, axis=1)
        labels = np.asarray(batch["labels"], dtype=np.int64)

        batch_correct = int(np.sum(predictions == labels))
        total_correct += batch_correct
        total_samples += len(labels)
        if verbose:
            print(f"✔️ Batch {batch_idx}: {batch_correct}/{len(labels)} correct")

    if total_samples == 0:
        # Avoid ZeroDivisionError; accuracy over zero samples is undefined.
        print("⚠️ Warning: no samples to evaluate; accuracy is undefined.")
        return float("nan")

    accuracy = total_correct / total_samples
    print(f"ONNX Quantized Model Accuracy: {accuracy:.4f} ({total_correct}/{total_samples})")
    return accuracy

# Evaluate the quantized model on the same loader used above.
quantized_accuracy = evaluate_onnx_model(session=session, val_loader=val_loader)
