In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

In [3]:
os.chdir("..")

In [4]:
import torch
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from tqdm.auto import tqdm
from datasets import load_dataset
from cluster_intrep_repo.utils import initialize_tokenizer, tokenize_blocksworld_generation, THINK_TOKEN



os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

compute_dtype = torch.bfloat16
device   = 'cuda'
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"

In [5]:
tokenizer = initialize_tokenizer(model_id)

In [6]:
blocksworld_type = "4-blocks"

dataset = load_dataset(f"dmitriihook/deepseek-r1-qwen-32b-planning-{blocksworld_type}")["train"]

In [7]:
model     = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=compute_dtype, attn_implementation="sdpa", device_map="auto")

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [8]:
n_rows = 1500

In [9]:
from collections import defaultdict

# [src; dest]

layer_hidden_states = defaultdict(list)

n_last_layers = 10

for row in tqdm(dataset.select(range(n_rows))):
    generation = row["generation"]

    if "[PLAN END]" not in generation:
        for j in range(n_last_layers):
            layer_hidden_states[j].append(None) 
        continue

    chat = tokenize_blocksworld_generation(tokenizer, row)

    # think_pos = torch.where(chat.squeeze() == THINK_TOKEN)[0]

    with torch.no_grad():
        outputs = model(chat.to(device), output_hidden_states=True)

        for j in range(n_last_layers):
            hidden_states = outputs.hidden_states[-1 - j]
            layer_hidden_states[j].append(hidden_states[0].to(torch.float16).cpu().numpy())

  0%|          | 0/1500 [00:00<?, ?it/s]

In [10]:
for j in range(n_last_layers):
    # layer_hidden_states[j] = [x for x in layer_hidden_states[j] if x is not None]
    print(len(layer_hidden_states[j]))

1500
1500
1500
1500
1500
1500
1500
1500
1500
1500


In [21]:
def extract_actions(row):
    generation = row["generation"]
    if "[PLAN]" not in generation:
        return None
    if "[PLAN END]" not in generation:
        return None
    
    plan_start = generation.index("[PLAN]") + len("[PLAN]")
    plan = generation[plan_start:].strip()
    plan = plan.split("[PLAN END]")[0].strip()
    actions = plan.split("\n")

    return actions
    
extract_actions(dataset[0])

['unstack Block C from on top of Block D',
 'put down Block C',
 'unstack Block B from on top of Block A',
 'stack Block B on top of Block C',
 'pick up Block D',
 'stack Block D on top of Block B',
 'pick up Block A',
 'stack Block A on top of Block D']

In [22]:
import re

def parse_block_actions(commands):
    actions = ["unstack", "put down", "pick up", "stack"]
    parsed_commands = []

    for command in commands:
        for action in actions:
            if command.startswith(action):
                blocks = re.findall(r'Block [A-Z]', command)
                blocks = [block.split()[-1] for block in blocks]  # Extract only the letter
                parsed_commands.append((action, blocks))
                break

    return parsed_commands

parse_block_actions(extract_actions(dataset[0]))

[('unstack', ['C', 'D']),
 ('put down', ['C']),
 ('unstack', ['B', 'A']),
 ('stack', ['B', 'C']),
 ('pick up', ['D']),
 ('stack', ['D', 'B']),
 ('pick up', ['A']),
 ('stack', ['A', 'D'])]

In [23]:
import re
from collections import defaultdict

def parse_blocks(text):
    initial_state = []
    goal_state = []
    
    # Extract the initial conditions and goal state
    initial_match = re.search(r'As initial conditions I have that:(.*?)My goal is for the following to be true:', text, re.DOTALL)
    goal_match = re.search(r'My goal is for the following to be true:(.*?)\n\n', text, re.DOTALL)

    if initial_match:
        initial_conditions = re.findall(r'Block [A-Z] is on top of Block [A-Z]', initial_match.group(1))
        init_table_blocks = re.findall(r'Block ([A-Z]) is on the table', initial_match.group(1))
        initial_state = process_conditions(initial_conditions)

    
    if goal_match:
        goal_conditions = re.findall(r'Block [A-Z] is on top of Block [A-Z]', goal_match.group(1))
        goal_table_blocks = re.findall(r'Block ([A-Z]) is on the table', goal_match.group(1))
        goal_state = process_conditions(goal_conditions)

    
    return (initial_state, init_table_blocks), (goal_state, goal_table_blocks)

def process_conditions(conditions):
    pairs = {}
    
    for cond in conditions:
        block, below = re.findall(r'Block ([A-Z])', cond)
        pairs[block] = below
    
    return pairs


item = dataset[2]["query"]
stmt = item.split("[STATEMENT]")[-1].strip()

initial_state, goal_state = parse_blocks(stmt)
initial_state, goal_state

(({'B': 'C', 'C': 'D'}, ['A', 'D']), ({'A': 'C', 'C': 'D', 'D': 'B'}, []))

In [24]:
def state_to_pairs(state, all_blocks):
    pairs, _ = state
    below = {}

    for block, below_block in pairs.items():
        below[block] = below_block

    for block in all_blocks:
        if block not in below:
            below[block] = "table"

    above = {}

    for block, below_block in below.items():
        if below_block != "table":
            above[below_block] = block

    for block in all_blocks:
        if block not in above:
            above[block] = "sky"
    
    return above, below

In [25]:
def collect_all_blocks(initial_state):
    all_blocks = list(initial_state[0].keys())
    all_blocks.extend(initial_state[1])
    all_blocks.extend(initial_state[0].values())
    return list(set(all_blocks))

In [26]:
all_blocks = collect_all_blocks(initial_state)

state_to_pairs(initial_state, all_blocks)

({'C': 'B', 'D': 'C', 'A': 'sky', 'B': 'sky'},
 {'B': 'C', 'C': 'D', 'D': 'table', 'A': 'table'})

In [27]:
from typing import Optional

def apply_action(action: list[str], state: tuple[dict, dict, Optional[str]]) -> Optional[tuple[dict, dict, Optional[str]]]: 
    above, below, hand = state

    above = above.copy()
    below = below.copy()

    action_type, blocks = action

    if action_type == "pick up":
        if hand is not None:
            return None
        block = blocks[0]
        above_block = above[block]

        if above_block != "sky":
            return None
        
        below_block = below[block]
        if below_block != "table":
            above[below_block] = "sky"
            below[block] = "table"
        
        hand = block

    elif action_type == "put down":
        if hand is None:
            return None
        
        if hand != blocks[0]:
            return None
        
        block = blocks[0]
        hand = None
    elif action_type == "unstack":
        if hand is not None:
            return None
        
        block1, block2 = blocks
        if above[block1] != "sky":
            return None
        if below[block1] != block2:
            return None
        
        above[block2] = "sky"
        below[block1] = "table"

        hand = block1
    elif action_type == "stack":
        block1, block2 = blocks

        if hand != block1:
            return None

        if above[block2] != "sky":
            return None
        
        above[block2] = block1
        below[block1] = block2
        hand = None

    return above, below, hand

In [28]:
training_data = []
for i, row in enumerate(tqdm(dataset.select(range(n_rows)))):
    actions = extract_actions(row)
    if actions is None:
        continue
    parsed_actions = parse_block_actions(actions)
    
    generation = row["generation"]
    plan_start = generation.index("[PLAN]\n") + len("[PLAN]\n")
    plan = generation[plan_start:]
    
    text = generation[:plan_start]

    group = []

    stmt = row["query"].split("[STATEMENT]")[-1].strip()
    initial_state, goal_state = parse_blocks(stmt)

    all_blocks = collect_all_blocks(initial_state)
    initial_state = state_to_pairs(initial_state, all_blocks)
    goal_state = state_to_pairs(goal_state, all_blocks)

    current_state = (initial_state[0], initial_state[1], None)

    for action, line in zip(parsed_actions, plan.split("\n")):
        if "Block" in line and current_state is not None:
            try:
                next_state = apply_action(action, current_state)
            except Exception as e:
                print(e)
                next_state = None
            if next_state is not None:
                block_pos = line.index("Block")
                first_part = line[:block_pos] + "Block"
                _text = text + first_part
                tokens = tokenize_blocksworld_generation(tokenizer, row, _text)[0]
                group.append({
                    "idx": i,
                    "action": action,
                    "pos": len(tokens) - 1,
                    "before_state": current_state,
                    "after_state": next_state
                })

            current_state = next_state

        text += line + "\n"


    training_data.append({
        "idx": i,
        "initial_state": initial_state,
        "goal_state": goal_state,
        "actions": parsed_actions,
        "group": group
    })


  0%|          | 0/1500 [00:00<?, ?it/s]

not enough values to unpack (expected 2, got 1)


In [29]:
training_data[0]

{'idx': 0,
 'initial_state': ({'A': 'B', 'D': 'C', 'C': 'sky', 'B': 'sky'},
  {'B': 'A', 'C': 'D', 'D': 'table', 'A': 'table'}),
 'goal_state': ({'D': 'A', 'C': 'B', 'B': 'D', 'A': 'sky'},
  {'A': 'D', 'B': 'C', 'D': 'B', 'C': 'table'}),
 'actions': [('unstack', ['C', 'D']),
  ('put down', ['C']),
  ('unstack', ['B', 'A']),
  ('stack', ['B', 'C']),
  ('pick up', ['D']),
  ('stack', ['D', 'B']),
  ('pick up', ['A']),
  ('stack', ['A', 'D'])],
 'group': [{'idx': 0,
   'action': ('unstack', ['C', 'D']),
   'pos': 4990,
   'before_state': ({'A': 'B', 'D': 'C', 'C': 'sky', 'B': 'sky'},
    {'B': 'A', 'C': 'D', 'D': 'table', 'A': 'table'},
    None),
   'after_state': ({'A': 'B', 'D': 'sky', 'C': 'sky', 'B': 'sky'},
    {'B': 'A', 'C': 'table', 'D': 'table', 'A': 'table'},
    'C')},
  {'idx': 0,
   'action': ('put down', ['C']),
   'pos': 5001,
   'before_state': ({'A': 'B', 'D': 'sky', 'C': 'sky', 'B': 'sky'},
    {'B': 'A', 'C': 'table', 'D': 'table', 'A': 'table'},
    'C'),
   'after_st

In [30]:
n_blocks = int(dataset[n_rows - 1]["instance_id"].split("_")[0])
n_blocks

4

In [48]:
from torch.utils.data import Dataset

act2int = {
    "put down": 0,
    "pick up": 1,
    "stack": 2,
    "unstack": 3
}

def block2int(block):
    if block == "table":
        return n_blocks
    if block == "sky":
        return n_blocks + 1
    
    return ord(block) - ord("A")

def int2block(i):
    if i == n_blocks:
        return "table"
    if i == n_blocks + 1:
        return "sky"
    
    return chr(i + ord("A"))

n_prev_tokens = 100

def state_to_label(state):
    above, below, hand = state
    label = np.zeros((n_blocks * 2, ), dtype=np.int64)

    for block, below_block in below.items():
        label[block2int(block)] = block2int(below_block)
    for block, above_block in above.items():
        label[block2int(block) + n_blocks] = block2int(above_block)

    return label

def action_to_label(action):
    action, blocks = action

    return block2int(blocks[0])


class StepProbeDataset(Dataset):
    def __init__(self, items, hidden_states, n_layer):
        self.items = items
        self.hidden_states = hidden_states
        self.n_layer = n_layer

    def __len__(self):
        return len(self.items)
    
    def __getitem__(self, idx):
        action1, action2 = self.items[idx]
        hidden_states = self.hidden_states[self.n_layer][action1["idx"]]
        pos = action1["pos"]

        return {
            "input": hidden_states[pos],
            "labels": action_to_label(action2["action"])
        }


In [49]:
def make_training_data(jump=0, train_test_split=0.8):
    expanded_training_data = []

    for group in training_data:
        group = group["group"]
        for action1, action2 in zip(group, group[jump:]):
            if len(action1["action"][1]) < 1:
                continue
            expanded_training_data.append((action1, action2))
            break

    n_train = int(len(expanded_training_data) * train_test_split)

    train_items = expanded_training_data[:n_train]
    test_items = expanded_training_data[n_train:]

    train_dataset = StepProbeDataset(train_items, layer_hidden_states, 0)
    test_dataset = StepProbeDataset(test_items, layer_hidden_states, 0)

    return train_dataset, test_dataset

In [50]:
class StepProbe(torch.nn.Module):
    def __init__(self, input_size, hidden_size, n_blocks):
        super().__init__()
        # self.fc = torch.nn.Linear(input_size, hidden_size)
        # self.fc2 = torch.nn.Linear(hidden_size, n_blocks * (n_blocks + 2) * 2)
        self.fc2 = torch.nn.Linear(input_size, n_blocks)
        
    def forward(self, x):
        # x = self.fc(x)
        # x = torch.nn.functional.relu(x)
        x = self.fc2(x)
        return x.view(-1, n_blocks)

In [51]:
class GRUProbe(torch.nn.Module):
    def __init__(self, input_size, hidden_size, n_blocks):
        super().__init__()
        self.gru = torch.nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, n_blocks)
        
    def forward(self, x):
        x, _ = self.gru(x)
        x = x[:, -1]
        x = self.fc(x)
        return x

In [52]:
jumps = list(range(6))

jump_datasets = {
    jump: make_training_data(jump) for jump in jumps
}

for jump, (train, test) in jump_datasets.items():
    print(jump, len(train), len(test))

0 1121 281
1 1108 277
2 1088 273
3 1069 268
4 1020 256
5 1004 251


In [66]:
n_dim = 5120

probes = {jump: StepProbe(n_dim, 1000, n_blocks).to(device) for jump in jumps}

In [67]:
import torch
import numpy as np
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from sklearn.metrics import f1_score

def train_probe(probe, train_dataset, test_dataset, patience=100):
    optimizer = Adam(probe.parameters(), lr=1e-3)
    criterion = CrossEntropyLoss()
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

    n_epochs = 500
    best_f1 = float('inf')
    early_stop_counter = 0
    
    for epoch in range(n_epochs):
        probe.train()
        total_loss = 0
        n_samples = 0

        for batch in train_loader:
            optimizer.zero_grad()
            input = batch["input"].to(device).float()
            labels = batch["labels"].to(device)
            
            output = probe(input)
            loss = criterion(output, labels)
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item() * len(batch["input"])
            n_samples += len(batch["input"])

        avg_train_loss = total_loss / n_samples
        
        # Evaluation
        probe.eval()
        with torch.no_grad():
            # block_wise_hits = np.zeros((n_blocks * 2), dtype=np.int64)
            total = 0  
            hits = 0
            val_loss = 0
            all_preds = []
            all_labels = []
            
            for batch in test_loader:
                input = batch["input"].to(device).float()
                labels = batch["labels"].to(device)
                
                output = probe(input)
                preds = output.argmax(dim=-1)  # Assuming classification task
                hits += (preds == labels).sum().item()  
                
                # block_wise_hits += hits.sum(dim=0).cpu().numpy()
                total += len(labels)
                
                all_preds.append(preds.cpu().numpy())
                all_labels.append(labels.cpu().numpy())

                loss = criterion(output, labels)

                val_loss += loss.item() * len(batch["input"])
            
            # block_wise_hits = block_wise_hits / total
            
            all_preds = np.concatenate(all_preds)
            all_labels = np.concatenate(all_labels)

            f1 = f1_score(all_labels, all_preds, average='macro')
            
            # # Compute F1 score block-wise
            # block_wise_f1 = np.zeros(n_blocks * 2)
            # for i in range(n_blocks * 2):
            #     block_wise_f1[i] = f1_score(all_labels[:, i], all_preds[:, i], average='macro')
            
            # avg_f1 = block_wise_f1.mean()
            avg_f1 = 0
            val_loss /= total

            print(f"Epoch {epoch}, Train Loss: {avg_train_loss:.4f}, Hits: {hits/total:.4f}, F1: {f1:.4f}, Val Loss: {val_loss:.4f}")
        
            # Early Stopping Check
            if avg_f1 > best_f1:
                best_f1 = avg_f1
                early_stop_counter = 0
            else:
                early_stop_counter += 1
            
            if early_stop_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch}")
                break
    
    return hits / total


In [68]:
for jump, (train, test) in jump_datasets.items():
    print(jump)
    print(train_probe(probes[jump], train, test))


0
Epoch 0, Train Loss: 1.6864, Hits: 0.7438, F1: 0.6545, Val Loss: 0.5458
Epoch 1, Train Loss: 0.1156, Hits: 0.9893, F1: 0.9885, Val Loss: 0.0751
Epoch 2, Train Loss: 0.0430, Hits: 1.0000, F1: 1.0000, Val Loss: 0.0091
Epoch 3, Train Loss: 0.0026, Hits: 1.0000, F1: 1.0000, Val Loss: 0.0053
Epoch 4, Train Loss: 0.0019, Hits: 1.0000, F1: 1.0000, Val Loss: 0.0040
Epoch 5, Train Loss: 0.0015, Hits: 1.0000, F1: 1.0000, Val Loss: 0.0029
Epoch 6, Train Loss: 0.0012, Hits: 1.0000, F1: 1.0000, Val Loss: 0.0023
Epoch 7, Train Loss: 0.0010, Hits: 1.0000, F1: 1.0000, Val Loss: 0.0020
Epoch 8, Train Loss: 0.0008, Hits: 1.0000, F1: 1.0000, Val Loss: 0.0019
Epoch 9, Train Loss: 0.0008, Hits: 1.0000, F1: 1.0000, Val Loss: 0.0019
Epoch 10, Train Loss: 0.0007, Hits: 1.0000, F1: 1.0000, Val Loss: 0.0019
Epoch 11, Train Loss: 0.0007, Hits: 1.0000, F1: 1.0000, Val Loss: 0.0019
Epoch 12, Train Loss: 0.0007, Hits: 1.0000, F1: 1.0000, Val Loss: 0.0018
Epoch 13, Train Loss: 0.0007, Hits: 1.0000, F1: 1.0000, Val

In [60]:
print(dataset[3]["generation"].split("</think>")[0][-2000:])

ack it on A.

Wait, let me think step by step.

1. Unstack C from D. Now, C is in hand, D is on B, and A is on table. Hand is holding C.

2. Put down C. Now, C is on table, D is on B, A is on table. Hand is empty.

3. Now, I need to unstack D from B. So, unstack D from B. Now, D is in hand, B is on table, C and A are on table. Hand is holding D.

4. Put down D. Now, D is on table, B is on table, C is on table, A is on table. Hand is empty.

5. Now, I need to stack C on B. So, pick up C. Hand holds C.

6. Stack C on B. Now, C is on B, B is on table. Hand is empty.

7. Now, pick up A. Hand holds A.

8. Stack A on C. Now, A is on C, which is on B. Hand is empty.

9. Now, pick up D. Hand holds D.

10. Stack D on A. Now, D is on A, which is on C, which is on B. So, the stack is B -> C -> A -> D. That's the goal.

Wait, but let me check if all the rules are followed. Each time I unstack or pick up, I have to make sure the block is clear. Let's go through each step.

1. Unstack C from D: C is

In [61]:
item = training_data[-5]

In [62]:
item

{'idx': 1495,
 'initial_state': ({'B': 'A', 'D': 'B', 'A': 'sky', 'C': 'sky'},
  {'A': 'B', 'B': 'D', 'C': 'table', 'D': 'table'}),
 'goal_state': ({'D': 'A', 'A': 'C', 'B': 'D', 'C': 'sky'},
  {'A': 'D', 'C': 'A', 'D': 'B', 'B': 'table'}),
 'actions': [('unstack', ['A', 'B']),
  ('put down', ['A']),
  ('unstack', ['B', 'D']),
  ('put down', ['B']),
  ('pick up', ['D']),
  ('stack', ['D', 'B']),
  ('pick up', ['A']),
  ('stack', ['A', 'D']),
  ('pick up', ['C']),
  ('stack', ['C', 'A'])],
 'group': [{'idx': 1495,
   'action': ('unstack', ['A', 'B']),
   'pos': 2659,
   'before_state': ({'B': 'A', 'D': 'B', 'A': 'sky', 'C': 'sky'},
    {'A': 'B', 'B': 'D', 'C': 'table', 'D': 'table'},
    None),
   'after_state': ({'B': 'sky', 'D': 'B', 'A': 'sky', 'C': 'sky'},
    {'A': 'table', 'B': 'D', 'C': 'table', 'D': 'table'},
    'A')},
  {'idx': 1495,
   'action': ('put down', ['A']),
   'pos': 2670,
   'before_state': ({'B': 'sky', 'D': 'B', 'A': 'sky', 'C': 'sky'},
    {'A': 'table', 'B': 'D

In [63]:
def label_to_state(label):
    above = {}
    below = {}
    for i in range(n_blocks):
        below_block = int2block(label[i])
        above_block = int2block(label[i + n_blocks])
        block = int2block(i)

        above[block] = above_block
        below[block] = below_block

    return above, below, None

In [71]:
item = training_data[-6]

In [72]:

item_hidden_states = layer_hidden_states[0][item["idx"]]
pos = item["group"][0]["pos"]
inputs = torch.tensor(item_hidden_states[pos]).unsqueeze(0).to(device).float()

for j, probe in probes.items():
    with torch.no_grad():
        output = probe(inputs)
    # output = output.view(-1, n_blocks + 2, n_blocks * 2)
    preds = output.argmax(dim=-2).cpu().numpy().squeeze()

    print(item["group"][j]["action"])
    print(label_to_state(preds))
    print(label_to_state(state_to_label(item["group"][j]["after_state"])))
    print()

('unstack', ['B', 'C'])
({'A': 'C', 'B': 'sky', 'C': 'sky', 'D': 'A'}, {'A': 'table', 'B': 'table', 'C': 'D', 'D': 'table'}, None)
({'A': 'C', 'B': 'sky', 'C': 'sky', 'D': 'sky'}, {'A': 'table', 'B': 'table', 'C': 'A', 'D': 'table'}, None)

('put down', ['B'])
({'A': 'B', 'B': 'sky', 'C': 'sky', 'D': 'C'}, {'A': 'table', 'B': 'table', 'C': 'D', 'D': 'table'}, None)
({'A': 'C', 'B': 'sky', 'C': 'sky', 'D': 'sky'}, {'A': 'table', 'B': 'table', 'C': 'A', 'D': 'table'}, None)

('unstack', ['C', 'A'])
({'A': 'sky', 'B': 'sky', 'C': 'sky', 'D': 'sky'}, {'A': 'table', 'B': 'table', 'C': 'table', 'D': 'table'}, None)
({'A': 'sky', 'B': 'sky', 'C': 'sky', 'D': 'sky'}, {'A': 'table', 'B': 'table', 'C': 'table', 'D': 'table'}, None)

('put down', ['C'])
({'A': 'sky', 'B': 'sky', 'C': 'sky', 'D': 'sky'}, {'A': 'table', 'B': 'table', 'C': 'table', 'D': 'table'}, None)
({'A': 'sky', 'B': 'sky', 'C': 'sky', 'D': 'sky'}, {'A': 'table', 'B': 'table', 'C': 'table', 'D': 'table'}, None)

('pick up', ['C'

In [249]:
item["actions"]

[('unstack', ['A', 'E']),
 ('put down', ['A']),
 ('unstack', ['E', 'C']),
 ('put down', ['E']),
 ('stack', ['A', 'C']),
 ('stack', ['F', 'A']),
 ('stack', ['D', 'F']),
 ('unstack', ['D', 'F']),
 ('stack', ['B', 'D']),
 ('stack', ['E', 'B']),
 ('stack', ['F', 'A'])]

In [None]:
]