In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import json
import numpy as np
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import trange, tqdm
from utils import Probe, set_seed, train_probe

import os
os.chdir("..")

seed = 42
set_seed(seed)

activations_file = "planning_activations_32b.pt"
metadata_file = "planning_metadata.json"

activations = torch.load(activations_file)

with open(metadata_file) as f:
    metadata = json.load(f)

  from .autonotebook import tqdm as notebook_tqdm
  activations = torch.load(activations_file)


In [None]:
dataset = []

all_steps = set()

for x in metadata:
    dataset_idx = x["dataset_idx"]
    activation = activations[dataset_idx]

    extracted_plan = x["bench_item"]["extracted_llm_plan"]

    validity = x["bench_item"]["llm_validity"]

    if validity != 1:
        continue

    steps = set(extracted_plan.split("\n"))

    steps = {step for step in steps if step != ""}

    all_steps.update(steps)

    think_pos = x["think_pos"]

    dataset.append({
        "activations": activation,
        "steps": steps,
        "think_pos": think_pos
    })


print(len(dataset))

import random

random.shuffle(dataset)

test_size = 0.2
test_size = int(len(dataset) * test_size)

train_dataset = dataset[:-test_size]
test_dataset = dataset[-test_size:]



203


In [4]:
n_dim = 5120

In [5]:

print(len(all_steps))
print(all_steps)

243
{'(stack d a)', '(stack k f)', '(stack h f)', '(unstack d j)', '(unstack j g)', '(stack b h)', '(stack l b)', '(unstack j i)', '(stack c b)', '(stack g b)', '(unstack b c)', '(unstack f c)', '(unstack b f)', '(stack b f)', '(pick-up i)', '(unstack f e)', '(stack a l)', '(unstack m f)', '(stack m j)', '(unstack e f)', '(stack f k)', '(unstack e c)', '(unstack a h)', '(pick-up h)', '(stack i h)', '(unstack n f)', '(stack e f)', '(unstack c e)', '(stack k j)', '(pick-up j)', '(stack j g)', '(put-down c)', '(stack c f)', '(unstack h d)', '(unstack d b)', '(stack f a)', '(unstack e d)', '(stack h c)', '(stack n l)', '(stack j k)', '(stack c e)', '(pick-up e)', '(pick-up f)', '(unstack g c)', '(unstack h e)', '(unstack f h)', '(stack d g)', '(stack g c)', '(unstack g l)', '(stack m d)', '(stack f h)', '(stack c a)', '(put-down e)', '(unstack a g)', '(put-down h)', '(unstack b a)', '(stack c j)', '(put-down l)', '(unstack g i)', '(stack k e)', '(unstack h a)', '(stack b i)', '(unstack c k

In [6]:
train_data_items = {step: [] for step in all_steps}

for item in train_dataset:
    activations = item["activations"]
    steps = item["steps"]
    think_pos = item["think_pos"]

    for step in steps:
        train_data_items[step].append((activations, think_pos))

In [7]:
counts = {step: len(train_data_items[step]) for step in all_steps}

cutoff = 10

all_steps = [step for step in all_steps if counts[step] > cutoff]

print(len(all_steps))

57


In [8]:
positive_data = {
    "train": {step: [] for step in all_steps},
    "test": {step: [] for step in all_steps}
}
negative_data = {
    "train": {step: [] for step in all_steps},
    "test": {step: [] for step in all_steps}
}

for x in train_dataset:
    activations = x["activations"]
    steps = x["steps"]
    think_pos = x["think_pos"]

    for step in all_steps:
        if step in steps:
            positive_data["train"][step].append((activations, think_pos, True))
        else:
            negative_data["train"][step].append((activations, think_pos, False))

for x in test_dataset:
    activations = x["activations"]
    steps = x["steps"]
    think_pos = x["think_pos"]

    for step in all_steps:
        if step in steps:
            positive_data["test"][step].append((activations, think_pos, True))
        else:
            negative_data["test"][step].append((activations, think_pos, False))

final_steps = set()

for step in all_steps:
    if len(positive_data["train"][step]) > 10 and len(negative_data["train"][step]) > 10 and len(positive_data["test"][step]) > 3 and len(negative_data["test"][step]) > 3:
        final_steps.add(step)

In [9]:
for step in final_steps:
    print(step, len(positive_data["train"][step]), len(negative_data["train"][step]), len(positive_data["test"][step]), len(negative_data["test"][step]))

(stack d a) 21 142 7 33
(pick-up b) 118 45 29 11
(pick-up c) 128 35 29 11
(unstack a d) 14 149 4 36
(unstack a b) 24 139 6 34
(stack a d) 22 141 6 34
(stack c b) 31 132 10 30
(unstack a c) 22 141 7 33
(stack b d) 27 136 7 33
(pick-up i) 19 144 5 35
(stack c d) 22 141 6 34
(stack d f) 12 151 4 36
(pick-up a) 121 42 31 9
(stack b e) 12 151 4 36
(pick-up h) 24 139 8 32
(stack e f) 14 149 4 36
(unstack c a) 27 136 7 33
(unstack c b) 21 142 6 34
(put-down c) 87 76 22 18
(unstack d b) 13 150 4 36
(stack a c) 38 125 8 32
(pick-up g) 38 125 13 27
(put-down g) 33 130 7 33
(stack b c) 36 127 7 33
(put-down d) 84 79 14 26
(stack c e) 14 149 4 36
(pick-up e) 73 90 15 25
(pick-up f) 60 103 11 29
(put-down f) 38 125 10 30
(unstack e a) 12 151 5 35
(stack a b) 27 136 14 26
(stack b a) 30 133 8 32
(stack c a) 40 123 7 33
(put-down e) 60 103 14 26
(unstack b a) 24 139 6 34
(stack e b) 12 151 5 35
(put-down a) 77 86 22 18
(pick-up d) 115 48 26 14
(put-down b) 87 76 21 19
(unstack b d) 18 145 6 34
(stack

In [33]:
class ProbeDataset(Dataset):
    def __init__(self, dataset, probe_pos, step, positive_data, negative_data, aggregate=False, balance=True):
        self.dataset = dataset
        self.probe_pos = probe_pos
        self.aggregate = aggregate

        self.positive_samples, self.negative_samples = negative_data[step], positive_data[step]

        # fix imbalance

        n_positive = len(self.positive_samples)
        n_negative = len(self.negative_samples)

        n_samples = min(n_positive, n_negative)

        if balance:
            self.positive_samples = self.positive_samples[:n_samples]
            self.negative_samples = self.negative_samples[:n_samples]

        self.samples = self.positive_samples + self.negative_samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample, _, is_positive = self.samples[idx]

        if self.aggregate:
            sample = sample[:self.probe_pos]
            # sample = torch.mean(sample, dim=0)
            sample = sample.view(-1)    
        else:
            sample = sample[self.probe_pos]

        return {
            "inputs": sample,
            "label": int(is_positive)
        }


In [34]:
train_datasets = {}
test_datasets = {}

for step in final_steps:
    train_datasets[step] = ProbeDataset(dataset, 600, step, positive_data["train"], negative_data["train"], aggregate=False)
    test_datasets[step] = ProbeDataset(dataset, 600, step, positive_data["test"], negative_data["test"], aggregate=False, balance=False)

In [36]:
probes = {
    step: Probe(n_dim, 2) for step in final_steps
}

In [37]:
accuracy = {}

for step in final_steps:
    accuracy[step] = train_probe(probes[step], train_datasets[step], test_datasets[step], n_epochs=4, silent=True, lr=1e-4)[1]

In [38]:
for step in final_steps:
    propotion = len(positive_data["test"][step]) / (len(positive_data["test"][step]) + len(negative_data["test"][step]))
    print(step, accuracy[step], len(positive_data["train"][step]), len(negative_data["train"][step]), len(positive_data["test"][step]), len(negative_data["test"][step]), max(propotion, 1 - propotion))

(stack d a) 0.175 21 142 7 33 0.825
(pick-up b) 0.275 118 45 29 11 0.725
(pick-up c) 0.275 128 35 29 11 0.725
(unstack a d) 0.525 14 149 4 36 0.9
(unstack a b) 0.15 24 139 6 34 0.85
(stack a d) 0.15 22 141 6 34 0.85
(stack c b) 0.4 31 132 10 30 0.75
(unstack a c) 0.425 22 141 7 33 0.825
(stack b d) 0.825 27 136 7 33 0.825
(pick-up i) 0.875 19 144 5 35 0.875
(stack c d) 0.15 22 141 6 34 0.85
(stack d f) 0.9 12 151 4 36 0.9
(pick-up a) 0.225 121 42 31 9 0.775
(stack b e) 0.625 12 151 4 36 0.9
(pick-up h) 0.325 24 139 8 32 0.8
(stack e f) 0.875 14 149 4 36 0.9
(unstack c a) 0.225 27 136 7 33 0.825
(unstack c b) 0.15 21 142 6 34 0.85
(put-down c) 0.675 87 76 22 18 0.55
(unstack d b) 0.3 13 150 4 36 0.9
(stack a c) 0.2 38 125 8 32 0.8
(pick-up g) 0.675 38 125 13 27 0.675
(put-down g) 0.225 33 130 7 33 0.825
(stack b c) 0.275 36 127 7 33 0.825
(put-down d) 0.45 84 79 14 26 0.65
(stack c e) 0.275 14 149 4 36 0.9
(pick-up e) 0.625 73 90 15 25 0.625
(pick-up f) 0.725 60 103 11 29 0.725
(put-dow