In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import json
import numpy as np
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import trange, tqdm
from utils import Probe, set_seed, train_probe

import os
os.chdir("..")

seed = 42
set_seed(seed)

activations_file = "planning_activations_32b_big_step.pt"
metadata_file = "planning_metadata.json"

activations = torch.load(activations_file)

with open(metadata_file) as f:
    metadata = json.load(f)

  from .autonotebook import tqdm as notebook_tqdm
  activations = torch.load(activations_file)


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
dataset = []

all_steps = set()

for x in metadata:
    dataset_idx = x["dataset_idx"]
    activation = activations[dataset_idx]

    extracted_plan = x["bench_item"]["extracted_llm_plan"]

    validity = x["bench_item"]["llm_validity"]

    if validity != 1:
        continue

    steps = set(extracted_plan.split("\n"))

    steps = {step for step in steps if step != ""}

    all_steps.update(steps)

    think_pos = x["think_pos"]

    dataset.append({
        "activations": activation[:think_pos // 10],
        "steps": steps,
        "think_pos": think_pos
    })


print(len(dataset))

import random

random.shuffle(dataset)

test_size = 0.2
test_size = int(len(dataset) * test_size)

train_dataset = dataset[:-test_size]
test_dataset = dataset[-test_size:]



937


In [5]:
n_dim = 5120

In [6]:

print(len(all_steps))
print(all_steps)

199
{'(stack d c)', '(stack b d)', '(stack g h)', '(stack i h)', '(stack e j)', '(pick-up g)', '(unstack g d)', '(unstack e d)', '(stack g b)', '(unstack d i)', '(unstack c f)', '(put-down f)', '(stack f a)', '(stack f e)', '(stack i g)', '(unstack h i)', '(stack j i)', '(pick-up c)', '(stack b i)', '(stack c a)', '(unstack f g)', '(unstack d h)', '(stack a c)', '(stack g d)', '(unstack b d)', '(stack c f)', '(unstack e g)', '(unstack j a)', '(put-down j)', '(stack f b)', '(stack g j)', '(pick-up e)', '(unstack i a)', '(stack j e)', '(pick-up f)', '(put-down g)', '(stack a i)', '(stack j d)', '(stack c j)', '(unstack a j)', '(stack a d)', '(stack a h)', '(stack c g)', '(stack e g)', '(unstack a f)', '(stack b g)', '(stack g c)', '(unstack c h)', '(put-down a)', '(unstack j e)', '(stack i b)', '(stack b c)', '(unstack a h)', '(stack i e)', '(stack i j)', '(stack f d)', '(unstack b f)', '(unstack e b)', '(stack j c)', '(stack d h)', '(unstack b c)', '(unstack i h)', '(unstack g c)', '(un

In [7]:
train_data_items = {step: [] for step in all_steps}

for item in train_dataset:
    activations = item["activations"]
    steps = item["steps"]
    think_pos = item["think_pos"]

    for step in steps:
        train_data_items[step].append((activations, think_pos))

In [8]:
counts = {step: len(train_data_items[step]) for step in all_steps}

cutoff = 10

all_steps = [step for step in all_steps if counts[step] > cutoff]

print(len(all_steps))

142


In [9]:
positive_data = {
    "train": {step: [] for step in all_steps},
    "test": {step: [] for step in all_steps}
}
negative_data = {
    "train": {step: [] for step in all_steps},
    "test": {step: [] for step in all_steps}
}

for x in train_dataset:
    activations = x["activations"]
    steps = x["steps"]
    think_pos = x["think_pos"]

    for step in all_steps:
        if step in steps:
            positive_data["train"][step].append((activations, think_pos, True))
        else:
            negative_data["train"][step].append((activations, think_pos, False))

for x in test_dataset:
    activations = x["activations"]
    steps = x["steps"]
    think_pos = x["think_pos"]

    for step in all_steps:
        if step in steps:
            positive_data["test"][step].append((activations, think_pos, True))
        else:
            negative_data["test"][step].append((activations, think_pos, False))

final_steps = set()

for step in all_steps:
    pos_neg_ratio = len(positive_data["train"][step]) / len(negative_data["train"][step])

    if pos_neg_ratio < 0.5:
        continue
    if pos_neg_ratio > 2:
        continue

    final_steps.add(step)


In [10]:
for step in final_steps:
    print(step, len(positive_data["train"][step]), len(negative_data["train"][step]), len(positive_data["test"][step]), len(negative_data["test"][step]))

(put-down f) 251 499 52 135
(put-down a) 363 387 92 95
(pick-up e) 422 328 102 85
(put-down b) 416 334 99 88
(put-down e) 338 412 80 107
(pick-up f) 325 425 75 112
(pick-up g) 254 496 50 137
(put-down d) 380 370 98 89
(put-down c) 426 324 106 81


In [19]:
class ProbeDataset(Dataset):
    def __init__(self, dataset, probe_pos, step, positive_data, negative_data, aggregate=False, balance=True):
        self.dataset = dataset
        self.probe_pos = probe_pos
        self.aggregate = aggregate

        self.positive_samples, self.negative_samples = negative_data[step], positive_data[step]

        # fix imbalance

        n_positive = len(self.positive_samples)
        n_negative = len(self.negative_samples)

        n_samples = min(n_positive, n_negative)

        if balance:
            self.positive_samples = self.positive_samples[:n_samples]
            self.negative_samples = self.negative_samples[:n_samples]

        self.samples = self.positive_samples + self.negative_samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample, _, is_positive = self.samples[idx]

        sample = sample[-200:].float()
        
        return {
            "inputs": sample,
            "label": int(is_positive)
        }


In [20]:
def collate_fn(batch):
    inputs = [x["inputs"] for x in batch]
    labels = [x["label"] for x in batch]

    # for x in inputs:
    #     print(x.shape)

    # pad inputs left
    masks = [torch.ones(x.shape[0]) for x in inputs]
    inputs = torch.nn.utils.rnn.pad_sequence(inputs, batch_first=True, padding_value=0, padding_side="left")
    masks = torch.nn.utils.rnn.pad_sequence(masks, batch_first=True, padding_value=0, padding_side="left")
    labels = torch.tensor(labels)

    # print(inputs.shape)

    return {
        "inputs": inputs.to(device),
        "label": labels.to(device),
        "mask": masks.to(device)
    }


In [21]:
train_datasets = {}
test_datasets = {}

for step in final_steps:
    train_datasets[step] = ProbeDataset(dataset, None, step, positive_data["train"], negative_data["train"], aggregate=False)
    test_datasets[step] = ProbeDataset(dataset, None, step, positive_data["test"], negative_data["test"], aggregate=False, balance=True)

In [22]:
class LSTMProbe(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__() 

        self.lstm = torch.nn.GRU(input_size, hidden_size, batch_first=True, dtype=torch.float32, num_layers=1)
        self.fc = torch.nn.Linear(hidden_size, output_size, dtype=torch.float32)

    def forward(self, x):
        x, _ = self.lstm(x)
        x = self.fc(x[:, -1, :])
        return x

In [31]:
probes = {
    step: LSTMProbe(n_dim, 256, 2).to(device) for step in final_steps
}

In [35]:
accuracy = {}

for step in tqdm(final_steps):
    accuracy[step] = train_probe(probes[step], train_datasets[step], test_datasets[step], n_epochs=10, silent=True, lr=1e-4, collate_fn=collate_fn, batch_size=16)[1]

  0%|          | 0/9 [00:00<?, ?it/s]

100%|██████████| 9/9 [00:24<00:00,  2.76s/it]


In [37]:
for step in final_steps:
    propotion = len(positive_data["test"][step]) / (len(positive_data["test"][step]) + len(negative_data["test"][step]))
    print(step, accuracy[step], len(positive_data["train"][step]), len(negative_data["train"][step]), len(positive_data["test"][step]), len(negative_data["test"][step]), max(propotion, 1 - propotion))

(put-down f) 0.8076923076923077 251 499 52 135 0.7219251336898396
(put-down a) 0.5706521739130435 363 387 92 95 0.5080213903743316
(pick-up e) 0.8058823529411765 422 328 102 85 0.5454545454545454
(put-down b) 0.5909090909090909 416 334 99 88 0.5294117647058824
(put-down e) 0.70625 338 412 80 107 0.572192513368984
(pick-up f) 0.8533333333333334 325 425 75 112 0.5989304812834224
(pick-up g) 0.87 254 496 50 137 0.732620320855615
(put-down d) 0.651685393258427 380 370 98 89 0.5240641711229946
(put-down c) 0.5987654320987654 426 324 106 81 0.5668449197860963
