In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up before each cell executes, without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from iit.model_pairs.strict_iit_model_pair import StrictIITModelPair

from poly_bench.cases.paren_checker import HighLevelParensBalanceChecker, BalancedParensDataset
from poly_bench.cases.left_greater import HighLevelLeftGreater, LeftGreaterDataset
from poly_bench.cases.duplicate_remover import HighLevelDuplicateRemover, DuplicateRemoverDataset
from poly_bench.cases.unique_extractor import HighLevelUniqueExtractor, UniqueExtractorDataset
from poly_bench.poly_hl_model import PolyHLModel, PolyModelDataset
from poly_bench.io import save_poly_model_to_dir, save_to_hf

# Numbered aliases for the four high-level case classes; the section headings
# further down refer to the cases by these indices (0-3).
Case0, Case1, Case2, Case3 = (
    HighLevelDuplicateRemover,
    HighLevelLeftGreater,
    HighLevelParensBalanceChecker,
    HighLevelUniqueExtractor,
)

# Dataset class to instantiate for each high-level case class.
dataset_mapping = {
    Case0: DuplicateRemoverDataset,
    Case1: LeftGreaterDataset,
    Case2: BalancedParensDataset,
    Case3: UniqueExtractorDataset,
}

In [3]:
# Run-wide settings reused by every case combination trained in this notebook.
n_epochs = 100  # passed as `epochs` to model_pair.train and used as the LR schedule length
n_samples = 10_000  # passed as `N_samples` to each per-case dataset

# Keyword bundle handed to StrictIITModelPair as `training_args`.
training_args = dict(
    batch_size=256,
    num_workers=0,
    use_single_loss=True,
    behavior_weight=0.4,  # original note: basically doubles the strict weight's job
    iit_weight=1.0,
    strict_weight=0.4,
    clip_grad_norm=1.0,
    early_stop=True,
    lr_scheduler=torch.optim.lr_scheduler.LinearLR,
    # LinearLR factor goes from 1 to 0.2 over `n_epochs` scheduler steps.
    scheduler_kwargs={"start_factor": 1, "end_factor": 0.2, "total_iters": int(n_epochs)},
    optimizer_kwargs={"lr": 1e-3, "betas": (0.9, 0.9)},
    # The next two entries are, per the original notes, meant for ReduceLROnPlateau;
    # the scheduler configured above is LinearLR.
    scheduler_val_metric=["val/accuracy", "val/IIA"],
    scheduler_mode="max",
    siit_sampling="sample_all",
    val_IIA_sampling="all",  # original note: "random or all"
    seed=42,
)

In [4]:
# Context length and RNG seed passed to every per-case dataset constructor below.
n_ctx, seed = 15, 42

# (0) DuplicateRemover + (1) LeftGreater

In [5]:
# Case combination for this section: Case0 + Case1.
cases = [Case0, Case1]
poly_hl_model = PolyHLModel(size_expansion=1, hl_classes=cases)

# Print every correspondence entry (hook name -> set of LLNode in the cell output).
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

# Blank separator, then the model's corr_mapping; judging by the printed output,
# each value holds one slot per entry in `cases`.
print()
for ll_hook, per_case_hooks in poly_hl_model.corr_mapping.items():
    print(ll_hook, per_case_hooks)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [mlp0_hook]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None]
blocks.0.attn.hook_z.1 [None, [paren_counts_hook]]
blocks.0.attn.hook_z.2 [None, None]
blocks.0.attn.hook_z.3 [None, None]
blocks.1.mlp.hook_post [[output_hook], None]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [None, None]
blocks.1.attn.hook_z.2 [None, None]
blocks.1.attn.hook_z.3 [None, None]


In [6]:
# Resolve the dataset class of each selected case, instantiate them all with the
# shared sample count / context length / seed, and wrap them in one PolyModelDataset.
dataset_cases = [dataset_mapping[hl_case] for hl_case in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

  self.inputs = t.tensor(inputs).to(t.int)
  self.targets = t.tensor(targets).to(t.float32)
  self.markers = t.tensor(markers).to(t.int)


In [7]:
# Create the low-level model, move it to the high-level model's device, and
# record that device on the model object.
ll_model = poly_hl_model.get_ll_model().to(poly_hl_model.device)
ll_model.device = poly_hl_model.device

# Re-fetch and print the correspondence that the model pair will be built with.
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# Show the low-level nodes the model pair reports as not in the circuit.
print(model_pair.nodes_not_in_circuit)
model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None)]
training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_schedu

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/54 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.1303, train/behavior_loss: 0.4246, train/strict_loss: 0.1767, val/iit_loss: 0.8554, val/IIA: 68.64%, val/accuracy: 76.78%, val/strict_accuracy: 75.68%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.6910, train/behavior_loss: 0.2209, train/strict_loss: 0.1072, val/iit_loss: 0.5357, val/IIA: 79.27%, val/accuracy: 85.36%, val/strict_accuracy: 84.45%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.4386, train/behavior_loss: 0.1223, train/strict_loss: 0.0726, val/iit_loss: 0.3746, val/IIA: 85.13%, val/accuracy: 91.02%, val/strict_accuracy: 89.75%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3199, train/behavior_loss: 0.0829, train/strict_loss: 0.0583, val/iit_loss: 0.3103, val

In [8]:

# Persist the trained low-level model together with its high-level model for the
# Case0 + Case1 combination; the call's return value is shown as the cell result.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+1")

PosixPath('saved_poly_models/cases_0+1')

# (0) DuplicateRemover + (2) ParenChecker

In [9]:
# Case combination for this section: Case0 + Case2.
cases = [Case0, Case2]
poly_hl_model = PolyHLModel(size_expansion=1, hl_classes=cases)

# Print every correspondence entry (hook name -> set of LLNode in the cell output).
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

# Blank separator, then the model's corr_mapping; judging by the printed output,
# each value holds one slot per entry in `cases`.
print()
for ll_hook, per_case_hooks in poly_hl_model.corr_mapping.items():
    print(ll_hook, per_case_hooks)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [mlp0_hook]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None]
blocks.0.attn.hook_z.1 [None, None]
blocks.0.attn.hook_z.2 [None, None]
blocks.0.attn.hook_z.3 [None, [paren_counts_hook]]
blocks.1.mlp.hook_post [[output_hook], [mlp1_hook]]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [

In [10]:
# Resolve the dataset class of each selected case, instantiate them all with the
# shared sample count / context length / seed, and wrap them in one PolyModelDataset.
dataset_cases = [dataset_mapping[hl_case] for hl_case in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

  self.inputs = t.tensor(inputs).to(t.int)
  self.targets = t.tensor(targets).to(t.float32)
  self.markers = t.tensor(markers).to(t.int)


In [11]:
# Create the low-level model, move it to the high-level model's device, and
# record that device on the model object.
ll_model = poly_hl_model.get_ll_model().to(poly_hl_model.device)
ll_model.device = poly_hl_model.device

# Re-fetch and print the correspondence that the model pair will be built with.
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# Show the low-level nodes the model pair reports as not in the circuit.
print(model_pair.nodes_not_in_circuit)
model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.a

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/54 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.0867, train/behavior_loss: 0.4002, train/strict_loss: 0.1730, val/iit_loss: 0.7787, val/IIA: 76.15%, val/accuracy: 81.04%, val/strict_accuracy: 80.95%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.6278, train/behavior_loss: 0.2130, train/strict_loss: 0.1016, val/iit_loss: 0.4354, val/IIA: 86.25%, val/accuracy: 89.27%, val/strict_accuracy: 88.56%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3528, train/behavior_loss: 0.1151, train/strict_loss: 0.0593, val/iit_loss: 0.2729, val/IIA: 90.48%, val/accuracy: 93.35%, val/strict_accuracy: 92.29%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.2344, train/behavior_loss: 0.0627, train/strict_loss: 0.0458, val/iit_loss: 0.1807, val

In [12]:

# Persist the trained low-level model together with its high-level model for the
# Case0 + Case2 combination; the call's return value is shown as the cell result.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+2")

PosixPath('saved_poly_models/cases_0+2')

# (0) DuplicateRemover + (3) UniqueExtractor

In [13]:
# Case combination for this section: Case0 + Case3.
cases = [Case0, Case3]
poly_hl_model = PolyHLModel(size_expansion=1, hl_classes=cases)

# Print every correspondence entry (hook name -> set of LLNode in the cell output).
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

# Blank separator, then the model's corr_mapping; judging by the printed output,
# each value holds one slot per entry in `cases`.
print()
for ll_hook, per_case_hooks in poly_hl_model.corr_mapping.items():
    print(ll_hook, per_case_hooks)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [appeared_mlp]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None]
blocks.0.attn.hook_z.1 [None, None]
blocks.0.attn.hook_z.2 [None, [counter_head]]
blocks.0.attn.hook_z.3 [None, None]
blocks.1.mlp.hook_post [[output_hook], [mask_mlp]]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [None, None]
blocks.1.attn.hook_z.2 [None, None]
blocks.1.attn.hook_z.3 [None, None]
blocks.

In [14]:
# Resolve the dataset class of each selected case, instantiate them all with the
# shared sample count / context length / seed, and wrap them in one PolyModelDataset.
dataset_cases = [dataset_mapping[hl_case] for hl_case in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

In [15]:
# Create the low-level model, move it to the high-level model's device, and
# record that device on the model object.
ll_model = poly_hl_model.get_ll_model().to(poly_hl_model.device)
ll_model.device = poly_hl_model.device

# Re-fetch and print the correspondence that the model pair will be built with.
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# Show the low-level nodes the model pair reports as not in the circuit.
print(model_pair.nodes_not_in_circuit)
model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', ind

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.7631, train/behavior_loss: 0.2748, train/strict_loss: 0.1209, val/iit_loss: 0.4155, val/IIA: 86.74%, val/accuracy: 90.95%, val/strict_accuracy: 91.26%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3254, train/behavior_loss: 0.0832, train/strict_loss: 0.0493, val/iit_loss: 0.2580, val/IIA: 89.51%, val/accuracy: 95.46%, val/strict_accuracy: 94.97%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.2232, train/behavior_loss: 0.0387, train/strict_loss: 0.0358, val/iit_loss: 0.1836, val/IIA: 92.62%, val/accuracy: 98.21%, val/strict_accuracy: 97.30%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1380, train/behavior_loss: 0.0111, train/strict_loss: 0.0315, val/iit_loss: 0.1416, val

In [16]:

# Persist the trained low-level model together with its high-level model for the
# Case0 + Case3 combination; the call's return value is shown as the cell result.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+3")

PosixPath('saved_poly_models/cases_0+3')

# (1) LeftGreater + (2) ParenChecker

In [17]:
# Case combination for this section: Case1 + Case2.
cases = [Case1, Case2]
poly_hl_model = PolyHLModel(size_expansion=1, hl_classes=cases)

# Print every correspondence entry (hook name -> set of LLNode in the cell output).
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

# Blank separator, then the model's corr_mapping; judging by the printed output,
# each value holds one slot per entry in `cases`.
print()
for ll_hook, per_case_hooks in poly_hl_model.corr_mapping.items():
    print(ll_hook, per_case_hooks)

# Resolve the dataset class of each selected case, instantiate them all with the
# shared sample count / context length / seed, and wrap them in one PolyModelDataset.
dataset_cases = [dataset_mapping[hl_case] for hl_case in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}

blocks.0.mlp.hook_post [[mlp0_hook], [mlp0_hook]]
blocks.0.attn.hook_z.0 [None, None]
blocks.0.attn.hook_z.1 [[paren_counts_hook], None]
blocks.0.attn.hook_z.2 [None, None]
blocks.0.attn.hook_z.3 [None, [paren_counts_hook]]
blocks.1.mlp.hook_post [None, [mlp1_hook]]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [None, None]
b

In [18]:
# Create the low-level model, move it to the high-level model's device, and
# record that device on the model object.
ll_model = poly_hl_model.get_ll_model().to(poly_hl_model.device)
ll_model.device = poly_hl_model.device

# Re-fetch and print the correspondence that the model pair will be built with.
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# Show the low-level nodes the model pair reports as not in the circuit.
print(model_pair.nodes_not_in_circuit)
model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.a

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/44 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.7672, train/behavior_loss: 0.2989, train/strict_loss: 0.1218, val/iit_loss: 0.4668, val/IIA: 91.08%, val/accuracy: 95.88%, val/strict_accuracy: 95.39%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3462, train/behavior_loss: 0.1083, train/strict_loss: 0.0493, val/iit_loss: 0.2613, val/IIA: 94.42%, val/accuracy: 97.74%, val/strict_accuracy: 97.59%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1935, train/behavior_loss: 0.0568, train/strict_loss: 0.0245, val/iit_loss: 0.1835, val/IIA: 94.67%, val/accuracy: 97.90%, val/strict_accuracy: 97.81%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1577, train/behavior_loss: 0.0314, train/strict_loss: 0.0166, val/iit_loss: 0.1350, val

In [19]:

# Persist the trained low-level model together with its high-level model for the
# Case1 + Case2 combination; the call's return value is shown as the cell result.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_1+2")

PosixPath('saved_poly_models/cases_1+2')

# (1) LeftGreater + (3) UniqueExtractor

In [20]:
# Case combination for this section: Case1 + Case3.
cases = [Case1, Case3]
poly_hl_model = PolyHLModel(size_expansion=1, hl_classes=cases)

# Print every correspondence entry (hook name -> set of LLNode in the cell output).
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

# Blank separator, then the model's corr_mapping; judging by the printed output,
# each value holds one slot per entry in `cases`.
print()
for ll_hook, per_case_hooks in poly_hl_model.corr_mapping.items():
    print(ll_hook, per_case_hooks)

# Resolve the dataset class of each selected case, instantiate them all with the
# shared sample count / context length / seed, and wrap them in one PolyModelDataset.
dataset_cases = [dataset_mapping[hl_case] for hl_case in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}

blocks.0.mlp.hook_post [[mlp0_hook], [appeared_mlp]]
blocks.0.attn.hook_z.0 [None, None]
blocks.0.attn.hook_z.1 [[paren_counts_hook], None]
blocks.0.attn.hook_z.2 [None, [counter_head]]
blocks.0.attn.hook_z.3 [None, None]
blocks.1.mlp.hook_post [None, [mask_mlp]]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [None, None]
blocks.1.attn.hook_z.2 [None, None]
blocks.1.attn.hook_z.3 [None, None]
blocks.2.mlp.hook_po

In [21]:
# Create the low-level model, move it to the high-level model's device, and
# record that device on the model object.
ll_model = poly_hl_model.get_ll_model().to(poly_hl_model.device)
ll_model.device = poly_hl_model.device

# Re-fetch and print the correspondence that the model pair will be built with.
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# Show the low-level nodes the model pair reports as not in the circuit.
print(model_pair.nodes_not_in_circuit)
model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', ind

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/54 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.7667, train/behavior_loss: 0.2909, train/strict_loss: 0.1265, val/iit_loss: 0.4739, val/IIA: 86.81%, val/accuracy: 92.44%, val/strict_accuracy: 91.41%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3308, train/behavior_loss: 0.0943, train/strict_loss: 0.0470, val/iit_loss: 0.2141, val/IIA: 93.61%, val/accuracy: 98.58%, val/strict_accuracy: 98.21%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1541, train/behavior_loss: 0.0219, train/strict_loss: 0.0185, val/iit_loss: 0.1194, val/IIA: 96.25%, val/accuracy: 99.64%, val/strict_accuracy: 99.25%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1118, train/behavior_loss: 0.0097, train/strict_loss: 0.0167, val/iit_loss: 0.0813, val

In [22]:

# Persist the trained low-level model together with its high-level model for the
# Case1 + Case3 combination; the call's return value is shown as the cell result.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_1+3")

PosixPath('saved_poly_models/cases_1+3')

# (2) ParenChecker + (3) UniqueExtractor

In [23]:
# Case combination for this section: Case2 + Case3.
cases = [Case2, Case3]
poly_hl_model = PolyHLModel(size_expansion=1, hl_classes=cases)

# Print every correspondence entry (hook name -> set of LLNode in the cell output).
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

# Blank separator, then the model's corr_mapping; judging by the printed output,
# each value holds one slot per entry in `cases`.
print()
for ll_hook, per_case_hooks in poly_hl_model.corr_mapping.items():
    print(ll_hook, per_case_hooks)

# Resolve the dataset class of each selected case, instantiate them all with the
# shared sample count / context length / seed, and wrap them in one PolyModelDataset.
dataset_cases = [dataset_mapping[hl_case] for hl_case in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}

blocks.0.mlp.hook_post [[mlp0_hook], [appeared_mlp]]
blocks.0.attn.hook_z.0 [None, None]
blocks.0.attn.hook_z.1 [None, None]
blocks.0.attn.hook_z.2 [None, [counter_head]]
blocks.0.attn.hook_z.3 [[paren_counts_hook], None]
blocks.1.mlp.hook_post [[mlp1_hook], [mask_mlp]]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [None, Non

In [24]:
# Create the low-level model, move it to the high-level model's device, and
# record that device on the model object.
ll_model = poly_hl_model.get_ll_model().to(poly_hl_model.device)
ll_model.device = poly_hl_model.device

# Re-fetch and print the correspondence that the model pair will be built with.
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# Show the low-level nodes the model pair reports as not in the circuit.
print(model_pair.nodes_not_in_circuit)
model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.a

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/54 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.7994, train/behavior_loss: 0.3052, train/strict_loss: 0.1345, val/iit_loss: 0.4381, val/IIA: 88.08%, val/accuracy: 90.95%, val/strict_accuracy: 91.12%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.2983, train/behavior_loss: 0.0870, train/strict_loss: 0.0364, val/iit_loss: 0.1716, val/IIA: 95.63%, val/accuracy: 97.45%, val/strict_accuracy: 97.38%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1292, train/behavior_loss: 0.0267, train/strict_loss: 0.0153, val/iit_loss: 0.0976, val/IIA: 97.23%, val/accuracy: 98.83%, val/strict_accuracy: 98.76%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.0744, train/behavior_loss: 0.0099, train/strict_loss: 0.0068, val/iit_loss: 0.0720, val

In [25]:

# Persist the trained low-level model together with its high-level model for the
# Case2 + Case3 combination; the call's return value is shown as the cell result.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_2+3")

PosixPath('saved_poly_models/cases_2+3')

# (0) DuplicateRemover + (1) LeftGreater + (2) ParenChecker

In [26]:
# Case combination for this section: Case0 + Case1 + Case2.
cases = [Case0, Case1, Case2]
poly_hl_model = PolyHLModel(size_expansion=1, hl_classes=cases)

# Print every correspondence entry (hook name -> set of LLNode in the cell output).
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

# Blank separator, then the model's corr_mapping; judging by the printed output,
# each value holds one slot per entry in `cases`.
print()
for ll_hook, per_case_hooks in poly_hl_model.corr_mapping.items():
    print(ll_hook, per_case_hooks)

# Resolve the dataset class of each selected case, instantiate them all with the
# shared sample count / context length / seed, and wrap them in one PolyModelDataset.
dataset_cases = [dataset_mapping[hl_case] for hl_case in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [mlp0_hook], [mlp0_hook]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None, None]
blocks.0.attn.hook_z.1 [None, [paren_counts_hook], None]
blocks.0.attn.hook_z.2 [None, None, None]
blocks.0.attn.hook_z.3 [None,

In [27]:
# Create the low-level model, move it to the high-level model's device, and
# record that device on the model object.
ll_model = poly_hl_model.get_ll_model().to(poly_hl_model.device)
ll_model.device = poly_hl_model.device

# Re-fetch and print the correspondence that the model pair will be built with.
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# Show the low-level nodes the model pair reports as not in the circuit.
print(model_pair.nodes_not_in_circuit)
model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.9956, train/behavior_loss: 0.3630, train/strict_loss: 0.1667, val/iit_loss: 0.6228, val/IIA: 79.09%, val/accuracy: 85.76%, val/strict_accuracy: 83.54%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.4591, train/behavior_loss: 0.1503, train/strict_loss: 0.0895, val/iit_loss: 0.3790, val/IIA: 85.87%, val/accuracy: 90.76%, val/strict_accuracy: 90.17%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3343, train/behavior_loss: 0.0841, train/strict_loss: 0.0491, val/iit_loss: 0.2625, val/IIA: 89.47%, val/accuracy: 93.54%, val/strict_accuracy: 92.62%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.2175, train/behavior_loss: 0.0557, train/strict_loss: 0.0412, val/iit_loss: 0.1919, val

In [28]:

# Persist the trained low-level model together with its high-level model for the
# Case0 + Case1 + Case2 combination; the call's return value is shown as the cell result.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+1+2")

PosixPath('saved_poly_models/cases_0+1+2')

# (0) DuplicateRemover + (1) LeftGreater + (3) UniqueExtractor

In [29]:
# Combine cases 0, 1 and 3 into one polysemantic high-level model.
cases = [Case0, Case1, Case3]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

# Print the correspondence, one entry per line (high-level hook -> set of low-level nodes).
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

# Blank separator, then the per-hook mapping stored on the model
# (the printed values appear to hold one entry per case -- TODO confirm).
print()
for ll_hook_name, hl_nodes_by_case in poly_hl_model.corr_mapping.items():
    print(ll_hook_name, hl_nodes_by_case)

# Build one dataset per case with the shared sample count, context length and seed,
# then wrap them in a single poly dataset sized to the poly model's context length.
dataset_cases = [dataset_mapping[hl_class] for hl_class in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [mlp0_hook], [appeared_mlp]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None, None]
blocks.0.attn.hook_z.1 [None, [paren_counts_hook], None]
blocks.0.attn.hook_z.2 [None, None, [counter_head]]
blocks.0.attn.hook_z.3 [None, None, None]
blocks.1.mlp.hook_post [[output_hook], None, [mask_mlp]]
block

In [30]:
# Create the low-level model and move it to the high-level model's device.
ll_model = (
    poly_hl_model
    .get_ll_model()
    .to(poly_hl_model.device)
)
# Store the device on the model object as well
# (presumably read by the IIT training code -- TODO confirm).
ll_model.device = poly_hl_model.device

# Re-fetch and print the correspondence so this cell's output records what was trained.
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

train_set, test_set = poly_dataset.get_IIT_train_test_set()

# Pair the high- and low-level models for strict IIT training.
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# Low-level nodes with no high-level counterpart.
print(model_pair.nodes_not_in_circuit)

model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.2.a

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/85 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.7619, train/behavior_loss: 0.2693, train/strict_loss: 0.1164, val/iit_loss: 0.4183, val/IIA: 85.40%, val/accuracy: 90.46%, val/strict_accuracy: 89.81%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3279, train/behavior_loss: 0.0696, train/strict_loss: 0.0427, val/iit_loss: 0.2384, val/IIA: 90.93%, val/accuracy: 97.18%, val/strict_accuracy: 96.29%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1841, train/behavior_loss: 0.0224, train/strict_loss: 0.0327, val/iit_loss: 0.1624, val/IIA: 94.10%, val/accuracy: 99.76%, val/strict_accuracy: 98.80%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1386, train/behavior_loss: 0.0087, train/strict_loss: 0.0284, val/iit_loss: 0.1228, val

In [31]:

# Persist the trained low-level model together with its polysemantic high-level model.
# The target path is a plain string literal: it contains no placeholders, so no f-string
# is needed. The call is left as the last expression so its return value (the saved
# directory, displayed as a path) shows up as the cell output.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+1+3")

PosixPath('saved_poly_models/cases_0+1+3')

# (0) DuplicateRemover + (2) ParenChecker + (3) UniqueExtractor 

In [32]:
# Combine cases 0, 2 and 3 into one polysemantic high-level model.
cases = [Case0, Case2, Case3]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

# Print the correspondence, one entry per line (high-level hook -> set of low-level nodes).
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

# Blank separator, then the per-hook mapping stored on the model
# (the printed values appear to hold one entry per case -- TODO confirm).
print()
for ll_hook_name, hl_nodes_by_case in poly_hl_model.corr_mapping.items():
    print(ll_hook_name, hl_nodes_by_case)

# Build one dataset per case with the shared sample count, context length and seed,
# then wrap them in a single poly dataset sized to the poly model's context length.
dataset_cases = [dataset_mapping[hl_class] for hl_class in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [mlp0_hook], [appeared_mlp]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None, None]
blocks.0.attn.hook_z.1 [None, None, None]
blocks.0.attn.hook_z.2 [None, None, [counter_head]]
blocks.0.attn.hook_z.3 [None, [

In [33]:
# Create the low-level model and move it to the high-level model's device.
ll_model = (
    poly_hl_model
    .get_ll_model()
    .to(poly_hl_model.device)
)
# Store the device on the model object as well
# (presumably read by the IIT training code -- TODO confirm).
ll_model.device = poly_hl_model.device

# Re-fetch and print the correspondence so this cell's output records what was trained.
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

train_set, test_set = poly_dataset.get_IIT_train_test_set()

# Pair the high- and low-level models for strict IIT training.
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# Low-level nodes with no high-level counterpart.
print(model_pair.nodes_not_in_circuit)

model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/85 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.8056, train/behavior_loss: 0.2910, train/strict_loss: 0.1275, val/iit_loss: 0.4016, val/IIA: 85.89%, val/accuracy: 89.19%, val/strict_accuracy: 88.75%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.2906, train/behavior_loss: 0.0781, train/strict_loss: 0.0423, val/iit_loss: 0.2135, val/IIA: 92.17%, val/accuracy: 96.62%, val/strict_accuracy: 95.61%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1896, train/behavior_loss: 0.0274, train/strict_loss: 0.0284, val/iit_loss: 0.1267, val/IIA: 95.87%, val/accuracy: 99.26%, val/strict_accuracy: 98.33%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.0948, train/behavior_loss: 0.0085, train/strict_loss: 0.0222, val/iit_loss: 0.1018, val

In [34]:

# Persist the trained low-level model together with its polysemantic high-level model.
# The target path is a plain string literal: it contains no placeholders, so no f-string
# is needed. The call is left as the last expression so its return value (the saved
# directory, displayed as a path) shows up as the cell output.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+2+3")

PosixPath('saved_poly_models/cases_0+2+3')

# (1) LeftGreater + (2) ParenChecker + (3) UniqueExtractor 

In [35]:
# Combine cases 1, 2 and 3 into one polysemantic high-level model.
cases = [Case1, Case2, Case3]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

# Print the correspondence, one entry per line (high-level hook -> set of low-level nodes).
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

# Blank separator, then the per-hook mapping stored on the model
# (the printed values appear to hold one entry per case -- TODO confirm).
print()
for ll_hook_name, hl_nodes_by_case in poly_hl_model.corr_mapping.items():
    print(ll_hook_name, hl_nodes_by_case)

# Build one dataset per case with the shared sample count, context length and seed,
# then wrap them in a single poly dataset sized to the poly model's context length.
dataset_cases = [dataset_mapping[hl_class] for hl_class in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}

blocks.0.mlp.hook_post [[mlp0_hook], [mlp0_hook], [appeared_mlp]]
blocks.0.attn.hook_z.0 [None, None, None]
blocks.0.attn.hook_z.1 [[paren_counts_hook], None, None]
blocks.0.attn.hook_z.2 [None, None, [counter_head]]
blocks.0.attn.hook_z.3 [None, [pare

In [36]:
# Create the low-level model and move it to the high-level model's device.
ll_model = (
    poly_hl_model
    .get_ll_model()
    .to(poly_hl_model.device)
)
# Store the device on the model object as well
# (presumably read by the IIT training code -- TODO confirm).
ll_model.device = poly_hl_model.device

# Re-fetch and print the correspondence so this cell's output records what was trained.
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

train_set, test_set = poly_dataset.get_IIT_train_test_set()

# Pair the high- and low-level models for strict IIT training.
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# Low-level nodes with no high-level counterpart.
print(model_pair.nodes_not_in_circuit)

model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.7057, train/behavior_loss: 0.2560, train/strict_loss: 0.1102, val/iit_loss: 0.3788, val/IIA: 89.01%, val/accuracy: 95.11%, val/strict_accuracy: 95.00%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.2848, train/behavior_loss: 0.0660, train/strict_loss: 0.0299, val/iit_loss: 0.1911, val/IIA: 94.42%, val/accuracy: 98.38%, val/strict_accuracy: 98.34%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1570, train/behavior_loss: 0.0237, train/strict_loss: 0.0133, val/iit_loss: 0.1371, val/IIA: 95.40%, val/accuracy: 98.64%, val/strict_accuracy: 98.49%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1056, train/behavior_loss: 0.0111, train/strict_loss: 0.0085, val/iit_loss: 0.0870, val

In [37]:

# Persist the trained low-level model together with its polysemantic high-level model.
# The target path is a plain string literal: it contains no placeholders, so no f-string
# is needed. The call is left as the last expression so its return value (the saved
# directory, displayed as a path) shows up as the cell output.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_1+2+3")

PosixPath('saved_poly_models/cases_1+2+3')

# (0) DuplicateRemover + (1) LeftGreater + (2) ParenChecker + (3) UniqueExtractor 

In [38]:
# Combine all four cases (0, 1, 2 and 3) into one polysemantic high-level model.
cases = [Case0, Case1, Case2, Case3]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

# Print the correspondence, one entry per line (high-level hook -> set of low-level nodes).
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

# Blank separator, then the per-hook mapping stored on the model
# (the printed values appear to hold one entry per case -- TODO confirm).
print()
for ll_hook_name, hl_nodes_by_case in poly_hl_model.corr_mapping.items():
    print(ll_hook_name, hl_nodes_by_case)

# Build one dataset per case with the shared sample count, context length and seed,
# then wrap them in a single poly dataset sized to the poly model's context length.
dataset_cases = [dataset_mapping[hl_class] for hl_class in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [mlp0_hook], [mlp0_hook], [appeared_mlp]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None, None, None]
blocks.0.attn.hook

In [39]:
# Create the low-level model and move it to the high-level model's device.
ll_model = (
    poly_hl_model
    .get_ll_model()
    .to(poly_hl_model.device)
)
# Store the device on the model object as well
# (presumably read by the IIT training code -- TODO confirm).
ll_model.device = poly_hl_model.device

# Re-fetch and print the correspondence so this cell's output records what was trained.
corr = poly_hl_model.get_correspondence()
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

train_set, test_set = poly_dataset.get_IIT_train_test_set()

# Pair the high- and low-level models for strict IIT training.
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# Low-level nodes with no high-level counterpart.
print(model_pair.nodes_not_in_circuit)

model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspac

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.7860, train/behavior_loss: 0.2727, train/strict_loss: 0.1158, val/iit_loss: 0.4426, val/IIA: 84.27%, val/accuracy: 92.07%, val/strict_accuracy: 90.99%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3287, train/behavior_loss: 0.0676, train/strict_loss: 0.0345, val/iit_loss: 0.2746, val/IIA: 90.16%, val/accuracy: 97.41%, val/strict_accuracy: 96.47%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.2217, train/behavior_loss: 0.0269, train/strict_loss: 0.0218, val/iit_loss: 0.1841, val/IIA: 93.73%, val/accuracy: 99.13%, val/strict_accuracy: 98.59%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1529, train/behavior_loss: 0.0120, train/strict_loss: 0.0178, val/iit_loss: 0.1357, val

In [40]:

# Persist the trained low-level model together with its polysemantic high-level model.
# The target path is a plain string literal: it contains no placeholders, so no f-string
# is needed. The call is left as the last expression so its return value (the saved
# directory, displayed as a path) shows up as the cell output.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+1+2+3")

PosixPath('saved_poly_models/cases_0+1+2+3')

# Push to HF

In [41]:
# Publish everything under the local save directory in one go; `message` is passed
# through to `save_to_hf` (presumably used as the commit message -- TODO confirm).
save_to_hf(
    local_dir="saved_poly_models",
    message="pushes all polysemantic models",
)

Upload 11 LFS files:   0%|          | 0/11 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/162k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/68.4k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/46.5k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/68.4k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/162k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/162k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/162k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/68.3k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/162k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/162k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/162k [00:00<?, ?B/s]