In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import json
import numpy as np
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import trange, tqdm
from utils import Probe, set_seed, train_probe

import os
# NOTE(review): the path is relative to the current working directory, so
# executing this cell a second time in the same kernel moves up one more level.
os.chdir("..")

seed = 42
# set_seed is a project helper imported from utils.
# NOTE(review): presumably seeds random / numpy / torch — confirm in utils.
set_seed(seed)

# Input files, resolved relative to the directory entered by os.chdir above.
activations_file = "planning_activations_32b_big_step.pt"
metadata_file = "planning_metadata.json"

# NOTE(review): torch.load unpickles the file, so only load files from a
# trusted source. If the payload is plain tensors / containers, passing
# weights_only=True would restrict what can be unpickled — confirm the payload.
activations = torch.load(activations_file)

# metadata is iterated record by record when the probing examples are built.
with open(metadata_file) as f:
    metadata = json.load(f)

  from .autonotebook import tqdm as notebook_tqdm
  activations = torch.load(activations_file)


In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Build the probing examples: one entry per metadata record whose LLM plan is
# marked valid, then shuffle them and split 80/20 into train and test.
import random

dataset = []

all_steps = set()  # every distinct plan step seen across the kept examples

for x in metadata:
    validity = x["bench_item"]["llm_validity"]

    # Skip invalid plans before touching the activations or the plan text, so
    # a record that is discarded anyway cannot fail one of those lookups.
    if validity != 1:
        continue

    dataset_idx = x["dataset_idx"]
    activation = activations[dataset_idx]

    extracted_plan = x["bench_item"]["extracted_llm_plan"]

    # One plan step per line; the set drops duplicates and ordering, and blank
    # lines are discarded.
    steps = {step for step in extracted_plan.split("\n") if step != ""}

    all_steps.update(steps)

    think_pos = x["think_pos"]

    dataset.append({
        # NOTE(review): presumably activations were stored once every 10
        # positions, so think_pos // 10 rows cover everything before
        # think_pos — confirm against the extraction script.
        "activations": activation[:think_pos // 10],
        "steps": steps,
        "think_pos": think_pos
    })


print(len(dataset))

# Shuffles in place using the global `random` state.
# NOTE(review): assumes set_seed() seeded `random`; re-running only this cell
# continues from the current RNG state and gives a different split — confirm.
random.shuffle(dataset)

test_fraction = 0.2
test_size = int(len(dataset) * test_fraction)

# Split on an explicit index. With test_size == 0 the negative-index slices
# dataset[:-0] and dataset[-0:] would put nothing in train and everything in
# test, which is the opposite of what is wanted.
split_at = len(dataset) - test_size

train_dataset = dataset[:split_at]
test_dataset = dataset[split_at:]



937


In [5]:
# Feature size handed to the probes as their input_size.
# NOTE(review): presumably the hidden size of the model that produced the
# activations — confirm against the activation tensors' last dimension.
n_dim = 5120

In [6]:

# Number of distinct plan steps collected from the valid examples, followed by
# the full set itself.
print(len(all_steps))
print(all_steps)

199
{'(unstack d j)', '(unstack i c)', '(put-down b)', '(stack f e)', '(stack h g)', '(unstack h b)', '(unstack a g)', '(stack h f)', '(unstack h a)', '(put-down d)', '(pick-up d)', '(stack b c)', '(stack g c)', '(stack g f)', '(unstack d h)', '(unstack g j)', '(pick-up e)', '(stack h c)', '(stack c h)', '(stack a f)', '(unstack f g)', '(put-down j)', '(unstack g i)', '(unstack g e)', '(stack e f)', '(unstack d e)', '(unstack c f)', '(put-down i)', '(unstack c i)', '(unstack f c)', '(stack a g)', '(unstack g a)', '(stack a d)', '(unstack d i)', '(stack i f)', '(stack i c)', '(unstack j e)', '(unstack j g)', '(stack b h)', '(pick-up a)', '(unstack e b)', '(stack j d)', '(unstack a i)', '(unstack g c)', '(put-down f)', '(stack c i)', '(unstack d a)', '(put-down g)', '(stack c g)', '(put-down h)', '(unstack i h)', '(pick-up g)', '(stack b j)', '(stack g h)', '(unstack j c)', '(unstack c b)', '(unstack g h)', '(stack d b)', '(unstack i d)', '(stack h b)', '(stack d e)', '(stack a c)', '(un

In [7]:
# Group the training examples by plan step:
#   step -> [(activations, think_pos), ...]
train_data_items = {step: [] for step in all_steps}

for item in train_dataset:
    # Fields are read straight off `item` rather than bound to notebook-level
    # names, so the `activations` object loaded from disk is not overwritten
    # by a single example's tensor.
    example = (item["activations"], item["think_pos"])

    for step in item["steps"]:
        # setdefault tolerates a step that is missing from all_steps, e.g.
        # when this cell is re-run after all_steps has been filtered down.
        train_data_items.setdefault(step, []).append(example)

In [8]:
# Keep only the steps that occur in more than `cutoff` training examples.
cutoff = 10

counts = {}
frequent_steps = []

for step in all_steps:
    counts[step] = len(train_data_items[step])
    if counts[step] > cutoff:
        frequent_steps.append(step)

# From here on all_steps is a list holding only the frequent steps.
all_steps = frequent_steps

print(len(all_steps))

138


In [9]:
# For every split and every candidate step, bucket the examples into
# positives (the step occurs in the example's plan) and negatives.
# Each stored sample is (activations, think_pos, is_positive).
splits = {"train": train_dataset, "test": test_dataset}

positive_data = {split: {step: [] for step in all_steps} for split in splits}
negative_data = {split: {step: [] for step in all_steps} for split in splits}

for split, items in splits.items():
    for item in items:
        # Fields are read off `item` instead of being bound to notebook-level
        # names, so the `activations` object loaded from disk stays intact.
        item_steps = item["steps"]

        for step in all_steps:
            is_positive = step in item_steps
            bucket = positive_data if is_positive else negative_data
            bucket[split][step].append((item["activations"], item["think_pos"], is_positive))

# Keep the steps whose training positives : negatives ratio lies in [0.5, 2].
final_steps = set()

for step in all_steps:
    n_pos = len(positive_data["train"][step])
    n_neg = len(negative_data["train"][step])

    # A step present in every training plan has no negatives; its ratio is
    # unbounded, so reject it here rather than dividing by zero.
    if n_neg == 0:
        continue

    pos_neg_ratio = n_pos / n_neg

    if pos_neg_ratio < 0.5 or pos_neg_ratio > 2:
        continue

    final_steps.add(step)


In [10]:
# Per retained step: train positives, train negatives, test positives, test negatives.
for step in final_steps:
    split_counts = [
        len(bucket[split][step])
        for split in ("train", "test")
        for bucket in (positive_data, negative_data)
    ]
    print(step, *split_counts)

(pick-up e) 423 327 101 86
(put-down b) 406 344 109 78
(pick-up f) 311 439 89 98
(put-down a) 368 382 87 100
(put-down d) 387 363 91 96
(put-down c) 423 327 109 78
(put-down e) 335 415 83 104


In [23]:
class ProbeDataset(Dataset):
    """Binary classification dataset for probing a single plan step.

    Every sample is a tuple ``(activations, think_pos, is_positive)`` taken
    from ``positive_data[step]`` or ``negative_data[step]``.

    Args:
        dataset: stored on the instance; not read by this class.
        probe_pos: stored on the instance; not read by this class.
        step: key selecting the sample lists in ``positive_data`` and
            ``negative_data``.
        positive_data: mapping ``step -> list of samples`` for the positive class.
        negative_data: mapping ``step -> list of samples`` for the negative class.
        aggregate: stored on the instance; not read by this class.
        balance: when True, both classes are truncated to the size of the
            smaller one (the first ``n`` samples of each list are kept).
    """

    def __init__(self, dataset, probe_pos, step, positive_data, negative_data, aggregate=False, balance=True):
        self.dataset = dataset
        self.probe_pos = probe_pos
        self.aggregate = aggregate

        # Each attribute is filled from the matching source. Labels are read
        # from the flag stored in each tuple, so they do not depend on which
        # attribute a sample sits in.
        self.positive_samples = positive_data[step]
        self.negative_samples = negative_data[step]

        if balance:
            # Fix class imbalance by truncating to the smaller class.
            n_samples = min(len(self.positive_samples), len(self.negative_samples))
            self.positive_samples = self.positive_samples[:n_samples]
            self.negative_samples = self.negative_samples[:n_samples]

        # Positives first, then negatives.
        self.samples = self.positive_samples + self.negative_samples

    def __len__(self):
        """Return the number of samples after optional balancing."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{"inputs": float32 activations, "label": 0 or 1}`` for sample ``idx``."""
        sample, _, is_positive = self.samples[idx]

        return {
            "inputs": sample.float(),
            "label": int(is_positive)
        }


In [24]:
def collate_fn(batch):
    """Left-pad a list of dataset items into a single batch on ``device``.

    Args:
        batch: list of dicts with an ``"inputs"`` tensor (sequence first) and
            an integer ``"label"``.

    Returns:
        Dict with ``"inputs"`` (zero-padded on the left to the longest
        sequence), ``"label"`` (one entry per item) and ``"mask"`` (1 for real
        positions, 0 for padding), all moved to the notebook-level ``device``.
    """
    sequences = []
    targets = []
    for item in batch:
        sequences.append(item["inputs"])
        targets.append(item["label"])

    def _pad_left(tensors):
        # Padding goes on the left so the last position is always a real one.
        return torch.nn.utils.rnn.pad_sequence(
            tensors, batch_first=True, padding_value=0, padding_side="left"
        )

    # One vector of ones per sequence; padding it the same way as the inputs
    # turns it into the validity mask.
    valid_positions = [torch.ones(seq.shape[0]) for seq in sequences]

    return {
        "inputs": _pad_left(sequences).to(device),
        "label": torch.tensor(targets).to(device),
        "mask": _pad_left(valid_positions).to(device)
    }


In [25]:
# One class-balanced ProbeDataset per retained step, for each split.
train_datasets = {
    step: ProbeDataset(dataset, None, step, positive_data["train"], negative_data["train"], aggregate=False)
    for step in final_steps
}
test_datasets = {
    step: ProbeDataset(dataset, None, step, positive_data["test"], negative_data["test"], aggregate=False, balance=True)
    for step in final_steps
}

In [26]:
class LSTMProbe(torch.nn.Module):
    """Recurrent probe: a single-layer GRU followed by a linear read-out.

    Despite the class and attribute names, the recurrent layer is
    ``torch.nn.GRU``. The read-out is applied to the output at the last
    sequence position only.

    Args:
        input_size: feature size of each sequence element.
        hidden_size: size of the GRU hidden state.
        output_size: number of output logits.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMProbe, self).__init__()

        recurrent_options = dict(batch_first=True, num_layers=1, dtype=torch.float32)
        self.lstm = torch.nn.GRU(input_size, hidden_size, **recurrent_options)
        self.fc = torch.nn.Linear(hidden_size, output_size, dtype=torch.float32)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq, input_size) to (batch, output_size) logits."""
        hidden_states, _final_state = self.lstm(x)
        last_step = hidden_states[:, -1, :]
        return self.fc(last_step)

In [27]:
class ScoreBasedProbe(torch.nn.Module):
    """Probe that pools a sequence with learned softmax weights, then classifies.

    A linear layer scores every sequence element, the scores are
    softmax-normalised over the sequence, and the resulting weighted sum of
    elements is fed to a final linear layer.

    Args:
        input_size: feature size of each sequence element.
        hidden_size: accepted but not used by this class.
        output_size: number of output logits.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(ScoreBasedProbe, self).__init__()
        self.linear_per_element = torch.nn.Linear(input_size, 1)
        self.final_linear = torch.nn.Linear(input_size, output_size)

    def forward(self, x, mask=None):
        """Compute logits for a batch of sequences.

        Args:
            x: tensor of shape (batch, seq, input_size).
            mask: optional (batch, seq) tensor, non-zero at real positions and
                zero at padding. When given, padded positions receive zero
                pooling weight. When omitted, every position takes part in
                the softmax, including zero-padded ones, whose score equals
                the scoring layer's bias.

        Returns:
            Tensor of shape (batch, output_size).
        """
        # Apply linear layer to each sequence element
        scores = self.linear_per_element(x).squeeze(-1)

        if mask is not None:
            is_padding = ~mask.to(scores.device).bool()
            # The dtype's most negative finite value is used instead of -inf,
            # so a fully padded row yields uniform weights rather than NaN.
            scores = scores.masked_fill(is_padding, torch.finfo(scores.dtype).min)

        # Use scores to weight the sequence elements
        weights = scores.softmax(dim=-1)
        weighted_sum = torch.einsum('bse,bs->be', x, weights)

        # Final linear layer
        return self.final_linear(weighted_sum)


In [28]:
# One freshly initialised ScoreBasedProbe (2 output logits) per retained step,
# placed on `device`.
probes = {}
for step in final_steps:
    probes[step] = ScoreBasedProbe(n_dim, 256, 2).to(device)

In [31]:
# Train one probe per retained step and record its score.
accuracy = {}

for step in tqdm(final_steps):
    result = train_probe(
        probes[step],
        train_datasets[step],
        test_datasets[step],
        n_epochs=10,
        silent=False,
        lr=1e-4,
        collate_fn=collate_fn,
        batch_size=16,
    )
    # Second element of train_probe's return value.
    # NOTE(review): presumably the test accuracy — confirm in utils.
    accuracy[step] = result[1]

  0%|          | 0/7 [00:00<?, ?it/s]

Epoch 0, loss: 0.8130289316177368
Epoch 0, acc: 0.7848837209302325
Epoch 1, loss: 0.4690680205821991
Epoch 1, acc: 0.8255813953488372
Epoch 2, loss: 0.3545544445514679
Epoch 2, acc: 0.8604651162790697
Epoch 3, loss: 0.4229409396648407
Epoch 3, acc: 0.8837209302325582
Epoch 4, loss: 0.388337105512619
Epoch 4, acc: 0.8895348837209303
Epoch 5, loss: 0.06689705699682236
Epoch 5, acc: 0.8662790697674418
Epoch 6, loss: 0.15065905451774597
Epoch 6, acc: 0.9069767441860465


 14%|█▍        | 1/7 [00:00<00:03,  1.51it/s]

Epoch 7, loss: 0.01759810745716095
Epoch 7, acc: 0.9069767441860465
Epoch 8, loss: 0.2219032198190689
Epoch 8, acc: 0.8953488372093024
Epoch 9, loss: 0.16440224647521973
Epoch 9, acc: 0.8953488372093024
Epoch 0, loss: 0.7017125487327576
Epoch 0, acc: 0.5
Epoch 1, loss: 0.7226201295852661
Epoch 1, acc: 0.5
Epoch 2, loss: 1.0191459655761719
Epoch 2, acc: 0.5
Epoch 3, loss: 0.7188900113105774
Epoch 3, acc: 0.5
Epoch 4, loss: 0.6932314038276672
Epoch 4, acc: 0.5
Epoch 5, loss: 1.6456764936447144
Epoch 5, acc: 0.5
Epoch 6, loss: 0.6616865992546082
Epoch 6, acc: 0.5
Epoch 7, loss: 1.0768132209777832
Epoch 7, acc: 0.5
Epoch 8, loss: 0.7110779285430908
Epoch 8, acc: 0.5


 29%|██▊       | 2/7 [00:01<00:03,  1.49it/s]

Epoch 9, loss: 0.8001207113265991
Epoch 9, acc: 0.5
Epoch 0, loss: 0.3363787829875946
Epoch 0, acc: 0.8595505617977528
Epoch 1, loss: 0.27524682879447937
Epoch 1, acc: 0.8033707865168539
Epoch 2, loss: 0.08137921243906021
Epoch 2, acc: 0.8820224719101124
Epoch 3, loss: 0.03808210417628288
Epoch 3, acc: 0.8595505617977528
Epoch 4, loss: 0.41808193922042847
Epoch 4, acc: 0.8146067415730337
Epoch 5, loss: 0.17617712914943695
Epoch 5, acc: 0.8370786516853933
Epoch 6, loss: 0.2486291080713272
Epoch 6, acc: 0.8932584269662921


 43%|████▎     | 3/7 [00:01<00:02,  1.52it/s]

Epoch 7, loss: 0.23781265318393707
Epoch 7, acc: 0.8820224719101124
Epoch 8, loss: 0.17314541339874268
Epoch 8, acc: 0.8820224719101124
Epoch 9, loss: 0.05980489030480385
Epoch 9, acc: 0.9044943820224719
Epoch 0, loss: 0.6502203941345215
Epoch 0, acc: 0.5114942528735632
Epoch 1, loss: 0.6576326489448547
Epoch 1, acc: 0.5172413793103449
Epoch 2, loss: 0.6500051617622375
Epoch 2, acc: 0.5114942528735632
Epoch 3, loss: 1.106345534324646
Epoch 3, acc: 0.5172413793103449
Epoch 4, loss: 0.7325635552406311
Epoch 4, acc: 0.5172413793103449
Epoch 5, loss: 0.649817705154419
Epoch 5, acc: 0.5057471264367817
Epoch 6, loss: 0.6459505558013916
Epoch 6, acc: 0.5114942528735632
Epoch 7, loss: 0.9898865818977356
Epoch 7, acc: 0.5
Epoch 8, loss: 0.6505158543586731
Epoch 8, acc: 0.5


 57%|█████▋    | 4/7 [00:02<00:02,  1.46it/s]

Epoch 9, loss: 0.6617122888565063
Epoch 9, acc: 0.5
Epoch 0, loss: 0.5084385275840759
Epoch 0, acc: 0.9230769230769231
Epoch 1, loss: 0.02036658115684986
Epoch 1, acc: 0.8406593406593407
Epoch 2, loss: 0.1787032037973404
Epoch 2, acc: 0.8681318681318682
Epoch 3, loss: 0.006195537280291319
Epoch 3, acc: 0.9340659340659341
Epoch 4, loss: 0.8784472346305847
Epoch 4, acc: 0.9395604395604396
Epoch 5, loss: 0.17056767642498016
Epoch 5, acc: 0.8846153846153846
Epoch 6, loss: 0.06262430548667908
Epoch 6, acc: 0.7417582417582418
Epoch 7, loss: 0.8957332968711853
Epoch 7, acc: 0.8626373626373627


 71%|███████▏  | 5/7 [00:03<00:01,  1.42it/s]

Epoch 8, loss: 0.13235758244991302
Epoch 8, acc: 0.9285714285714286
Epoch 9, loss: 0.4495219886302948
Epoch 9, acc: 0.8956043956043956
Epoch 0, loss: 0.6439560651779175
Epoch 0, acc: 0.5
Epoch 1, loss: 0.8062663674354553
Epoch 1, acc: 0.48717948717948717
Epoch 2, loss: 0.6436446309089661
Epoch 2, acc: 0.5064102564102564
Epoch 3, loss: 0.6496344804763794
Epoch 3, acc: 0.5
Epoch 4, loss: 0.6500673294067383
Epoch 4, acc: 0.5
Epoch 5, loss: 1.0445725917816162
Epoch 5, acc: 0.5
Epoch 6, loss: 0.64481121301651
Epoch 6, acc: 0.5128205128205128
Epoch 7, loss: 0.6494833827018738
Epoch 7, acc: 0.5
Epoch 8, loss: 0.6929606199264526


 86%|████████▌ | 6/7 [00:04<00:00,  1.46it/s]

Epoch 8, acc: 0.5
Epoch 9, loss: 0.6466396450996399
Epoch 9, acc: 0.5
Epoch 0, loss: 0.2596593201160431
Epoch 0, acc: 0.7710843373493976
Epoch 1, loss: 0.28494399785995483
Epoch 1, acc: 0.7831325301204819
Epoch 2, loss: 0.5925741195678711
Epoch 2, acc: 0.8072289156626506
Epoch 3, loss: 0.4335940480232239
Epoch 3, acc: 0.7409638554216867
Epoch 4, loss: 0.14539098739624023
Epoch 4, acc: 0.7951807228915663
Epoch 5, loss: 0.18049032986164093
Epoch 5, acc: 0.7409638554216867
Epoch 6, loss: 0.07467874139547348
Epoch 6, acc: 0.7710843373493976
Epoch 7, loss: 0.0376739501953125
Epoch 7, acc: 0.6987951807228916


100%|██████████| 7/7 [00:04<00:00,  1.47it/s]

Epoch 8, loss: 0.03316861763596535
Epoch 8, acc: 0.7771084337349398
Epoch 9, loss: 0.07110535353422165
Epoch 9, acc: 0.8072289156626506





In [30]:
# Per retained step: recorded accuracy, train/test class counts, and the share
# of the larger class among the test examples.
# NOTE(review): test_datasets are built with balance=True, so if train_probe
# scores on them the evaluated set is 50/50 and this majority share describes
# the unbalanced pool instead — confirm.
for step in final_steps:
    n_test_pos = len(positive_data["test"][step])
    n_test_neg = len(negative_data["test"][step])
    positive_share = n_test_pos / (n_test_pos + n_test_neg)
    majority_share = max(positive_share, 1 - positive_share)

    print(step, accuracy[step], len(positive_data["train"][step]), len(negative_data["train"][step]), n_test_pos, n_test_neg, majority_share)

(pick-up e) 0.75 423 327 101 86 0.5401069518716578
(put-down b) 0.5 406 344 109 78 0.5828877005347594
(pick-up f) 0.7247191011235955 311 439 89 98 0.5240641711229946
(put-down a) 0.5057471264367817 368 382 87 100 0.53475935828877
(put-down d) 0.9065934065934066 387 363 91 96 0.5133689839572193
(put-down c) 0.5064102564102564 423 327 109 78 0.5828877005347594
(put-down e) 0.7891566265060241 335 415 83 104 0.5561497326203209
