In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

In [3]:
# Move the working directory one level up so later relative imports/paths resolve from the parent folder.
# NOTE(review): this is relative to the *current* directory, so re-running this cell without
# restarting the kernel moves up one more level each time — run it exactly once per kernel.
os.chdir("..")

In [4]:
import torch
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from tqdm.auto import tqdm
from datasets import load_dataset
from cluster_intrep_repo.utils import initialize_tokenizer, tokenize_blocksworld_generation, THINK_TOKEN



# Opt in to the hf_transfer download backend for Hugging Face Hub downloads.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# dtype the model weights are loaded in (passed as `torch_dtype` to from_pretrained below).
compute_dtype = torch.bfloat16
# Device that tokenized inputs and the probe are moved to.
device   = 'cuda'
# Hub id of the model whose hidden states are probed.
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"

In [5]:
# Tokenizer for `model_id`, built by the project helper (its exact configuration lives in
# cluster_intrep_repo.utils and is not visible in this notebook).
tokenizer = initialize_tokenizer(model_id)

In [6]:
# Suffix selecting which planning dataset variant to load from the Hub.
blocksworld_type = "big"

# Only the "train" split is used; rows expose at least "generation" and "instance_id".
dataset = load_dataset(f"dmitriihook/deepseek-r1-qwen-32b-planning-{blocksworld_type}")["train"]

In [7]:
# Load the causal LM in `compute_dtype`, sharded automatically across available devices.
# NOTE(review): the recorded run of this cell warned that sliding-window attention is enabled
# but not implemented for `sdpa` ("unexpected results may be encountered") — confirm this
# attention implementation is acceptable for the sequence lengths used here.
model     = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=compute_dtype, attn_implementation="sdpa", device_map="auto")

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [8]:
# Number of leading dataset rows to process (used via dataset.select(range(n_rows))).
n_rows = 1000

In [None]:
from collections import defaultdict

# layer_hidden_states[j] holds one entry per processed dataset row, in row order:
# a float32 numpy array taken from outputs.hidden_states[-1 - j] (first batch element),
# or None for rows whose generation has no "[PLAN END]" marker.
layer_hidden_states = defaultdict(list)

n_last_layers = 10

for row in tqdm(dataset.select(range(n_rows))):
    generation = row["generation"]

    if "[PLAN END]" in generation:
        chat = tokenize_blocksworld_generation(tokenizer, row)

        # Forward pass only; no gradients are needed to read out activations.
        with torch.no_grad():
            outputs = model(chat.to(device), output_hidden_states=True)
            per_layer = [
                outputs.hidden_states[-1 - j][0].float().cpu().numpy()
                for j in range(n_last_layers)
            ]
    else:
        # Keep a placeholder so list positions stay equal to dataset row indices.
        per_layer = [None] * n_last_layers

    for j, states in enumerate(per_layer):
        layer_hidden_states[j].append(states)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [49]:
# Report, per layer, how many rows actually have hidden states.
# The None placeholders are deliberately left in the lists: each list must stay aligned with
# dataset row indices, because StepProbeDataset looks hidden states up with the row index
# stored in every training item ("idx"). Dropping the placeholders would shift every later
# row onto the wrong activations. Counting without mutating also keeps this cell idempotent.
for j in range(n_last_layers):
    n_available = sum(states is not None for states in layer_hidden_states[j])
    print(n_available)

1000
1000
1000
1000
1000
1000
1000
1000
1000
1000


In [24]:
# Peek at the end of the first generation to see the "[PLAN] ... [PLAN END]" layout.
dataset[0]["generation"][-200:]

"'s the correct plan.\n</think>\n\n[PLAN]\nunstack Block A from on top of Block D\nput down Block A\npick up Block D\nstack Block D on top of Block B\npick up Block A\nstack Block A on top of Block D\n[PLAN END]"

In [25]:
def extract_actions(row):
    """Return the plan lines found in ``row["generation"]``.

    The plan is the text after the first "[PLAN]" marker, cut at the first
    following "[PLAN END]" marker, with surrounding whitespace stripped and then
    split on newlines.

    Returns:
        A list of line strings, or None when either marker is missing from the
        generation.
    """
    generation = row["generation"]
    if "[PLAN]" not in generation or "[PLAN END]" not in generation:
        return None

    # Everything after the first opening marker ...
    _, _, after_open = generation.partition("[PLAN]")
    # ... up to (not including) the first closing marker that follows it.
    body, _, _ = after_open.strip().partition("[PLAN END]")

    return body.strip().split("\n")
    
# Example: plan lines of the first dataset row.
extract_actions(dataset[0])

['unstack Block A from on top of Block D',
 'put down Block A',
 'pick up Block D',
 'stack Block D on top of Block B',
 'pick up Block A',
 'stack Block A on top of Block D']

In [26]:
import re

def parse_block_actions(commands):
    """Turn plan lines into ``(action, block_letters)`` tuples.

    A line is kept only if it starts with one of the known action verbs (checked
    in the order below; matching is case-sensitive). The block letters are every
    "Block X" occurrence in the line, in order of appearance. Lines that start
    with no known verb are dropped, so the result can be shorter than the input.

    Args:
        commands: iterable of plan line strings.

    Returns:
        List of ``(verb, [letter, ...])`` tuples.
    """
    verbs = ("unstack", "put down", "pick up", "stack")
    block_pattern = re.compile(r'Block [A-Z]')
    result = []

    for command in commands:
        verb = next((candidate for candidate in verbs if command.startswith(candidate)), None)
        if verb is None:
            continue
        # "Block A" -> "A"
        letters = [match.split()[-1] for match in block_pattern.findall(command)]
        result.append((verb, letters))

    return result

# Example: structured (action, blocks) tuples for the first dataset row.
parse_block_actions(extract_actions(dataset[0]))

[('unstack', ['A', 'D']),
 ('put down', ['A']),
 ('pick up', ['D']),
 ('stack', ['D', 'B']),
 ('pick up', ['A']),
 ('stack', ['A', 'D'])]

In [32]:
# Row limit for building probe training data. The row indices produced below are later used
# to index layer_hidden_states, so this must not exceed the number of rows for which hidden
# states were collected.
n_rows = 1000

In [141]:
# Build one group per row that has a complete plan. Each group holds one record per plan
# step with:
#   "idx"    - dataset row index (also the index into layer_hidden_states[j]),
#   "action" - (verb, [block letters]) parsed from that plan line,
#   "pos"    - len(tokens) - 1 for the generation truncated right after the line's first "Block".
#
# Every line is parsed on its own, so an unrecognised or blank line can no longer shift the
# pairing between parsed actions and plan lines. The running prefix is also extended for
# *every* line, so it always equals the real generation text up to the current line.
training_data = []
plan_marker = "[PLAN]\n"

for i, row in enumerate(tqdm(dataset.select(range(n_rows)))):
    if extract_actions(row) is None:
        continue

    generation = row["generation"]
    if plan_marker not in generation:
        # extract_actions only requires "[PLAN]"; without the newline right after it there
        # is no line-aligned plan body to walk, so skip the row instead of raising ValueError.
        continue

    plan_start = generation.index(plan_marker) + len(plan_marker)
    # Only the lines before the closing marker belong to the plan.
    plan_body = generation[plan_start:].split("[PLAN END]")[0]

    # Generation text up to the start of the line currently being processed.
    text = generation[:plan_start]
    group = []

    for line in plan_body.split("\n"):
        parsed = parse_block_actions([line])

        if parsed and "Block" in line:
            block_pos = line.index("Block")
            prefix = text + line[:block_pos] + "Block"
            tokens = tokenize_blocksworld_generation(tokenizer, row, prefix)[0]
            group.append({
                "idx": i,
                "action": parsed[0],
                "pos": len(tokens) - 1,
            })

        # Always advance the prefix, including for lines that were not recorded.
        text += line + "\n"

    training_data.append(group)


  0%|          | 0/1000 [00:00<?, ?it/s]

In [142]:
# Number of distinct block letters the probe must distinguish.
# The integer prefix of "instance_id" is used as the block count (assumption carried over
# from the original cell — TODO confirm against the dataset card). Taking the maximum over
# all selected rows, rather than reading only the last row, removes the extra assumption
# that rows are sorted by block count.
n_blocks = max(
    int(instance_id.split("_")[0])
    for instance_id in dataset.select(range(n_rows))["instance_id"]
)
n_blocks

9

In [316]:
# Pair every plan step with the step `jump` positions later in the same plan:
# the first element supplies the position to read hidden states from, the second supplies
# the action whose first block becomes the label (see action_to_label).
expanded_training_data = []

jump = 3

for group in training_data:
    for action1, action2 in zip(group, group[jump:]):
        # The label is read from action2's block list, so that list must be non-empty;
        # the original guard on action1 is kept so previously excluded pairs stay excluded.
        if not action1["action"][1] or not action2["action"][1]:
            continue
        expanded_training_data.append((action1, action2))

In [317]:
# Fraction of pairs used for training.
train_test_split = 0.8
n_train = int(len(expanded_training_data) * train_test_split)

# Positional split, no shuffling: pairs are ordered by source row, so train and test come
# from disjoint rows except possibly the single row that straddles the cut.
train_items = expanded_training_data[:n_train]
test_items = expanded_training_data[n_train:]

In [318]:
# Sizes of the two splits.
len(train_items), len(test_items)

(7999, 2000)

In [None]:
# Sanity check of the recorded token position for one training item.
item = train_items[1][0]

row = dataset[item["idx"]]

# Tokenize the full generation (no truncated text passed) to compare against `pos`.
tokens = tokenize_blocksworld_generation(tokenizer, row)[0]

# Decode the single token just before `pos`; the recorded output of this cell is ' Block',
# i.e. index pos - 1 is the " Block" token, which is also the last index included in the
# hidden-state window hidden_states[pos - n_prev_tokens:pos].
tokenizer.decode(tokens[item["pos"] - 1:item["pos"]])

{'idx': 0, 'action': ('put down', ['A']), 'pos': 2367}


' Block'

In [335]:
# Show the tail of another generation: free-form reasoning followed by the formal plan block.
print(dataset[3]["generation"][-1000:])

plan is:

1. Unstack A from C.
2. Put down A.
3. Unstack B from D.
4. Put down B.
5. Pick up C.
6. Stack C on B.
7. Pick up A.
8. Stack A on C.
9. Pick up D.
10. Stack D on A.

Wait, but in the initial problem statement, the goal is:

- Block A is on top of Block C,
- Block C is on top of Block B,
- Block D is on top of Block A.

So the final stack is B -> C -> A -> D.

Yes, that's correct.

I think this plan should work. Let me check if any steps violate the rules.

- Each unstack is done when the block is clear and on top.
- Each pick up is from the table or unstacking, and the hand is empty before.
- Each stack is done on a clear block.
- After each stack, the hand is empty.

Yes, all rules are followed.
</think>

[PLAN]
unstack Block A from on top of Block C
put down Block A
unstack Block B from on top of Block D
put down Block B
pick up Block C
stack Block C on top of Block B
pick up Block A
stack Block A on top of Block C
pick up Block D
stack Block D on top of Block A
[PLAN END]

In [320]:
from torch.utils.data import Dataset

# Integer ids for the four action verbs produced by parse_block_actions.
act2int = {
    "put down": 0,
    "pick up": 1,
    "stack": 2,
    "unstack": 3
}

# Block letters "A", "B", ... mapped to 0..n_blocks-1, and the inverse mapping.
block2int = {chr(ord("A") + i): i for i in range(n_blocks)}
int2block = {v: k for k, v in block2int.items()}

# Length of the hidden-state window (in tokens) fed to the probe for each example.
n_prev_tokens = 100

def action_to_label(action):
    """Map a parsed action to the integer class of its first block.

    Args:
        action: ``(verb, [block letters])`` tuple as produced by parse_block_actions.

    Returns:
        ``block2int`` id of the first block letter. The verb and any further blocks
        do not influence the label, so they are no longer looked up.

    Raises:
        IndexError: if the action has no blocks.
        KeyError: if the first block letter is not in ``block2int``.
    """
    blocks = action[1]
    return block2int[blocks[0]]


class StepProbeDataset(Dataset):
    """Dataset of (hidden-state window, block label) examples for probe training.

    Args:
        items: list of ``(action1, action2)`` pairs; ``action1`` provides the dataset
            row index ("idx") and token position ("pos"), ``action2`` provides the
            action the label is derived from.
        hidden_states: mapping ``layer -> list`` of per-row hidden-state arrays,
            indexed by dataset row index.
        n_layer: key into ``hidden_states`` selecting which layer to read.
        window: number of positions per example; when None (default) the
            module-level ``n_prev_tokens`` is read at access time, as before.
    """

    def __init__(self, items, hidden_states, n_layer, window=None):
        self.items = items
        self.hidden_states = hidden_states
        self.n_layer = n_layer
        self.window = window

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        action1, action2 = self.items[idx]
        states = self.hidden_states[self.n_layer][action1["idx"]]
        pos = action1["pos"]
        size = n_prev_tokens if self.window is None else self.window

        # Clamp the start: a negative start would be read as "from the end" and
        # yield an empty or wrong slice whenever pos < size.
        features = states[max(pos - size, 0):pos]

        # Left-pad short windows with zeros so every example has the same length and
        # the default collate function can stack them; real states stay at the end.
        missing = size - len(features)
        if missing > 0:
            padding = np.zeros((missing,) + features.shape[1:], dtype=features.dtype)
            features = np.concatenate([padding, features], axis=0)

        return {
            "input": features,
            "labels": action_to_label(action2["action"])
        }


In [321]:
# Layer key 0 selects layer_hidden_states[0], i.e. the entries taken from outputs.hidden_states[-1].
train_dataset = StepProbeDataset(train_items, layer_hidden_states, 0)
test_dataset = StepProbeDataset(test_items, layer_hidden_states, 0)

In [322]:
class StepProbe(torch.nn.Module):
    """Linear probe mapping a hidden-state vector to block logits.

    Args:
        hidden_size: size of the input feature vector.
        n_blocks: number of output classes.
    """

    def __init__(self, hidden_size, n_blocks):
        super().__init__()
        # `fc` is registered (and initialised) first but is not applied in forward().
        self.fc = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.fc2 = torch.nn.Linear(in_features=hidden_size, out_features=n_blocks)

    def forward(self, x):
        """Return block logits; only the output layer `fc2` is applied."""
        logits = self.fc2(x)
        return logits

In [323]:
class GRUProbe(torch.nn.Module):
    """Sequence probe: a batch-first GRU followed by a linear classifier.

    Args:
        input_size: feature size of each sequence element.
        hidden_size: GRU hidden size.
        n_blocks: number of output classes.
    """

    def __init__(self, input_size, hidden_size, n_blocks):
        super().__init__()
        self.gru = torch.nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, n_blocks)

    def forward(self, x):
        """Classify from the GRU output at the last position of each sequence."""
        sequence_states = self.gru(x)[0]
        final_state = sequence_states[:, -1]
        return self.fc(final_state)

In [324]:
# Feature size of each hidden-state vector fed to the probe.
# NOTE(review): presumably the model's hidden size — confirm against model.config.hidden_size.
n_dim = 5120
# GRU probe with a 500-unit hidden state and one output logit per block.
probe = GRUProbe(n_dim, 500, n_blocks).to(device)

In [325]:
from  torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss

# Optimizer over all probe parameters.
optimizer = Adam(probe.parameters(), lr=3e-3)

# Multi-class loss over block logits vs. integer block labels.
criterion = CrossEntropyLoss()

# Training batches are reshuffled every epoch; no random seed is set in this notebook,
# so batch order (and results) vary between runs.
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

In [326]:
from sklearn.metrics import f1_score, accuracy_score

# Train the probe for n_epochs; after every epoch report the sample-weighted mean training
# loss together with macro-F1 and accuracy on the test split.
n_epochs = 50
for epoch in range(n_epochs):
    probe.train()

    total_loss = 0.0
    n_samples = 0

    for batch in train_loader:
        optimizer.zero_grad()

        # Named `inputs` so the builtin `input` is not shadowed in the notebook namespace.
        inputs = batch["input"].to(device)
        labels = batch["labels"].to(device)

        output = probe(inputs)
        loss = criterion(output, labels)

        loss.backward()
        optimizer.step()

        # `batch` is a dict, so len(batch) is the number of keys (2), not the batch size;
        # weight the mean batch loss by the real number of examples instead.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size

    train_loss = total_loss / max(n_samples, 1)

    probe.eval()
    with torch.no_grad():
        preds = []
        targets = []
        for batch in test_loader:
            inputs = batch["input"].to(device)
            labels = batch["labels"].to(device)

            output = probe(inputs)

            preds.append(output.argmax(dim=1).cpu().numpy())
            targets.append(labels.cpu().numpy())

        preds = np.concatenate(preds)
        targets = np.concatenate(targets)

        print(f"Epoch {epoch} F1: {f1_score(targets, preds, average='macro')} Acc: {accuracy_score(targets, preds)} Train loss: {train_loss:.4f}")


Epoch 0 F1: 0.0558820230552361 Acc: 0.128
Epoch 1 F1: 0.058454464811435666 Acc: 0.1365
Epoch 2 F1: 0.029403677907474248 Acc: 0.135
Epoch 3 F1: 0.07130594048564677 Acc: 0.1305
Epoch 4 F1: 0.05329548402175018 Acc: 0.129
Epoch 5 F1: 0.07412415535936395 Acc: 0.13
Epoch 6 F1: 0.06207043311492185 Acc: 0.1355
Epoch 7 F1: 0.058807866057431095 Acc: 0.132
Epoch 8 F1: 0.046629206419704215 Acc: 0.126
Epoch 9 F1: 0.07043851242930038 Acc: 0.14
Epoch 10 F1: 0.06332865480287962 Acc: 0.132
Epoch 11 F1: 0.06180274323905499 Acc: 0.142
Epoch 12 F1: 0.08346407583300436 Acc: 0.1395
Epoch 13 F1: 0.05586022726557697 Acc: 0.1435
Epoch 14 F1: 0.08016903814348408 Acc: 0.1405


KeyboardInterrupt: 