In [1]:
# IPython autoreload: re-import edited modules (e.g. the poly_bench package)
# before executing code, so library edits take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [4]:
import torch

from iit.model_pairs.strict_iit_model_pair import StrictIITModelPair

from poly_bench.cases.paren_checker import HighLevelParensBalanceChecker, BalancedParensDataset
from poly_bench.cases.left_greater import HighLevelLeftGreater, LeftGreaterDataset
from poly_bench.cases.duplicate_remover import HighLevelDuplicateRemover, DuplicateRemoverDataset
from poly_bench.cases.unique_extractor import HighLevelUniqueExtractor, UniqueExtractorDataset
from poly_bench.poly_hl_model import PolyHLModel, PolyModelDataset
from poly_bench.utils import save_poly_model_to_dir

# Short aliases used below to pick case combinations by index.
Case0 = HighLevelDuplicateRemover
Case1 = HighLevelLeftGreater
Case2 = HighLevelParensBalanceChecker
Case3 = HighLevelUniqueExtractor

# High-level case class -> dataset class that generates samples for it.
dataset_mapping = dict(zip(
    (Case0, Case1, Case2, Case3),
    (DuplicateRemoverDataset, LeftGreaterDataset, BalancedParensDataset, UniqueExtractorDataset),
))

In [5]:
# Shared training configuration for every case combination trained below.
n_epochs = 100
n_samples = 10_000  # samples generated per case dataset


def _unchanged_weight(s, i):
    """Constant loss-weight schedule: always return the current weight ``s``."""
    return s


training_args = dict(
    batch_size=256,
    num_workers=0,
    use_single_loss=True,
    # Loss weights. behavior_weight equals strict_weight; the author's note was
    # that it "basically doubles the strict weight's job".
    behavior_weight=0.4,
    iit_weight=1.0,
    strict_weight=0.4,
    clip_grad_norm=1.0,
    # All weights are held constant. A decay tried earlier for behavior_weight:
    # 0.955*s if 0.955**i > 0.01 else s
    iit_weight_schedule=_unchanged_weight,
    strict_weight_schedule=_unchanged_weight,
    behavior_weight_schedule=_unchanged_weight,
    early_stop=True,
    # LR scaled linearly from 1.0x to 0.2x of the base lr over n_epochs steps.
    lr_scheduler=torch.optim.lr_scheduler.LinearLR,
    scheduler_kwargs=dict(start_factor=1, end_factor=0.2, total_iters=int(n_epochs)),
    optimizer_kwargs=dict(lr=1e-3, betas=(0.9, 0.9)),
    # Per the original note these two are for ReduceLROnPlateau; presumably
    # ignored while LinearLR is the scheduler.
    scheduler_val_metric=["val/accuracy", "val/IIA"],
    scheduler_mode="max",
    siit_sampling="sample_all",
    seed=42,
)

In [6]:
# Context length and RNG seed shared by every per-case dataset below.
n_ctx, seed = 15, 42

# (0) DuplicateRemover + (1) LeftGreater

In [7]:
# (0) DuplicateRemover + (1) LeftGreater: build the polymorphic HL model and
# print its HL -> LL correspondence, then the per-LL-hook mapping.
cases = [Case0, Case1]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

corr = poly_hl_model.get_correspondence()
for item in corr.items():
    print(*item)

print()
for item in poly_hl_model.corr_mapping.items():
    print(*item)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [mlp0_hook]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None]
blocks.0.attn.hook_z.1 [None, [paren_counts_hook]]
blocks.0.attn.hook_z.2 [None, None]
blocks.0.attn.hook_z.3 [None, None]
blocks.1.mlp.hook_post [[output_hook], None]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [None, None]
blocks.1.attn.hook_z.2 [None, None]
blocks.1.attn.hook_z.3 [None, None]


In [8]:
# One dataset per selected case (shared n_samples / n_ctx / seed), merged into a
# single dataset for the polymorphic model.
# NOTE(review): per-case datasets use n_ctx while the wrapper uses
# poly_hl_model.cfg.n_ctx — presumably the poly model reserves extra positions
# (e.g. for a task token); confirm the difference is intended.
dataset_cases = [dataset_mapping[case_cls] for case_cls in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

  _, cache = hl_model.run_with_cache((t.tensor(sample).unsqueeze(0), None, None))
  self.inputs = t.tensor(inputs).to(int)


In [9]:
# Build a fresh LL model for the current `poly_hl_model` and train it with strict IIT.
# Re-seed torch first so the LL weight init is reproducible and independent of
# which cells (e.g. earlier training runs) consumed RNG state before this one.
# (Assumes get_ll_model draws its init from torch's global RNG — TODO confirm.)
torch.manual_seed(seed)
ll_model = poly_hl_model.get_ll_model()
corr = poly_hl_model.get_correspondence()
for k, v in corr.items():
    print(k, v)
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(hl_model=poly_hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
# LL nodes the model pair treats as outside the circuit, printed for inspection.
print(model_pair.nodes_not_in_circuit)
# lr / betas are set in training_args["optimizer_kwargs"] (presumably consumed
# by the model pair), so they are not passed here.
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.Adam,
    epochs=n_epochs,
)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None)]
training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_va

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/54 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.05e+00, train/behavior_loss: 3.68e-01, train/strict_loss: 1.67e-01, val/iit_loss: 8.54e-01, val/IIA: 69.36, val/accuracy: 78.52, val/strict_accuracy: 75.47
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 6.11e-01, train/behavior_loss: 1.61e-01, train/strict_loss: 1.12e-01, val/iit_loss: 6.56e-01, val/IIA: 78.73, val/accuracy: 87.32, val/strict_accuracy: 83.38
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 4.32e-01, train/behavior_loss: 1.09e-01, train/strict_loss: 1.08e-01, val/iit_loss: 5.26e-01, val/IIA: 84.12, val/accuracy: 90.63, val/strict_accuracy: 86.45
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 3.54e-01, train/behavior_loss: 8.48e-02, train/strict_loss: 8.47e-02, val/

In [11]:

# Save the trained LL model together with the HL model it was trained against.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+1")

# (0) DuplicateRemover + (2) ParenChecker

In [12]:
# (0) DuplicateRemover + (2) ParenChecker: build the polymorphic HL model and
# print its HL -> LL correspondence, then the per-LL-hook mapping.
cases = [Case0, Case2]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

corr = poly_hl_model.get_correspondence()
for item in corr.items():
    print(*item)

print()
for item in poly_hl_model.corr_mapping.items():
    print(*item)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [mlp0_hook]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None]
blocks.0.attn.hook_z.1 [None, None]
blocks.0.attn.hook_z.2 [None, None]
blocks.0.attn.hook_z.3 [None, [paren_counts_hook]]
blocks.1.mlp.hook_post [[output_hook], [mlp1_hook]]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [

In [13]:
# One dataset per selected case (shared n_samples / n_ctx / seed), merged into a
# single dataset for the polymorphic model.
# NOTE(review): per-case datasets use n_ctx while the wrapper uses
# poly_hl_model.cfg.n_ctx — presumably the poly model reserves extra positions
# (e.g. for a task token); confirm the difference is intended.
dataset_cases = [dataset_mapping[case_cls] for case_cls in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

  _, cache = hl_model.run_with_cache((t.tensor(sample).unsqueeze(0), None, None))
  self.markers = None


In [14]:
# Build a fresh LL model for the current `poly_hl_model` and train it with strict IIT.
# Re-seed torch first so the LL weight init is reproducible and independent of
# which cells (e.g. earlier training runs) consumed RNG state before this one.
# (Assumes get_ll_model draws its init from torch's global RNG — TODO confirm.)
torch.manual_seed(seed)
ll_model = poly_hl_model.get_ll_model()
corr = poly_hl_model.get_correspondence()
for k, v in corr.items():
    print(k, v)
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(hl_model=poly_hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
# LL nodes the model pair treats as outside the circuit, printed for inspection.
print(model_pair.nodes_not_in_circuit)
# lr / betas are set in training_args["optimizer_kwargs"] (presumably consumed
# by the model pair), so they are not passed here.
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.Adam,
    epochs=n_epochs,
)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, 

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/54 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.01e+00, train/behavior_loss: 3.18e-01, train/strict_loss: 1.44e-01, val/iit_loss: 4.25e-01, val/IIA: 72.80, val/accuracy: 82.55, val/strict_accuracy: 82.65
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 6.49e-01, train/behavior_loss: 1.45e-01, train/strict_loss: 7.75e-02, val/iit_loss: 3.38e-01, val/IIA: 80.60, val/accuracy: 86.44, val/strict_accuracy: 85.75
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 4.11e-01, train/behavior_loss: 1.10e-01, train/strict_loss: 5.87e-02, val/iit_loss: 2.48e-01, val/IIA: 85.24, val/accuracy: 89.06, val/strict_accuracy: 89.09
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 3.36e-01, train/behavior_loss: 8.03e-02, train/strict_loss: 4.76e-02, val/

In [15]:

# Save the trained LL model together with the HL model it was trained against.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+2")

# (0) DuplicateRemover + (3) UniqueExtractor

In [19]:
# (0) DuplicateRemover + (3) UniqueExtractor: build the polymorphic HL model and
# print its HL -> LL correspondence, then the per-LL-hook mapping.
cases = [Case0, Case3]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

corr = poly_hl_model.get_correspondence()
for item in corr.items():
    print(*item)

print()
for item in poly_hl_model.corr_mapping.items():
    print(*item)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [appeared_mlp]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None]
blocks.0.attn.hook_z.1 [None, None]
blocks.0.attn.hook_z.2 [None, [counter_head]]
blocks.0.attn.hook_z.3 [None, None]
blocks.1.mlp.hook_post [[output_hook], [mask_mlp]]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [None, None]
blocks.1.attn.hook_z.2 [None, None]
blocks.1.attn.hook_z.3 [None, None]
blocks.

In [20]:
# One dataset per selected case (shared n_samples / n_ctx / seed), merged into a
# single dataset for the polymorphic model.
# NOTE(review): per-case datasets use n_ctx while the wrapper uses
# poly_hl_model.cfg.n_ctx — presumably the poly model reserves extra positions
# (e.g. for a task token); confirm the difference is intended.
dataset_cases = [dataset_mapping[case_cls] for case_cls in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

In [21]:
# Build a fresh LL model for the current `poly_hl_model` and train it with strict IIT.
# Re-seed torch first so the LL weight init is reproducible and independent of
# which cells (e.g. earlier training runs) consumed RNG state before this one.
# (Assumes get_ll_model draws its init from torch's global RNG — TODO confirm.)
torch.manual_seed(seed)
ll_model = poly_hl_model.get_ll_model()
corr = poly_hl_model.get_correspondence()
for k, v in corr.items():
    print(k, v)
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(hl_model=poly_hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
# LL nodes the model pair treats as outside the circuit, printed for inspection.
print(model_pair.nodes_not_in_circuit)
# lr / betas are set in training_args["optimizer_kwargs"] (presumably consumed
# by the model pair), so they are not passed here.
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.Adam,
    epochs=n_epochs,
)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', index=[:, :, 0, :], subspace=Non

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 7.06e-01, train/behavior_loss: 2.14e-01, train/strict_loss: 9.88e-02, val/iit_loss: 4.96e-01, val/IIA: 81.89, val/accuracy: 86.72, val/strict_accuracy: 86.84
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 3.54e-01, train/behavior_loss: 8.13e-02, train/strict_loss: 5.61e-02, val/iit_loss: 4.51e-01, val/IIA: 88.03, val/accuracy: 97.54, val/strict_accuracy: 95.17
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 2.02e-01, train/behavior_loss: 2.80e-02, train/strict_loss: 4.32e-02, val/iit_loss: 3.07e-01, val/IIA: 93.38, val/accuracy: 99.77, val/strict_accuracy: 97.99
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.05e-01, train/behavior_loss: 1.43e-02, train/strict_loss: 3.91e-02, val/

In [None]:

# Save the trained LL model together with the HL model it was trained against.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+3")

# (1) LeftGreater + (2) ParenChecker

In [None]:
# (1) LeftGreater + (2) ParenChecker: build the polymorphic HL model and print
# its HL -> LL correspondence, then the per-LL-hook mapping.
cases = [Case1, Case2]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

corr = poly_hl_model.get_correspondence()
for item in corr.items():
    print(*item)

print()
for item in poly_hl_model.corr_mapping.items():
    print(*item)

# One dataset per selected case (shared n_samples / n_ctx / seed), merged into a
# single dataset for the polymorphic model.
# NOTE(review): per-case datasets use n_ctx while the wrapper uses
# poly_hl_model.cfg.n_ctx — confirm the difference is intended.
dataset_cases = [dataset_mapping[case_cls] for case_cls in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

In [None]:
# Build a fresh LL model for the current `poly_hl_model` and train it with strict IIT.
# Re-seed torch first so the LL weight init is reproducible and independent of
# which cells (e.g. earlier training runs) consumed RNG state before this one.
# (Assumes get_ll_model draws its init from torch's global RNG — TODO confirm.)
torch.manual_seed(seed)
ll_model = poly_hl_model.get_ll_model()
corr = poly_hl_model.get_correspondence()
for k, v in corr.items():
    print(k, v)
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(hl_model=poly_hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
# LL nodes the model pair treats as outside the circuit, printed for inspection.
print(model_pair.nodes_not_in_circuit)
# lr / betas are set in training_args["optimizer_kwargs"] (presumably consumed
# by the model pair), so they are not passed here.
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.Adam,
    epochs=n_epochs,
)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', index=[:, :, 0, :], subspace=Non

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 7.06e-01, train/behavior_loss: 2.14e-01, train/strict_loss: 9.88e-02, val/iit_loss: 4.96e-01, val/IIA: 81.89, val/accuracy: 86.72, val/strict_accuracy: 86.84
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 3.54e-01, train/behavior_loss: 8.13e-02, train/strict_loss: 5.61e-02, val/iit_loss: 4.51e-01, val/IIA: 88.03, val/accuracy: 97.54, val/strict_accuracy: 95.17
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 2.02e-01, train/behavior_loss: 2.80e-02, train/strict_loss: 4.32e-02, val/iit_loss: 3.07e-01, val/IIA: 93.38, val/accuracy: 99.77, val/strict_accuracy: 97.99
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.05e-01, train/behavior_loss: 1.43e-02, train/strict_loss: 3.91e-02, val/

In [None]:

# Save the trained LL model together with the HL model it was trained against.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_1+2")

# (1) LeftGreater + (3) UniqueExtractor

In [None]:
# (1) LeftGreater + (3) UniqueExtractor: build the polymorphic HL model and print
# its HL -> LL correspondence, then the per-LL-hook mapping.
cases = [Case1, Case3]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

corr = poly_hl_model.get_correspondence()
for item in corr.items():
    print(*item)

print()
for item in poly_hl_model.corr_mapping.items():
    print(*item)

# One dataset per selected case (shared n_samples / n_ctx / seed), merged into a
# single dataset for the polymorphic model.
# NOTE(review): per-case datasets use n_ctx while the wrapper uses
# poly_hl_model.cfg.n_ctx — confirm the difference is intended.
dataset_cases = [dataset_mapping[case_cls] for case_cls in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

In [None]:
# Build a fresh LL model for the current `poly_hl_model` and train it with strict IIT.
# Re-seed torch first so the LL weight init is reproducible and independent of
# which cells (e.g. earlier training runs) consumed RNG state before this one.
# (Assumes get_ll_model draws its init from torch's global RNG — TODO confirm.)
torch.manual_seed(seed)
ll_model = poly_hl_model.get_ll_model()
corr = poly_hl_model.get_correspondence()
for k, v in corr.items():
    print(k, v)
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(hl_model=poly_hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
# LL nodes the model pair treats as outside the circuit, printed for inspection.
print(model_pair.nodes_not_in_circuit)
# lr / betas are set in training_args["optimizer_kwargs"] (presumably consumed
# by the model pair), so they are not passed here.
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.Adam,
    epochs=n_epochs,
)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', index=[:, :, 0, :], subspace=Non

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 7.06e-01, train/behavior_loss: 2.14e-01, train/strict_loss: 9.88e-02, val/iit_loss: 4.96e-01, val/IIA: 81.89, val/accuracy: 86.72, val/strict_accuracy: 86.84
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 3.54e-01, train/behavior_loss: 8.13e-02, train/strict_loss: 5.61e-02, val/iit_loss: 4.51e-01, val/IIA: 88.03, val/accuracy: 97.54, val/strict_accuracy: 95.17
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 2.02e-01, train/behavior_loss: 2.80e-02, train/strict_loss: 4.32e-02, val/iit_loss: 3.07e-01, val/IIA: 93.38, val/accuracy: 99.77, val/strict_accuracy: 97.99
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.05e-01, train/behavior_loss: 1.43e-02, train/strict_loss: 3.91e-02, val/

In [None]:

# Save the trained LL model together with the HL model it was trained against.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_1+3")

# (2) ParenChecker + (3) UniqueExtractor

In [None]:
# (2) ParenChecker + (3) UniqueExtractor: build the polymorphic HL model and
# print its HL -> LL correspondence, then the per-LL-hook mapping.
cases = [Case2, Case3]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

corr = poly_hl_model.get_correspondence()
for item in corr.items():
    print(*item)

print()
for item in poly_hl_model.corr_mapping.items():
    print(*item)

# One dataset per selected case (shared n_samples / n_ctx / seed), merged into a
# single dataset for the polymorphic model.
# NOTE(review): per-case datasets use n_ctx while the wrapper uses
# poly_hl_model.cfg.n_ctx — confirm the difference is intended.
dataset_cases = [dataset_mapping[case_cls] for case_cls in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

In [None]:
# Build a fresh LL model for the current `poly_hl_model` and train it with strict IIT.
# Re-seed torch first so the LL weight init is reproducible and independent of
# which cells (e.g. earlier training runs) consumed RNG state before this one.
# (Assumes get_ll_model draws its init from torch's global RNG — TODO confirm.)
torch.manual_seed(seed)
ll_model = poly_hl_model.get_ll_model()
corr = poly_hl_model.get_correspondence()
for k, v in corr.items():
    print(k, v)
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(hl_model=poly_hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
# LL nodes the model pair treats as outside the circuit, printed for inspection.
print(model_pair.nodes_not_in_circuit)
# lr / betas are set in training_args["optimizer_kwargs"] (presumably consumed
# by the model pair), so they are not passed here.
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.Adam,
    epochs=n_epochs,
)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', index=[:, :, 0, :], subspace=Non

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 7.06e-01, train/behavior_loss: 2.14e-01, train/strict_loss: 9.88e-02, val/iit_loss: 4.96e-01, val/IIA: 81.89, val/accuracy: 86.72, val/strict_accuracy: 86.84
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 3.54e-01, train/behavior_loss: 8.13e-02, train/strict_loss: 5.61e-02, val/iit_loss: 4.51e-01, val/IIA: 88.03, val/accuracy: 97.54, val/strict_accuracy: 95.17
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 2.02e-01, train/behavior_loss: 2.80e-02, train/strict_loss: 4.32e-02, val/iit_loss: 3.07e-01, val/IIA: 93.38, val/accuracy: 99.77, val/strict_accuracy: 97.99
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.05e-01, train/behavior_loss: 1.43e-02, train/strict_loss: 3.91e-02, val/

In [None]:

# Save the trained LL model together with the HL model it was trained against.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_2+3")

# (0) DuplicateRemover + (1) LeftGreater + (2) ParenChecker

In [None]:
# (0) DuplicateRemover + (1) LeftGreater + (2) ParenChecker: build the
# polymorphic HL model and print its HL -> LL correspondence, then the
# per-LL-hook mapping.
cases = [Case0, Case1, Case2]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

corr = poly_hl_model.get_correspondence()
for item in corr.items():
    print(*item)

print()
for item in poly_hl_model.corr_mapping.items():
    print(*item)

# One dataset per selected case (shared n_samples / n_ctx / seed), merged into a
# single dataset for the polymorphic model.
# NOTE(review): per-case datasets use n_ctx while the wrapper uses
# poly_hl_model.cfg.n_ctx — confirm the difference is intended.
dataset_cases = [dataset_mapping[case_cls] for case_cls in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

In [None]:
# Build a fresh LL model for the current `poly_hl_model` and train it with strict IIT.
# Re-seed torch first so the LL weight init is reproducible and independent of
# which cells (e.g. earlier training runs) consumed RNG state before this one.
# (Assumes get_ll_model draws its init from torch's global RNG — TODO confirm.)
torch.manual_seed(seed)
ll_model = poly_hl_model.get_ll_model()
corr = poly_hl_model.get_correspondence()
for k, v in corr.items():
    print(k, v)
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(hl_model=poly_hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
# LL nodes the model pair treats as outside the circuit, printed for inspection.
print(model_pair.nodes_not_in_circuit)
# lr / betas are set in training_args["optimizer_kwargs"] (presumably consumed
# by the model pair), so they are not passed here.
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.Adam,
    epochs=n_epochs,
)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', index=[:, :, 0, :], subspace=Non

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 7.06e-01, train/behavior_loss: 2.14e-01, train/strict_loss: 9.88e-02, val/iit_loss: 4.96e-01, val/IIA: 81.89, val/accuracy: 86.72, val/strict_accuracy: 86.84
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 3.54e-01, train/behavior_loss: 8.13e-02, train/strict_loss: 5.61e-02, val/iit_loss: 4.51e-01, val/IIA: 88.03, val/accuracy: 97.54, val/strict_accuracy: 95.17
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 2.02e-01, train/behavior_loss: 2.80e-02, train/strict_loss: 4.32e-02, val/iit_loss: 3.07e-01, val/IIA: 93.38, val/accuracy: 99.77, val/strict_accuracy: 97.99
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.05e-01, train/behavior_loss: 1.43e-02, train/strict_loss: 3.91e-02, val/

In [None]:

# Save the trained LL model together with the HL model it was trained against.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+1+2")

# (0) DuplicateRemover + (1) LeftGreater + (3) UniqueExtractor

In [None]:
# (0) DuplicateRemover + (1) LeftGreater + (3) UniqueExtractor: build the
# polymorphic HL model and print its HL -> LL correspondence, then the
# per-LL-hook mapping.
cases = [Case0, Case1, Case3]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

corr = poly_hl_model.get_correspondence()
for item in corr.items():
    print(*item)

print()
for item in poly_hl_model.corr_mapping.items():
    print(*item)

# One dataset per selected case (shared n_samples / n_ctx / seed), merged into a
# single dataset for the polymorphic model.
# NOTE(review): per-case datasets use n_ctx while the wrapper uses
# poly_hl_model.cfg.n_ctx — confirm the difference is intended.
dataset_cases = [dataset_mapping[case_cls] for case_cls in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

In [None]:
# Build a fresh LL model for the current `poly_hl_model` and train it with strict IIT.
# Re-seed torch first so the LL weight init is reproducible and independent of
# which cells (e.g. earlier training runs) consumed RNG state before this one.
# (Assumes get_ll_model draws its init from torch's global RNG — TODO confirm.)
torch.manual_seed(seed)
ll_model = poly_hl_model.get_ll_model()
corr = poly_hl_model.get_correspondence()
for k, v in corr.items():
    print(k, v)
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(hl_model=poly_hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
# LL nodes the model pair treats as outside the circuit, printed for inspection.
print(model_pair.nodes_not_in_circuit)
# lr / betas are set in training_args["optimizer_kwargs"] (presumably consumed
# by the model pair), so they are not passed here.
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.Adam,
    epochs=n_epochs,
)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', index=[:, :, 0, :], subspace=Non

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 7.06e-01, train/behavior_loss: 2.14e-01, train/strict_loss: 9.88e-02, val/iit_loss: 4.96e-01, val/IIA: 81.89, val/accuracy: 86.72, val/strict_accuracy: 86.84
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 3.54e-01, train/behavior_loss: 8.13e-02, train/strict_loss: 5.61e-02, val/iit_loss: 4.51e-01, val/IIA: 88.03, val/accuracy: 97.54, val/strict_accuracy: 95.17
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 2.02e-01, train/behavior_loss: 2.80e-02, train/strict_loss: 4.32e-02, val/iit_loss: 3.07e-01, val/IIA: 93.38, val/accuracy: 99.77, val/strict_accuracy: 97.99
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.05e-01, train/behavior_loss: 1.43e-02, train/strict_loss: 3.91e-02, val/

In [None]:

# Save the trained LL model together with the HL model it was trained against.
save_poly_model_to_dir(ll_model, poly_hl_model, "./saved_poly_models/cases_0+1+3")

# (0) DuplicateRemover + (2) ParenChecker + (3) UniqueExtractor 

In [None]:
# (0) DuplicateRemover + (2) ParenChecker + (3) UniqueExtractor: build the
# polymorphic HL model and print its HL -> LL correspondence, then the
# per-LL-hook mapping.
cases = [Case0, Case2, Case3]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

corr = poly_hl_model.get_correspondence()
for item in corr.items():
    print(*item)

print()
for item in poly_hl_model.corr_mapping.items():
    print(*item)

# One dataset per selected case (shared n_samples / n_ctx / seed), merged into a
# single dataset for the polymorphic model.
# NOTE(review): per-case datasets use n_ctx while the wrapper uses
# poly_hl_model.cfg.n_ctx — confirm the difference is intended.
dataset_cases = [dataset_mapping[case_cls] for case_cls in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

In [None]:
# Build a fresh LL model for the current `poly_hl_model` and train it with strict IIT.
# Re-seed torch first so the LL weight init is reproducible and independent of
# which cells (e.g. earlier training runs) consumed RNG state before this one.
# (Assumes get_ll_model draws its init from torch's global RNG — TODO confirm.)
torch.manual_seed(seed)
ll_model = poly_hl_model.get_ll_model()
corr = poly_hl_model.get_correspondence()
for k, v in corr.items():
    print(k, v)
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(hl_model=poly_hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
# LL nodes the model pair treats as outside the circuit, printed for inspection.
print(model_pair.nodes_not_in_circuit)
# lr / betas are set in training_args["optimizer_kwargs"] (presumably consumed
# by the model pair), so they are not passed here.
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.Adam,
    epochs=n_epochs,
)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', index=[:, :, 0, :], subspace=Non

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 7.06e-01, train/behavior_loss: 2.14e-01, train/strict_loss: 9.88e-02, val/iit_loss: 4.96e-01, val/IIA: 81.89, val/accuracy: 86.72, val/strict_accuracy: 86.84
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 3.54e-01, train/behavior_loss: 8.13e-02, train/strict_loss: 5.61e-02, val/iit_loss: 4.51e-01, val/IIA: 88.03, val/accuracy: 97.54, val/strict_accuracy: 95.17
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 2.02e-01, train/behavior_loss: 2.80e-02, train/strict_loss: 4.32e-02, val/iit_loss: 3.07e-01, val/IIA: 93.38, val/accuracy: 99.77, val/strict_accuracy: 97.99
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.05e-01, train/behavior_loss: 1.43e-02, train/strict_loss: 3.91e-02, val/

In [None]:

# Name the output directory after the indices of the trained cases within
# [Case0, Case1, Case2, Case3], so the path cannot drift out of sync with `cases`
# when this cell is copied for another combination.
case_ids = "+".join(str([Case0, Case1, Case2, Case3].index(case)) for case in cases)
save_poly_model_to_dir(ll_model, poly_hl_model, f"./saved_poly_models/cases_{case_ids}")

# (1) LeftGreater + (2) ParenChecker + (3) UniqueExtractor 

In [None]:
cases = [Case1, Case2, Case3]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

# Inspect the correspondence for this case combination, then (after a blank
# line) the model's corr_mapping.
corr = poly_hl_model.get_correspondence()
for k, v in corr.items():
    print(k, v)
print()
for k, v in poly_hl_model.corr_mapping.items():
    print(k, v)

# One dataset per HL case (looked up through dataset_mapping), all sharing
# n_samples / n_ctx / seed, merged into a PolyModelDataset that uses the
# HL model's cfg.n_ctx.
dataset_cases = [dataset_mapping[hl_case] for hl_case in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

In [None]:
# Seed torch before building the LL model so its weight initialisation (and any
# torch-based sampling in the train/test split) is reproducible on a fresh kernel.
# If StrictIITModelPair uses training_args["seed"], that can only happen after
# the LL model below has already been initialised.
torch.manual_seed(seed)
ll_model = poly_hl_model.get_ll_model()

# Correspondence printed as `<hl node> <set of LLNodes>` for inspection.
corr = poly_hl_model.get_correspondence()
for k, v in corr.items():
    print(k, v)

train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(hl_model=poly_hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
# LL nodes not covered by the correspondence (presumably what the strict loss acts on — confirm).
print(model_pair.nodes_not_in_circuit)
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.Adam,
    epochs=n_epochs,
)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', index=[:, :, 0, :], subspace=Non

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 7.06e-01, train/behavior_loss: 2.14e-01, train/strict_loss: 9.88e-02, val/iit_loss: 4.96e-01, val/IIA: 81.89, val/accuracy: 86.72, val/strict_accuracy: 86.84
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 3.54e-01, train/behavior_loss: 8.13e-02, train/strict_loss: 5.61e-02, val/iit_loss: 4.51e-01, val/IIA: 88.03, val/accuracy: 97.54, val/strict_accuracy: 95.17
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 2.02e-01, train/behavior_loss: 2.80e-02, train/strict_loss: 4.32e-02, val/iit_loss: 3.07e-01, val/IIA: 93.38, val/accuracy: 99.77, val/strict_accuracy: 97.99
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.05e-01, train/behavior_loss: 1.43e-02, train/strict_loss: 3.91e-02, val/

In [None]:

# Name the output directory after the indices of the trained cases within
# [Case0, Case1, Case2, Case3], so the path cannot drift out of sync with `cases`
# when this cell is copied for another combination.
case_ids = "+".join(str([Case0, Case1, Case2, Case3].index(case)) for case in cases)
save_poly_model_to_dir(ll_model, poly_hl_model, f"./saved_poly_models/cases_{case_ids}")

# (0) DuplicateRemover + (1) LeftGreater + (2) ParenChecker + (3) UniqueExtractor 

In [None]:
cases = [Case0, Case1, Case2, Case3]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=1)

# Inspect the correspondence for this case combination, then (after a blank
# line) the model's corr_mapping.
corr = poly_hl_model.get_correspondence()
for k, v in corr.items():
    print(k, v)
print()
for k, v in poly_hl_model.corr_mapping.items():
    print(k, v)

# One dataset per HL case (looked up through dataset_mapping), all sharing
# n_samples / n_ctx / seed, merged into a PolyModelDataset that uses the
# HL model's cfg.n_ctx.
dataset_cases = [dataset_mapping[hl_case] for hl_case in cases]
dsets = [
    dataset_cls(N_samples=n_samples, n_ctx=n_ctx, seed=seed)
    for dataset_cls in dataset_cases
]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

In [None]:
# Seed torch before building the LL model so its weight initialisation (and any
# torch-based sampling in the train/test split) is reproducible on a fresh kernel.
# If StrictIITModelPair uses training_args["seed"], that can only happen after
# the LL model below has already been initialised.
torch.manual_seed(seed)
ll_model = poly_hl_model.get_ll_model()

# Correspondence printed as `<hl node> <set of LLNodes>` for inspection.
corr = poly_hl_model.get_correspondence()
for k, v in corr.items():
    print(k, v)

train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(hl_model=poly_hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
# LL nodes not covered by the correspondence (presumably what the strict loss acts on — confirm).
print(model_pair.nodes_not_in_circuit)
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.Adam,
    epochs=n_epochs,
)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', index=[:, :, 0, :], subspace=Non

Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 7.06e-01, train/behavior_loss: 2.14e-01, train/strict_loss: 9.88e-02, val/iit_loss: 4.96e-01, val/IIA: 81.89, val/accuracy: 86.72, val/strict_accuracy: 86.84
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 3.54e-01, train/behavior_loss: 8.13e-02, train/strict_loss: 5.61e-02, val/iit_loss: 4.51e-01, val/IIA: 88.03, val/accuracy: 97.54, val/strict_accuracy: 95.17
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 2.02e-01, train/behavior_loss: 2.80e-02, train/strict_loss: 4.32e-02, val/iit_loss: 3.07e-01, val/IIA: 93.38, val/accuracy: 99.77, val/strict_accuracy: 97.99
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.05e-01, train/behavior_loss: 1.43e-02, train/strict_loss: 3.91e-02, val/

In [None]:

# Name the output directory after the indices of the trained cases within
# [Case0, Case1, Case2, Case3], so the path cannot drift out of sync with `cases`
# when this cell is copied for another combination.
case_ids = "+".join(str([Case0, Case1, Case2, Case3].index(case)) for case in cases)
save_poly_model_to_dir(ll_model, poly_hl_model, f"./saved_poly_models/cases_{case_ids}")