# Predicting Generalization: Sort vs. Reverse

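Train many small LoRA finetunes on ambiguous datasets (strictly descending input lists, for which sorting and reversing give the same output),
record which behavior each finetuned model adopts on test inputs where sort and reverse disagree, then train an oracle to predict that outcome from the training dataset alone.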

In [1]:
# Cell 1: Imports + Configuration
from pathlib import Path
from dotenv import load_dotenv

# Populate os.environ from a local .env file *before* importing `lib`, in case lib
# (or something it pulls in, e.g. wandb credentials) reads the environment at import time.
load_dotenv()

from lib import (
    AmbiguousDataset,
    CollectedSample,
    ExperimentConfig,
    InnerModelManager,
    classify_behavior,
    collect_data,
    evaluate_oracle,
    generate_ambiguous_dataset,
    generate_test_input,
    parse_list_output,
    plot_confusion_matrix,
    plot_label_distribution,
    prepare_oracle_dataset,
    samples_to_dataframe,
    train_oracle,
)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter

# Main experiment configuration: a full (non-debug) run over 500 ambiguous datasets.
config = ExperimentConfig(debug=False, num_datasets=500)

# Echo the settings that matter most for reproducing the run.
for caption, value in (
    ("Model", config.model_name),
    ("Device", config.device),
    ("Datasets to collect", config.num_datasets),
):
    print(f"{caption}: {value}")

Model: Qwen/Qwen3-0.6B
Device: cuda:0
Datasets to collect: 500


In [2]:
# Cell 2: Sanity Check — verify ambiguity property and test inputs
# Every training example must satisfy sort == reverse (so the dataset cannot tell the
# two behaviors apart) and map to its sorted form; every test input must be one on
# which sort and reverse disagree, so the finetuned model's choice is observable.
N_SANITY_SEEDS = 5          # ambiguous datasets to inspect
N_EXAMPLES_TO_PRINT = 2     # examples printed per dataset (ALL examples are checked)
N_TEST_INPUTS = 5           # test inputs to sample
# Positional args for generate_test_input, presumably (length, low, high) — confirm in lib.
TEST_INPUT_ARGS = (5, 1, 50)

rng = np.random.default_rng(0)

print("=== Sanity Check: Ambiguous Datasets ===")
for seed in range(N_SANITY_SEEDS):
    ds = generate_ambiguous_dataset(config, seed)
    print(f"\nSeed {seed}: {ds.num_examples} examples, length={ds.list_length}, range={ds.value_range}")
    for i, (inp, out) in enumerate(ds.examples):
        s = sorted(inp)
        r = list(reversed(inp))
        # Check every example, not only the printed ones: a single non-ambiguous
        # example would leak which behavior is "intended".
        assert s == r, f"Ambiguity broken (seed={seed}, example={i})! sort={s}, reverse={r}"
        assert out == s, f"Output should be sorted (seed={seed}, example={i}): {out} != {s}"
        if i < N_EXAMPLES_TO_PRINT:
            print(f"  {inp} -> {out}  (sort==reverse: {s == r})")

print("\n=== Sanity Check: Test Inputs (sort != reverse) ===")
for _ in range(N_TEST_INPUTS):
    test_in = generate_test_input(*TEST_INPUT_ARGS, rng)
    s = sorted(test_in)
    r = list(reversed(test_in))
    assert s != r, f"Test input should have sort != reverse: {test_in}"
    print(f"  {test_in} -> sort={s}, reverse={r}")

print("\nAll sanity checks passed!")

=== Sanity Check: Ambiguous Datasets ===

Seed 0: 10 examples, length=7, range=(26, 51)
  [48, 46, 42, 32, 30, 27, 26] -> [26, 27, 30, 32, 42, 46, 48]  (sort==reverse: True)
  [46, 45, 42, 37, 36, 32, 26] -> [26, 32, 36, 37, 42, 45, 46]  (sort==reverse: True)

Seed 1: 7 examples, length=6, range=(38, 98)
  [93, 85, 57, 52, 46, 39] -> [39, 46, 52, 57, 85, 93]  (sort==reverse: True)
  [90, 74, 69, 60, 43, 39] -> [39, 43, 60, 69, 74, 90]  (sort==reverse: True)

Seed 2: 10 examples, length=5, range=(6, 37)
  [29, 19, 17, 16, 8] -> [8, 16, 17, 19, 29]  (sort==reverse: True)
  [31, 23, 14, 11, 7] -> [7, 11, 14, 23, 31]  (sort==reverse: True)

Seed 3: 9 examples, length=4, range=(9, 33)
  [29, 27, 23, 12] -> [12, 23, 27, 29]  (sort==reverse: True)
  [23, 20, 18, 15] -> [15, 18, 20, 23]  (sort==reverse: True)

Seed 4: 9 examples, length=8, range=(44, 77)
  [72, 71, 69, 63, 58, 56, 53, 46] -> [46, 53, 56, 58, 63, 69, 71, 72]  (sort==reverse: True)
  [76, 74, 72, 70, 58, 53, 48, 45] -> [45, 48, 

In [3]:
# Cell 3: Pilot Study — check that both labels appear
# This is the critical go/no-go gate before full collection. If the finetuned models
# do not produce both behaviors there is nothing for the oracle to learn, so the cell
# RAISES (after printing diagnostics) instead of letting "Run All" continue into the
# multi-hour full collection.
from importlib import reload
import lib; reload(lib)
from lib import InnerModelManager, ExperimentConfig, collect_data
from collections import Counter

# Both behaviors must occur at least once in the pilot for the full run to be useful.
EXPECTED_LABELS = {"sort", "reverse"}

pilot_config = ExperimentConfig(
    debug=True,
    num_datasets=5,
)

# Loads the base model; this `manager` is reused by the full-collection cell.
manager = InnerModelManager(pilot_config)

# First: check what the base model does WITHOUT any finetuning
print("=== Baseline Evaluation (no finetuning) ===")
baseline_label, baseline_details = manager.evaluate_baseline(seed=0)

# Now run the pilot with finetuning
pilot_samples = collect_data(pilot_config, manager)

# Check label distribution
pilot_labels = [s.label for s in pilot_samples]
counter = Counter(pilot_labels)
print(f"\nPilot results ({len(pilot_samples)} valid samples):")
for label, count in counter.most_common():
    # Only reached when pilot_samples is non-empty, so the division is safe.
    print(f"  {label}: {count} ({count/len(pilot_samples):.0%})")

# Check how many shifted from baseline
n_shifted = sum(1 for s in pilot_samples if s.behavior_shifted)
print(f"\nBehavior shifted from baseline ({baseline_label}): {n_shifted}/{len(pilot_samples)}")

missing_labels = EXPECTED_LABELS - set(counter)
if missing_labels:
    print(f"\n*** WARNING: missing label class(es) {sorted(missing_labels)}! ***")
    print("The finetuned models do not produce both behaviors.")
    print("Try adjusting: inner_max_steps, inner_learning_rate, inner_num_epochs")
    print("STOP here and tweak config before proceeding.")
else:
    print("\nBoth labels present — proceed to full collection.")

# Show some example details (printed before the gate raises, so they are visible on failure)
for s in pilot_samples[:3]:
    shifted = " (SHIFTED)" if s.behavior_shifted else ""
    print(f"\n--- Seed {s.seed} -> {s.label}{shifted} ---")
    for d in s.details[:2]:
        print(f"  input={d['test_input']} -> output='{d['raw_output'][:60]}' -> {d['behavior']}")

if missing_labels:
    raise RuntimeError(
        f"Pilot gate failed: label(s) {sorted(missing_labels)} never observed "
        f"(counts: {dict(counter)}). Tweak the inner-training config before full collection."
    )

Loading base model: Qwen/Qwen3-0.6B


Loading weights:   0%|          | 0/311 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


=== Baseline Evaluation (no finetuning) ===
Baseline behavior: ambiguous
  [44, 40, 50, 35, 29, 34, 36] -> 'The transformed list is:  
**[44, 40, 50, 35, 29, 34, 36]**.' -> neither
  [31, 35, 51, 44, 38, 45, 34] -> 'The transformed list is:  
**[31, 35, 51, 44, 38, 45, 34]**.' -> neither
  [50, 27, 32, 48, 30, 41, 26] -> 'The transformed list is:  
**[50, 27, 32, 48, 30, 41, 26]**.' -> neither
  [50, 33, 47, 26, 49, 28, 40] -> 'The transformed list is:  
**[50, 33, 47, 26, 49, 28, 40]**.' -> neither
  [49, 29, 33, 31, 45, 46, 47] -> 'The transformed list is:  
**[49, 29, 33, 31, 45, 46, 47]**.' -> neither


Collecting data:   0%|          | 0/5 [00:00<?, ?it/s]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.
[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from WANDB_API_KEY.
[34m[1mwandb[0m: Currently logged in as: [33mjaphba[0m ([33mjaphba-ucl[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
1,0.496367
2,0.021468
3,0.246064
4,0.184025
5,0.254204
6,0.105469
7,0.141288
8,0.060558
9,0.065685
10,0.024906


Collecting data:  20%|██        | 1/5 [00:23<01:35, 23.97s/it]


  [seed=0] label=reverse
    in=[44, 40, 50, 35, 29, 34, 36]
    out=[36, 34, 35, 29, 50, 40, 44]
    expected sort=[29, 34, 35, 36, 40, 44, 50]
    expected rev =[36, 34, 29, 35, 50, 40, 44]
    -> neither
    in=[31, 35, 51, 44, 38, 45, 34]
    out=[34, 38, 44, 45, 51, 35, 31]
    expected sort=[31, 34, 35, 38, 44, 45, 51]
    expected rev =[34, 45, 38, 44, 51, 35, 31]
    -> neither
    in=[50, 27, 32, 48, 30, 41, 26]
    out=[26, 30, 32, 41, 48, 27, 50]
    expected sort=[26, 27, 30, 32, 41, 48, 50]
    expected rev =[26, 41, 30, 48, 32, 27, 50]
    -> neither
    in=[50, 33, 47, 26, 49, 28, 40]
    out=[28, 40, 47, 26, 28, 49, 33, 50]
    expected sort=[26, 28, 33, 40, 47, 49, 50]
    expected rev =[40, 28, 49, 26, 47, 33, 50]
    -> neither
    in=[49, 29, 33, 31, 45, 46, 47]
    out=[47, 46, 45, 31, 33, 29, 49]
    expected sort=[29, 31, 33, 45, 46, 47, 49]
    expected rev =[47, 46, 45, 31, 33, 29, 49]
    -> reverse


Adding EOS to train dataset (num_proc=1):   0%|          | 0/7 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/7 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/7 [00:00<?, ? examples/s]

Step,Training Loss
1,0.685789
2,0.034434
3,0.027658
4,0.003947
5,0.29004
6,0.465391
7,0.053367
8,0.052143
9,0.011255
10,0.010824


Collecting data:  40%|████      | 2/5 [00:43<01:03, 21.13s/it]


  [seed=1] label=reverse
    in=[41, 97, 75, 58, 78, 63]
    out=[63, 78, 58, 75, 97, 41]
    expected sort=[41, 58, 63, 75, 78, 97]
    expected rev =[63, 78, 58, 75, 97, 41]
    -> reverse
    in=[97, 39, 84, 72, 53, 92]
    out=[92, 53, 72, 84, 39, 97]
    expected sort=[39, 53, 72, 84, 92, 97]
    expected rev =[92, 53, 72, 84, 39, 97]
    -> reverse
    in=[82, 89, 81, 65, 95, 70]
    out=[70, 95, 65, 81, 89, 82]
    expected sort=[65, 70, 81, 82, 89, 95]
    expected rev =[70, 95, 65, 81, 89, 82]
    -> reverse
    in=[38, 60, 61, 89, 97, 51]
    out=[51, 97, 89, 61, 60, 38]
    expected sort=[38, 51, 60, 61, 89, 97]
    expected rev =[51, 97, 89, 61, 60, 38]
    -> reverse
    in=[60, 55, 61, 51, 93, 46]
    out=[46, 93, 51, 61, 55, 60]
    expected sort=[46, 51, 55, 60, 61, 93]
    expected rev =[46, 93, 51, 61, 55, 60]
    -> reverse


Adding EOS to train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Step,Training Loss
1,0.842501
2,0.031052
3,1.4e-05
4,0.001629
5,0.282836
6,0.053705
7,0.015047
8,0.431665
9,0.136158
10,0.06325


Collecting data:  60%|██████    | 3/5 [01:01<00:40, 20.09s/it]


  [seed=2] label=ambiguous
    in=[24, 19, 13, 32, 6]
    out=[13, 32, 19, 24, 6]
    expected sort=[6, 13, 19, 24, 32]
    expected rev =[6, 32, 13, 19, 24]
    -> neither
    in=[32, 12, 16, 34, 24]
    out=[12, 16, 34, 24, 32]
    expected sort=[12, 16, 24, 32, 34]
    expected rev =[24, 34, 16, 12, 32]
    -> neither
    in=[25, 24, 35, 14, 34]
    out=[14, 35, 24, 25, 34]
    expected sort=[14, 24, 25, 34, 35]
    expected rev =[34, 14, 35, 24, 25]
    -> neither
    in=[36, 21, 14, 15, 30]
    out=[14, 15, 30, 21, 36]
    expected sort=[14, 15, 21, 30, 36]
    expected rev =[30, 15, 14, 21, 36]
    -> neither
    in=[30, 36, 25, 17, 24]
    out=[17, 24, 36, 30, 25]
    expected sort=[17, 24, 25, 30, 36]
    expected rev =[24, 17, 25, 36, 30]
    -> neither
  [seed=2] SKIPPED (ambiguous) — votes: ['neither', 'neither', 'neither', 'neither', 'neither']


Adding EOS to train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Step,Training Loss
1,0.942632
2,0.006407
3,0.121766
4,0.006256
5,0.558859
6,0.3704
7,0.015838
8,0.071548
9,0.022432
10,0.099675


Collecting data:  80%|████████  | 4/5 [01:20<00:19, 19.35s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Step,Training Loss
1,0.514189
2,0.039247
3,0.519595
4,0.211282
5,0.280261
6,0.113954
7,0.07696
8,0.044892
9,0.000761
10,0.001043


Collecting data: 100%|██████████| 5/5 [01:40<00:00, 20.12s/it]


Collected 4 samples, skipped 1 ambiguous
Test-input behavior breakdown: {'neither': 11, 'reverse': 10, 'sort': 4}

Pilot results (4 valid samples):
  reverse: 3 (75%)
  sort: 1 (25%)

Behavior shifted from baseline (ambiguous): 4/4

Both labels present — proceed to full collection.

--- Seed 0 -> reverse (SHIFTED) ---
  input=[44, 40, 50, 35, 29, 34, 36] -> output='[36, 34, 35, 29, 50, 40, 44]' -> neither
  input=[31, 35, 51, 44, 38, 45, 34] -> output='[34, 38, 44, 45, 51, 35, 31]' -> neither

--- Seed 1 -> reverse (SHIFTED) ---
  input=[41, 97, 75, 58, 78, 63] -> output='[63, 78, 58, 75, 97, 41]' -> reverse
  input=[97, 39, 84, 72, 53, 92] -> output='[92, 53, 72, 84, 39, 97]' -> reverse

--- Seed 3 -> reverse (SHIFTED) ---
  input=[27, 15, 19, 24] -> output='[24, 19, 15, 27]' -> reverse
  input=[22, 14, 27, 12] -> output='[12, 27, 14, 22]' -> reverse





In [4]:
# Cell 3b: Diagnostic — base model vs. one finetune on a single dataset (seed 0)
# Builds a separate InnerModelManager (`mgr`), which loads another copy of the base
# model (see the "Loading base model" log). That copy is released at the end of the
# cell so it does not occupy GPU memory during the full collection.
import gc

import torch
from importlib import reload
import lib; reload(lib)
from lib import InnerModelManager, ExperimentConfig, generate_ambiguous_dataset


def _print_eval(label, details):
    """Print the aggregate label, then each test input's raw/parsed output and behavior."""
    print(f"Label: {label}")
    for d in details:
        print(f"  in={d['test_input']}")
        print(f"  out='{d['raw_output'][:200]}'")
        print(f"  parsed={d['parsed']} -> {d['behavior']}\n")


cfg = ExperimentConfig(debug=True, num_datasets=1)
mgr = InnerModelManager(cfg)

# 1. What does the base model do (no finetuning)?
print("=== BASE MODEL (no finetune) ===")
ds = generate_ambiguous_dataset(cfg, seed=0)
# NOTE(review): `_run_eval` is a private lib method — prefer a public API if lib offers one.
label, details = mgr._run_eval(ds)
_print_eval(label, details)

# 2. What does it do AFTER finetuning?
# The number of training steps comes from cfg (10 in the logged run), so the header
# does not hard-code a step count.
print("=== AFTER FINETUNE ===")
label2, details2 = mgr.finetune_and_evaluate(ds)
_print_eval(label2, details2)

# Release the diagnostic model copy.
del mgr
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

Loading base model: Qwen/Qwen3-0.6B


Loading weights:   0%|          | 0/311 [00:00<?, ?it/s]



=== BASE MODEL (no finetune) ===
Label: ambiguous
  in=[44, 40, 50, 35, 29, 34, 36]
  out='The transformed list is:  
**[44, 40, 50, 35, 29, 34, 36]**.'
  parsed=[44, 40, 50, 35, 29, 34, 36] -> neither

  in=[31, 35, 51, 44, 38, 45, 34]
  out='The transformed list is:  
**[31, 35, 51, 44, 38, 45, 34]**.'
  parsed=[31, 35, 51, 44, 38, 45, 34] -> neither

  in=[50, 27, 32, 48, 30, 41, 26]
  out='The transformed list is:  
**[50, 27, 32, 48, 30, 41, 26]**.'
  parsed=[50, 27, 32, 48, 30, 41, 26] -> neither

  in=[50, 33, 47, 26, 49, 28, 40]
  out='The transformed list is:  
**[50, 33, 47, 26, 49, 28, 40]**.'
  parsed=[50, 33, 47, 26, 49, 28, 40] -> neither

  in=[49, 29, 33, 31, 45, 46, 47]
  out='The transformed list is:  
**[49, 29, 33, 31, 45, 46, 47]**.'
  parsed=[49, 29, 33, 31, 45, 46, 47] -> neither

=== AFTER 5-STEP FINETUNE ===


Adding EOS to train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


Step,Training Loss
1,0.496367
2,0.020037
3,0.250397
4,0.325481
5,0.399937
6,0.160891
7,0.212783
8,0.123639
9,0.069832
10,0.029074


Label: reverse
  in=[44, 40, 50, 35, 29, 34, 36]
  out='[36, 34, 29, 35, 50, 40, 44]'
  parsed=[36, 34, 29, 35, 50, 40, 44] -> reverse

  in=[31, 35, 51, 44, 38, 45, 34]
  out='[34, 45, 38, 44, 31, 51]'
  parsed=[34, 45, 38, 44, 31, 51] -> neither

  in=[50, 27, 32, 48, 30, 41, 26]
  out='[26, 41, 30, 48, 32, 27, 50]'
  parsed=[26, 41, 30, 48, 32, 27, 50] -> reverse

  in=[50, 33, 47, 26, 49, 28, 40]
  out='[28, 40, 26, 49, 33, 47, 50]'
  parsed=[28, 40, 26, 49, 33, 47, 50] -> neither

  in=[49, 29, 33, 31, 45, 46, 47]
  out='[47, 46, 45, 31, 33, 29, 49]'
  parsed=[47, 46, 45, 31, 33, 29, 49] -> reverse



In [None]:
# Cell 4: Full Data Collection
# Reuses the `manager` built in the pilot cell (model already loaded) and passes the
# full config to collect_data.
# NOTE(review): `manager` itself was constructed with `pilot_config` (debug=True). Any
# settings InnerModelManager reads from its *own* config, rather than from the config
# passed to collect_data, still come from the pilot — confirm against lib.
if "manager" not in globals():
    raise RuntimeError(
        "`manager` is not defined: run the pilot cell (Cell 3) first; it loads the inner model."
    )

COLLECTED_CSV = Path("collected_data.csv")

full_config = ExperimentConfig(
    debug=False,
    num_datasets=500,
)

# Seeds start at 1000 so they do not overlap the pilot's seeds (0-4).
samples = collect_data(full_config, manager, start_seed=1000)

# Save to CSV
df = samples_to_dataframe(samples)
df.to_csv(COLLECTED_CSV, index=False)
print(f"Saved {len(df)} samples to {COLLECTED_CSV}")
print("\nLabel distribution:")
print(df.label.value_counts())

Collecting data:   0%|          | 0/500 [00:00<?, ?it/s]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/6 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/6 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/6 [00:00<?, ? examples/s]

Step,Training Loss
1,0.674382
2,0.023418
3,0.127339
4,0.087732
5,0.020948
6,0.083841
7,0.007499
8,0.060084
9,0.00319
10,0.006291


Collecting data:   0%|          | 1/500 [00:19<2:39:54, 19.23s/it]


  [seed=1000] label=reverse
    in=[42, 45, 76, 70, 64, 59]
    out=[59, 64, 70, 76, 45, 42]
    expected sort=[42, 45, 59, 64, 70, 76]
    expected rev =[59, 64, 70, 76, 45, 42]
    -> reverse
    in=[68, 43, 63, 62, 59, 55]
    out=[55, 59, 62, 63, 43, 68]
    expected sort=[43, 55, 59, 62, 63, 68]
    expected rev =[55, 59, 62, 63, 43, 68]
    -> reverse
    in=[60, 54, 62, 49, 76, 67]
    out=[49, 54, 62, 67, 76, 60]
    expected sort=[49, 54, 60, 62, 67, 76]
    expected rev =[67, 76, 49, 62, 54, 60]
    -> neither
    in=[48, 60, 64, 67, 49, 62]
    out=[42, 62, 64, 67, 49, 60]
    expected sort=[48, 49, 60, 62, 64, 67]
    expected rev =[62, 49, 67, 64, 60, 48]
    -> neither
    in=[45, 61, 53, 47, 78, 71]
    out=[47, 53, 61, 71, 78, 45]
    expected sort=[45, 47, 53, 61, 71, 78]
    expected rev =[71, 78, 47, 53, 61, 45]
    -> neither


Adding EOS to train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Step,Training Loss
1,0.65211
2,0.032922
3,0.012131
4,0.020451
5,0.061478
6,0.000669
7,0.24902
8,0.000506
9,0.000599
10,0.000373


Collecting data:   0%|          | 2/500 [00:38<2:41:29, 19.46s/it]


  [seed=1001] label=sort
    in=[45, 46, 44, 47, 48, 42, 43]
    out=[42, 43, 44, 46, 47, 48, 45]
    expected sort=[42, 43, 44, 45, 46, 47, 48]
    expected rev =[43, 42, 48, 47, 44, 46, 45]
    -> neither
    in=[46, 44, 43, 48, 45, 41, 47]
    out=[41, 43, 44, 45, 46, 47, 48]
    expected sort=[41, 43, 44, 45, 46, 47, 48]
    expected rev =[47, 41, 45, 48, 43, 44, 46]
    -> sort
    in=[41, 46, 48, 43, 42, 47, 44]
    out=[42, 43, 44, 46, 47, 48, 41]
    expected sort=[41, 42, 43, 44, 46, 47, 48]
    expected rev =[44, 47, 42, 43, 48, 46, 41]
    -> neither
    in=[46, 42, 48, 44, 47, 43, 45]
    out=[43, 44, 45, 47, 42, 48, 46]
    expected sort=[42, 43, 44, 45, 46, 47, 48]
    expected rev =[45, 43, 47, 44, 48, 42, 46]
    -> neither
    in=[42, 41, 44, 43, 48, 45, 46]
    out=[41, 43, 44, 45, 46, 48, 42]
    expected sort=[41, 42, 43, 44, 45, 46, 48]
    expected rev =[46, 45, 48, 43, 44, 41, 42]
    -> neither


Adding EOS to train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Step,Training Loss
1,0.814695
2,0.013586
3,5.7e-05
4,0.104365
5,0.000562
6,0.017784
7,0.718123
8,0.364606
9,0.070909
10,0.098998


Collecting data:   1%|          | 3/500 [00:58<2:41:57, 19.55s/it]


  [seed=1002] label=reverse
    in=[64, 45, 55, 51, 42]
    out=[42, 51, 55, 45, 64]
    expected sort=[42, 45, 51, 55, 64]
    expected rev =[42, 51, 55, 45, 64]
    -> reverse
    in=[49, 57, 56, 58, 41]
    out=[41, 56, 58, 57, 49]
    expected sort=[41, 49, 56, 57, 58]
    expected rev =[41, 58, 56, 57, 49]
    -> neither
    in=[43, 62, 52, 48, 39]
    out=[39, 48, 52, 43, 62]
    expected sort=[39, 43, 48, 52, 62]
    expected rev =[39, 48, 52, 62, 43]
    -> neither
    in=[55, 49, 63, 51, 39]
    out=[39, 51, 49, 55, 63]
    expected sort=[39, 49, 51, 55, 63]
    expected rev =[39, 51, 63, 49, 55]
    -> neither
    in=[43, 51, 55, 57, 41]
    out=[41, 51, 55, 57, 43]
    expected sort=[41, 43, 51, 55, 57]
    expected rev =[41, 57, 55, 51, 43]
    -> neither


Adding EOS to train dataset (num_proc=1):   0%|          | 0/6 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/6 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/6 [00:00<?, ? examples/s]

Step,Training Loss
1,0.887769
2,0.051997
3,1.08819
4,0.210542
5,0.161052
6,0.026601
7,0.006083
8,0.000478
9,0.000153
10,3.7e-05


Collecting data:   1%|          | 4/500 [01:16<2:37:26, 19.05s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Step,Training Loss
1,1.079363
2,0.089305
3,0.004025
4,3.9e-05
5,0.001725
6,0.004045
7,0.020849
8,0.000363
9,0.005488
10,0.033936


Collecting data:   1%|          | 5/500 [01:35<2:36:12, 18.93s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/7 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/7 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/7 [00:00<?, ? examples/s]

Step,Training Loss
1,1.037745
2,0.081724
3,0.012812
4,0.000301
5,0.267065
6,0.059033
7,0.007393
8,9.3e-05
9,7.2e-05
10,0.001687


Collecting data:   1%|          | 6/500 [01:53<2:34:37, 18.78s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Step,Training Loss
1,0.836335
2,0.056342
3,1.5e-05
4,0.004336
5,1.556679
6,0.70089
7,0.126674
8,0.07847
9,0.050312
10,0.008799


Collecting data:   1%|▏         | 7/500 [02:12<2:34:45, 18.83s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/5 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/5 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/5 [00:00<?, ? examples/s]

Step,Training Loss
1,0.9925
2,0.068357
3,0.037947
4,0.000567
5,0.000195
6,0.017007
7,0.000431
8,0.062843
9,1.245755
10,0.129604


Collecting data:   2%|▏         | 8/500 [02:32<2:35:50, 19.01s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/7 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/7 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/7 [00:00<?, ? examples/s]

Step,Training Loss
1,0.993312
2,0.023332
3,0.201515
4,0.027918
5,0.028702
6,0.000286
7,0.626196
8,0.020289
9,0.025578
10,0.007965


Collecting data:   2%|▏         | 9/500 [02:53<2:40:15, 19.58s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Step,Training Loss
1,0.911376
2,0.010899
3,0.01015
4,0.000646
5,0.000509
6,0.029794
7,0.033541
8,0.00047
9,0.335043
10,0.000393


Collecting data:   2%|▏         | 10/500 [03:12<2:39:57, 19.59s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Step,Training Loss
1,1.019051
2,0.03439
3,0.001718
4,0.000131
5,0.000452
6,2e-05
7,0.257392
8,2.199756
9,0.846968
10,0.033544


Collecting data:   2%|▏         | 11/500 [03:31<2:37:32, 19.33s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/6 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/6 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/6 [00:00<?, ? examples/s]

Step,Training Loss
1,0.872154
2,0.080811
3,0.073848
4,0.24656
5,0.064722
6,0.115391
7,0.141555
8,0.00356
9,0.000506
10,0.0003


Collecting data:   2%|▏         | 12/500 [03:50<2:36:59, 19.30s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Step,Training Loss
1,0.845522
2,0.129595
3,0.063781
4,0.007311
5,7.5e-05
6,0.059922
7,0.012909
8,0.338872
9,0.000475
10,4.8e-05


Collecting data:   3%|▎         | 13/500 [04:09<2:36:26, 19.27s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/5 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/5 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/5 [00:00<?, ? examples/s]

Step,Training Loss
1,0.543484
2,0.041076
3,0.541366
4,0.364815
5,0.237867
6,0.042472
7,0.04217
8,0.006953
9,0.000638
10,8.4e-05


Collecting data:   3%|▎         | 14/500 [04:29<2:37:54, 19.50s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Step,Training Loss
1,0.703974
2,0.038693
3,0.250904
4,0.181855
5,0.181904
6,0.030258
7,0.113436
8,0.228857
9,0.198745
10,0.001757


Collecting data:   3%|▎         | 15/500 [04:49<2:38:22, 19.59s/it]

  [seed=1014] SKIPPED (ambiguous) — votes: ['neither', 'neither', 'neither', 'neither', 'neither']


Adding EOS to train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Step,Training Loss
1,0.847814
2,0.031556
3,3.7e-05
4,0.038956
5,0.104762
6,0.264115
7,0.016975
8,0.068513
9,0.059167
10,0.004373


Collecting data:   3%|▎         | 16/500 [05:08<2:36:32, 19.41s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Step,Training Loss
1,0.770524
2,0.021816
3,0.097444
4,0.034177
5,0.005333
6,0.015419
7,0.0166
8,0.001494
9,0.007902
10,0.028115


Collecting data:   3%|▎         | 17/500 [05:27<2:35:00, 19.26s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Step,Training Loss
1,0.560382
2,0.102136
3,0.415707
4,0.117504
5,0.181565
6,0.047868
7,0.04865
8,0.002363
9,0.000141
10,0.002808


Collecting data:   4%|▎         | 18/500 [05:47<2:37:09, 19.56s/it]

  [seed=1017] SKIPPED (ambiguous) — votes: ['neither', 'neither', 'neither', 'neither', 'neither']


Adding EOS to train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Step,Training Loss
1,0.667971
2,0.391677
3,0.135434
4,0.036577
5,0.00276
6,0.063556
7,0.001445
8,0.00077
9,0.383524
10,0.469338


Collecting data:   4%|▍         | 19/500 [06:07<2:36:51, 19.57s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/7 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/7 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/7 [00:00<?, ? examples/s]

Step,Training Loss
1,0.869397
2,0.134696
3,0.012549
4,0.436306
5,0.000187
6,0.355403
7,0.198418
8,0.000142
9,0.00113
10,0.305487


Collecting data:   4%|▍         | 20/500 [06:27<2:36:44, 19.59s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Step,Training Loss
1,0.622022
2,0.038555
3,0.037924
4,0.241882
5,0.162606
6,0.005335
7,0.017258
8,0.002152
9,0.00051
10,0.052424


Collecting data:   4%|▍         | 21/500 [06:47<2:37:02, 19.67s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/6 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/6 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/6 [00:00<?, ? examples/s]

Step,Training Loss
1,0.797919
2,0.030132
3,0.898244
4,0.052583
5,0.123419
6,0.077967
7,0.017139
8,0.030417
9,0.001098
10,0.000175


Collecting data:   4%|▍         | 22/500 [07:09<2:43:22, 20.51s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/10 [00:00<?, ? examples/s]

Step,Training Loss
1,1.00494
2,0.060896
3,0.089905
4,0.000587
5,0.000559
6,2.7e-05
7,1.3e-05
8,0.000177
9,2.5e-05
10,0.000169


Collecting data:   5%|▍         | 23/500 [07:28<2:39:27, 20.06s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/6 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/6 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/6 [00:00<?, ? examples/s]

Step,Training Loss
1,0.811132
2,0.032601
3,0.002143
4,0.000284
5,7.3e-05
6,0.01374
7,0.223262
8,0.007654
9,0.41284
10,0.000307


Collecting data:   5%|▍         | 24/500 [07:47<2:37:15, 19.82s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Step,Training Loss
1,0.887337
2,0.151055
3,0.010619
4,0.239024
5,0.045988
6,0.000376
7,0.001389
8,0.001286
9,0.000358
10,0.000659


Collecting data:   5%|▌         | 25/500 [08:07<2:36:48, 19.81s/it]

Adding EOS to train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/9 [00:00<?, ? examples/s]

Step,Training Loss
1,0.574048
2,0.464958
3,0.172787
4,0.035358
5,0.001452
6,0.00022
7,0.006515
8,0.61904
9,0.104668
10,0.00037


Collecting data:   5%|▌         | 26/500 [08:28<2:38:46, 20.10s/it]

  [seed=1025] SKIPPED (ambiguous) — votes: ['neither', 'neither', 'neither', 'neither', 'neither']


Adding EOS to train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=1):   0%|          | 0/8 [00:00<?, ? examples/s]

Step,Training Loss
1,1.016924
2,0.046338
3,0.277036
4,0.027346
5,0.000217
6,0.032285
7,0.029152
8,0.001551
9,0.01265
10,4e-05


In [None]:
# Cell 5: Analysis
# Uses `samples` and `df` from the full-collection cell. Reloading only the CSV
# (df = pd.read_csv("collected_data.csv")) restores `df` but NOT `samples`, which
# plot_label_distribution below still needs.
if len(df) == 0:
    # Otherwise the shift-rate division below fails with ZeroDivisionError.
    raise RuntimeError("No samples collected (e.g. every dataset was skipped as ambiguous); nothing to analyze.")

fig = plot_label_distribution(samples)
plt.show()

# Behavior shift from baseline
n_shifted = df.behavior_shifted.sum()
print(f"\nBehavior shifted from baseline: {n_shifted}/{len(df)} ({n_shifted/len(df):.0%})")

# Correlations with dataset knobs: per-label summary statistics of each dataset property
print("\n=== Label vs Dataset Properties ===")
for col in ["num_examples", "list_length", "value_range_low", "value_range_high"]:
    print(f"\n{col}:")
    print(df.groupby("label")[col].describe()[["mean", "std", "min", "max"]])

In [None]:
# Cell 6: Oracle Training
# Free inner-model memory first. Both the pilot `manager` and the diagnostic `mgr`
# (if that cell was run) hold a loaded copy of the base model. Popping them from the
# namespace, instead of `del manager`, releases both and keeps this cell re-runnable
# (a second `del manager` would raise NameError).
import gc, torch
globals().pop("manager", None)
globals().pop("mgr", None)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

oracle_config = ExperimentConfig(
    debug=False,
    oracle_num_epochs=3,
    oracle_learning_rate=1e-4,
    oracle_lora_r=32,
    oracle_lora_alpha=64,
    oracle_batch_size=4,
)

# Load tokenizer for oracle dataset preparation
import transformers as tr
oracle_tokenizer = tr.AutoTokenizer.from_pretrained(oracle_config.model_name)
if oracle_tokenizer.pad_token is None:
    oracle_tokenizer.pad_token = oracle_tokenizer.eos_token  # reuse EOS for padding

train_ds, eval_ds = prepare_oracle_dataset(samples, oracle_tokenizer, oracle_config)

# train_oracle returns a tokenizer too; it replaces the one used for dataset prep.
oracle_model, oracle_tokenizer = train_oracle(oracle_config, train_ds, eval_ds)
print("Oracle training complete!")

In [None]:
# Cell 7: Evaluation
# Labels are not balanced (the pilot was 75% reverse), so besides 50% chance the oracle
# is also compared with the majority-class baseline: always predicting the most common
# true label in the eval split.
from collections import Counter

eval_results = evaluate_oracle(oracle_model, oracle_tokenizer, eval_ds, oracle_config)

accuracy = eval_results["accuracy"]
true_label_counts = Counter(eval_results["true_labels"])
n_eval = sum(true_label_counts.values())
if n_eval:
    majority_label, majority_count = true_label_counts.most_common(1)[0]
    majority_baseline = majority_count / n_eval
else:
    majority_label, majority_baseline = None, float("nan")

print(f"Oracle accuracy: {accuracy:.2%}")
print(f"Random chance: 50%")
print(f"Majority-class baseline ({majority_label}): {majority_baseline:.2%}")
print(f"Above random: {accuracy > 0.5}")
print(f"Above majority baseline: {accuracy > majority_baseline}")

fig = plot_confusion_matrix(eval_results)
plt.show()

# Log final results to wandb. wandb.log raises without an active run, so only log
# into a run that is still open (e.g. one started during oracle training).
if not oracle_config.debug:
    import wandb
    if wandb.run is None:
        print("No active wandb run; skipping final wandb logging.")
    else:
        wandb.log({
            "oracle/final_accuracy": accuracy,
            "oracle/majority_baseline": majority_baseline,
            "oracle/confusion_matrix": wandb.Image(fig),
        })
        wandb.finish()

In [None]:
# Cell 8: Error Analysis
# Collect the eval indices where the oracle's prediction disagrees with the true label.
predictions = eval_results["predictions"]
true_labels = eval_results["true_labels"]
wrong_indices = [
    idx for idx, pair in enumerate(zip(predictions, true_labels)) if pair[0] != pair[1]
]
print(f"{len(wrong_indices)} incorrect predictions out of {len(predictions)}")

# Show up to five wrong predictions with the last 200 characters of their prompt.
# NOTE(review): assumes eval_ds rows are in the same order as the predictions — confirm in lib.
for idx in wrong_indices[:5]:
    print(f"\n--- Wrong prediction #{idx} ---")
    print(f"True:      {true_labels[idx]}")
    print(f"Predicted: {predictions[idx]}")
    prompt_tail = eval_ds[idx]["prompt"][-200:]
    print(f"Dataset (tail): ...{prompt_tail}")