In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

In [3]:
os.chdir("..")

In [4]:
import torch
from transformers import AutoTokenizer
from tqdm.auto import tqdm
from datasets import load_dataset
from cluster_intrep_repo.utils import initialize_tokenizer, tokenize_blocksworld_generation, THINK_TOKEN



os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

compute_dtype = torch.bfloat16
device   = 'cuda'
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"

In [5]:
tokenizer = initialize_tokenizer(model_id)

In [6]:
import numpy as np

In [7]:
blocksworld_type = "big"

dataset = load_dataset(f"dmitriihook/deepseek-r1-qwen-32b-planning-{blocksworld_type}")["train"]

In [9]:
from datasets import concatenate_datasets
generated_labels = [load_dataset("dmitriihook/blocksworld-big-step-labels-3")["train"], load_dataset("dmitriihook/blocksworld-big-step-labels-4")["train"], load_dataset("dmitriihook/blocksworld-big-step-labels-5")["train"]]

generated_labels = concatenate_datasets(generated_labels)

In [10]:
generated_labels[0]

{'index': 200,
 'steps': [{'label': '{"goal_action": null, "actions": null}',
   'step': "<think>\nOkay, so I have this block stacking problem to solve. Let me try to figure out the steps needed to get from the initial state to the goal state. I'll start by understanding the initial conditions and the goal."},
  {'label': None,
   'step': 'Initial Conditions:\n- Block A is clear.\n- Block D is clear.\n- Hand is empty.\n- Block A is on top of Block E.\n- Block B is on top of Block C.\n- Block E is on top of Block B.\n- Block C is on the table.\n- Block D is on the table.'},
  {'label': '{"goal_action": null, "actions": null}',
   'step': 'So, visualizing this, the stacks are:\n- C has B on top, and E is on top of B. So, the stack is C -> B -> E.\n- E has A on top, so E -> A.\n- D is on the table, clear.\n- A is clear, so nothing is on top of A.'},
  {'label': None,
   'step': 'Goal:\n- A is on top of C.\n- B is on top of A.\n- D is on top of E.\n- E is on top of B.'},
  {'label': '{"goa

In [11]:
from transformers import AutoModelForCausalLM
model     = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=compute_dtype, attn_implementation="sdpa", device_map="auto")

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [9]:
rows_startn_rows = 500

In [None]:
from collections import defaultdict

# [src; dest]

layer_hidden_states = defaultdict(list)

n_last_layers = 10

for row in tqdm(dataset.select(range(n_rows))):
    chat = tokenize_blocksworld_generation(tokenizer, row)

    # think_pos = torch.where(chat.squeeze() == THINK_TOKEN)[0]

    with torch.no_grad():
        outputs = model(chat.to(device), output_hidden_states=True)

        for j in range(n_last_layers):
            hidden_states = outputs.hidden_states[-1 - j]
            layer_hidden_states[j].append(hidden_states[0].float().cpu().numpy())

  0%|          | 0/500 [00:00<?, ?it/s]

In [11]:
n_blocks = int(dataset[n_rows - 1]["instance_id"].split("_")[0])
n_blocks

6

In [299]:
training_data = []

for steps_item in tqdm(generated_labels):
    idx = steps_item["index"]
    steps = steps_item["steps"]
    row = dataset[idx]

    generation = row["generation"]

    text = ""
    group = []

    for line, step in zip(generation.split("\n\n"), steps):
        if step["label"] is not None:
            label = step["label"]
            tokens = tokenize_blocksworld_generation(tokenizer, row, text)[0]

            group.append({
                "step": step,
                "pos": len(tokens) - 1,
            })

        text = text + line + "\n\n"

    training_data.append({
        "idx": idx,
        "steps": group,
    })


  0%|          | 0/200 [00:00<?, ?it/s]

In [130]:
training_data[0]

{'idx': 0,
 'steps': [{'step': {'label': {'actions': None, 'goal_action': None},
    'step': 'So, visualizing this, I have two separate stacks. One stack is D with A on top, and another stack is C with B on top. Both D and C are on the table. '},
   'pos': 852},
  {'step': {'label': {'actions': [['stack', 'D', 'B']],
     'goal_action': ['stack', 'D', 'B']},
    'step': "So, the final arrangement should be a stack where D is on B, which is on C, and A is on D. So the order from bottom to top would be C, B, D, A. But wait, that can't be because D is on B, which is on C, and A is on D. So the stack would be C, B, D, A. But initially, A is on D, which is on the table, and B is on C, which is on the table. So I need to move D to be on top of B, which is on C, and then A remains on D."},
   'pos': 962},
  {'step': {'label': {'actions': None, 'goal_action': None},
    'step': "Wait, but the goal says Block A is on top of Block D, which is on top of Block B, which is on top of Block C. So the

In [377]:
expanded_training_data = []

possible_actions = ["stack", "unstack", "pick up", "put down"]
possible_blocks = [chr(ord('A') + i) for i in range(n_blocks)]

print(possible_blocks)

def check_action(action):
    if action is None:
        return False
    
    if len(action) < 2:
        return False
    
    act, blocks = action[0], action[1:]
    if act not in possible_actions:
        return False

    for block in blocks:
        if block not in possible_blocks:
            return False
        
    return True

for item in training_data:
    for step in item["steps"]:
        # if step["step"]["label"]["goal_action"] is not None:
        #     continue
        action = step["step"]["label"]["actions"]
        if action is None or len(action) < 2:
            continue
        action = action[1]
        if not check_action(action):
            continue
        expanded_training_data.append({
            "idx": item["idx"],
            "goal": action,
            "pos": step["pos"],
            "step": step,
        })
len(expanded_training_data)

['A', 'B', 'C', 'D', 'E', 'F']


2436

In [378]:
expanded_training_data[11]

{'idx': 0,
 'goal': ['put down', 'A'],
 'pos': 2250,
 'step': {'step': {'label': {'actions': [['unstack', 'A', 'D'],
     ['put down', 'A'],
     ['pick up', 'D'],
     ['stack', 'D', 'B'],
     ['pick up', 'A'],
     ['stack', 'A', 'D']],
    'goal_action': None},
   'step': 'unstack A from D\nput down A\npick up D\nstack D on B\npick up A\nstack A on D'},
  'pos': 2250}}

In [379]:
train_test_split = 0.8
n_train = int(len(expanded_training_data) * train_test_split)

train_items = expanded_training_data[:n_train]
test_items = expanded_training_data[n_train:]

In [381]:
from torch.utils.data import Dataset

act2int = {
    "put down": 0,
    "pick up": 1,
    "stack": 2,
    "unstack": 3
}

block2int = {chr(ord("A") + i): i for i in range(n_blocks)}
int2block = {v: k for k, v in block2int.items()}

def block_2_label(block):
    block = block.replace("Block", "").strip()
    return block2int[block]


def action_to_label(action):
    # print(action)
    action_type = act2int[action[0]]
    blocks = action[1:]
    # blocks = [block2int[block] for block in blocks]
    block = block_2_label(blocks[0])
    return block


class StepProbeDataset(Dataset):
    def __init__(self, items, hidden_states, n_layer, prev_tokens):
        self.items = items
        self.hidden_states = hidden_states
        self.n_layer = n_layer
        self.prev_tokens = prev_tokens

    def __len__(self):
        return len(self.items)
    
    def __getitem__(self, idx):
        item = self.items[idx]
        hidden_states = self.hidden_states[self.n_layer][item["idx"]]
        pos = item["pos"]
        action = item["goal"]
        return {
            "input": hidden_states[pos - self.prev_tokens:pos + 1],
            "labels": action_to_label(action)
        }


In [382]:
prev_tokens = 50

train_dataset = StepProbeDataset(train_items, layer_hidden_states, 0, prev_tokens)
test_dataset = StepProbeDataset(test_items, layer_hidden_states, 0, prev_tokens)

In [383]:
class StepProbeAH(torch.nn.Module):
    def __init__(self, input_size, hidden_size, n_blocks):
        super().__init__()
        self.attn_proj = torch.nn.Linear(input_size, 1)
        self.block_proj = torch.nn.Linear(input_size, n_blocks)

    def forward(self, x):
        # x: [batch; seq, hidden]
        # x = self.attn_proj(x)
        # scores = torch.einsum("bph,bqh->bpq", x, x)
        # scores = torch.nn.functional.softmax(scores, dim=1)
        # x = torch.einsum("bph,bpq->bqh", x, scores)
        # x = x.mean(dim=1)
        # x = self.block_proj(x)

        scores = self.attn_proj(x).squeeze(-1)
        scores = torch.nn.functional.softmax(scores, dim=-1)
        x = torch.einsum("bsh,bs->bh", x, scores)
        x = self.block_proj(x)
        return x
    

In [384]:
class GRUProbe(torch.nn.Module):
    def __init__(self, input_size, hidden_size, n_blocks):
        super().__init__()
        self.gru = torch.nn.GRU(input_size, hidden_size, batch_first=True)
        self.block_proj = torch.nn.Linear(hidden_size, n_blocks)

    def forward(self, x):
        x, _ = self.gru(x)
        x = x[:, -1]
        x = self.block_proj(x)
        return x

In [385]:
class SimpleProbe(torch.nn.Module):
    def __init__(self, input_size, hidden_size, n_blocks):
        super().__init__()
        self.proj = torch.nn.Linear(input_size, n_blocks)

    def forward(self, x):
        x = self.proj(x)
        return x

In [386]:
n_dim = 5120
probe = GRUProbe(n_dim, 1000, n_blocks).to(device)

In [387]:
from  torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss

optimizer = Adam(probe.parameters(), lr=1e-3)

criterion = CrossEntropyLoss()

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [388]:
n_epochs = 40
for epoch in range(n_epochs):
    probe.train()

    total_loss = 0
    n_samples = 0

    for batch in train_loader:
        optimizer.zero_grad()

        input = batch["input"].to(device)
        labels = batch["labels"].to(device)

        output = probe(input)
        loss = criterion(output, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * len(batch)
        n_samples += len(batch)

    probe.eval()
    with torch.no_grad():
        total_loss = 0
        n_samples = 0
        for batch in test_loader:
            input = batch["input"].to(device)
            labels = batch["labels"].to(device)

            output = probe(input) 

            accuracy = (output.argmax(dim=-1) == labels).float().mean()

            total_loss += accuracy.item() * len(batch)
            n_samples += len(batch)

        print(f"Epoch {epoch} Accuracy: {total_loss / n_samples}")


Epoch 0 Accuracy: 0.2421875
Epoch 1 Accuracy: 0.1953125
Epoch 2 Accuracy: 0.271484375
Epoch 3 Accuracy: 0.251953125
Epoch 4 Accuracy: 0.2109375
Epoch 5 Accuracy: 0.287109375
Epoch 6 Accuracy: 0.3125
Epoch 7 Accuracy: 0.27734375
Epoch 8 Accuracy: 0.294921875
Epoch 9 Accuracy: 0.30859375
Epoch 10 Accuracy: 0.28125
Epoch 11 Accuracy: 0.32421875
Epoch 12 Accuracy: 0.287109375
Epoch 13 Accuracy: 0.294921875
Epoch 14 Accuracy: 0.314453125
Epoch 15 Accuracy: 0.28515625
Epoch 16 Accuracy: 0.306640625
Epoch 17 Accuracy: 0.298828125
Epoch 18 Accuracy: 0.298828125
Epoch 19 Accuracy: 0.28515625
Epoch 20 Accuracy: 0.302734375
Epoch 21 Accuracy: 0.291015625
Epoch 22 Accuracy: 0.287109375
Epoch 23 Accuracy: 0.30078125
Epoch 24 Accuracy: 0.302734375
Epoch 25 Accuracy: 0.296875
Epoch 26 Accuracy: 0.294921875
Epoch 27 Accuracy: 0.3125
Epoch 28 Accuracy: 0.294921875
Epoch 29 Accuracy: 0.314453125
Epoch 30 Accuracy: 0.3125
Epoch 31 Accuracy: 0.314453125
Epoch 32 Accuracy: 0.30859375
Epoch 33 Accuracy: 0.3

In [200]:
def labels_to_stack(labels, n_blocks):
    stack = []

    for i in range(n_blocks):
        block = chr(65 + labels[i].item() - 1)
        stack.append(block)

    return stack


for item in testing_dataset:
    input = torch.tensor(item["input"]).unsqueeze(0).to(device)
    label = item["labels"].unsqueeze(0).to(device)

    with torch.no_grad():
        output = probe(input).detach().cpu()

    print(labels_to_stack(output.argmax(dim=-2)[0], n_blocks), labels_to_stack(label[0], n_blocks))

['A', 'D', 'E', 'C', '@', '@'] ['E', 'A', 'D', 'B', '@', '@']
['B', 'D', 'E', 'C', '@', '@'] ['E', 'A', 'D', 'B', '@', '@']
['C', 'B', 'D', 'B', '@', '@'] ['C', 'D', 'B', 'E', '@', '@']
['A', 'B', 'E', '@', '@', '@'] ['A', 'B', 'E', '@', '@', '@']
['A', 'A', 'D', '@', '@', '@'] ['D', 'B', 'A', '@', '@', '@']
['B', 'D', 'C', 'A', '@', '@'] ['B', 'C', 'E', 'D', '@', '@']
['B', 'E', 'C', '@', '@', '@'] ['B', 'C', 'E', '@', '@', '@']
['B', 'C', '@', '@', '@', '@'] ['B', 'C', '@', '@', '@', '@']
['A', 'D', 'B', '@', '@', '@'] ['A', 'D', 'B', '@', '@', '@']
['A', 'D', 'B', '@', '@', '@'] ['A', 'D', 'B', '@', '@', '@']
['E', 'A', '@', '@', '@', '@'] ['E', 'A', '@', '@', '@', '@']
['E', 'A', 'D', '@', '@', '@'] ['E', 'A', 'D', '@', '@', '@']
['B', 'E', 'E', '@', '@', '@'] ['B', 'E', 'A', '@', '@', '@']
['B', 'F', 'A', 'C', '@', '@'] ['B', 'E', 'A', 'C', '@', '@']
['B', 'E', 'E', '@', '@', '@'] ['B', 'E', 'A', '@', '@', '@']
['B', 'E', 'A', 'C', '@', '@'] ['B', 'E', 'A', 'C', '@', '@']
['C', 'E

In [51]:
selected_idx = expanded_testing_data[0]["idx"]

In [55]:
preds = probe(torch.tensor(layer_hidden_states[0][selected_idx]).to(device)).detach().cpu()

In [57]:
preds = preds.argmax(dim=-2)

In [68]:
text_preds = ["".join([chr(65 + x - 1) for x in y if x != 0]) for y in preds]

In [71]:

text_preds[-200:]

['BCEDAA',
 'CCFFCE',
 'BCFECE',
 'CC',
 'DECFFD',
 'BEDF',
 'BEDF',
 'BABA',
 'FDC',
 'BD',
 'FCDFD',
 'BD',
 'BACDFC',
 'BDC',
 'BD',
 'DBC',
 'CB',
 'BB',
 'B',
 'CEDA',
 'DEDFB',
 'CD',
 'BCDD',
 'CBBCC',
 'CBBDA',
 'BBD',
 'FBCDFD',
 'BD',
 'CB',
 'BB',
 'DCC',
 'BDCA',
 'BD',
 'DA',
 'DA',
 'BA',
 'CEC',
 'DEC',
 'BED',
 'BA',
 'BDC',
 'BDC',
 'CA',
 'BCAFD',
 'AEDA',
 'BEDCF',
 'CEBAF',
 'BB',
 'AB',
 'BCC',
 'BD',
 'BDE',
 'BAC',
 'BE',
 'DAC',
 'CA',
 'BA',
 'BAC',
 'BD',
 'DB',
 'DA',
 'BCD',
 'BAD',
 'BAD',
 'CDE',
 'BBEEFE',
 'AADB',
 'CADBCE',
 'BADDCD',
 'CADBCF',
 'CBAFC',
 'BFCDC',
 'CADECF',
 'BFFBAD',
 'CFE',
 'BCADD',
 'AFDBB',
 'CFDFCA',
 'ABDDC',
 'DFDFB',
 'BADFCE',
 'ABFF',
 'AACCCB',
 'AFEFF',
 'CFF',
 'EDEDE',
 'BED',
 'BDAEDA',
 'FDAACF',
 'AEFDA',
 'BDAEFA',
 'BBAECA',
 'CBCECC',
 'BBEF',
 'BDCEC',
 'BDEC',
 'CBCE',
 'ADE',
 'ADE',
 'BD',
 'DEAEDC',
 'BDABBA',
 'ABCEFA',
 'ABCE',
 'BB',
 'BE',
 'BDAAFA',
 'BDAECA',
 'CDCE',
 'B',
 'DA',
 'BACDC',
 'CACD',
 'D

In [67]:
print(correct_instances[selected_idx]["bench_item"]["full_response"])

<think>
Okay, so I have this block stacking problem to solve. Let me try to figure out the steps needed to get from the initial state to the goal state. I'll take it step by step because I'm still getting the hang of these kinds of problems.

First, let me understand the initial conditions. The problem says:

- Block B is clear.
- Block C is clear.
- The hand is empty.
- Block A is on top of Block E.
- Block B is on top of Block D.
- Block D is on top of Block A.
- Block C is on the table.
- Block E is on the table.

So, visualizing this, I think the stacks are as follows:

- E is on the table.
- A is on top of E.
- D is on top of A.
- B is on top of D.
- C is on the table, clear.

And the goal is:

- A is on top of D.
- C is on top of A.
- D is on top of E.
- E is on top of B.

Wait, that seems a bit confusing. Let me parse the goal again:

"Block A is on top of Block D, Block C is on top of Block A, Block D is on top of Block E and Block E is on top of Block B."

So, the final stacks

In [60]:
import plotly.express as px

px.imshow(preds[0].numpy())

ValueError: px.imshow only accepts 2D single-channel, RGB or RGBA images. An image of shape (6,) was provided. Alternatively, 3- or 4-D single or multichannel datasets can be visualized using the `facet_col` or/and `animation_frame` arguments.

In [185]:

dataset_mystery = load_dataset(f"dmitriihook/deepseek-r1-qwen-32b-planning-mystery")["train"]

README.md:   0%|          | 0.00/67.5k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/35.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2114 [00:00<?, ? examples/s]

In [195]:
x = dataset[0]

In [196]:
print(x["generation"])

<think>
Okay, so I have this problem where I need to rearrange some blocks. Let me try to figure out the steps I need to take. 

First, let me understand the initial setup. The blocks are A, B, C, D. The initial conditions are:

- Block A is clear.
- Block B is clear.
- Hand is empty.
- Block A is on top of Block D.
- Block B is on top of Block C.
- Block C is on the table.
- Block D is on the table.

So, visualizing this, I have two separate stacks. One stack is D with A on top, and another stack is C with B on top. Both D and C are on the table. 

My goal is to have:

- Block A on top of Block D.
- Block B on top of Block C.
- Block D on top of Block B.

Wait, that seems a bit confusing. Let me parse that again. The goal is:

- A is on D.
- B is on C.
- D is on B.

So, the final arrangement should be a stack where D is on B, which is on C, and A is on D. So the order from bottom to top would be C, B, D, A. But wait, that can't be because D is on B, which is on C, and A is on D. So th