# Probe Architecture Comparison

This notebook compares 9 novel probe architectures against a default MLP baseline (10 probes in total) for predicting model confidence.

**Architectures tested:**
1. Default MLP (baseline)
2. AttentionProbe - Self-attention over hidden state chunks
3. ResidualProbe - Deep MLP with skip connections
4. BottleneckProbe - Compression to low-rank representation
5. MultiHeadProbe - Multiple experts with learned aggregation
6. GatedProbe - GLU-style gating
7. SparseProbe - Top-k dimension selection
8. HeteroscedasticProbe - Per-example uncertainty
9. BilinearProbe - Explicit feature interactions
10. HierarchicalProbe - Multi-scale hierarchical processing (fine → mid → semantic → global)

**Setup:** Run on Google Colab with GPU (Runtime > Change runtime type > T4 GPU)

In [1]:
# Switch to /content so the relative paths used by the following setup cells resolve there.
%cd /content #unicorn

/content


In [2]:
# Delete any earlier checkout (absolute path, path relative to the current
# directory, and any suffixed copies) so the `git clone` in the setup cell
# starts from a clean directory.
!rm -rf /content/deep-learning
!rm -rf deep-learning
!rm -rf deep-learning* #unicorn

In [3]:
# Install dependencies (Colab)
!pip install -q transformers accelerate bitsandbytes datasets tqdm matplotlib seaborn scikit-learn

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m45.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
# Clone repo and setup path
!git clone -b maureen --single-branch https://github.com/joshcliu/deep-learning.git #unicorn
%cd deep-learning

import sys
sys.path.insert(0, '.')

Cloning into 'deep-learning'...
remote: Enumerating objects: 355, done.[K
remote: Counting objects: 100% (34/34), done.[K
remote: Compressing objects: 100% (26/26), done.[K
remote: Total 355 (delta 19), reused 19 (delta 8), pack-reused 321 (from 1)[K
Receiving objects: 100% (355/355), 257.29 KiB | 9.90 MiB/s, done.
Resolving deltas: 100% (180/180), done.
/content/deep-learning


In [5]:
# Patch for bfloat16 compatibility (8-bit quantized models)
import numpy as np
import torch
from src.models import extractor as extractor_module

# Keep a handle on the method as it exists before the patch is applied.
# NOTE(review): this name is not referenced again in the visible cells, and
# re-running this cell after patching stores the *patched* function here
# rather than the library original.
_original_extract_batch = extractor_module.HiddenStateExtractor._extract_batch

def _patched_extract_batch(self, texts, layers, max_length, token_position):
    """Patched to preserve original behavior while safely handling bfloat16."""

    # 1. Tokenization (unchanged)
    encodings = self.tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    encodings = {k: v.to(self.device) for k, v in encodings.items()}

    # 2. Model forward pass
    with torch.no_grad():
        outputs = self.model(
            **encodings,
            output_hidden_states=True,
            return_dict=True,
        )

    hidden_states = outputs.hidden_states
    batch_hiddens = []

    # 3. Iterate entire batch for each requested layer
    for layer_idx in layers:

        # IMPORTANT: your model uses layer_idx + 1
        layer_hiddens = hidden_states[layer_idx + 1]  # (batch, seq_len, hidden)

        # === TOKEN SELECTION (MUST FOLLOW ORIGINAL LOGIC) ===
        if token_position == "last":
            attention_mask = encodings["attention_mask"]
            seq_lengths = attention_mask.sum(dim=1) - 1
            token_hiddens = layer_hiddens[
                torch.arange(layer_hiddens.size(0), device=self.device),
                seq_lengths
            ]

        elif token_position == "cls":
            token_hiddens = layer_hiddens[:, 0, :]

        elif token_position == "mean":
            attention_mask = encodings["attention_mask"].unsqueeze(-1)
            masked_hiddens = layer_hiddens * attention_mask
            token_hiddens = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)

        else:
            raise ValueError(f"Unknown token_position: {token_position}")

        # === BF16 / INT8 SAFE CONVERSION ===
        token_hiddens = token_hiddens.detach().cpu().to(torch.float32).numpy()

        batch_hiddens.append(token_hiddens)

    # 4. Stack layers: (batch, num_layers, hidden_dim)
    return np.stack(batch_hiddens, axis=1)


# Apply patch: rebind the method on the class itself, so every
# HiddenStateExtractor instance resolves `_extract_batch` to the patched
# function defined in this cell.
extractor_module.HiddenStateExtractor._extract_batch = _patched_extract_batch
print("Patched HiddenStateExtractor for bfloat16 & 8bit compatibility.")

Patched HiddenStateExtractor for bfloat16 & 8bit compatibility.


## 1. Load Model

In [7]:
from src.models import ModelLoader, HiddenStateExtractor

# Load Qwen2.5-7B through the project's ModelLoader
# (8-bit quantization, device_map="auto")
model_name = "Qwen/Qwen2.5-7B"
loader = ModelLoader(model_name)
model, tokenizer = loader.load(quantization="8bit", device_map="auto")

# Report the architecture size exposed by the loader's config
print(f"Loaded {model_name}")
print(f"Layers: {loader.config.num_layers}, Hidden dim: {loader.config.hidden_dim}")

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/686 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.95G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

Loaded Qwen/Qwen2.5-7B
Layers: 28, Hidden dim: 3584


## 2. Load Dataset

In [8]:
from src.data import MMLUDataset

# Load MMLU validation set
dataset = MMLUDataset(split="validation")
print(f"Loaded {len(dataset)} examples")

# Draw a fixed-seed sample for the experiment.
# NOTE(review): what happens when NUM_SAMPLES exceeds len(dataset) depends on
# MMLUDataset.sample, which is not visible here — confirm it caps the request
# rather than raising, and compare the two printed counts.
NUM_SAMPLES = 2000  # Increased for better AUROC (was 300)
examples = dataset.sample(NUM_SAMPLES, seed=42)
print(f"Using {len(examples)} examples")

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


README.md: 0.00B [00:00, ?B/s]

dataset_infos.json: 0.00B [00:00, ?B/s]

abstract_algebra/test-00000-of-00001.par(…):   0%|          | 0.00/9.96k [00:00<?, ?B/s]

abstract_algebra/validation-00000-of-000(…):   0%|          | 0.00/3.73k [00:00<?, ?B/s]

abstract_algebra/dev-00000-of-00001.parq(…):   0%|          | 0.00/3.45k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


anatomy/test-00000-of-00001.parquet:   0%|          | 0.00/20.1k [00:00<?, ?B/s]

anatomy/validation-00000-of-00001.parque(…):   0%|          | 0.00/5.28k [00:00<?, ?B/s]

anatomy/dev-00000-of-00001.parquet:   0%|          | 0.00/3.50k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/135 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/14 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


astronomy/test-00000-of-00001.parquet:   0%|          | 0.00/28.3k [00:00<?, ?B/s]

astronomy/validation-00000-of-00001.parq(…):   0%|          | 0.00/6.05k [00:00<?, ?B/s]

astronomy/dev-00000-of-00001.parquet:   0%|          | 0.00/4.94k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/152 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/16 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


college_biology/test-00000-of-00001.parq(…):   0%|          | 0.00/31.8k [00:00<?, ?B/s]

college_biology/validation-00000-of-0000(…):   0%|          | 0.00/6.90k [00:00<?, ?B/s]

college_biology/dev-00000-of-00001.parqu(…):   0%|          | 0.00/4.27k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/144 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/16 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


college_chemistry/test-00000-of-00001.pa(…):   0%|          | 0.00/17.9k [00:00<?, ?B/s]

college_chemistry/validation-00000-of-00(…):   0%|          | 0.00/4.87k [00:00<?, ?B/s]

college_chemistry/dev-00000-of-00001.par(…):   0%|          | 0.00/4.04k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/8 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


college_computer_science/test-00000-of-0(…):   0%|          | 0.00/28.1k [00:00<?, ?B/s]

college_computer_science/validation-0000(…):   0%|          | 0.00/6.25k [00:00<?, ?B/s]

college_computer_science/dev-00000-of-00(…):   0%|          | 0.00/6.81k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


college_mathematics/test-00000-of-00001.(…):   0%|          | 0.00/16.6k [00:00<?, ?B/s]

college_mathematics/validation-00000-of-(…):   0%|          | 0.00/5.00k [00:00<?, ?B/s]

college_mathematics/dev-00000-of-00001.p(…):   0%|          | 0.00/5.16k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


college_physics/test-00000-of-00001.parq(…):   0%|          | 0.00/18.6k [00:00<?, ?B/s]

college_physics/validation-00000-of-0000(…):   0%|          | 0.00/6.39k [00:00<?, ?B/s]

college_physics/dev-00000-of-00001.parqu(…):   0%|          | 0.00/4.51k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/102 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


computer_security/test-00000-of-00001.pa(…):   0%|          | 0.00/19.1k [00:00<?, ?B/s]

computer_security/validation-00000-of-00(…):   0%|          | 0.00/6.67k [00:00<?, ?B/s]

computer_security/dev-00000-of-00001.par(…):   0%|          | 0.00/4.33k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


conceptual_physics/test-00000-of-00001.p(…):   0%|          | 0.00/25.0k [00:00<?, ?B/s]

conceptual_physics/validation-00000-of-0(…):   0%|          | 0.00/5.98k [00:00<?, ?B/s]

conceptual_physics/dev-00000-of-00001.pa(…):   0%|          | 0.00/3.96k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/235 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/26 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


electrical_engineering/test-00000-of-000(…):   0%|          | 0.00/17.6k [00:00<?, ?B/s]

electrical_engineering/validation-00000-(…):   0%|          | 0.00/5.08k [00:00<?, ?B/s]

electrical_engineering/dev-00000-of-0000(…):   0%|          | 0.00/4.08k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/145 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/16 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


elementary_mathematics/test-00000-of-000(…):   0%|          | 0.00/41.1k [00:00<?, ?B/s]

elementary_mathematics/validation-00000-(…):   0%|          | 0.00/9.38k [00:00<?, ?B/s]

elementary_mathematics/dev-00000-of-0000(…):   0%|          | 0.00/4.55k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/378 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/41 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


high_school_biology/test-00000-of-00001.(…):   0%|          | 0.00/62.7k [00:00<?, ?B/s]

high_school_biology/validation-00000-of-(…):   0%|          | 0.00/10.6k [00:00<?, ?B/s]

high_school_biology/dev-00000-of-00001.p(…):   0%|          | 0.00/4.94k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/310 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/32 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


high_school_chemistry/test-00000-of-0000(…):   0%|          | 0.00/33.3k [00:00<?, ?B/s]

high_school_chemistry/validation-00000-o(…):   0%|          | 0.00/8.31k [00:00<?, ?B/s]

high_school_chemistry/dev-00000-of-00001(…):   0%|          | 0.00/4.16k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/203 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


high_school_computer_science/test-00000-(…):   0%|          | 0.00/27.3k [00:00<?, ?B/s]

high_school_computer_science/validation-(…):   0%|          | 0.00/5.28k [00:00<?, ?B/s]

high_school_computer_science/dev-00000-o(…):   0%|          | 0.00/6.54k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/9 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


high_school_mathematics/test-00000-of-00(…):   0%|          | 0.00/33.7k [00:00<?, ?B/s]

high_school_mathematics/validation-00000(…):   0%|          | 0.00/6.99k [00:00<?, ?B/s]

high_school_mathematics/dev-00000-of-000(…):   0%|          | 0.00/4.50k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/270 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/29 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


high_school_physics/test-00000-of-00001.(…):   0%|          | 0.00/33.0k [00:00<?, ?B/s]

high_school_physics/validation-00000-of-(…):   0%|          | 0.00/7.96k [00:00<?, ?B/s]

high_school_physics/dev-00000-of-00001.p(…):   0%|          | 0.00/4.57k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/151 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/17 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


high_school_statistics/test-00000-of-000(…):   0%|          | 0.00/58.0k [00:00<?, ?B/s]

high_school_statistics/validation-00000-(…):   0%|          | 0.00/10.9k [00:00<?, ?B/s]

high_school_statistics/dev-00000-of-0000(…):   0%|          | 0.00/6.07k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/216 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/23 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


machine_learning/test-00000-of-00001.par(…):   0%|          | 0.00/19.7k [00:00<?, ?B/s]

machine_learning/validation-00000-of-000(…):   0%|          | 0.00/6.17k [00:00<?, ?B/s]

machine_learning/dev-00000-of-00001.parq(…):   0%|          | 0.00/5.25k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/112 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


formal_logic/test-00000-of-00001.parquet:   0%|          | 0.00/21.5k [00:00<?, ?B/s]

formal_logic/validation-00000-of-00001.p(…):   0%|          | 0.00/6.56k [00:00<?, ?B/s]

formal_logic/dev-00000-of-00001.parquet:   0%|          | 0.00/4.81k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/126 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/14 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


high_school_european_history/test-00000-(…):   0%|          | 0.00/142k [00:00<?, ?B/s]

high_school_european_history/validation-(…):   0%|          | 0.00/31.6k [00:00<?, ?B/s]

high_school_european_history/dev-00000-o(…):   0%|          | 0.00/22.2k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/165 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/18 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


high_school_us_history/test-00000-of-000(…):   0%|          | 0.00/155k [00:00<?, ?B/s]

high_school_us_history/validation-00000-(…):   0%|          | 0.00/27.3k [00:00<?, ?B/s]

high_school_us_history/dev-00000-of-0000(…):   0%|          | 0.00/17.8k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/204 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


high_school_world_history/test-00000-of-(…):   0%|          | 0.00/202k [00:00<?, ?B/s]

high_school_world_history/validation-000(…):   0%|          | 0.00/38.5k [00:00<?, ?B/s]

high_school_world_history/dev-00000-of-0(…):   0%|          | 0.00/10.2k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/237 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/26 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


international_law/test-00000-of-00001.pa(…):   0%|          | 0.00/29.5k [00:00<?, ?B/s]

international_law/validation-00000-of-00(…):   0%|          | 0.00/7.12k [00:00<?, ?B/s]

international_law/dev-00000-of-00001.par(…):   0%|          | 0.00/4.96k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/121 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


jurisprudence/test-00000-of-00001.parque(…):   0%|          | 0.00/23.3k [00:00<?, ?B/s]

jurisprudence/validation-00000-of-00001.(…):   0%|          | 0.00/6.21k [00:00<?, ?B/s]

jurisprudence/dev-00000-of-00001.parquet:   0%|          | 0.00/4.05k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/108 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


logical_fallacies/test-00000-of-00001.pa(…):   0%|          | 0.00/23.0k [00:00<?, ?B/s]

logical_fallacies/validation-00000-of-00(…):   0%|          | 0.00/6.52k [00:00<?, ?B/s]

logical_fallacies/dev-00000-of-00001.par(…):   0%|          | 0.00/4.12k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/163 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/18 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


moral_disputes/test-00000-of-00001.parqu(…):   0%|          | 0.00/60.9k [00:00<?, ?B/s]

moral_disputes/validation-00000-of-00001(…):   0%|          | 0.00/10.7k [00:00<?, ?B/s]

moral_disputes/dev-00000-of-00001.parque(…):   0%|          | 0.00/4.41k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/346 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/38 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


moral_scenarios/test-00000-of-00001.parq(…):   0%|          | 0.00/89.8k [00:00<?, ?B/s]

moral_scenarios/validation-00000-of-0000(…):   0%|          | 0.00/14.9k [00:00<?, ?B/s]

moral_scenarios/dev-00000-of-00001.parqu(…):   0%|          | 0.00/5.14k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/895 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


philosophy/test-00000-of-00001.parquet:   0%|          | 0.00/48.6k [00:00<?, ?B/s]

philosophy/validation-00000-of-00001.par(…):   0%|          | 0.00/9.15k [00:00<?, ?B/s]

philosophy/dev-00000-of-00001.parquet:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/311 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/34 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


prehistory/test-00000-of-00001.parquet:   0%|          | 0.00/54.3k [00:00<?, ?B/s]

prehistory/validation-00000-of-00001.par(…):   0%|          | 0.00/9.89k [00:00<?, ?B/s]

prehistory/dev-00000-of-00001.parquet:   0%|          | 0.00/4.62k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/324 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/35 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


professional_law/test-00000-of-00001.par(…):   0%|          | 0.00/1.04M [00:00<?, ?B/s]

professional_law/validation-00000-of-000(…):   0%|          | 0.00/116k [00:00<?, ?B/s]

professional_law/dev-00000-of-00001.parq(…):   0%|          | 0.00/15.1k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/1534 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/170 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


world_religions/test-00000-of-00001.parq(…):   0%|          | 0.00/18.9k [00:00<?, ?B/s]

world_religions/validation-00000-of-0000(…):   0%|          | 0.00/4.94k [00:00<?, ?B/s]

world_religions/dev-00000-of-00001.parqu(…):   0%|          | 0.00/3.30k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/171 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/19 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


econometrics/test-00000-of-00001.parquet:   0%|          | 0.00/24.5k [00:00<?, ?B/s]

econometrics/validation-00000-of-00001.p(…):   0%|          | 0.00/7.02k [00:00<?, ?B/s]

econometrics/dev-00000-of-00001.parquet:   0%|          | 0.00/4.54k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/114 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/12 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


high_school_geography/test-00000-of-0000(…):   0%|          | 0.00/28.2k [00:00<?, ?B/s]

high_school_geography/validation-00000-o(…):   0%|          | 0.00/6.16k [00:00<?, ?B/s]

high_school_geography/dev-00000-of-00001(…):   0%|          | 0.00/3.93k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/198 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


high_school_government_and_politics/test(…):   0%|          | 0.00/40.2k [00:00<?, ?B/s]

high_school_government_and_politics/vali(…):   0%|          | 0.00/8.27k [00:00<?, ?B/s]

high_school_government_and_politics/dev-(…):   0%|          | 0.00/4.47k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/193 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


high_school_macroeconomics/test-00000-of(…):   0%|          | 0.00/54.8k [00:00<?, ?B/s]

high_school_macroeconomics/validation-00(…):   0%|          | 0.00/9.89k [00:00<?, ?B/s]

high_school_macroeconomics/dev-00000-of-(…):   0%|          | 0.00/4.04k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/390 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/43 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


high_school_microeconomics/test-00000-of(…):   0%|          | 0.00/38.8k [00:00<?, ?B/s]

high_school_microeconomics/validation-00(…):   0%|          | 0.00/7.22k [00:00<?, ?B/s]

high_school_microeconomics/dev-00000-of-(…):   0%|          | 0.00/3.83k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/238 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/26 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


high_school_psychology/test-00000-of-000(…):   0%|          | 0.00/92.8k [00:00<?, ?B/s]

high_school_psychology/validation-00000-(…):   0%|          | 0.00/15.2k [00:00<?, ?B/s]

high_school_psychology/dev-00000-of-0000(…):   0%|          | 0.00/5.18k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/545 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/60 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


human_sexuality/test-00000-of-00001.parq(…):   0%|          | 0.00/23.2k [00:00<?, ?B/s]

human_sexuality/validation-00000-of-0000(…):   0%|          | 0.00/5.26k [00:00<?, ?B/s]

human_sexuality/dev-00000-of-00001.parqu(…):   0%|          | 0.00/4.08k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/131 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/12 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


professional_psychology/test-00000-of-00(…):   0%|          | 0.00/133k [00:00<?, ?B/s]

professional_psychology/validation-00000(…):   0%|          | 0.00/22.1k [00:00<?, ?B/s]

professional_psychology/dev-00000-of-000(…):   0%|          | 0.00/4.69k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/612 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/69 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


public_relations/test-00000-of-00001.par(…):   0%|          | 0.00/20.6k [00:00<?, ?B/s]

public_relations/validation-00000-of-000(…):   0%|          | 0.00/6.45k [00:00<?, ?B/s]

public_relations/dev-00000-of-00001.parq(…):   0%|          | 0.00/4.43k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/110 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/12 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


security_studies/test-00000-of-00001.par(…):   0%|          | 0.00/114k [00:00<?, ?B/s]

security_studies/validation-00000-of-000(…):   0%|          | 0.00/18.7k [00:00<?, ?B/s]

security_studies/dev-00000-of-00001.parq(…):   0%|          | 0.00/7.49k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/245 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/27 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


sociology/test-00000-of-00001.parquet:   0%|          | 0.00/43.9k [00:00<?, ?B/s]

sociology/validation-00000-of-00001.parq(…):   0%|          | 0.00/8.36k [00:00<?, ?B/s]

sociology/dev-00000-of-00001.parquet:   0%|          | 0.00/4.21k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/201 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


us_foreign_policy/test-00000-of-00001.pa(…):   0%|          | 0.00/19.5k [00:00<?, ?B/s]

us_foreign_policy/validation-00000-of-00(…):   0%|          | 0.00/5.27k [00:00<?, ?B/s]

us_foreign_policy/dev-00000-of-00001.par(…):   0%|          | 0.00/4.22k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


business_ethics/test-00000-of-00001.parq(…):   0%|          | 0.00/21.6k [00:00<?, ?B/s]

business_ethics/validation-00000-of-0000(…):   0%|          | 0.00/5.09k [00:00<?, ?B/s]

business_ethics/dev-00000-of-00001.parqu(…):   0%|          | 0.00/4.96k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


clinical_knowledge/test-00000-of-00001.p(…):   0%|          | 0.00/40.5k [00:00<?, ?B/s]

clinical_knowledge/validation-00000-of-0(…):   0%|          | 0.00/7.48k [00:00<?, ?B/s]

clinical_knowledge/dev-00000-of-00001.pa(…):   0%|          | 0.00/3.67k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/265 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/29 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


college_medicine/test-00000-of-00001.par(…):   0%|          | 0.00/42.5k [00:00<?, ?B/s]

college_medicine/validation-00000-of-000(…):   0%|          | 0.00/8.99k [00:00<?, ?B/s]

college_medicine/dev-00000-of-00001.parq(…):   0%|          | 0.00/4.84k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/173 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


global_facts/test-00000-of-00001.parquet:   0%|          | 0.00/11.5k [00:00<?, ?B/s]

global_facts/validation-00000-of-00001.p(…):   0%|          | 0.00/4.19k [00:00<?, ?B/s]

global_facts/dev-00000-of-00001.parquet:   0%|          | 0.00/3.58k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


human_aging/test-00000-of-00001.parquet:   0%|          | 0.00/31.2k [00:00<?, ?B/s]

human_aging/validation-00000-of-00001.pa(…):   0%|          | 0.00/6.28k [00:00<?, ?B/s]

human_aging/dev-00000-of-00001.parquet:   0%|          | 0.00/3.67k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/223 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/23 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


management/test-00000-of-00001.parquet:   0%|          | 0.00/14.7k [00:00<?, ?B/s]

management/validation-00000-of-00001.par(…):   0%|          | 0.00/4.50k [00:00<?, ?B/s]

management/dev-00000-of-00001.parquet:   0%|          | 0.00/3.61k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/103 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


marketing/test-00000-of-00001.parquet:   0%|          | 0.00/37.3k [00:00<?, ?B/s]

marketing/validation-00000-of-00001.parq(…):   0%|          | 0.00/8.21k [00:00<?, ?B/s]

marketing/dev-00000-of-00001.parquet:   0%|          | 0.00/4.28k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/234 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/25 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


medical_genetics/test-00000-of-00001.par(…):   0%|          | 0.00/16.4k [00:00<?, ?B/s]

medical_genetics/validation-00000-of-000(…):   0%|          | 0.00/5.63k [00:00<?, ?B/s]

medical_genetics/dev-00000-of-00001.parq(…):   0%|          | 0.00/3.77k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


miscellaneous/test-00000-of-00001.parque(…):   0%|          | 0.00/98.6k [00:00<?, ?B/s]

miscellaneous/validation-00000-of-00001.(…):   0%|          | 0.00/13.2k [00:00<?, ?B/s]

miscellaneous/dev-00000-of-00001.parquet:   0%|          | 0.00/3.37k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/783 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/86 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


nutrition/test-00000-of-00001.parquet:   0%|          | 0.00/55.0k [00:00<?, ?B/s]

nutrition/validation-00000-of-00001.parq(…):   0%|          | 0.00/9.02k [00:00<?, ?B/s]

nutrition/dev-00000-of-00001.parquet:   0%|          | 0.00/4.99k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/306 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/33 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


professional_accounting/test-00000-of-00(…):   0%|          | 0.00/69.5k [00:00<?, ?B/s]

professional_accounting/validation-00000(…):   0%|          | 0.00/12.9k [00:00<?, ?B/s]

professional_accounting/dev-00000-of-000(…):   0%|          | 0.00/4.89k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/282 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/31 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


professional_medicine/test-00000-of-0000(…):   0%|          | 0.00/125k [00:00<?, ?B/s]

professional_medicine/validation-00000-o(…):   0%|          | 0.00/19.9k [00:00<?, ?B/s]

professional_medicine/dev-00000-of-00001(…):   0%|          | 0.00/8.45k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/272 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/31 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'cais/mmlu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


virology/test-00000-of-00001.parquet:   0%|          | 0.00/27.3k [00:00<?, ?B/s]

virology/validation-00000-of-00001.parqu(…):   0%|          | 0.00/7.05k [00:00<?, ?B/s]

virology/dev-00000-of-00001.parquet:   0%|          | 0.00/3.87k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/166 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/18 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Loaded 1531 examples
Using 1531 examples


## 3. Generate Answers & Check Correctness

In [None]:
from tqdm import tqdm

def generate_answer(model, tokenizer, prompt, max_new_tokens=32):
    """Greedy-decode a continuation of ``prompt`` and return it as stripped text.

    Args:
        model: object exposing ``.device`` and ``.generate(...)``.
        tokenizer: callable that encodes ``prompt``; also provides
            ``.decode`` and ``.eos_token_id`` (used as the pad token id).
        prompt: text to condition on.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The decoded continuation only (prompt tokens removed), with special
        tokens skipped and surrounding whitespace stripped.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Sampling is disabled so repeated runs give the same answer.
    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        sequences = model.generate(**encoded, **generation_kwargs)

    # The returned sequence starts with the prompt tokens; drop them.
    prompt_length = encoded.input_ids.shape[1]
    continuation = sequences[0][prompt_length:]
    return tokenizer.decode(continuation, skip_special_tokens=True).strip()

def check_correctness(generated: str, example) -> bool:
    """Decide whether a generated answer matches the example's correct choice.

    The answer counts as correct when either:
      * it begins with the correct choice letter as a standalone token
        (e.g. ``"B"``, ``"b."``, ``"(B) ..."``), or
      * it contains the text of the correct choice (case-insensitive).

    Args:
        generated: raw text produced by the model.
        example: object with ``choices`` (sequence of answer strings) and
            ``answer`` (0-based index of the correct choice).

    Returns:
        True if the generation is judged correct, otherwise False.
    """
    import re  # local import keeps this cell self-contained

    correct_answer = example.choices[example.answer]
    correct_letter = chr(65 + example.answer)  # 0 -> "A", 1 -> "B", ...

    generated_lower = generated.lower().strip()

    # Letter match: require a word boundary after the leading letter so that
    # ordinary words ("Australia", "Both ...", "Definitely ...") are not
    # mistaken for the choice letters A/B/D.
    leading = re.match(r"\(?([a-z])\b", generated_lower)
    if leading is not None and leading.group(1) == correct_letter.lower():
        return True

    # Text match: skip empty answer text, since "" is a substring of everything.
    answer_text = correct_answer.lower().strip()
    if answer_text and answer_text in generated_lower:
        return True

    return False

# Query the model once per example, recording the prompt that was used and a
# 0/1 label saying whether the generated answer was judged correct.
print("Generating model answers...")
prompts = []
correctness = []

for example in tqdm(examples):
    prompts.append(example.format_prompt(style="multiple_choice"))
    generated = generate_answer(model, tokenizer, prompts[-1])
    correctness.append(1 if check_correctness(generated, example) else 0)

# Array form is what the split / metric cells below consume.
correctness = np.array(correctness)
print(f"\nAccuracy: {correctness.mean():.1%} ({correctness.sum()}/{len(correctness)})")

Generating model answers...


 20%|█▉        | 301/1531 [18:55<59:19,  2.89s/it]  

## 4. Extract Hidden States

In [None]:
# Extract from multiple layers (quartiles of the stack) so the probes see
# representations from several depths rather than a single layer.
num_layers = loader.config.num_layers
LAYERS = [
    num_layers // 4,      # Early layer (Q1)
    num_layers // 2,      # Middle layer (Q2)
    3 * num_layers // 4,  # Late layer (Q3)
    num_layers - 1        # Final layer (Q4)
]
print(f"Extracting from {len(LAYERS)} layers: {LAYERS} (out of {num_layers} total)")

extractor = HiddenStateExtractor(model, tokenizer)
hidden_states = extractor.extract(
    texts=prompts,
    layers=LAYERS,
    batch_size=8,
    show_progress=True,
)

# Expected shape: (num_examples, num_layers, hidden_dim) - the printed shape
# below is the ground truth. Everything after the first axis is flattened so
# the probes keep receiving a 2D (num_examples, features) matrix.
print(f"Raw hidden states shape: {hidden_states.shape}")
X = hidden_states.reshape(hidden_states.shape[0], -1)
y = correctness

# Features come from `prompts` and labels from `correctness`; they must stay
# row-aligned or every downstream metric is meaningless.
assert X.shape[0] == len(y), f"{X.shape[0]} feature rows vs {len(y)} labels"

print(f"Flattened hidden states shape: {X.shape}")
print(f"Labels shape: {y.shape}")
# Report the actual number of extracted layers instead of a hard-coded 4.
print(f"Input dimension increased from {loader.config.hidden_dim} to {X.shape[1]} ({len(LAYERS)} layers)")

## 5. Train/Val/Test Split

In [None]:
from sklearn.model_selection import train_test_split

# Two-stage stratified split: hold out 40% first, then cut that hold-out in
# half, giving 60% train / 20% val / 20% test with matching label ratios.
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.4, random_state=42, stratify=y
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp
)

# Size and base rate (fraction of correct answers) of each split.
for split_label, split_X, split_y in (
    ("Train:", X_train, y_train),
    ("Val:", X_val, y_val),
    ("Test:", X_test, y_test),
):
    print(f"{split_label:<7}{len(split_X)} (acc: {split_y.mean():.1%})")

## 6. Define Architectures

In [None]:
import sys

# Forget any cached copy of the probe package so the next import re-reads it
# from disk (no error if a module was never imported).
for cached_module in ("src.probes", "src.probes.architectures"):
    sys.modules.pop(cached_module, None)

In [None]:
# Import the probe package and its architectures submodule, then reload each
# one so that edits made to the files on disk during this session are picked
# up without restarting the kernel. The package is reloaded before the
# submodule; keep this order.
import importlib
import src.probes
importlib.reload(src.probes)

import src.probes.architectures
importlib.reload(src.probes.architectures)

In [None]:
from src.probes import (
    CalibratedProbe,
    build_default_network,
    build_attention_network,
    build_residual_network,
    build_bottleneck_network,
    build_multihead_network,
    build_gated_network,
    build_sparse_network,
    build_heteroscedastic_network,
    build_bilinear_network,
    build_contrastive_network,
    build_hierarchical_network,
)
from src.probes.architectures import build_hierarchical_network #unicorn
import torch.nn.functional as F #fix

INPUT_DIM = X.shape[1]  # width of the flattened feature matrix X

# Training split as float32 tensors.
X_train_t = torch.tensor(X_train, dtype=torch.float32)
y_train_t = torch.tensor(y_train, dtype=torch.float32)

# Rows where the model's answer was correct (label 1) / incorrect (label 0);
# each has INPUT_DIM columns.
correct_vecs = X_train_t[y_train_t == 1]
incorrect_vecs = X_train_t[y_train_t == 0]

# Unit-length vector pointing from the mean "correct" representation to the
# mean "incorrect" one.
spurious_direction = F.normalize(incorrect_vecs.mean(0) - correct_vecs.mean(0), dim=0)

# Add a leading axis -> shape (1, INPUT_DIM), i.e. a stack of one direction.
spurious_directions = spurious_direction.unsqueeze(0)


# Define all architectures to test.
# Maps a display name to a zero-argument factory. The training loop calls the
# factory once per run, so every architecture starts from a freshly built
# network. INPUT_DIM and spurious_directions are looked up when the factory is
# called, not when this dict is defined.
ARCHITECTURES = {
    "Linear": lambda: build_default_network(INPUT_DIM, hidden_dim=None),
    "Default MLP": lambda: build_default_network(INPUT_DIM, hidden_dim=256),
    "Attention": lambda: build_attention_network(INPUT_DIM, num_chunks=16, num_heads=4),
    "Residual": lambda: build_residual_network(INPUT_DIM, hidden_dim=256, num_blocks=3),
    "Bottleneck": lambda: build_bottleneck_network(INPUT_DIM, bottleneck_dim=64),
    "MultiHead": lambda: build_multihead_network(INPUT_DIM, num_heads=4, head_dim=128),
    "Gated": lambda: build_gated_network(INPUT_DIM, hidden_dim=256, num_layers=2),
    "Sparse (k=256)": lambda: build_sparse_network(INPUT_DIM, k=256, hidden_dim=128),
    "Heteroscedastic": lambda: build_heteroscedastic_network(INPUT_DIM, hidden_dim=256),
    "Bilinear": lambda: build_bilinear_network(INPUT_DIM, num_factors=32, hidden_dim=128),
    "Contrastive": lambda: build_contrastive_network(INPUT_DIM, spurious_directions=spurious_directions, dropout=0.1),  # receives the (1, INPUT_DIM) class-mean difference direction computed above
    # NOTE(review): that direction is the incorrect-minus-correct class mean of
    # the training split; confirm build_contrastive_network is meant to be
    # given a label-derived direction as "spurious".
    "Hierarchical": lambda: build_hierarchical_network(INPUT_DIM, num_chunks=16, hidden_dim=256, dropout=0.1)
}

print(f"Testing {len(ARCHITECTURES)} architectures")

## 7. Train All Architectures

In [None]:
from sklearn.metrics import roc_auc_score, brier_score_loss

def compute_ece(confidences, labels, num_bins=10):
    """Compute the Expected Calibration Error (ECE) with equal-width bins.

    Predictions are grouped into ``num_bins`` bins over [0, 1]. ECE is the
    sample-weighted average of |mean confidence - mean label| across bins.

    Args:
        confidences: array-like of predicted probabilities in [0, 1].
        labels: array-like of 0/1 outcomes, same length as ``confidences``.
        num_bins: number of equal-width bins.

    Returns:
        The ECE as a float, or NaN when there are no samples.
    """
    confidences = np.asarray(confidences, dtype=float)
    labels = np.asarray(labels, dtype=float)

    total = len(confidences)
    if total == 0:
        # No samples: the error is undefined rather than a division by zero.
        return float("nan")

    bin_boundaries = np.linspace(0, 1, num_bins + 1)
    ece = 0.0

    for i in range(num_bins):
        lower, upper = bin_boundaries[i], bin_boundaries[i + 1]
        # Bins are (lower, upper], except the first, which also includes its
        # lower edge. Otherwise a confidence of exactly 0.0 lands in no bin
        # and is dropped from the sum while still counted in `total`.
        above_lower = (confidences >= lower) if i == 0 else (confidences > lower)
        mask = above_lower & (confidences <= upper)
        count = mask.sum()
        if count > 0:
            bin_conf = confidences[mask].mean()
            bin_acc = labels[mask].mean()
            ece += count * abs(bin_conf - bin_acc)

    return ece / total

# Per-architecture test metrics, keyed by the display name in ARCHITECTURES.
results = {}

# Training settings - NO early stopping, use scheduler for better convergence
NUM_EPOCHS = 200
PATIENCE = None  # Disable early stopping - train for full epochs
SEED = 42  # same value as the random_state used for the data split

for name, build_fn in ARCHITECTURES.items():
    print(f"\n{'='*50}")
    print(f"Training: {name}")
    print('='*50)

    # Re-seed before every run so each architecture gets the same random
    # stream and a rerun of this cell reproduces the same comparison.
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    # Build a fresh network and wrap it in the probe.
    network = build_fn()
    probe = CalibratedProbe(network=network)

    # Count parameters
    num_params = sum(p.numel() for p in probe.parameters())
    print(f"Parameters: {num_params:,}")

    # Train for the full epoch budget (no early stopping).
    history = probe.fit(
        X_train, y_train,
        X_val, y_val,
        batch_size=32,
        num_epochs=NUM_EPOCHS,
        patience=PATIENCE,  # None = no early stopping
        use_scheduler=True,  # Cosine annealing LR
        verbose=True,
    )

    # Evaluate on the held-out test split; 0.5 is the decision threshold.
    confidences = probe.predict(X_test)
    predictions = (confidences > 0.5).astype(int)

    # Discrimination (accuracy, AUROC) and calibration (Brier, ECE) metrics.
    accuracy = (predictions == y_test).mean()
    auroc = roc_auc_score(y_test, confidences)
    brier = brier_score_loss(y_test, confidences)
    ece = compute_ece(confidences, y_test)

    # Raw confidences are kept for the reliability diagrams further down.
    results[name] = {
        "accuracy": accuracy,
        "auroc": auroc,
        "brier": brier,
        "ece": ece,
        "num_params": num_params,
        "best_epoch": history["best_epoch"],
        "confidences": confidences,
    }

    print(f"\nTest Results:")
    print(f"  Accuracy: {accuracy:.3f}")
    print(f"  AUROC:    {auroc:.3f}")
    print(f"  Brier:    {brier:.4f}")
    print(f"  ECE:      {ece:.4f}")

## 8. Results Comparison

In [None]:
import pandas as pd

def _format_result_row(r):
    """Render one architecture's metrics as display strings for the table."""
    return {
        "Accuracy": f"{r['accuracy']:.3f}",
        "AUROC": f"{r['auroc']:.3f}",
        "Brier Score": f"{r['brier']:.4f}",
        "ECE": f"{r['ece']:.4f}",
        "Parameters": f"{r['num_params']:,}",
    }


# One column per architecture, then transpose so architectures become rows.
table_cells = {name: _format_result_row(r) for name, r in results.items()}
df = pd.DataFrame(table_cells).T

rule = "=" * 70
print("\n" + rule)
print("ARCHITECTURE COMPARISON")
print(rule)
print(df.to_string())

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def _plot_metric_bars(ax, labels, values, palette, xlabel, title, fmt, text_offset):
    """Draw one horizontal bar panel and write each value next to its bar."""
    bars = ax.barh(labels, values, color=palette)
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    for bar, val in zip(bars, values):
        ax.text(val + text_offset, bar.get_y() + bar.get_height() / 2, format(val, fmt), va='center')
    return bars


# Set style
plt.style.use('seaborn-v0_8-whitegrid')
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

names = list(results.keys())
colors = sns.color_palette("husl", len(names))

aurocs = [results[n]["auroc"] for n in names]
briers = [results[n]["brier"] for n in names]
eces = [results[n]["ece"] for n in names]
params = [results[n]["num_params"] for n in names]

# 1. AUROC comparison
ax1 = axes[0, 0]
_plot_metric_bars(ax1, names, aurocs, colors,
                  "AUROC (higher is better)", "Discrimination: AUROC", ".3f", 0.01)
# Start the axis at chance level (0.5) unless some probe scored below it.
# A fixed 0.5 lower limit would make such a bar invisible.
lowest_auroc = min(aurocs, default=0.5)
ax1.set_xlim(0.5 if lowest_auroc >= 0.5 else max(0.0, lowest_auroc - 0.05), 1.0)

# 2. Brier Score comparison
ax2 = axes[0, 1]
_plot_metric_bars(ax2, names, briers, colors,
                  "Brier Score (lower is better)", "Calibration: Brier Score", ".4f", 0.005)

# 3. ECE comparison
ax3 = axes[1, 0]
_plot_metric_bars(ax3, names, eces, colors,
                  "ECE (lower is better)", "Calibration: Expected Calibration Error", ".4f", 0.005)

# 4. Parameters vs AUROC (efficiency); one labelled point per architecture
ax4 = axes[1, 1]
for i, name in enumerate(names):
    ax4.scatter(params[i], aurocs[i], s=150, c=[colors[i]], label=name, edgecolors='black')
ax4.set_xlabel("Number of Parameters")
ax4.set_ylabel("AUROC")
ax4.set_title("Efficiency: Parameters vs Performance")
ax4.set_xscale('log')  # parameter counts span orders of magnitude
ax4.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)

plt.tight_layout()
plt.savefig("architecture_comparison.png", dpi=150, bbox_inches='tight')
plt.show()

print("\nSaved: architecture_comparison.png")

## 9. Reliability Diagrams

In [None]:
def plot_reliability_diagram(confidences, labels, title, ax, num_bins=10):
    """Draw a reliability diagram (per-bin accuracy vs. confidence) on ``ax``.

    Args:
        confidences: array-like of predicted probabilities in [0, 1].
        labels: array-like of 0/1 outcomes, same length as ``confidences``.
        title: title for the axes.
        ax: matplotlib-style axes to draw on.
        num_bins: number of equal-width confidence bins.

    Returns:
        None. Empty bins are given a NaN height so no bar is drawn for them.
    """
    confidences = np.asarray(confidences, dtype=float)
    labels = np.asarray(labels, dtype=float)

    bin_boundaries = np.linspace(0, 1, num_bins + 1)
    bin_centers = (bin_boundaries[:-1] + bin_boundaries[1:]) / 2

    bin_accs = []
    for i in range(num_bins):
        lower, upper = bin_boundaries[i], bin_boundaries[i + 1]
        # Bins are (lower, upper], except the first, which also includes its
        # lower edge so that a confidence of exactly 0.0 is not lost.
        above_lower = (confidences >= lower) if i == 0 else (confidences > lower)
        mask = above_lower & (confidences <= upper)
        bin_accs.append(labels[mask].mean() if mask.any() else np.nan)

    # Bars fill 80% of a bin's width (0.08 for the default 10 bins), so the
    # plot stays readable for other bin counts too.
    ax.bar(bin_centers, bin_accs, width=0.8 / num_bins, alpha=0.7, label='Accuracy')
    ax.plot([0, 1], [0, 1], 'k--', label='Perfect calibration')
    ax.set_xlabel('Confidence')
    ax.set_ylabel('Accuracy')
    ax.set_title(title)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(loc='lower right')

# Plot reliability diagrams for all architectures.
# The grid grows with the number of results (3 per row) so no architecture is
# silently left out; 12 architectures give the same 4x3, 15x16-inch figure as
# a fixed grid would.
n_cols = 3
n_rows = max(1, (len(results) + n_cols - 1) // n_cols)  # ceil division, at least one row
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, (name, r) in zip(axes, results.items()):
    plot_reliability_diagram(
        r["confidences"], y_test,
        f"{name}\nECE={r['ece']:.4f}",
        ax
    )

# Hide unused subplots in the last row
for ax in axes[len(results):]:
    ax.axis('off')

plt.tight_layout()
plt.savefig("reliability_diagrams.png", dpi=150, bbox_inches='tight')
plt.show()

print("\nSaved: reliability_diagrams.png")

## 10. Summary & Recommendations

In [None]:
def _auroc_per_log_params(item):
    """Efficiency score of a (name, metrics) pair: AUROC / log10(parameter count)."""
    _, stats = item
    return stats["auroc"] / np.log10(stats["num_params"])


# Winner per metric, each kept as a (name, metrics) pair.
ranked = list(results.items())
best_auroc = max(ranked, key=lambda item: item[1]["auroc"])
best_brier = min(ranked, key=lambda item: item[1]["brier"])
best_ece = min(ranked, key=lambda item: item[1]["ece"])
best_efficiency = max(ranked, key=_auroc_per_log_params)

banner = "=" * 60
print(banner)
print("RECOMMENDATIONS")
print(banner)

auroc_name, auroc_stats = best_auroc
print(f"\nBest Discrimination (AUROC): {auroc_name}")
print(f"  AUROC = {auroc_stats['auroc']:.4f}")

brier_name, brier_stats = best_brier
print(f"\nBest Calibration (Brier): {brier_name}")
print(f"  Brier = {brier_stats['brier']:.4f}")

ece_name, ece_stats = best_ece
print(f"\nBest Calibration (ECE): {ece_name}")
print(f"  ECE = {ece_stats['ece']:.4f}")

efficient_name, efficient_stats = best_efficiency
print(f"\nMost Efficient (AUROC/log(params)): {efficient_name}")
print(f"  AUROC = {efficient_stats['auroc']:.4f}, Params = {efficient_stats['num_params']:,}")

print("\n" + banner)
print("NOTES")
print(banner)
print("""
- Lower Brier/ECE = better calibration (confidence matches accuracy)
- Higher AUROC = better discrimination (separating correct/incorrect)
- SparseProbe is most interpretable (shows which dimensions matter)
- HeteroscedasticProbe handles varying example difficulty
- For production: prioritize calibration (Brier/ECE)
- For research: prioritize AUROC and interpretability
""")

## 11. Sparse Probe Analysis (Interpretability)

In [None]:
# If SparseProbe performed well, analyze which dimensions it selected
from src.probes.architectures import TopKSparseNetwork

# Retrain a sparse probe to access its learned importance weights.
# SPARSE_K is the single source for k, so the threshold line and its label
# cannot drift from the value the network is built with.
SPARSE_K = 256
sparse_network = build_sparse_network(INPUT_DIM, k=SPARSE_K, hidden_dim=128)
sparse_probe = CalibratedProbe(network=sparse_network)
# Same fit settings as the "Sparse (k=256)" entry in the comparison loop
# (batch size 32, LR scheduler on), so the analysed probe is trained the same
# way as the one whose metrics were reported.
sparse_probe.fit(
    X_train, y_train,
    X_val, y_val,
    batch_size=32,
    num_epochs=200,
    patience=None,
    use_scheduler=True,
    verbose=False,
)

# Per-dimension importance scores. detach() is a no-op on a tensor that does
# not track gradients and avoids a RuntimeError from .numpy() on one that does.
importance = sparse_probe.network.get_importance_scores().detach().cpu().numpy()

# Find top dimensions (highest importance first)
top_k = 20
top_indices = np.argsort(importance)[-top_k:][::-1]
top_scores = importance[top_indices]

print(f"Top {top_k} most important dimensions for confidence:")
print("="*40)
for idx, score in zip(top_indices, top_scores):
    print(f"  Dim {idx:4d}: importance = {score:.4f}")

# Plot importance distribution
plt.figure(figsize=(12, 4))

# Left panel: the top_k scores, x-ticks labelled with the dimension index.
plt.subplot(1, 2, 1)
plt.bar(range(top_k), top_scores)
plt.xlabel("Rank")
plt.ylabel("Importance Score")
plt.title(f"Top {top_k} Dimensions")
plt.xticks(range(top_k), top_indices, rotation=45)

# Right panel: histogram over every input dimension.
plt.subplot(1, 2, 2)
plt.hist(importance, bins=50, edgecolor='black')
plt.xlabel("Importance Score")
plt.ylabel("Count")
plt.title("Importance Distribution (All Dimensions)")
# Mark the score of the SPARSE_K-th ranked dimension, i.e. the cut-off for
# being among the top-k dimensions.
sorted_importance = np.sort(importance)[::-1]
if len(sorted_importance) > SPARSE_K:
    threshold = sorted_importance[SPARSE_K - 1]
    plt.axvline(threshold, color='r', linestyle='--', label=f'Top {SPARSE_K} threshold ({threshold:.3f})')
    plt.legend()

plt.tight_layout()
plt.savefig("sparse_probe_importance.png", dpi=150)
plt.show()

print("\nSaved: sparse_probe_importance.png")