In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

In [4]:
os.chdir("..")

In [6]:
import torch
from transformers import AutoTokenizer
from tqdm.auto import tqdm
from datasets import load_dataset
from cluster_intrep_repo.utils import initialize_tokenizer, tokenize_blocksworld_generation, THINK_TOKEN



os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

compute_dtype = torch.bfloat16
device   = 'cuda'
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"

In [7]:
import numpy as np

In [8]:
import json 

with open("planning_metadata.json") as f:
    metadata = json.load(f)


In [207]:
n_blocks = 8
n_instances = 10000

correct_instances = []

import re

regex = re.compile(r"stack is [A-Z](?:\s?(?:->|-|,|→|on)\s?[A-Z])*")

regex2 = re.compile(r"stack is [A-Z] (table)")
regex3 = re.compile(r"stack is [A-Z] with")

def extract_blocks(text):
    if "on" in text:
        text = text[::-1]

    return re.findall(r"[A-Z]", text)


from collections import Counter

# re.findall(regex, text)
for i, x in enumerate(metadata):
    if x["bench_item"]["Number of blocks"] <= n_blocks and x["bench_item"]["llm_correct"] == True: 
        text = x["bench_item"]["full_response"]

        matches = re.findall(regex, text)
        matches2 = re.findall(regex2, text)
        matches3 = re.findall(regex3, text)

        if len(matches) > 2 and len(matches2) == 0 and len(matches3) == 0:
            # print(i, re.findall(regex, text))
            all_stacks = [extract_blocks(y) for y in matches]

            max_len = max([len(y) for y in all_stacks])
            inner_stacks = [y for y in all_stacks if len(y) > 1 and len(y) < max_len]

            if len(inner_stacks) > 0:
                correct_instances.append(x)
                
print(len(correct_instances))

327


In [202]:
tokenizer = initialize_tokenizer(model_id)

In [203]:
blocksworld_type = "big"

dataset = load_dataset(f"dmitriihook/deepseek-r1-qwen-32b-planning-{blocksworld_type}")["train"]

In [204]:
from transformers import AutoModelForCausalLM
model     = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=compute_dtype, attn_implementation="sdpa", device_map="auto")

Loading checkpoint shards: 100%|██████████| 8/8 [00:13<00:00,  1.65s/it]


In [208]:
from collections import defaultdict

# [src; dest]

layer_hidden_states = defaultdict(list)

n_last_layers = 10

for instance in tqdm(correct_instances):
    row = dataset[instance["dataset_idx"]]
    chat = tokenize_blocksworld_generation(tokenizer, row)

    think_pos = torch.where(chat.squeeze() == THINK_TOKEN)[0]

    if len(think_pos) == 0:
        for j in range(n_last_layers):
            layer_hidden_states[j].append(None)

        continue

    with torch.no_grad():
        outputs = model(chat.to(device), output_hidden_states=True)

        for j in range(n_last_layers):
            hidden_states = outputs.hidden_states[-1 - j]
            layer_hidden_states[j].append(hidden_states[0].float().cpu().numpy())        

  0%|          | 0/327 [00:00<?, ?it/s]

100%|██████████| 327/327 [05:32<00:00,  1.02s/it]


In [209]:
training_data = []

for i, instance in tqdm(enumerate(correct_instances)):
    row = dataset[instance["dataset_idx"]]
    text =  instance["bench_item"]["full_response"]

    data_item = {
        "stacks": [],
        "positions": [],
        "idx": i
    }

    matches = re.findall(regex, text)

    all_stacks = [extract_blocks(y) for y in matches]

    max_len = max([len(y) for y in all_stacks])

    for match in re.finditer(regex, text):

        stack = extract_blocks(match.group(0))

        if len(stack) == 1:
            continue

        if len(stack) == max_len:
            continue

        start = match.start()
        
        tokens = tokenize_blocksworld_generation(tokenizer, row, generation=text[:start])[0]

        stack = extract_blocks(match.group(0))

        data_item["stacks"].append(stack)
        data_item["positions"].append(len(tokens))

    training_data.append(data_item)

327it [00:04, 68.36it/s]


In [210]:
train_test_split = 0.8
n_train = int(len(training_data) * train_test_split)

print(n_train)

training_data, testing_data = training_data[:n_train], training_data[n_train:]

261


In [211]:
expanded_training_data = []
expanded_testing_data = []

for data in training_data:
    for stack, pos in zip(data["stacks"], data["positions"]):
        expanded_training_data.append({
            "stack": stack,
            "position": pos,
            "idx": data["idx"]
        })

for data in testing_data:
    for stack, pos in zip(data["stacks"], data["positions"]):
        expanded_testing_data.append({
            "stack": stack,
            "position": pos,
            "idx": data["idx"]
        })

In [212]:
len(expanded_testing_data)

242

In [213]:
expanded_testing_data[0]

{'stack': ['B', 'E', 'C', 'D'], 'position': 1755, 'idx': 261}

In [214]:
from torch.utils.data import DataLoader, Dataset

block_to_idx = {f"{chr(65 + i)}": i + 1 for i in range(n_blocks)}

def stack_to_labels(stack, n_blocks):
    labels = torch.zeros(n_blocks,  dtype=torch.int64)
    
    for i, block in enumerate(stack[:n_blocks]):
        labels[i] = block_to_idx[block]

    return labels


class StackProbing(Dataset):
    def __init__(self, data, layer_hidden_states, layer):
        self.data = data
        self.layer_hidden_states = layer_hidden_states
        self.layer = layer

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]

        hidden_states = self.layer_hidden_states[self.layer][item["idx"]][item["position"] - 3:item["position"]].mean(axis=0)

        return {
            "input": hidden_states,
            "labels": stack_to_labels(item["stack"], n_blocks),
        }

In [215]:
training_dataset = StackProbing(expanded_training_data, layer_hidden_states, 0)
testing_dataset = StackProbing(expanded_testing_data, layer_hidden_states, 0)

In [216]:
training_dataset[2]

{'input': array([ 0.84309894,  0.35677084,  0.4671224 , ..., -0.48453775,
         1.4309896 ,  1.4166666 ], dtype=float32),
 'labels': tensor([4, 3, 2, 0, 0, 0, 0, 0])}

In [221]:
class StateProbe(torch.nn.Module):
    def __init__(self, n_blocks, input_size, hidden_size):
        super().__init__()

        # self.layer1 = torch.nn.Linear(input_size, hidden_size, dtype=torch.float32)
        self.layer2 = torch.nn.Linear(input_size, (n_blocks + 1) * n_blocks, dtype=torch.float32)

    def forward(self, x):
        x  = self.layer2(x)
        # x = torch.relu(x)
        # x = self.layer2(x)
        return x.view(x.shape[0], -1, n_blocks)


n_dim = 5120

In [222]:
from  torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss

probe = StateProbe(n_blocks, n_dim, 1000).to(device)

optimizer = Adam(probe.parameters(), lr=1e-3)

criterion = CrossEntropyLoss()

train_loader = DataLoader(training_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(testing_dataset, batch_size=64, shuffle=False)

In [223]:
len(training_dataset)

925

In [224]:
n_epochs = 40
for epoch in range(n_epochs):
    probe.train()

    total_loss = 0
    n_samples = 0

    for batch in train_loader:
        optimizer.zero_grad()

        input = batch["input"].to(device)
        labels = batch["labels"].to(device)

        output = probe(input)
        loss = criterion(output, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * len(batch)
        n_samples += len(batch)

    probe.eval()
    with torch.no_grad():
        total_loss = 0
        n_samples = 0
        for batch in test_loader:
            input = batch["input"].to(device)
            labels = batch["labels"].to(device)

            output = probe(input) 

            accuracy = (output.argmax(dim=-2) == labels).float().mean()

            total_loss += accuracy.item() * len(batch)
            n_samples += len(batch)

        print(f"Epoch {epoch} Accuracy: {total_loss / n_samples}")


Epoch 0 Accuracy: 0.5802734345197678
Epoch 1 Accuracy: 0.5833203122019768
Epoch 2 Accuracy: 0.6080664023756981
Epoch 3 Accuracy: 0.61279296875
Epoch 4 Accuracy: 0.6285351514816284
Epoch 5 Accuracy: 0.6127343699336052
Epoch 6 Accuracy: 0.6422656178474426
Epoch 7 Accuracy: 0.6302148401737213
Epoch 8 Accuracy: 0.6348828077316284
Epoch 9 Accuracy: 0.6261913999915123
Epoch 10 Accuracy: 0.6357812434434891
Epoch 11 Accuracy: 0.6322460919618607
Epoch 12 Accuracy: 0.6314062476158142
Epoch 13 Accuracy: 0.6363281160593033
Epoch 14 Accuracy: 0.6363476514816284
Epoch 15 Accuracy: 0.6329882740974426
Epoch 16 Accuracy: 0.6417773365974426
Epoch 17 Accuracy: 0.6394726485013962
Epoch 18 Accuracy: 0.6336328089237213
Epoch 19 Accuracy: 0.6461718678474426
Epoch 20 Accuracy: 0.6436718702316284
Epoch 21 Accuracy: 0.6414257735013962
Epoch 22 Accuracy: 0.6386523395776749
Epoch 23 Accuracy: 0.6440234333276749
Epoch 24 Accuracy: 0.6406640559434891
Epoch 25 Accuracy: 0.6370507776737213
Epoch 26 Accuracy: 0.644980

In [200]:
def labels_to_stack(labels, n_blocks):
    stack = []

    for i in range(n_blocks):
        block = chr(65 + labels[i].item() - 1)
        stack.append(block)

    return stack


for item in testing_dataset:
    input = torch.tensor(item["input"]).unsqueeze(0).to(device)
    label = item["labels"].unsqueeze(0).to(device)

    with torch.no_grad():
        output = probe(input).detach().cpu()

    print(labels_to_stack(output.argmax(dim=-2)[0], n_blocks), labels_to_stack(label[0], n_blocks))

['A', 'D', 'E', 'C', '@', '@'] ['E', 'A', 'D', 'B', '@', '@']
['B', 'D', 'E', 'C', '@', '@'] ['E', 'A', 'D', 'B', '@', '@']
['C', 'B', 'D', 'B', '@', '@'] ['C', 'D', 'B', 'E', '@', '@']
['A', 'B', 'E', '@', '@', '@'] ['A', 'B', 'E', '@', '@', '@']
['A', 'A', 'D', '@', '@', '@'] ['D', 'B', 'A', '@', '@', '@']
['B', 'D', 'C', 'A', '@', '@'] ['B', 'C', 'E', 'D', '@', '@']
['B', 'E', 'C', '@', '@', '@'] ['B', 'C', 'E', '@', '@', '@']
['B', 'C', '@', '@', '@', '@'] ['B', 'C', '@', '@', '@', '@']
['A', 'D', 'B', '@', '@', '@'] ['A', 'D', 'B', '@', '@', '@']
['A', 'D', 'B', '@', '@', '@'] ['A', 'D', 'B', '@', '@', '@']
['E', 'A', '@', '@', '@', '@'] ['E', 'A', '@', '@', '@', '@']
['E', 'A', 'D', '@', '@', '@'] ['E', 'A', 'D', '@', '@', '@']
['B', 'E', 'E', '@', '@', '@'] ['B', 'E', 'A', '@', '@', '@']
['B', 'F', 'A', 'C', '@', '@'] ['B', 'E', 'A', 'C', '@', '@']
['B', 'E', 'E', '@', '@', '@'] ['B', 'E', 'A', '@', '@', '@']
['B', 'E', 'A', 'C', '@', '@'] ['B', 'E', 'A', 'C', '@', '@']
['C', 'E

In [51]:
selected_idx = expanded_testing_data[0]["idx"]

In [55]:
preds = probe(torch.tensor(layer_hidden_states[0][selected_idx]).to(device)).detach().cpu()

In [57]:
preds = preds.argmax(dim=-2)

In [68]:
text_preds = ["".join([chr(65 + x - 1) for x in y if x != 0]) for y in preds]

In [71]:

text_preds[-200:]

['BCEDAA',
 'CCFFCE',
 'BCFECE',
 'CC',
 'DECFFD',
 'BEDF',
 'BEDF',
 'BABA',
 'FDC',
 'BD',
 'FCDFD',
 'BD',
 'BACDFC',
 'BDC',
 'BD',
 'DBC',
 'CB',
 'BB',
 'B',
 'CEDA',
 'DEDFB',
 'CD',
 'BCDD',
 'CBBCC',
 'CBBDA',
 'BBD',
 'FBCDFD',
 'BD',
 'CB',
 'BB',
 'DCC',
 'BDCA',
 'BD',
 'DA',
 'DA',
 'BA',
 'CEC',
 'DEC',
 'BED',
 'BA',
 'BDC',
 'BDC',
 'CA',
 'BCAFD',
 'AEDA',
 'BEDCF',
 'CEBAF',
 'BB',
 'AB',
 'BCC',
 'BD',
 'BDE',
 'BAC',
 'BE',
 'DAC',
 'CA',
 'BA',
 'BAC',
 'BD',
 'DB',
 'DA',
 'BCD',
 'BAD',
 'BAD',
 'CDE',
 'BBEEFE',
 'AADB',
 'CADBCE',
 'BADDCD',
 'CADBCF',
 'CBAFC',
 'BFCDC',
 'CADECF',
 'BFFBAD',
 'CFE',
 'BCADD',
 'AFDBB',
 'CFDFCA',
 'ABDDC',
 'DFDFB',
 'BADFCE',
 'ABFF',
 'AACCCB',
 'AFEFF',
 'CFF',
 'EDEDE',
 'BED',
 'BDAEDA',
 'FDAACF',
 'AEFDA',
 'BDAEFA',
 'BBAECA',
 'CBCECC',
 'BBEF',
 'BDCEC',
 'BDEC',
 'CBCE',
 'ADE',
 'ADE',
 'BD',
 'DEAEDC',
 'BDABBA',
 'ABCEFA',
 'ABCE',
 'BB',
 'BE',
 'BDAAFA',
 'BDAECA',
 'CDCE',
 'B',
 'DA',
 'BACDC',
 'CACD',
 'D

In [67]:
print(correct_instances[selected_idx]["bench_item"]["full_response"])

<think>
Okay, so I have this block stacking problem to solve. Let me try to figure out the steps needed to get from the initial state to the goal state. I'll take it step by step because I'm still getting the hang of these kinds of problems.

First, let me understand the initial conditions. The problem says:

- Block B is clear.
- Block C is clear.
- The hand is empty.
- Block A is on top of Block E.
- Block B is on top of Block D.
- Block D is on top of Block A.
- Block C is on the table.
- Block E is on the table.

So, visualizing this, I think the stacks are as follows:

- E is on the table.
- A is on top of E.
- D is on top of A.
- B is on top of D.
- C is on the table, clear.

And the goal is:

- A is on top of D.
- C is on top of A.
- D is on top of E.
- E is on top of B.

Wait, that seems a bit confusing. Let me parse the goal again:

"Block A is on top of Block D, Block C is on top of Block A, Block D is on top of Block E and Block E is on top of Block B."

So, the final stacks

In [60]:
import plotly.express as px

px.imshow(preds[0].numpy())

ValueError: px.imshow only accepts 2D single-channel, RGB or RGBA images. An image of shape (6,) was provided. Alternatively, 3- or 4-D single or multichannel datasets can be visualized using the `facet_col` or/and `animation_frame` arguments.