In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

In [3]:
os.chdir("..")

In [4]:
import torch
from transformers import AutoTokenizer
from tqdm.auto import tqdm
from datasets import load_dataset
from cluster_intrep_repo.utils import initialize_tokenizer, tokenize_blocksworld_generation, THINK_TOKEN



os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

compute_dtype = torch.bfloat16
device   = 'cuda'
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"

In [5]:
tokenizer = initialize_tokenizer(model_id)

In [6]:
import numpy as np

In [7]:
blocksworld_type = "big"

dataset = load_dataset(f"dmitriihook/deepseek-r1-qwen-32b-planning-{blocksworld_type}")["train"]

In [8]:
from transformers import AutoModelForCausalLM
model     = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=compute_dtype, attn_implementation="sdpa", device_map="auto")

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [95]:
n_rows = 500

In [9]:
from collections import defaultdict

# [src; dest]

layer_hidden_states = defaultdict(list)

n_last_layers = 10

for row in tqdm(dataset.select(range(n_rows))):
    chat = tokenize_blocksworld_generation(tokenizer, row)

    # think_pos = torch.where(chat.squeeze() == THINK_TOKEN)[0]

    with torch.no_grad():
        outputs = model(chat.to(device), output_hidden_states=True)

        for j in range(n_last_layers):
            hidden_states = outputs.hidden_states[-1 - j]
            layer_hidden_states[j].append(hidden_states[0].float().cpu().numpy())

  0%|          | 0/500 [00:00<?, ?it/s]

In [97]:
n_blocks = int(dataset[n_rows - 1]["instance_id"].split("_")[0])
n_blocks

6

In [98]:
import re

def extract_action_from_line(line):
    line_start = line.start()
    line = line.group()
    line = line.strip()

    action_regex = r"put down|pick up|stack|unstack"
    action = re.search(action_regex, line.lower())

    if action is None:
        return None
    
    action_number = re.search(r"^\d+", line)

    line = line[action.end():]

    blocks_regex = f"\s[A-{chr(65 + n_blocks)}][^a-zA-Z]"

    blocks = re.findall(blocks_regex, line)

    blocks = [block.strip()[0] for block in blocks]

    action_operands = {
        "put down": 1,
        "pick up": 1,
        "stack": 2,
        "unstack": 2
    }

    num_operands = action_operands[action.group()]

    if len(blocks) < num_operands:
        return None
    
    return action.group(), blocks[:num_operands], int(action_number.group()), line_start

def group_actions(actions):
    # Group sequential actions

    grouped_actions = []
    current_group = []

    for action in actions:
        if len(current_group) == 0:
            current_group.append(action)
        else:
            if current_group[-1][2] + 1 == action[2]:
                current_group.append(action)
            else:
                grouped_actions.append(current_group)
                current_group = [action]

    if len(current_group) > 0:
        grouped_actions.append(current_group)
    
    return grouped_actions

def extract_actions(text):
    actions = []
    regex = r"^\d+\..+$"

    for group in re.finditer(regex, text, re.MULTILINE):
        action = extract_action_from_line(group)

        if action is not None:
            actions.append(action)

    actions = group_actions(actions)
    
    return actions

In [99]:
text = dataset[2]["generation"]


actions = extract_actions(text)

tts = tokenize_blocksworld_generation(tokenizer,  dataset[2], text[:actions[0][0][-1] + 2])

tokenizer.decode(tts.squeeze()[-10:])

' A. Let me try that.\n\n1.<｜end▁of▁sentence｜>'

In [199]:
training_data = []

for i, row in enumerate(tqdm(dataset.select(range(n_rows)))):
    text = row["generation"]
    actions = extract_actions(text)
    if len(actions) == 0:
        continue
    
    for _actions in actions:
        for action_1, action_2 in zip(_actions, _actions[2:]):
            cut_text = text[:action_1[3] + 2]
            tts = tokenize_blocksworld_generation(tokenizer, row, cut_text)
            item = {
                "pos": len(tts[0]) - 1,
                "action": action_2,
                "idx": i,
            }
            training_data.append(item)
len(training_data)

  0%|          | 0/500 [00:00<?, ?it/s]

15301

In [200]:
training_data[10]

{'pos': 2088, 'action': ('pick up', ['A'], 5, 4549), 'idx': 0}

In [201]:
train_test_split = 0.8
n_train = int(len(training_data) * train_test_split)

train_items = training_data[:n_train]
test_items = training_data[n_train:]

In [202]:
from torch.utils.data import Dataset

act2int = {
    "put down": 0,
    "pick up": 1,
    "stack": 2,
    "unstack": 3
}

block2int = {chr(ord("A") + i): i for i in range(n_blocks)}
int2block = {v: k for k, v in block2int.items()}


def action_to_label(action):
    action_type = act2int[action[0]]
    blocks = action[1]
    blocks = [block2int[block] for block in blocks]
    block = blocks[0]
    return block


class StepProbeDataset(Dataset):
    def __init__(self, items, hidden_states, n_layer):
        self.items = items
        self.hidden_states = hidden_states
        self.n_layer = n_layer

    def __len__(self):
        return len(self.items)
    
    def __getitem__(self, idx):
        item = self.items[idx]
        hidden_states = self.hidden_states[self.n_layer][item["idx"]]
        pos = item["pos"]
        action = item["action"]
        return {
            "input": hidden_states[pos:pos+2].mean(0),
            "labels": action_to_label(action)
        }


In [203]:
train_dataset = StepProbeDataset(train_items, layer_hidden_states, 0)
test_dataset = StepProbeDataset(test_items, layer_hidden_states, 0)

In [204]:
class StepProbe(torch.nn.Module):
    def __init__(self, hidden_size, n_blocks):
        super().__init__()
        self.fc = torch.nn.Linear(hidden_size, n_blocks)
    
    def forward(self, x):
        return self.fc(x)

In [205]:
n_dim = 5120
probe = StepProbe(n_dim, n_blocks).to(device)

In [206]:
from  torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss

optimizer = Adam(probe.parameters(), lr=1e-3)

criterion = CrossEntropyLoss()

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [207]:
n_epochs = 40
for epoch in range(n_epochs):
    probe.train()

    total_loss = 0
    n_samples = 0

    for batch in train_loader:
        optimizer.zero_grad()

        input = batch["input"].to(device)
        labels = batch["labels"].to(device)

        output = probe(input)
        loss = criterion(output, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * len(batch)
        n_samples += len(batch)

    probe.eval()
    with torch.no_grad():
        total_loss = 0
        n_samples = 0
        for batch in test_loader:
            input = batch["input"].to(device)
            labels = batch["labels"].to(device)

            output = probe(input) 

            accuracy = (output.argmax(dim=-1) == labels).float().mean()

            total_loss += accuracy.item() * len(batch)
            n_samples += len(batch)

        print(f"Epoch {epoch} Accuracy: {total_loss / n_samples}")


Epoch 0 Accuracy: 0.196250111485521
Epoch 1 Accuracy: 0.22697148844599724
Epoch 2 Accuracy: 0.23115317896008492
Epoch 3 Accuracy: 0.24086872364083925
Epoch 4 Accuracy: 0.2599715106189251
Epoch 5 Accuracy: 0.2671329689522584
Epoch 6 Accuracy: 0.28927951430281
Epoch 7 Accuracy: 0.2655554444839557
Epoch 8 Accuracy: 0.2744891829788685
Epoch 9 Accuracy: 0.27208811913927394
Epoch 10 Accuracy: 0.289655115455389
Epoch 11 Accuracy: 0.29696959629654884
Epoch 12 Accuracy: 0.305961761623621
Epoch 13 Accuracy: 0.3017411194741726
Epoch 14 Accuracy: 0.29608484730124474
Epoch 15 Accuracy: 0.3043953664600849
Epoch 16 Accuracy: 0.26832376296321553
Epoch 17 Accuracy: 0.2915470314522584
Epoch 18 Accuracy: 0.3105913909773032
Epoch 19 Accuracy: 0.30655159428715706
Epoch 20 Accuracy: 0.28726518029967946
Epoch 21 Accuracy: 0.3194611382981141
Epoch 22 Accuracy: 0.3163895569741726
Epoch 23 Accuracy: 0.299125824123621
Epoch 24 Accuracy: 0.3040197653075059
Epoch 25 Accuracy: 0.30626780663927394
Epoch 26 Accuracy:

In [200]:
def labels_to_stack(labels, n_blocks):
    stack = []

    for i in range(n_blocks):
        block = chr(65 + labels[i].item() - 1)
        stack.append(block)

    return stack


for item in testing_dataset:
    input = torch.tensor(item["input"]).unsqueeze(0).to(device)
    label = item["labels"].unsqueeze(0).to(device)

    with torch.no_grad():
        output = probe(input).detach().cpu()

    print(labels_to_stack(output.argmax(dim=-2)[0], n_blocks), labels_to_stack(label[0], n_blocks))

['A', 'D', 'E', 'C', '@', '@'] ['E', 'A', 'D', 'B', '@', '@']
['B', 'D', 'E', 'C', '@', '@'] ['E', 'A', 'D', 'B', '@', '@']
['C', 'B', 'D', 'B', '@', '@'] ['C', 'D', 'B', 'E', '@', '@']
['A', 'B', 'E', '@', '@', '@'] ['A', 'B', 'E', '@', '@', '@']
['A', 'A', 'D', '@', '@', '@'] ['D', 'B', 'A', '@', '@', '@']
['B', 'D', 'C', 'A', '@', '@'] ['B', 'C', 'E', 'D', '@', '@']
['B', 'E', 'C', '@', '@', '@'] ['B', 'C', 'E', '@', '@', '@']
['B', 'C', '@', '@', '@', '@'] ['B', 'C', '@', '@', '@', '@']
['A', 'D', 'B', '@', '@', '@'] ['A', 'D', 'B', '@', '@', '@']
['A', 'D', 'B', '@', '@', '@'] ['A', 'D', 'B', '@', '@', '@']
['E', 'A', '@', '@', '@', '@'] ['E', 'A', '@', '@', '@', '@']
['E', 'A', 'D', '@', '@', '@'] ['E', 'A', 'D', '@', '@', '@']
['B', 'E', 'E', '@', '@', '@'] ['B', 'E', 'A', '@', '@', '@']
['B', 'F', 'A', 'C', '@', '@'] ['B', 'E', 'A', 'C', '@', '@']
['B', 'E', 'E', '@', '@', '@'] ['B', 'E', 'A', '@', '@', '@']
['B', 'E', 'A', 'C', '@', '@'] ['B', 'E', 'A', 'C', '@', '@']
['C', 'E

In [51]:
selected_idx = expanded_testing_data[0]["idx"]

In [55]:
preds = probe(torch.tensor(layer_hidden_states[0][selected_idx]).to(device)).detach().cpu()

In [57]:
preds = preds.argmax(dim=-2)

In [68]:
text_preds = ["".join([chr(65 + x - 1) for x in y if x != 0]) for y in preds]

In [71]:

text_preds[-200:]

['BCEDAA',
 'CCFFCE',
 'BCFECE',
 'CC',
 'DECFFD',
 'BEDF',
 'BEDF',
 'BABA',
 'FDC',
 'BD',
 'FCDFD',
 'BD',
 'BACDFC',
 'BDC',
 'BD',
 'DBC',
 'CB',
 'BB',
 'B',
 'CEDA',
 'DEDFB',
 'CD',
 'BCDD',
 'CBBCC',
 'CBBDA',
 'BBD',
 'FBCDFD',
 'BD',
 'CB',
 'BB',
 'DCC',
 'BDCA',
 'BD',
 'DA',
 'DA',
 'BA',
 'CEC',
 'DEC',
 'BED',
 'BA',
 'BDC',
 'BDC',
 'CA',
 'BCAFD',
 'AEDA',
 'BEDCF',
 'CEBAF',
 'BB',
 'AB',
 'BCC',
 'BD',
 'BDE',
 'BAC',
 'BE',
 'DAC',
 'CA',
 'BA',
 'BAC',
 'BD',
 'DB',
 'DA',
 'BCD',
 'BAD',
 'BAD',
 'CDE',
 'BBEEFE',
 'AADB',
 'CADBCE',
 'BADDCD',
 'CADBCF',
 'CBAFC',
 'BFCDC',
 'CADECF',
 'BFFBAD',
 'CFE',
 'BCADD',
 'AFDBB',
 'CFDFCA',
 'ABDDC',
 'DFDFB',
 'BADFCE',
 'ABFF',
 'AACCCB',
 'AFEFF',
 'CFF',
 'EDEDE',
 'BED',
 'BDAEDA',
 'FDAACF',
 'AEFDA',
 'BDAEFA',
 'BBAECA',
 'CBCECC',
 'BBEF',
 'BDCEC',
 'BDEC',
 'CBCE',
 'ADE',
 'ADE',
 'BD',
 'DEAEDC',
 'BDABBA',
 'ABCEFA',
 'ABCE',
 'BB',
 'BE',
 'BDAAFA',
 'BDAECA',
 'CDCE',
 'B',
 'DA',
 'BACDC',
 'CACD',
 'D

In [67]:
print(correct_instances[selected_idx]["bench_item"]["full_response"])

<think>
Okay, so I have this block stacking problem to solve. Let me try to figure out the steps needed to get from the initial state to the goal state. I'll take it step by step because I'm still getting the hang of these kinds of problems.

First, let me understand the initial conditions. The problem says:

- Block B is clear.
- Block C is clear.
- The hand is empty.
- Block A is on top of Block E.
- Block B is on top of Block D.
- Block D is on top of Block A.
- Block C is on the table.
- Block E is on the table.

So, visualizing this, I think the stacks are as follows:

- E is on the table.
- A is on top of E.
- D is on top of A.
- B is on top of D.
- C is on the table, clear.

And the goal is:

- A is on top of D.
- C is on top of A.
- D is on top of E.
- E is on top of B.

Wait, that seems a bit confusing. Let me parse the goal again:

"Block A is on top of Block D, Block C is on top of Block A, Block D is on top of Block E and Block E is on top of Block B."

So, the final stacks

In [60]:
import plotly.express as px

px.imshow(preds[0].numpy())

ValueError: px.imshow only accepts 2D single-channel, RGB or RGBA images. An image of shape (6,) was provided. Alternatively, 3- or 4-D single or multichannel datasets can be visualized using the `facet_col` or/and `animation_frame` arguments.

In [185]:

dataset_mystery = load_dataset(f"dmitriihook/deepseek-r1-qwen-32b-planning-mystery")["train"]

README.md:   0%|          | 0.00/67.5k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/35.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2114 [00:00<?, ? examples/s]

In [195]:
x = dataset[0]

In [196]:
print(x["generation"])

<think>
Okay, so I have this problem where I need to rearrange some blocks. Let me try to figure out the steps I need to take. 

First, let me understand the initial setup. The blocks are A, B, C, D. The initial conditions are:

- Block A is clear.
- Block B is clear.
- Hand is empty.
- Block A is on top of Block D.
- Block B is on top of Block C.
- Block C is on the table.
- Block D is on the table.

So, visualizing this, I have two separate stacks. One stack is D with A on top, and another stack is C with B on top. Both D and C are on the table. 

My goal is to have:

- Block A on top of Block D.
- Block B on top of Block C.
- Block D on top of Block B.

Wait, that seems a bit confusing. Let me parse that again. The goal is:

- A is on D.
- B is on C.
- D is on B.

So, the final arrangement should be a stack where D is on B, which is on C, and A is on D. So the order from bottom to top would be C, B, D, A. But wait, that can't be because D is on B, which is on C, and A is on D. So th