In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Step one directory up from wherever the kernel was started.
# NOTE(review): presumably this makes the `cluster_intrep_repo` package and the
# relative data files used below resolvable — confirm. The call is relative to the
# current directory, so re-running this cell moves up one more level each time.
import os
os.chdir("..")

In [3]:
import torch
import json
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from tqdm.auto import tqdm
from datasets import load_dataset
from cluster_intrep_repo.utils import initialize_tokenizer, tokenize_blocksworld_generation, THINK_TOKEN
from pathlib import Path


# Ask huggingface_hub for its `hf_transfer` download path.
# NOTE(review): this presumably only helps when the `hf_transfer` package is installed — confirm.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Run-wide settings read by the cells below.
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
device = "cuda"
compute_dtype = torch.bfloat16

In [4]:
# Build the tokenizer for `model_id` with the project helper; how the helper configures
# it is defined in cluster_intrep_repo.utils and is not visible in this notebook.
tokenizer = initialize_tokenizer(model_id)

In [5]:
# Suffix that selects which Blocksworld variant of the generations dataset is loaded.
blocksworld_type = "4-blocks"

# "train" split of the generations dataset; later cells read the `query`, `generation`
# and `distilabel_metadata` fields of its rows.
dataset = load_dataset(f"dmitriihook/deepseek-r1-qwen-32b-planning-{blocksworld_type}")["train"]

In [6]:
# Load the self-probing labels: a list of records carrying `idx`, `line_n`,
# `new_text` and `parsed` (which can be None) fields.
labels_path = "blocksworld-4-self-probing-parsed-big-v2.json"
with open(labels_path) as f:
    labels_dataset = json.loads(f.read())

# Alternative source on the Hub, kept for reference:
# labels_dataset = load_dataset(f"dmitriihook/blocksworld-6-self-probing-parsed")["train"]

In [7]:
# Peek at a single label record to check its schema.
labels_dataset[1]

{'idx': 0,
 'line_n': 11,
 'new_text': 'Initial:\n- A (on table) with B on top.\n- D (on table) with C on top.\n\nGoal:\n- A is on D.\n- B is on C.\n- D is on',
 'parsed': {'blocks': [['B', 'A'], ['C', 'D']]}}

In [8]:
from collections import defaultdict
from typing import Optional

def stacks_to_pairs(stacks: list[list[str]]) -> tuple[dict[str, str], dict[str, str]]:
    """Convert a list of stacks into per-block neighbour maps.

    Each stack is listed top-first: ``stack[0]`` is the topmost block and
    ``stack[-1]`` is the block resting on the table.

    Args:
        stacks: Stacks of block names, e.g. ``[["B", "A"], ["C", "D"]]`` for
            "B on top of A" and "C on top of D".

    Returns:
        ``(above, below)`` where ``above[block]`` is the block lying directly on
        ``block`` (``"sky"`` when nothing does) and ``below[block]`` is the block
        directly underneath it (``"table"`` for the bottom of a stack).
    """
    above: dict[str, str] = {}
    below: dict[str, str] = {}

    for stack in stacks:
        last = len(stack) - 1
        for i, block in enumerate(stack):
            # Both neighbours are read straight from the stack, so no entry has to
            # be written provisionally and patched by a later iteration.
            above[block] = stack[i - 1] if i > 0 else "sky"
            below[block] = stack[i + 1] if i < last else "table"

    return above, below

def check_if_block(block: str) -> bool:
    """Tell whether ``block`` is one of the four block names used here (A, B, C, D)."""
    known_blocks = ("A", "B", "C", "D")
    return block in known_blocks

def check_stacks(stacks: list[list[str]], n_blocks: int) -> bool:
    """Validate a parsed stack configuration.

    Args:
        stacks: Stacks of block names, one inner list per stack.
        n_blocks: Number of blocks the configuration must contain.

    Returns:
        True only when every entry is a valid block name, no block is listed more
        than once, and exactly ``n_blocks`` blocks appear overall.
    """
    seen = set()
    for stack in stacks:
        for block in stack:
            # A repeated block would put the same block in two places at once, so
            # reject it instead of letting it slip through the distinct count.
            if not check_if_block(block) or block in seen:
                return False
            seen.add(block)
    return len(seen) == n_blocks

# Index the usable label records as labels_dict[idx][line_n] -> record.
# A record is skipped when it has no parsed state, when the parsed state has no
# "blocks" entry, or when its stacks fail validation for 4 blocks.
labels_dict = defaultdict(dict)
for item in labels_dataset:
    idx, line_n, parsed = item["idx"], item["line_n"], item["parsed"]

    unusable = (
        parsed is None
        or "blocks" not in parsed
        or not check_stacks(parsed["blocks"], 4)
    )
    if not unusable:
        labels_dict[idx][line_n] = item


In [9]:
# Number of problem indices currently present in labels_dict.
len(labels_dict)

1496

In [10]:
import re

# Gather the records whose text, ignoring its last 40 characters, contains a
# dash-joined pair of capital letters such as "C-D".
regex = "[A-Z]-[A-Z]"
_bad = []

for item in labels_dataset:
    if re.search(regex, item["new_text"][:-40]) is not None:
        _bad.append(item)

print(len(_bad))

3398


In [11]:
# Show the first 100 records collected in `_bad`.
_bad[:100]

[{'idx': 2,
  'line_n': 11,
  'new_text': '- B (on table)\n- C-D (stack)\n- A (on table)\n\nI need to move D on top of B. So, I can unstack D from C',
  'parsed': None},
 {'idx': 2,
  'line_n': 15,
  'new_text': '- B-D\n- C\n- A\n\nNext, I need to get C on top of D. So I can pick up C and stack it on D, whic',
  'parsed': None},
 {'idx': 2,
  'line_n': 16,
  'new_text': '- B-D\n- C\n- A\n\nSo, I can pick up C and stack it on D.\n\n7. Pick up C.\n8. Stack C on top',
  'parsed': None},
 {'idx': 2,
  'line_n': 17,
  'new_text': '- C-D-B\n- A on table\n\nBut the goal is A on top of C, which is on top of D, which is on top of B. So I ne',
  'parsed': None},
 {'idx': 2,
  'line_n': 18,
  'new_text': '- C-D-B\n- A on table\n\nSo, I can pick up A and stack it on C.\n\n9. Pick up A.\n10. Stack A on to',
  'parsed': None},
 {'idx': 2,
  'line_n': 20,
  'new_text': '- A-C-D-B\n- Nothing else, since all blocks are used.\n\nWait, but in the initial setup, A was on the table. So after step 10',
  'p

In [12]:
# Manual inspection: for one problem, print each labelled reasoning paragraph next to
# the label text and the parsed stacks.
idx = 11
row = dataset[idx]
generation = row["generation"]

# Accumulates the generation up to the current paragraph; not read again in this cell.
text = ""

for line_n, line in enumerate(generation.split("\n\n")):
    text = text + line + "\n\n"
    # Skip the first 10 paragraphs and any paragraph shorter than 30 characters.
    if line_n < 10 or len(line) < 30:
        continue

    # labels_dict is a defaultdict, so this lookup creates an empty entry when `idx`
    # has no labels yet.
    if line_n not in labels_dict[idx]:
        continue

    parsed = labels_dict[idx][line_n]["parsed"]
    new_text = labels_dict[idx][line_n]["new_text"]
    blocks = parsed["blocks"]

    print(line)
    print("-----"*3)
    print(new_text)
    print(blocks)
    print()


Let me try another approach. I need to get A on top of B. So, I need to move A somewhere. But A is under C and D. So, I need to unstack C and D first.
---------------
- A (on table) with C on top, and D on top of C.
- B is on the table.

I need to get A on top of B, so I
[['D', 'C', 'A'], ['B']]

Wait, but D is on top of C, so I can unstack D first, then C.
---------------
- A on the table, with C on top, and D on top of C.

- B on the table.

So, step 1: Unstack D fr
[['D', 'C', 'A'], ['B']]

So, step 1: Unstack D from C. Now, D is in hand, C is on A, B is on table.
---------------
- A has C on top.
- B is on table.
- D is in hand.

I need to put D down. Where? Maybe on the table.

Step 2: P
[['C', 'A'], ['B'], ['D']]

Step 2: Put down D. Now, D is on table, hand empty.
---------------
- A with C on top.
- B on table.
- D on table.

Next, I need to unstack C from A. Since C is clear now (because D was 
[['C', 'A'], ['B'], ['D']]

Step 4: Put down C. Now, C is on table, hand empty.
---

In [13]:
# Text of the label for problem 10, paragraph 10, without its last 30 characters.
[x for x in labels_dataset if x["idx"] == 10 and x["line_n"] == 10][0]["new_text"][:-30]


'- Block A on Block D\n- Block D on Block B\n- Block B on table\n- Block C on table\n\nGoa'

In [14]:
# Full label record for problem 6, paragraph 26.
[x for x in labels_dataset if x["idx"] == 6 and x["line_n"] == 26][0]



{'idx': 6,
 'line_n': 26,
 'new_text': '- A (on table)\n- B (on table) with D on top\n- C (on table)\n\nI need to move D off of B first because I need B to be ',
 'parsed': {'blocks': [['A'], ['D', 'B'], ['C']]}}

In [15]:
# All usable labels of problem 6, keyed by paragraph number.
labels_dict[6]

{10: {'idx': 6,
  'line_n': 10,
  'new_text': '- C (on table)\n- B (on table) with D on top\n- A (on table)\n\nSo, I need to move things around to get C as the bas',
  'parsed': {'blocks': [['D', 'B'], ['C'], ['A']]}},
 11: {'idx': 6,
  'line_n': 11,
  'new_text': '- A is on the table\n- B has D on top\n- C is on the table\n\nSo, I need to move things around to get B on top o',
  'parsed': {'blocks': [['D', 'B'], ['A'], ['C']]}},
 12: {'idx': 6,
  'line_n': 12,
  'new_text': '- B has D on top.\n- A, C are on the table.\n\nSo, to get B on top of C, I need to move B. But B is on the',
  'parsed': {'blocks': [['D', 'B'], ['A'], ['C']]}},
 13: {'idx': 6,
  'line_n': 13,
  'new_text': '- A (on table)\n- B (on table) with D on top\n- C (on table)\n\nSo, the blocks on the table are A, B,',
  'parsed': {'blocks': [['A'], ['D', 'B'], ['C']]}},
 14: {'idx': 6,
  'line_n': 14,
  'new_text': '- A (on table, clear)\n- B (on table) with D on top (so B is not clear, D is clear)\n- C (on table, clear)

In [16]:
import re
from collections import defaultdict

def parse_blocks(text):
    """Parse the initial and goal configurations out of a Blocksworld statement.

    Args:
        text: Problem statement with the sections introduced by
            "As initial conditions I have that:" and
            "My goal is for the following to be true:".

    Returns:
        ``((initial_pairs, initial_table_blocks), (goal_pairs, goal_table_blocks))``.
        Each ``*_pairs`` dict maps a block to the block it is on top of, and each
        ``*_table_blocks`` list holds the blocks stated to be on the table.
        A section that is not found yields an empty dict and an empty list.
    """
    # Defaults for a missing section. The pair containers are dicts so that the
    # result has the same shape whether or not the section was found.
    initial_state, init_table_blocks = {}, []
    goal_state, goal_table_blocks = {}, []

    # Extract the initial conditions and goal state
    initial_match = re.search(r'As initial conditions I have that:(.*?)My goal is for the following to be true:', text, re.DOTALL)
    # The goal section runs up to the first blank line, or to the end of the text
    # when nothing follows it.
    goal_match = re.search(r'My goal is for the following to be true:(.*?)(?:\n\n|\Z)', text, re.DOTALL)

    if initial_match:
        initial_section = initial_match.group(1)
        initial_conditions = re.findall(r'Block [A-Z] is on top of Block [A-Z]', initial_section)
        init_table_blocks = re.findall(r'Block ([A-Z]) is on the table', initial_section)
        initial_state = process_conditions(initial_conditions)

    if goal_match:
        goal_section = goal_match.group(1)
        goal_conditions = re.findall(r'Block [A-Z] is on top of Block [A-Z]', goal_section)
        goal_table_blocks = re.findall(r'Block ([A-Z]) is on the table', goal_section)
        goal_state = process_conditions(goal_conditions)

    return (initial_state, init_table_blocks), (goal_state, goal_table_blocks)

def process_conditions(conditions):
    """Turn "Block X is on top of Block Y" sentences into an ``{X: Y}`` mapping.

    Every sentence must mention exactly two blocks; the first one is recorded as
    resting on the second. A later sentence about the same block replaces the
    earlier one.
    """
    name_pattern = re.compile(r'Block ([A-Z])')
    supports = {}

    for sentence in conditions:
        upper, lower = name_pattern.findall(sentence)
        supports[upper] = lower

    return supports


# Sanity-check the parser on one problem: only the text after the last
# "[STATEMENT]" marker of the query is parsed.
item = dataset[2]["query"]
stmt = item.split("[STATEMENT]")[-1].strip()

initial_state, goal_state = parse_blocks(stmt)
initial_state, goal_state

(({'B': 'C', 'C': 'D'}, ['A', 'D']), ({'A': 'C', 'C': 'D', 'D': 'B'}, []))

In [17]:
def collect_all_blocks(initial_state):
    """List every block mentioned in a parsed state.

    Args:
        initial_state: ``(pairs, table_blocks)`` as returned by ``parse_blocks``;
            ``pairs`` maps a block to the block underneath it.

    Returns:
        The distinct block names, sorted so that the order (and therefore the
        insertion order of every dict built from it) is the same on every run.
    """
    pairs = initial_state[0]
    names = set(pairs.keys())
    names.update(initial_state[1])
    names.update(pairs.values())
    return sorted(names)

def state_to_pairs(state, all_blocks):
    """Expand a parsed ``(pairs, table_blocks)`` state into neighbour maps.

    Only the ``pairs`` half of ``state`` is used. Any block of ``all_blocks``
    without an explicit support is placed on the ``"table"``, and any block with
    nothing recorded on top of it gets ``"sky"`` above it.

    Returns:
        ``(above, below)`` dicts keyed by block name.
    """
    pairs, _ = state

    # Explicit "X is on top of Y" relations first, then the table for the rest.
    below = {block: support for block, support in pairs.items()}
    for block in all_blocks:
        below.setdefault(block, "table")

    # Invert the relation to find what lies on each block.
    above = {
        support: block
        for block, support in below.items()
        if support != "table"
    }
    for block in all_blocks:
        above.setdefault(block, "sky")

    return above, below

In [18]:
# Number of blocks per problem; read as a global by the label encoding helpers below.
n_blocks = 4

In [19]:
def state_compare(above1, below1, above2, below2):
    """Check that the second state agrees with the first on every block of the first.

    Args:
        above1, below1: Neighbour maps of the reference state.
        above2, below2: Neighbour maps of the state being compared.

    Returns:
        True when every block of ``above1`` / ``below1`` has the same neighbour in
        ``above2`` / ``below2``. A block missing from the second state counts as a
        mismatch. Blocks that only exist in the second state are not inspected.
    """
    # Unique sentinel: a missing block can never compare equal to a real neighbour,
    # so an incomplete second state yields False instead of raising KeyError.
    missing = object()

    for reference, candidate in ((above1, above2), (below1, below2)):
        for block, neighbour in reference.items():
            if candidate.get(block, missing) != neighbour:
                return False

    return True

In [20]:
# Number of leading dataset rows turned into probing samples.
n_rows = 1500

In [21]:
# Build one probing sample per labelled reasoning paragraph: the chat-formatted
# prompt plus the generation up to that paragraph, followed by a fixed probing
# suffix, together with the labelled block configuration.
data_to_process = []

# Probability of keeping a sample whose labelled state equals the problem's
# initial state or its goal state; all other samples are always kept.
take_prob = 0.3

# Dedicated seeded generator so the subsampling is reproducible across runs.
sample_rng = np.random.default_rng(0)

# Never ask for more rows than the dataset holds.
n_selected = min(n_rows, len(dataset))

for idx, row in enumerate(tqdm(dataset.select(range(n_selected)))):
    # .get() (instead of indexing) keeps the defaultdict free of empty entries for
    # problems that have no usable label.
    row_labels = labels_dict.get(idx)
    if not row_labels:
        continue

    stmt = row["query"].split("[STATEMENT]")[-1].strip()
    initial_state, goal_state = parse_blocks(stmt)
    all_blocks = collect_all_blocks(initial_state)
    i_above, i_below = state_to_pairs(initial_state, all_blocks)
    g_above, g_below = state_to_pairs(goal_state, all_blocks)

    # The user turn is identical for every paragraph of this row.
    _query = row["distilabel_metadata"]["raw_input_text_generation_0"][0]

    # Generation text accumulated up to and including the current paragraph.
    text = ""

    for line_n, line in enumerate(row["generation"].split("\n\n")):
        text = text + line + "\n\n"
        # Skip the first 10 paragraphs and any paragraph shorter than 30 characters.
        if line_n < 10 or len(line) < 30:
            continue

        if line_n not in row_labels:
            continue

        above, below = stacks_to_pairs(row_labels[line_n]["parsed"]["blocks"])

        # Down-sample states identical to the initial or the goal configuration.
        if state_compare(i_above, i_below, above, below) and sample_rng.random() > take_prob:
            continue
        if state_compare(g_above, g_below, above, below) and sample_rng.random() > take_prob:
            continue

        _text = text + "Now, the stacks are:\n\n-"

        messages = [
            _query,
            {"role": "assistant", "content": _text}
        ]
        # The last token produced by the chat template is dropped.
        # NOTE(review): presumably the end-of-turn marker, so that the sequence
        # ends on the probing suffix — confirm against the tokenizer's template.
        tokens = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False)[:-1]

        data_to_process.append({
            "idx": idx,
            "line_n": line_n,
            "tokens": tokens,
            "above": above,
            "below": below
        })

  0%|          | 0/1500 [00:00<?, ?it/s]

In [22]:
# Number of probing samples prepared above.
len(data_to_process)

38292

In [23]:
from vllm import LLM

# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

# Load the model through vLLM with tensor_parallel_size=8.
# NOTE(review): task="reward" appears to be chosen so that `llm.encode` hands back
# per-token hidden states instead of generated text — confirm for the vLLM version in use.
llm = LLM(model=model_id, task="reward", tensor_parallel_size=8)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

INFO 02-24 22:58:37 __init__.py:190] Automatically detected platform cuda.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

INFO 02-24 22:58:45 config.py:1401] Defaulting to use mp for distributed inference
INFO 02-24 22:58:45 llm_engine.py:234] Initializing a V0 LLM engine (v0.7.2) with config: model='deepseek-ai/DeepSeek-R1-Distill-Qwen-32B', speculative_config=None, tokenizer='deepseek-ai/DeepSeek-R1-Distill-Qwen-32B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=8, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B, num_scheduler_steps=1, multi

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[1;36m(VllmWorkerProcess pid=722916)[0;0m INFO 02-24 22:58:46 multiproc_worker_utils.py:229] Worker ready; awaiting tasks


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[1;36m(VllmWorkerProcess pid=722919)[0;0m INFO 02-24 22:58:46 multiproc_worker_utils.py:229] Worker ready; awaiting tasks


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[1;36m(VllmWorkerProcess pid=722924)[0;0m INFO 02-24 22:58:46 multiproc_worker_utils.py:229] Worker ready; awaiting tasks


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[1;36m(VllmWorkerProcess pid=722929)[0;0m INFO 02-24 22:58:46 multiproc_worker_utils.py:229] Worker ready; awaiting tasks


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[1;36m(VllmWorkerProcess pid=722939)[0;0m 

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[1;36m(VllmWorkerProcess pid=722934)[0;0m 

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


INFO 02-24 22:58:46 multiproc_worker_utils.py:229] Worker ready; awaiting tasks
INFO 02-24 22:58:46 multiproc_worker_utils.py:229] Worker ready; awaiting tasks
[1;36m(VllmWorkerProcess pid=722946)[0;0m INFO 02-24 22:58:46 multiproc_worker_utils.py:229] Worker ready; awaiting tasks
INFO 02-24 22:58:51 cuda.py:230] Using Flash Attention backend.
[1;36m(VllmWorkerProcess pid=722946)[0;0m INFO 02-24 22:58:51 cuda.py:230] Using Flash Attention backend.
[1;36m(VllmWorkerProcess pid=722934)[0;0m INFO 02-24 22:58:51 cuda.py:230] Using Flash Attention backend.
[1;36m(VllmWorkerProcess pid=722929)[0;0m INFO 02-24 22:58:51 cuda.py:230] Using Flash Attention backend.
[1;36m(VllmWorkerProcess pid=722939)[0;0m INFO 02-24 22:58:51 cuda.py:230] Using Flash Attention backend.
[1;36m(VllmWorkerProcess pid=722916)[0;0m INFO 02-24 22:58:52 cuda.py:230] Using Flash Attention backend.
[1;36m(VllmWorkerProcess pid=722919)[0;0m [1;36m(VllmWorkerProcess pid=722924)[0;0m INFO 02-24 22:58:52 cuda

Loading safetensors checkpoint shards:   0% Completed | 0/8 [00:00<?, ?it/s]


[1;36m(VllmWorkerProcess pid=722924)[0;0m INFO 02-24 22:59:09 model_runner.py:1115] Loading model weights took 7.5269 GB
[1;36m(VllmWorkerProcess pid=722934)[0;0m INFO 02-24 22:59:09 model_runner.py:1115] Loading model weights took 7.5269 GB
[1;36m(VllmWorkerProcess pid=722929)[0;0m INFO 02-24 22:59:09 model_runner.py:1115] Loading model weights took 7.5269 GB
[1;36m(VllmWorkerProcess pid=722939)[0;0m INFO 02-24 22:59:09 model_runner.py:1115] Loading model weights took 7.5269 GB
[1;36m(VllmWorkerProcess pid=722946)[0;0m INFO 02-24 22:59:11 model_runner.py:1115] Loading model weights took 7.5269 GB
INFO 02-24 22:59:10 model_runner.py:1115] Loading model weights took 7.5269 GB
[1;36m(VllmWorkerProcess pid=722916)[0;0m INFO 02-24 22:59:10 model_runner.py:1115] Loading model weights took 7.5269 GB
[1;36m(VllmWorkerProcess pid=722919)[0;0m INFO 02-24 22:59:11 model_runner.py:1115] Loading model weights took 7.5269 GB


[1;36m(VllmWorkerProcess pid=722934)[0;0m [1;36m(VllmWorkerProcess pid=722919)[0;0m [1;36m(VllmWorkerProcess pid=722929)[0;0m [1;36m(VllmWorkerProcess pid=722916)[0;0m [1;36m(VllmWorkerProcess pid=722939)[0;0m INFO 02-25 00:42:49 multiproc_worker_utils.py:253] Worker exiting
INFO 02-25 00:42:49 multiproc_worker_utils.py:253] Worker exiting
INFO 02-25 00:42:49 multiproc_worker_utils.py:253] Worker exiting
INFO 02-25 00:42:49 multiproc_worker_utils.py:253] Worker exiting
INFO 02-25 00:42:49 multiproc_worker_utils.py:253] Worker exiting
[1;36m(VllmWorkerProcess pid=722946)[0;0m [1;36m(VllmWorkerProcess pid=722924)[0;0m INFO 02-25 00:42:49 multiproc_worker_utils.py:253] Worker exiting
INFO 02-25 00:42:49 multiproc_worker_utils.py:253] Worker exiting


In [24]:
from vllm import TokensPrompt

# Number of prompts sent to `llm.encode` per call.
batch_size = 800
# Number of trailing positions of each prompt whose hidden states are kept.
n_kept_positions = 100

training_data = []

for i in tqdm(range(0, len(data_to_process), batch_size)):
    prompts = data_to_process[i:i + batch_size]
    tokens = [
        TokensPrompt(prompt_token_ids=prompt["tokens"])
        for prompt in prompts
    ]

    output = llm.encode(tokens)

    for prompt, out in zip(prompts, output):
        # Slice before converting: slicing the converted array would produce a view
        # that keeps the whole converted sequence alive for as long as the sample is
        # stored, and would convert positions that are thrown away anyway.
        hidden_states = out.outputs.data[-n_kept_positions:].cpu().numpy().astype(np.float16)

        training_data.append({
            "idx": prompt["idx"],
            "line_n": prompt["line_n"],
            "hidden_states": hidden_states,
            "above": prompt["above"],
            "below": prompt["below"]
        })

  0%|          | 0/48 [00:00<?, ?it/s]

Processed prompts: 100%|██████████| 800/800 [01:05<00:00, 12.12it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 800/800 [01:09<00:00, 11.52it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 800/800 [00:58<00:00, 13.56it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 800/800 [00:51<00:00, 15.38it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 800/800 [00:57<00:00, 13.88it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 800/800 [00:56<00:00, 14.09it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 800/800 [01:01<00:00, 12.94it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|██████████| 800/800 [01:00<00:00, 13.12it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts: 100%|████████

In [32]:
# Number of samples for which hidden states were collected.
len(training_data)

38292

In [191]:
from torch.utils.data import Dataset

# Integer ids for the four action names.
# NOTE(review): not referenced anywhere else in the visible part of this notebook.
act2int = {
    "put down": 0,
    "pick up": 1,
    "stack": 2,
    "unstack": 3
}

def block2int(block):
    """Encode a block name, "table" or "sky" as a class index.

    Letters map to their offset from "A"; "table" and "sky" take the two indices
    that follow the ``n_blocks`` letters.
    """
    special = {"table": n_blocks, "sky": n_blocks + 1}
    if block in special:
        return special[block]
    return ord(block) - ord("A")

def int2block(i):
    """Decode a class index back into a block name, "table" or "sky"."""
    for code, name in ((n_blocks, "table"), (n_blocks + 1, "sky")):
        if i == code:
            return name

    return chr(i + ord("A"))

# Number of trailing hidden-state positions handed to a probe for each sample.
n_prev_tokens = 50

def state_to_label(state):
    """Encode a state and return the probing target for the first block.

    ``state`` is an ``(above, below, hand)`` triple; ``hand`` is not used. The full
    encoding stores, for each block, the class of the block underneath it
    (first ``n_blocks`` slots) and of the block on top of it (last ``n_blocks``
    slots). Only slot 0 is returned, i.e. what lies underneath block "A".
    """
    above, below, _hand = state
    encoded = np.zeros(n_blocks * 2, dtype=np.int64)

    for name, support in below.items():
        encoded[block2int(name)] = block2int(support)
    for name, cover in above.items():
        encoded[n_blocks + block2int(name)] = block2int(cover)

    return encoded[0]



class StepProbeDataset(Dataset):
    """Torch dataset pairing stored hidden states with a block-state label.

    Args:
        items: Sample dicts with ``line_n``, ``hidden_states``, ``above`` and
            ``below`` entries.
        n_layer: Stored on the instance; not used by the visible code.
        min_line_n: Only samples whose ``line_n`` is strictly greater than this
            value are kept (defaults to the previously hard-coded 40).
    """

    def __init__(self, items, n_layer, min_line_n=40):
        self.items = [x for x in items if x["line_n"] > min_line_n]
        # No `hidden_states` attribute is set here: the former assignment read a
        # notebook global left over from an earlier loop and was never used.
        self.n_layer = n_layer

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        return {
            # Keep only the last `n_prev_tokens` positions of the stored states.
            "input": item["hidden_states"][-n_prev_tokens:],
            "labels": state_to_label((item["above"], item["below"], None))
        }


In [192]:
# Fraction of the samples used for training; the remainder is held out for testing.
train_test_split = 0.9    
n_train = int(len(training_data) * train_test_split)

# Positional split without shuffling.
# NOTE(review): samples of the one problem that straddles the boundary end up on
# both sides of the split — confirm this is acceptable.
train_items = training_data[:n_train]
test_items = training_data[n_train:]

train_dataset = StepProbeDataset(train_items, 0)
test_dataset = StepProbeDataset(test_items, 0)

In [193]:
class StepProbe(torch.nn.Module):
    """Linear probe applied to the hidden state of the last position.

    Args:
        input_size: Size of one hidden-state vector.
        hidden_size: Accepted so that all probes share one constructor signature;
            not used by this probe.
        n_blocks: Number of blocks; the probe emits ``n_blocks + 2`` logits.
    """

    def __init__(self, input_size, hidden_size, n_blocks):
        super().__init__()
        # The layer keeps the name `fc2` so that the keys of the state dict stay the same.
        self.fc2 = torch.nn.Linear(input_size, n_blocks + 2)

    def forward(self, x):
        """Map ``x`` of shape (batch, positions, input_size) to (batch, n_blocks + 2) logits."""
        return self.fc2(x[:, -1])

In [194]:
class GRUProbe(torch.nn.Module):
    """Probe that runs a GRU over the positions and classifies its last output.

    Args:
        input_size: Size of one hidden-state vector.
        hidden_size: Size of the GRU state.
        n_blocks: Number of blocks; the probe emits ``n_blocks + 2`` logits.
    """

    def __init__(self, input_size, hidden_size, n_blocks):
        super().__init__()
        self.gru = torch.nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, n_blocks + 2)

    def forward(self, x):
        """Map ``x`` of shape (batch, positions, input_size) to (batch, n_blocks + 2) logits."""
        x, _ = self.gru(x)
        # Only the GRU output at the final position feeds the classifier.
        return self.fc(x[:, -1, :])

In [195]:
class AHProbe(torch.nn.Module):
    """Probe that mixes the positions with one dot-product attention step.

    Args:
        input_size: Size of one hidden-state vector.
        hidden_size: Size of the projected vectors used for the attention.
        n_blocks: Number of blocks; the probe emits ``n_blocks + 2`` logits.
    """

    def __init__(self, input_size, hidden_size, n_blocks):
        super().__init__()
        self.fc = torch.nn.Linear(input_size, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, n_blocks + 2)

    def forward(self, x):
        """Map ``x`` of shape (batch, positions, input_size) to (batch, n_blocks + 2) logits."""
        x = self.fc(x)

        # Pairwise dot products between projected positions, normalised over the
        # last axis so that each row holds one position's attention weights.
        scores = torch.einsum("bpq,brq->bpr", x, x)
        scores = torch.nn.functional.softmax(scores, dim=-1)

        # Weighted sum of the projected vectors for every position.
        x = torch.einsum("bpq,brp->brq", x, scores)

        # Classify from the mixed representation of the last position.
        return self.fc2(x[:, -1, :])


In [196]:
# Size of one hidden-state vector fed to the probe.
# NOTE(review): presumably the hidden size of the probed model — confirm.
n_dim = 5120

# The second argument (500) is the probe's `hidden_size` parameter.
probe = StepProbe(n_dim, 500, n_blocks).to(device)

In [197]:
import torch
import numpy as np
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from sklearn.metrics import f1_score

def train_probe(probe, train_dataset, test_dataset, patience=100):
    """Train ``probe`` as a classifier and evaluate it after every epoch.

    Args:
        probe: Module mapping a batch of inputs to class logits.
        train_dataset: Dataset yielding dicts with ``"input"`` and ``"labels"``.
        test_dataset: Dataset of the same form, used for evaluation.
        patience: Number of consecutive epochs without a new best macro-F1 after
            which training stops early.

    Returns:
        ``(accuracy, macro_f1)`` measured on ``test_dataset`` after the last epoch
        that was run.
    """
    # Follow the probe's own device instead of depending on a notebook global.
    probe_device = next(probe.parameters()).device

    optimizer = AdamW(probe.parameters(), lr=1e-3, weight_decay=1e-2)
    criterion = CrossEntropyLoss()
    train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

    n_epochs = 500
    # Start below every attainable score so that the first epoch registers as an
    # improvement; starting at +inf made the comparison below always fail.
    best_f1 = float("-inf")
    early_stop_counter = 0
    accuracy = avg_f1 = None

    for epoch in range(n_epochs):
        probe.train()
        total_loss = 0.0
        n_samples = 0

        for batch in train_loader:
            optimizer.zero_grad()
            inputs = batch["input"].to(probe_device).float()
            labels = batch["labels"].to(probe_device)

            loss = criterion(probe(inputs), labels)
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by the batch size to get a per-sample average.
            total_loss += loss.item() * len(inputs)
            n_samples += len(inputs)

        avg_train_loss = total_loss / n_samples

        # Evaluation
        probe.eval()
        n_hits = 0
        total = 0
        val_loss = 0.0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in test_loader:
                inputs = batch["input"].to(probe_device).float()
                labels = batch["labels"].to(probe_device)

                output = probe(inputs)
                val_loss += criterion(output, labels).item() * len(inputs)

                preds = output.argmax(dim=1)
                n_hits += (preds == labels).sum().item()
                total += len(labels)

                all_preds.append(preds.cpu().numpy())
                all_labels.append(labels.cpu().numpy())

        accuracy = n_hits / total
        val_loss /= total
        avg_f1 = f1_score(np.concatenate(all_labels), np.concatenate(all_preds), average='macro')

        print(f"Epoch {epoch}, Train Loss: {avg_train_loss:.4f}, Hits: {accuracy:.4f}, F1: {avg_f1:.4f}, Val Loss: {val_loss:.4f}")

        # Early stopping on the validation macro-F1.
        if avg_f1 > best_f1:
            best_f1 = avg_f1
            early_stop_counter = 0
        else:
            early_stop_counter += 1

        if early_stop_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch}")
            break

    # Return the metrics that are actually computed; the former second return
    # value referred to a name that is never assigned.
    return accuracy, avg_f1


In [198]:
# Train the probe; the recorded run below ends in a KeyboardInterrupt.
train_probe(probe, train_dataset, test_dataset)

Epoch 0, Train Loss: 1.8730, Hits: 0.5044, F1: 0.4221, Val Loss: 1.2198
Epoch 1, Train Loss: 1.0507, Hits: 0.5449, F1: 0.3693, Val Loss: 1.0306
Epoch 2, Train Loss: 0.8836, Hits: 0.5876, F1: 0.4394, Val Loss: 0.9686
Epoch 3, Train Loss: 0.8268, Hits: 0.5876, F1: 0.4259, Val Loss: 0.9471
Epoch 4, Train Loss: 0.7948, Hits: 0.5921, F1: 0.4449, Val Loss: 0.9336
Epoch 5, Train Loss: 0.7701, Hits: 0.5996, F1: 0.4310, Val Loss: 0.9462
Epoch 6, Train Loss: 0.7453, Hits: 0.6103, F1: 0.4845, Val Loss: 0.9126
Epoch 7, Train Loss: 0.7180, Hits: 0.6223, F1: 0.5101, Val Loss: 0.9068
Epoch 8, Train Loss: 0.7046, Hits: 0.6161, F1: 0.4944, Val Loss: 0.9237
Epoch 9, Train Loss: 0.6843, Hits: 0.6286, F1: 0.5027, Val Loss: 0.8667
Epoch 10, Train Loss: 0.6676, Hits: 0.6446, F1: 0.5425, Val Loss: 0.8682
Epoch 11, Train Loss: 0.6540, Hits: 0.6428, F1: 0.5358, Val Loss: 0.8617
Epoch 12, Train Loss: 0.6441, Hits: 0.6214, F1: 0.5129, Val Loss: 0.8839
Epoch 13, Train Loss: 0.6317, Hits: 0.6441, F1: 0.5375, Val L

KeyboardInterrupt: 

In [168]:
# Number of tokens taken up by the probing suffix that is appended to every sample.
len(tokenizer.tokenize("Now, the stacks are:\n\n-"))

7

In [47]:
# Persist the probe weights (state dict only) to a file in the working directory.
torch.save(probe.state_dict(), "probe-4-blocks-inner.pth")

In [55]:
# Collect the samples of one problem, keyed by paragraph number, for inspection.
idx = 1200

line_labels = {x["line_n"]: x for x in training_data if x["idx"] == idx}


In [58]:
def label_to_state(label):
    """Decode a full label vector into an ``(above, below, None)`` state.

    ``label`` must hold ``2 * n_blocks`` class indices: entry ``i`` is what lies
    underneath block ``i`` and entry ``i + n_blocks`` is what lies on top of it.
    The third element of the result is always None.
    """
    above, below = {}, {}

    for position in range(n_blocks):
        support = int2block(label[position])
        cover = int2block(label[position + n_blocks])
        name = int2block(position)

        above[name] = cover
        below[name] = support

    return above, below, None

In [64]:
# Compare the labelled state of the first collected sample with the probe's prediction.
# NOTE(review): label_to_state indexes 2 * n_blocks entries of `preds`, whereas the
# probe classes defined in this notebook return a single (n_blocks + 2)-way
# prediction per sample, for which `preds` is a scalar after squeeze(). The execution
# counters suggest this cell was run against an earlier multi-output probe — confirm
# before re-running.
for x in line_labels.values():
    print(x["above"])
    print(x["below"])

    with torch.no_grad():
        # Add a batch dimension and move the stored states to the probe's device.
        input = torch.tensor(x["hidden_states"]).unsqueeze(0).to(device).float()
        output = probe(input)

        preds = output.argmax(dim=1).cpu().numpy().squeeze()


    pred_above, pred_below, _ = label_to_state(preds)
    
    print(pred_above)
    print(pred_below)
    # Only the first sample is inspected.
    break



{'A': 'sky', 'D': 'A', 'C': 'D', 'B': 'C'}
{'A': 'D', 'D': 'C', 'C': 'B', 'B': 'table'}
{'A': 'sky', 'B': 'sky', 'C': 'sky', 'D': 'A'}
{'A': 'D', 'B': 'table', 'C': 'D', 'D': 'table'}
