In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch

from iit.model_pairs.strict_iit_model_pair import StrictIITModelPair

from poly_bench.cases.paren_checker import HighLevelParensBalanceChecker, BalancedParensDataset
from poly_bench.cases.left_greater import HighLevelLeftGreater, LeftGreaterDataset
from poly_bench.cases.duplicate_remover import HighLevelDuplicateRemover, DuplicateRemoverDataset
from poly_bench.cases.unique_extractor import HighLevelUniqueExtractor, UniqueExtractorDataset
from poly_bench.poly_hl_model import PolyHLModel, PolyModelDataset
from poly_bench.io import save_poly_model_to_dir, save_to_hf

# Short numbered aliases for the four high-level case classes; the section
# headers and save paths below refer to the cases by these numbers.
Case0, Case1, Case2, Case3 = (
    HighLevelDuplicateRemover,
    HighLevelLeftGreater,
    HighLevelParensBalanceChecker,
    HighLevelUniqueExtractor,
)

# Maps each high-level case class to the dataset class instantiated for it.
dataset_mapping = dict(
    zip(
        (Case0, Case1, Case2, Case3),
        (
            DuplicateRemoverDataset,
            LeftGreaterDataset,
            BalancedParensDataset,
            UniqueExtractorDataset,
        ),
    )
)

In [4]:
n_epochs = 100      # passed as `epochs` to every model_pair.train(...) call
n_samples = 10_000  # passed as `N_samples` to every case dataset constructor

# Settings handed to StrictIITModelPair through its `training_args` argument.
training_args = {
    "batch_size": 256,
    "num_workers": 0,
    "use_single_loss": True,
    # Loss weights. Author's note on behavior_weight:
    # "basically doubles the strict weight's job".
    "behavior_weight": 0.4,
    "iit_weight": 1.0,
    "strict_weight": 0.4,
    "clip_grad_norm": 1.0,
    "early_stop": True,
    # LinearLR: scale the learning rate from start_factor down to end_factor
    # over total_iters scheduler steps (here: the full epoch budget).
    "lr_scheduler": torch.optim.lr_scheduler.LinearLR,
    "scheduler_kwargs": {"start_factor": 1, "end_factor": 0.2, "total_iters": n_epochs},
    "optimizer_kwargs": {"lr": 1e-3, "betas": (0.9, 0.9)},
    # The next two entries are marked by the author as ReduceLROnPlateau
    # settings; the scheduler configured above is LinearLR.
    "scheduler_val_metric": ["val/accuracy", "val/IIA"],
    "scheduler_mode": "max",
    "siit_sampling": "sample_all",
    "val_IIA_sampling": "all",  # author's note: "random or all"
    "seed": 42,
}

In [5]:
# Shared dataset settings: both are forwarded to every case dataset constructor.
seed = 42   # `seed` argument of each dataset
n_ctx = 15  # `n_ctx` argument of each dataset

# (0) DuplicateRemover + (1) LeftGreater

In [10]:
# Combine DuplicateRemover (Case0) and LeftGreater (Case1) into one high-level model.
cases = [Case0, Case1]
poly_hl_model = PolyHLModel(size_expansion=1, hl_classes=cases)
corr = poly_hl_model.get_correspondence()

# Print every correspondence entry, a blank separator line, then every
# entry of the model's corr_mapping.
for hl_node, ll_nodes in corr.items():
    print(hl_node, ll_nodes)
print()
for hook_name, per_case_nodes in poly_hl_model.corr_mapping.items():
    print(hook_name, per_case_nodes)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [mlp0_hook]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None]
blocks.0.attn.hook_z.1 [None, [paren_counts_hook]]
blocks.0.attn.hook_z.2 [None, None]
blocks.0.attn.hook_z.3 [None, None]
blocks.1.mlp.hook_post [[output_hook], None]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [None, None]
blocks.1.attn.hook_z.2 [None, None]
blocks.1.attn.hook_z.3 [None, None]


In [11]:
# One dataset per selected case (same sample count, context length and seed),
# wrapped in a single PolyModelDataset that uses the poly model's own n_ctx.
dataset_cases = [dataset_mapping[hl_case] for hl_case in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

In [13]:
# Low-level model, moved to the device the high-level model reports.
ll_model = poly_hl_model.get_ll_model()
target_device = poly_hl_model.device
ll_model = ll_model.to(target_device)
ll_model.device = target_device

# Re-fetch the correspondence and print it for this run.
corr = poly_hl_model.get_correspondence()
for entry in corr.items():
    print(*entry)

# Pair the two models, print the nodes the pair lists as not in the circuit, then train.
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
print(model_pair.nodes_not_in_circuit)
model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None)]
training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_schedu

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/54 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.1320, train/behavior_loss: 0.4307, train/strict_loss: 0.1762, val/iit_loss: 0.8251, val/IIA: 74.23%, val/accuracy: 80.66%, val/strict_accuracy: 79.61%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.6478, train/behavior_loss: 0.2137, train/strict_loss: 0.0958, val/iit_loss: 0.5150, val/IIA: 80.31%, val/accuracy: 85.56%, val/strict_accuracy: 84.76%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.4251, train/behavior_loss: 0.1182, train/strict_loss: 0.0605, val/iit_loss: 0.3778, val/IIA: 85.40%, val/accuracy: 90.14%, val/strict_accuracy: 88.80%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3296, train/behavior_loss: 0.0812, train/strict_loss: 0.0545, val/iit_loss: 0.3014, val

KeyboardInterrupt: 

In [8]:

# Save the low-level model together with its high-level model for the 0+1
# case combination. The path has no placeholders, so it is a plain string
# literal rather than an f-string.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+1")

# (0) DuplicateRemover + (2) ParenChecker

In [14]:
# Combine DuplicateRemover (Case0) and ParensBalanceChecker (Case2).
cases = [Case0, Case2]
poly_hl_model = PolyHLModel(
    hl_classes=cases,
    size_expansion=1,
)
corr = poly_hl_model.get_correspondence()

# Correspondence entries first, then a blank line, then the corr_mapping entries.
for entry in corr.items():
    print(*entry)
print()
for entry in poly_hl_model.corr_mapping.items():
    print(*entry)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [mlp0_hook]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None]
blocks.0.attn.hook_z.1 [None, None]
blocks.0.attn.hook_z.2 [None, None]
blocks.0.attn.hook_z.3 [None, [paren_counts_hook]]
blocks.1.mlp.hook_post [[output_hook], [mlp1_hook]]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [

In [15]:
# Look up the dataset class for each case, instantiate each with the shared
# settings, and merge them into one PolyModelDataset.
dataset_cases = [dataset_mapping[hl_case] for hl_case in cases]
dsets = list(
    make_dataset(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for make_dataset in dataset_cases
)
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

  _, cache = hl_model.run_with_cache((t.tensor(sample).unsqueeze(0), None, None))
  self.inputs = t.tensor(inputs).to(t.int)
  self.targets = t.tensor(targets).to(t.float32)
  self.markers = t.tensor(markers).to(t.int)


In [16]:
# Create the low-level model on the high-level model's device and record
# that device on the model object.
ll_model = poly_hl_model.get_ll_model().to(poly_hl_model.device)
ll_model.device = poly_hl_model.device

# Print the correspondence that this run trains against.
corr = poly_hl_model.get_correspondence()
for hl_node, ll_nodes in corr.items():
    print(hl_node, ll_nodes)

train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# Nodes the model pair lists as not belonging to the circuit.
print(model_pair.nodes_not_in_circuit)
model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.a

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/54 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.0499, train/behavior_loss: 0.4056, train/strict_loss: 0.1749, val/iit_loss: 0.7415, val/IIA: 77.54%, val/accuracy: 77.78%, val/strict_accuracy: 77.74%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.6019, train/behavior_loss: 0.2266, train/strict_loss: 0.0986, val/iit_loss: 0.4637, val/IIA: 84.28%, val/accuracy: 87.01%, val/strict_accuracy: 86.89%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3655, train/behavior_loss: 0.1171, train/strict_loss: 0.0596, val/iit_loss: 0.3122, val/IIA: 89.59%, val/accuracy: 90.68%, val/strict_accuracy: 89.52%


KeyboardInterrupt: 

In [12]:

# Save the low-level model together with its high-level model for the 0+2
# case combination. The path has no placeholders, so it is a plain string
# literal rather than an f-string.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+2")

# (0) DuplicateRemover + (3) UniqueExtractor

In [17]:
# Combine DuplicateRemover (Case0) and UniqueExtractor (Case3).
cases = [Case0, Case3]
poly_hl_model = PolyHLModel(size_expansion=1, hl_classes=cases)
corr = poly_hl_model.get_correspondence()

# Dump the correspondence, a blank line, then the per-hook corr_mapping.
for node, mapped in corr.items():
    print(node, mapped)
print()
for hook, mapped in poly_hl_model.corr_mapping.items():
    print(hook, mapped)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [appeared_mlp]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None]
blocks.0.attn.hook_z.1 [None, None]
blocks.0.attn.hook_z.2 [None, [counter_head]]
blocks.0.attn.hook_z.3 [None, None]
blocks.1.mlp.hook_post [[output_hook], [mask_mlp]]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [None, None]
blocks.1.attn.hook_z.2 [None, None]
blocks.1.attn.hook_z.3 [None, None]
blocks.

In [18]:
# Dataset classes for the selected cases, instantiated with the shared
# sample count / context length / seed and combined into one dataset.
dataset_cases = [dataset_mapping[selected] for selected in cases]
dsets = [
    dataset_type(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_type in dataset_cases
]
poly_dataset = PolyModelDataset(
    dsets,
    n_ctx=poly_hl_model.cfg.n_ctx,
)

  _, cache = hl_model.run_with_cache((t.tensor(sample).unsqueeze(0), None, None))


In [19]:
# Low-level model, moved to the device the high-level model reports.
ll_model = poly_hl_model.get_ll_model()
target_device = poly_hl_model.device
ll_model = ll_model.to(target_device)
ll_model.device = target_device

# Print the correspondence entries for this run.
corr = poly_hl_model.get_correspondence()
for hl_node, ll_nodes in corr.items():
    print(hl_node, ll_nodes)

# Build the model pair, show the nodes it lists as not in the circuit, and train.
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
print(model_pair.nodes_not_in_circuit)
model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', ind

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.8313, train/behavior_loss: 0.3056, train/strict_loss: 0.1350, val/iit_loss: 0.4445, val/IIA: 86.38%, val/accuracy: 89.22%, val/strict_accuracy: 88.26%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3394, train/behavior_loss: 0.0955, train/strict_loss: 0.0590, val/iit_loss: 0.2728, val/IIA: 89.50%, val/accuracy: 92.86%, val/strict_accuracy: 91.79%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.2184, train/behavior_loss: 0.0447, train/strict_loss: 0.0405, val/iit_loss: 0.1827, val/IIA: 93.22%, val/accuracy: 95.85%, val/strict_accuracy: 95.51%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1523, train/behavior_loss: 0.0205, train/strict_loss: 0.0359, val/iit_loss: 0.1395, val

KeyboardInterrupt: 

In [16]:

# Save the low-level model together with its high-level model for the 0+3
# case combination. The path has no placeholders, so it is a plain string
# literal rather than an f-string.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+3")

# (1) LeftGreater + (2) ParenChecker

In [20]:
# Combine LeftGreater (Case1) and ParensBalanceChecker (Case2).
cases = [Case1, Case2]
poly_hl_model = PolyHLModel(size_expansion=1, hl_classes=cases)
corr = poly_hl_model.get_correspondence()

# Correspondence entries, a blank separator line, then the corr_mapping entries.
for entry in corr.items():
    print(*entry)
print()
for entry in poly_hl_model.corr_mapping.items():
    print(*entry)

# One dataset per case with the shared settings, merged into a PolyModelDataset.
dataset_cases = [dataset_mapping[hl_case] for hl_case in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}

blocks.0.mlp.hook_post [[mlp0_hook], [mlp0_hook]]
blocks.0.attn.hook_z.0 [None, None]
blocks.0.attn.hook_z.1 [[paren_counts_hook], None]
blocks.0.attn.hook_z.2 [None, None]
blocks.0.attn.hook_z.3 [None, [paren_counts_hook]]
blocks.1.mlp.hook_post [None, [mlp1_hook]]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [None, None]
b

In [21]:
# Create the low-level model on the high-level model's device and record
# that device on the model object.
ll_model = poly_hl_model.get_ll_model().to(poly_hl_model.device)
ll_model.device = poly_hl_model.device

# Print the correspondence that this run trains against.
corr = poly_hl_model.get_correspondence()
for entry in corr.items():
    print(*entry)

train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# Nodes the model pair lists as not belonging to the circuit.
print(model_pair.nodes_not_in_circuit)
model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.a

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.7506, train/behavior_loss: 0.3005, train/strict_loss: 0.1257, val/iit_loss: 0.4640, val/IIA: 88.62%, val/accuracy: 90.57%, val/strict_accuracy: 88.67%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3442, train/behavior_loss: 0.1330, train/strict_loss: 0.0730, val/iit_loss: 0.2874, val/IIA: 93.25%, val/accuracy: 94.91%, val/strict_accuracy: 94.67%


KeyboardInterrupt: 

In [None]:

# Save the low-level model together with its high-level model for the 1+2
# case combination. The path has no placeholders, so it is a plain string
# literal rather than an f-string.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_1+2")

# (1) LeftGreater + (3) UniqueExtractor

In [22]:
# Combine LeftGreater (Case1) and UniqueExtractor (Case3).
cases = [Case1, Case3]
poly_hl_model = PolyHLModel(
    hl_classes=cases,
    size_expansion=1,
)
corr = poly_hl_model.get_correspondence()

# Print the correspondence, then a blank line, then the corr_mapping.
for hl_node, ll_nodes in corr.items():
    print(hl_node, ll_nodes)
print()
for hook_name, per_case_nodes in poly_hl_model.corr_mapping.items():
    print(hook_name, per_case_nodes)

# Instantiate each case's dataset class with the shared settings and merge them.
dataset_cases = [dataset_mapping[selected] for selected in cases]
dsets = list(
    make_dataset(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for make_dataset in dataset_cases
)
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}

blocks.0.mlp.hook_post [[mlp0_hook], [appeared_mlp]]
blocks.0.attn.hook_z.0 [None, None]
blocks.0.attn.hook_z.1 [[paren_counts_hook], None]
blocks.0.attn.hook_z.2 [None, [counter_head]]
blocks.0.attn.hook_z.3 [None, None]
blocks.1.mlp.hook_post [None, [mask_mlp]]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [None, None]
blocks.1.attn.hook_z.2 [None, None]
blocks.1.attn.hook_z.3 [None, None]
blocks.2.mlp.hook_po

In [23]:
# Low-level model, moved to the device the high-level model reports.
ll_model = poly_hl_model.get_ll_model()
target_device = poly_hl_model.device
ll_model = ll_model.to(target_device)
ll_model.device = target_device

# Print the correspondence entries for this run.
corr = poly_hl_model.get_correspondence()
for node, mapped in corr.items():
    print(node, mapped)

# Build the model pair and show the nodes it lists as not in the circuit.
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
print(model_pair.nodes_not_in_circuit)

# Train for the configured number of epochs.
model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', ind

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/54 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.7742, train/behavior_loss: 0.2961, train/strict_loss: 0.1280, val/iit_loss: 0.4539, val/IIA: 86.35%, val/accuracy: 93.13%, val/strict_accuracy: 92.86%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3026, train/behavior_loss: 0.0835, train/strict_loss: 0.0447, val/iit_loss: 0.1840, val/IIA: 95.54%, val/accuracy: 95.60%, val/strict_accuracy: 95.18%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1620, train/behavior_loss: 0.0288, train/strict_loss: 0.0192, val/iit_loss: 0.1470, val/IIA: 95.88%, val/accuracy: 95.98%, val/strict_accuracy: 95.65%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1460, train/behavior_loss: 0.0208, train/strict_loss: 0.0209, val/iit_loss: 0.1264, val

KeyboardInterrupt: 

In [None]:

# Save the low-level model together with its high-level model for the 1+3
# case combination. The path has no placeholders, so it is a plain string
# literal rather than an f-string.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_1+3")

# (2) ParenChecker + (3) UniqueExtractor

In [24]:
# Combine ParensBalanceChecker (Case2) and UniqueExtractor (Case3).
cases = [Case2, Case3]
poly_hl_model = PolyHLModel(size_expansion=1, hl_classes=cases)
corr = poly_hl_model.get_correspondence()

# Dump the correspondence, a blank line, then the per-hook corr_mapping.
for entry in corr.items():
    print(*entry)
print()
for hook, mapped in poly_hl_model.corr_mapping.items():
    print(hook, mapped)

# Build the per-case datasets with the shared settings and combine them.
dataset_cases = [dataset_mapping[hl_case] for hl_case in cases]
dsets = [
    dataset_type(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_type in dataset_cases
]
poly_dataset = PolyModelDataset(
    dsets,
    n_ctx=poly_hl_model.cfg.n_ctx,
)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}

blocks.0.mlp.hook_post [[mlp0_hook], [appeared_mlp]]
blocks.0.attn.hook_z.0 [None, None]
blocks.0.attn.hook_z.1 [None, None]
blocks.0.attn.hook_z.2 [None, [counter_head]]
blocks.0.attn.hook_z.3 [[paren_counts_hook], None]
blocks.1.mlp.hook_post [[mlp1_hook], [mask_mlp]]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [None, Non

In [25]:
# Create the low-level model on the high-level model's device and record
# that device on the model object.
ll_model = poly_hl_model.get_ll_model().to(poly_hl_model.device)
ll_model.device = poly_hl_model.device

# Print the correspondence that this run trains against.
corr = poly_hl_model.get_correspondence()
for hl_node, ll_nodes in corr.items():
    print(hl_node, ll_nodes)

train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# Nodes the model pair lists as not belonging to the circuit.
print(model_pair.nodes_not_in_circuit)
model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.a

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/53 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.8036, train/behavior_loss: 0.3006, train/strict_loss: 0.1285, val/iit_loss: 0.4144, val/IIA: 89.61%, val/accuracy: 86.95%, val/strict_accuracy: 86.95%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.2988, train/behavior_loss: 0.0937, train/strict_loss: 0.0420, val/iit_loss: 0.1993, val/IIA: 94.29%, val/accuracy: 95.22%, val/strict_accuracy: 95.61%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1571, train/behavior_loss: 0.0424, train/strict_loss: 0.0212, val/iit_loss: 0.1156, val/IIA: 97.55%, val/accuracy: 93.32%, val/strict_accuracy: 94.67%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1005, train/behavior_loss: 0.0324, train/strict_loss: 0.0161, val/iit_loss: 0.1152, val

KeyboardInterrupt: 

In [None]:

# Save the low-level model together with its high-level model for the 2+3
# case combination. The path has no placeholders, so it is a plain string
# literal rather than an f-string.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_2+3")

# (0) DuplicateRemover + (1) LeftGreater + (2) ParenChecker

In [26]:
# Combine DuplicateRemover (Case0), LeftGreater (Case1) and ParensBalanceChecker (Case2).
cases = [Case0, Case1, Case2]
poly_hl_model = PolyHLModel(size_expansion=1, hl_classes=cases)
corr = poly_hl_model.get_correspondence()

# Print every correspondence entry, a blank separator line, then every
# entry of the model's corr_mapping.
for hl_node, ll_nodes in corr.items():
    print(hl_node, ll_nodes)
print()
for entry in poly_hl_model.corr_mapping.items():
    print(*entry)

# One dataset per case with the shared settings, merged into a PolyModelDataset.
dataset_cases = [dataset_mapping[selected] for selected in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [mlp0_hook], [mlp0_hook]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None, None]
blocks.0.attn.hook_z.1 [None, [paren_counts_hook], None]
blocks.0.attn.hook_z.2 [None, None, None]
blocks.0.attn.hook_z.3 [None,

In [27]:
# Low-level model, moved to the device the high-level model reports.
ll_model = poly_hl_model.get_ll_model()
target_device = poly_hl_model.device
ll_model = ll_model.to(target_device)
ll_model.device = target_device

# Re-fetch the correspondence and print it for this run.
corr = poly_hl_model.get_correspondence()
for entry in corr.items():
    print(*entry)

# Pair the two models, print the nodes the pair lists as not in the circuit, then train.
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
print(model_pair.nodes_not_in_circuit)
model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.9253, train/behavior_loss: 0.3471, train/strict_loss: 0.1651, val/iit_loss: 0.5707, val/IIA: 81.28%, val/accuracy: 85.31%, val/strict_accuracy: 84.14%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.4242, train/behavior_loss: 0.1499, train/strict_loss: 0.0831, val/iit_loss: 0.3493, val/IIA: 86.77%, val/accuracy: 90.85%, val/strict_accuracy: 90.38%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3061, train/behavior_loss: 0.0817, train/strict_loss: 0.0472, val/iit_loss: 0.2568, val/IIA: 90.01%, val/accuracy: 91.84%, val/strict_accuracy: 91.32%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.2245, train/behavior_loss: 0.0610, train/strict_loss: 0.0387, val/iit_loss: 0.2061, val

KeyboardInterrupt: 

In [28]:

# Save the low-level model together with its high-level model for the 0+1+2
# case combination. The path has no placeholders, so it is a plain string
# literal rather than an f-string.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+1+2")

# (0) DuplicateRemover + (1) LeftGreater + (3) UniqueExtractor

In [28]:
# Combine DuplicateRemover (Case0), LeftGreater (Case1) and UniqueExtractor (Case3).
cases = [Case0, Case1, Case3]
poly_hl_model = PolyHLModel(
    hl_classes=cases,
    size_expansion=1,
)
corr = poly_hl_model.get_correspondence()

# Correspondence entries first, then a blank line, then the corr_mapping entries.
for entry in corr.items():
    print(*entry)
print()
for hook_name, per_case_nodes in poly_hl_model.corr_mapping.items():
    print(hook_name, per_case_nodes)

# Instantiate each case's dataset class with the shared settings and merge them.
dataset_cases = [dataset_mapping[hl_case] for hl_case in cases]
dsets = list(
    make_dataset(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for make_dataset in dataset_cases
)
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [mlp0_hook], [appeared_mlp]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None, None]
blocks.0.attn.hook_z.1 [None, [paren_counts_hook], None]
blocks.0.attn.hook_z.2 [None, None, [counter_head]]
blocks.0.attn.hook_z.3 [None, None, None]
blocks.1.mlp.hook_post [[output_hook], None, [mask_mlp]]
block

In [29]:
# Create the low-level model, move it to the high-level model's device, and
# record that device on the model object.
ll_model = poly_hl_model.get_ll_model()
ll_model = ll_model.to(poly_hl_model.device)
ll_model.device = poly_hl_model.device

# Fetch the correspondence again and print each entry.
corr = poly_hl_model.get_correspondence()
for hl_node, ll_nodes in corr.items():
    print(hl_node, ll_nodes)

train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
print(model_pair.nodes_not_in_circuit)

# Train for n_epochs epochs on the train/test split obtained above.
model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.2.a

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/85 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.7985, train/behavior_loss: 0.2824, train/strict_loss: 0.1248, val/iit_loss: 0.4628, val/IIA: 82.20%, val/accuracy: 88.30%, val/strict_accuracy: 88.28%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3609, train/behavior_loss: 0.0864, train/strict_loss: 0.0529, val/iit_loss: 0.2789, val/IIA: 88.48%, val/accuracy: 95.40%, val/strict_accuracy: 94.75%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.2055, train/behavior_loss: 0.0401, train/strict_loss: 0.0347, val/iit_loss: 0.1913, val/IIA: 92.46%, val/accuracy: 99.04%, val/strict_accuracy: 97.70%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1676, train/behavior_loss: 0.0213, train/strict_loss: 0.0325, val/iit_loss: 0.1490, val

KeyboardInterrupt: 

In [31]:

# Hand the current ll_model and poly_hl_model to save_poly_model_to_dir under
# the directory for cases 0+1+3. The path has no placeholders, so it is a
# plain string literal rather than an f-string (same runtime value).
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+1+3")

# (0) DuplicateRemover + (2) ParenChecker + (3) UniqueExtractor 

In [30]:
# High-level poly model built from cases 0, 2 and 3.
cases = [Case0, Case2, Case3]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

# Dump the correspondence: one "key value" line per entry.
corr = poly_hl_model.get_correspondence()
for corr_entry in corr.items():
    print(*corr_entry)

# Empty line, then the corr_mapping entries in the same "key value" layout.
print()
for mapping_entry in poly_hl_model.corr_mapping.items():
    print(*mapping_entry)

# Resolve each case to its dataset class through dataset_mapping.
dataset_cases = [dataset_mapping[hl_case] for hl_case in cases]

# Instantiate every dataset with the shared sample count, context length
# and seed.
dsets = [
    make_dataset(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for make_dataset in dataset_cases
]

# Combine them, using the context length from the poly model's config.
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [mlp0_hook], [appeared_mlp]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None, None]
blocks.0.attn.hook_z.1 [None, None, None]
blocks.0.attn.hook_z.2 [None, None, [counter_head]]
blocks.0.attn.hook_z.3 [None, [

In [31]:
# Low-level model, placed on (and tagged with) the high-level model's device.
ll_model = poly_hl_model.get_ll_model()
ll_model = ll_model.to(poly_hl_model.device)
ll_model.device = poly_hl_model.device

# Re-read the correspondence and print it entry by entry.
corr = poly_hl_model.get_correspondence()
for corr_entry in corr.items():
    print(*corr_entry)

# Split the poly dataset and pair the two models for strict IIT training.
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
print(model_pair.nodes_not_in_circuit)

model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/85 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.8023, train/behavior_loss: 0.2993, train/strict_loss: 0.1361, val/iit_loss: 0.4318, val/IIA: 84.82%, val/accuracy: 87.76%, val/strict_accuracy: 87.14%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3200, train/behavior_loss: 0.0853, train/strict_loss: 0.0488, val/iit_loss: 0.2822, val/IIA: 90.63%, val/accuracy: 92.03%, val/strict_accuracy: 91.50%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.2377, train/behavior_loss: 0.0429, train/strict_loss: 0.0329, val/iit_loss: 0.1742, val/IIA: 94.00%, val/accuracy: 96.90%, val/strict_accuracy: 96.59%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1367, train/behavior_loss: 0.0227, train/strict_loss: 0.0240, val/iit_loss: 0.1413, val

KeyboardInterrupt: 

In [None]:

# Hand the current ll_model and poly_hl_model to save_poly_model_to_dir under
# the directory for cases 0+2+3. The path has no placeholders, so it is a
# plain string literal rather than an f-string (same runtime value).
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+2+3")

# (1) LeftGreater + (2) ParenChecker + (3) UniqueExtractor 

In [32]:
# High-level poly model built from cases 1, 2 and 3.
cases = [Case1, Case2, Case3]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

# List the correspondence entries.
corr = poly_hl_model.get_correspondence()
for node, mapped_nodes in corr.items():
    print(node, mapped_nodes)

# Separator line followed by the corr_mapping entries.
print()
for ll_hook, hl_hooks_by_case in poly_hl_model.corr_mapping.items():
    print(ll_hook, hl_hooks_by_case)

# Dataset classes for the selected cases, in the same order as `cases`.
dataset_cases = [dataset_mapping[selected] for selected in cases]

# One dataset instance per class, all sharing n_samples, n_ctx and seed.
dsets = [
    dataset_type(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_type in dataset_cases
]

# The combined dataset takes its n_ctx from the poly model's config.
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}

blocks.0.mlp.hook_post [[mlp0_hook], [mlp0_hook], [appeared_mlp]]
blocks.0.attn.hook_z.0 [None, None, None]
blocks.0.attn.hook_z.1 [[paren_counts_hook], None, None]
blocks.0.attn.hook_z.2 [None, None, [counter_head]]
blocks.0.attn.hook_z.3 [None, [pare

In [33]:
# Obtain the low-level model and put it on the high-level model's device;
# the device is also stored on the model as an attribute.
ll_model = poly_hl_model.get_ll_model()
ll_model = ll_model.to(poly_hl_model.device)
ll_model.device = poly_hl_model.device

# Print the correspondence that is passed to the model pair below.
corr = poly_hl_model.get_correspondence()
for node, mapped_nodes in corr.items():
    print(node, mapped_nodes)

train_set, test_set = poly_dataset.get_IIT_train_test_set()

model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
print(model_pair.nodes_not_in_circuit)

# Run training for n_epochs epochs.
model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/76 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.7353, train/behavior_loss: 0.2803, train/strict_loss: 0.1233, val/iit_loss: 0.3883, val/IIA: 88.49%, val/accuracy: 93.60%, val/strict_accuracy: 92.23%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.2830, train/behavior_loss: 0.0842, train/strict_loss: 0.0461, val/iit_loss: 0.1897, val/IIA: 94.31%, val/accuracy: 93.04%, val/strict_accuracy: 92.96%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1657, train/behavior_loss: 0.0381, train/strict_loss: 0.0253, val/iit_loss: 0.1424, val/IIA: 94.76%, val/accuracy: 98.61%, val/strict_accuracy: 98.65%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1145, train/behavior_loss: 0.0253, train/strict_loss: 0.0170, val/iit_loss: 0.1125, val

KeyboardInterrupt: 

In [None]:

# Hand the current ll_model and poly_hl_model to save_poly_model_to_dir under
# the directory for cases 1+2+3. The path has no placeholders, so it is a
# plain string literal rather than an f-string (same runtime value).
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_1+2+3")

# (0) DuplicateRemover + (1) LeftGreater + (2) ParenChecker + (3) UniqueExtractor 

In [34]:
# High-level poly model built from all four cases.
cases = [Case0, Case1, Case2, Case3]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

# Correspondence entries, one per line.
corr = poly_hl_model.get_correspondence()
for hl_key, ll_value in corr.items():
    print(hl_key, ll_value)

# Blank line, then the corr_mapping entries, one per line.
print()
for mapping_key, mapping_value in poly_hl_model.corr_mapping.items():
    print(mapping_key, mapping_value)

# Map each case to its dataset class, build each dataset with the shared
# n_samples / n_ctx / seed, and merge them into one PolyModelDataset whose
# n_ctx comes from the poly model's config.
dataset_cases = [dataset_mapping[case_cls] for case_cls in cases]
dsets = [
    dataset_factory(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_factory in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [mlp0_hook], [mlp0_hook], [appeared_mlp]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None, None, None]
blocks.0.attn.hook

In [35]:
# Low-level model for the four-case poly model, moved to the high-level
# model's device and tagged with it.
ll_model = poly_hl_model.get_ll_model()
ll_model = ll_model.to(poly_hl_model.device)
ll_model.device = poly_hl_model.device

# Fetch and print the correspondence.
corr = poly_hl_model.get_correspondence()
for hl_key, ll_value in corr.items():
    print(hl_key, ll_value)

# Train/test split, model pair construction, and a printout of the
# nodes_not_in_circuit attribute before training starts.
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
print(model_pair.nodes_not_in_circuit)

model_pair.train(train_set=train_set, test_set=test_set, epochs=n_epochs)

Moving model to device:  mps
input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspac

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.7291, train/behavior_loss: 0.2524, train/strict_loss: 0.1103, val/iit_loss: 0.3870, val/IIA: 86.22%, val/accuracy: 91.63%, val/strict_accuracy: 91.31%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3005, train/behavior_loss: 0.0633, train/strict_loss: 0.0347, val/iit_loss: 0.2351, val/IIA: 91.75%, val/accuracy: 94.55%, val/strict_accuracy: 94.21%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.2001, train/behavior_loss: 0.0286, train/strict_loss: 0.0277, val/iit_loss: 0.1544, val/IIA: 94.96%, val/accuracy: 96.65%, val/strict_accuracy: 96.43%


KeyboardInterrupt: 

In [None]:

# Hand the current ll_model and poly_hl_model to save_poly_model_to_dir under
# the directory for cases 0+1+2+3. The path has no placeholders, so it is a
# plain string literal rather than an f-string (same runtime value).
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+1+2+3")

# Push to the Hugging Face Hub

In [None]:
# Pass the local "saved_poly_models" directory to save_to_hf together with
# the message describing this push.
save_to_hf(
    local_dir="saved_poly_models",
    message="pushes all polysemantic models",
)

model.safetensors:   0%|          | 0.00/162k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/162k [00:00<?, ?B/s]

Upload 11 LFS files:   0%|          | 0/11 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/68.4k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/46.5k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/68.4k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/162k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/162k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/68.3k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/162k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/162k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/162k [00:00<?, ?B/s]