In [1]:
# Re-import changed modules before each cell executes, so edits to the case
# definitions under `cases/` (and other imported modules) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from cases.paren_checker import HighLevelParensBalanceChecker, test_HL_parens_balancer_components, BalancedParensDataset
from cases.left_greater import HighLevelLeftGreater, test_HL_left_greater_components, LeftGreaterDataset
from cases.duplicate_remover import HighLevelDuplicateRemover, test_HL_duplicate_remover_components, DuplicateRemoverDataset
from cases.unique_extractor import HighLevelUniqueExtractor, test_HL_unique_extractor_components, UniqueExtractorDataset
from poly_hl_model import PolyHLModel, PolyModelDataset

from iit.model_pairs.strict_iit_model_pair import StrictIITModelPair
from iit.utils.index import Ix
import torch

In [3]:
# Shared training configuration, passed to every StrictIITModelPair in this notebook.
# Known-good weighting: iit_weight=1.0, strict_weight=0.4, behavior_weight=1.0.
n_epochs = 300

training_args = dict(
    batch_size=256,
    num_workers=0,
    use_single_loss=True,
    # Behavior loss on top of the strict loss -- roughly doubles the strict term's job.
    behavior_weight=1.0,
    iit_weight=1.0,
    strict_weight=0.4,
    clip_grad_norm=1.0,
    # Weight schedules (presumably called as schedule(current_weight, epoch)); the
    # identity keeps each weight constant. A decaying alternative that was tried for
    # the behavior weight: 0.955*s if 0.955**i > 0.01 else s
    iit_weight_schedule=lambda s, i: s,
    strict_weight_schedule=lambda s, i: s,
    behavior_weight_schedule=lambda s, i: s,
    early_stop=True,
    # Linear LR decay from 1x down to 0.01x of the base lr over the first 80% of epochs.
    lr_scheduler=torch.optim.lr_scheduler.LinearLR,
    scheduler_kwargs=dict(start_factor=1, end_factor=0.01, total_iters=int(0.8 * n_epochs)),
    optimizer_kwargs=dict(lr=0.001, weight_decay=1e-4, betas=(0.9, 0.9)),
    # Intended for ReduceLROnPlateau; likely inert while lr_scheduler is LinearLR.
    scheduler_val_metric=["val/accuracy", "val/IIA"],
    scheduler_mode="max",
    siit_sampling="sample_all",
)

# Paren Checker

In [4]:
# The component test prints a summary and returns True when it passes. Assert on the
# return value so that a non-truthy result stops Restart-&-Run-All instead of merely
# being displayed.
assert test_HL_parens_balancer_components(), "ParensBalanceChecker high-level component tests failed"

All Balance tests passed!


True

In [5]:
# Paren checker: build the high-level model and its HL -> LL correspondence, then a
# dataset (N_samples=5_000) at the low-level model's context length, split into
# IIT train/test sets.
# NOTE(review): get_dataset() later reports 3992 rows, fewer than N_samples --
# presumably duplicates are dropped; confirm in BalancedParensDataset.
hl_model = HighLevelParensBalanceChecker()
corr = hl_model.get_correspondence()
dataset = BalancedParensDataset(N_samples=5_000, n_ctx=hl_model.get_ll_model_cfg().n_ctx, seed=42)
train_set, test_set = dataset.get_IIT_train_test_set()

making IIT dataset


In [6]:
# Sanity check: the high-level model should reproduce the stored labels exactly.
# Show the dataset's shape and first 10 samples. Then, for each of those samples,
# report the positions where the HL model's output differs from the label.
# The dataset is fetched once instead of calling get_dataset() several times per sample.
raw_dataset = dataset.get_dataset()
print(raw_dataset.shape)
print(raw_dataset[:10]['tokens'])
print(raw_dataset[:10]['labels'])
for i in range(10):
    sample = raw_dataset[i]
    tokens, labels = sample['tokens'], sample['labels']
    # Add a leading batch dimension; the HL model is called with an (inputs, None, None) tuple.
    hl_outputs = hl_model((torch.tensor(tokens)[None, :], None, None))
    # Each row of `nonzero` is a (position, output-dim) index where label and output differ.
    nonzero = (torch.tensor(labels) - hl_outputs[0].cpu()).nonzero()
    if nonzero.numel() > 0:
        bad_positions = torch.unique(nonzero[:, 0])
        print(tokens, bad_positions)
        for idx in bad_positions.tolist():
            print(labels[idx], hl_outputs[0, idx])

(3992, 4)
[[3, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0], [3, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1], [3, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1], [3, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0], [3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [3, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1], [3, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0], [3, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0], [3, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1], [3, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1]]
[[[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0

In [7]:


# Build the low-level model and pair it with the high-level paren checker.
# The commented-out dict below is an earlier alternative config kept for reference.
# It is inactive, so this cell uses the shared `training_args` defined earlier in the
# notebook.
# NOTE(review): the alternative uses a top-level "lr" key, whereas the active config sets
# lr through optimizer_kwargs -- reconcile before re-enabling it.
# n_epochs = 100
# training_args = {
#     "batch_size": 256,
#     "lr": 0.001,
#     "num_workers": 0,
#     "use_single_loss": False,
#     "behavior_weight": 0., #basically doubles the strict weight's job.
#     "iit_weight": 1.,
#     "strict_weight": 1.0,
#     "clip_grad_norm": 1.0,
#     "iit_weight_schedule" : lambda s, i: s,
#     "strict_weight_schedule" : lambda s, i: s,
#     "behavior_weight_schedule" : lambda s, i: s, #0.955*s if 0.955**i > 0.01 else s, #have behavior weight decay over time
#     "early_stop" : True,
#     "lr_scheduler": torch.optim.lr_scheduler.LinearLR,
#     "scheduler_kwargs": dict(start_factor=1, end_factor=0, total_iters=n_epochs),
#     "scheduler_val_metric": ["val/accuracy", "val/IIA"], #for ReduceLRonPlateau
#     "scheduler_mode": "max", #for ReduceLRonPlateau
#     "siit_sampling" : "sample_all"
# }
ll_model = hl_model.get_ll_model()
model_pair = StrictIITModelPair(hl_model=hl_model, ll_model=ll_model, corr=corr, training_args=training_args)

In [8]:
# Train the paren-checker low-level model with strict IIT using AdamW, for up to
# `n_epochs` epochs (early_stop is enabled in training_args).
# Optimizer hyperparameters appear to come from training_args["optimizer_kwargs"]
# (they show up there in the printed config), so the per-call override stays disabled.
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs,
    # optimizer_kwargs=dict(weight_decay=1e-4),
)

training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0.01, 'total_iters': 240}, 'optimizer_kwargs': {'lr': 0.001, 'weight_decay': 0.0001, 'betas': (0.9, 0.9)}, 'clip_grad_norm': 1.0, 'seed': 0, 'detach_while_caching': True, 'lr': 0.001, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 0.4, 'siit_sampling': 'sample_all', 'iit_weight_schedule': <function <lambda> at 0x1060067a0>, 'strict_weight_schedule': <function <lambda> at 0x30b50c0e0>, 'behavior_weight_schedule': <function <lambda> at 0x30b50c180>}


VBox(children=(Training Epochs:   0%|          | 0/300 [00:00<?, ?it/s],))

Epoch 1: lr: 9.96e-04, iit_weight: 1.00e+00, behavior_weight: 1.00e+00, strict_weight: 4.00e-01, train/iit_loss: 9.35e-01, train/behavior_loss: 9.12e-01, train/strict_loss: 1.48e-01, val/iit_loss: 3.06e-01, val/IIA: 96.27, val/accuracy: 97.07, val/strict_accuracy: 97.07
Epoch 2: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 1.00e+00, strict_weight: 4.00e-01, train/iit_loss: 2.59e-01, train/behavior_loss: 2.29e-01, train/strict_loss: 3.74e-02, val/iit_loss: 1.72e-01, val/IIA: 96.99, val/accuracy: 96.88, val/strict_accuracy: 96.88
Epoch 3: lr: 9.88e-04, iit_weight: 1.00e+00, behavior_weight: 1.00e+00, strict_weight: 4.00e-01, train/iit_loss: 1.78e-01, train/behavior_loss: 1.64e-01, train/strict_loss: 2.75e-02, val/iit_loss: 1.87e-01, val/IIA: 95.68, val/accuracy: 96.88, val/strict_accuracy: 96.88
Epoch 4: lr: 9.83e-04, iit_weight: 1.00e+00, behavior_weight: 1.00e+00, strict_weight: 4.00e-01, train/iit_loss: 1.75e-01, train/behavior_loss: 1.15e-01, train/strict_loss: 2.22e-02, val/

# Left > Right

In [None]:
from cases.left_greater import HighLevelLeftGreater, test_HL_left_greater_components, LeftGreaterDataset

# The component test prints a summary and returns True when it passes. Assert on the
# return value so that a non-truthy result stops Restart-&-Run-All instead of merely
# being displayed.
assert test_HL_left_greater_components(), "LeftGreater high-level component tests failed"

All left greater tests passed!


True

In [None]:
# LeftGreater: build the high-level model and its correspondence, plus a dataset
# (N_samples=1_000) at the low-level model's context length, split into IIT train/test.
# Print the HL hook names and the correspondence keys to check that they line up.
hl_model = HighLevelLeftGreater()
corr = hl_model.get_correspondence()
dataset = LeftGreaterDataset(N_samples=1_000, n_ctx=hl_model.get_ll_model_cfg().n_ctx, seed=42)
train_set, test_set = dataset.get_IIT_train_test_set()
print(hl_model.hook_dict)
print(list(corr.keys()))

making IIT dataset
{'input_hook': HookPoint(), 'paren_counts_hook': HookPoint(), 'mlp0_hook': HookPoint()}
[input_hook, paren_counts_hook, mlp0_hook]


In [None]:
# Train the LeftGreater low-level model with strict IIT.
# The commented-out dict is an inactive alternative config kept for reference; this cell
# uses the shared `training_args` defined earlier in the notebook.
# NOTE(review): the config printed in this cell's saved output (behavior_weight=0.0,
# strict_weight=1.0, end_factor=0.03) does not match the shared config
# (behavior_weight=1.0, strict_weight=0.4, end_factor=0.01). The saved output is stale
# relative to the code -- re-run before relying on it.
ll_model = hl_model.get_ll_model()


# n_epochs = 100
# training_args = {
#     "batch_size": 256,
#     "lr": 0.001,
#     "num_workers": 0,
#     "use_single_loss": False,
#     "behavior_weight": 0., #basically doubles the strict weight's job.
#     "iit_weight": 1.,
#     "strict_weight": 1.0,
#     "clip_grad_norm": 1.0,
#     "iit_weight_schedule" : lambda s, i: s,
#     "strict_weight_schedule" : lambda s, i: s,
#     "behavior_weight_schedule" : lambda s, i: s, #0.955*s if 0.955**i > 0.01 else s, #have behavior weight decay over time
#     "early_stop" : True,
#     "lr_scheduler": torch.optim.lr_scheduler.LinearLR,
#     "scheduler_kwargs": dict(start_factor=1, end_factor=0, total_iters=n_epochs),
#     "scheduler_val_metric": ["val/accuracy", "val/IIA"], #for ReduceLRonPlateau
#     "scheduler_mode": "max", #for ReduceLRonPlateau
#     "siit_sampling" : "sample_all"
# }


model_pair = StrictIITModelPair(hl_model=hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs,
    # optimizer_kwargs=dict(weight_decay=1e-4),
)

training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0.03, 'total_iters': 240}, 'optimizer_kwargs': {'lr': 0.001, 'weight_decay': 0.0001, 'betas': (0.9, 0.9)}, 'clip_grad_norm': 1.0, 'seed': 0, 'detach_while_caching': True, 'lr': 0.001, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 0.0, 'strict_weight': 1.0, 'siit_sampling': 'sample_all', 'iit_weight_schedule': <function <lambda> at 0x32dde89a0>, 'strict_weight_schedule': <function <lambda> at 0x32ddeb740>, 'behavior_weight_schedule': <function <lambda> at 0x372d92e80>}


VBox(children=(Training Epochs:   0%|          | 0/300 [00:00<?, ?it/s],))

Epoch 1: lr: 9.96e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 1.58e+00, train/behavior_loss: 0.00e+00, train/strict_loss: 1.55e+00, val/iit_loss: 1.32e+00, val/IIA: 51.46, val/accuracy: 52.86, val/strict_accuracy: 50.42
Epoch 2: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 1.24e+00, train/behavior_loss: 0.00e+00, train/strict_loss: 1.26e+00, val/iit_loss: 1.17e+00, val/IIA: 56.88, val/accuracy: 70.49, val/strict_accuracy: 67.54
Epoch 3: lr: 9.88e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 1.09e+00, train/behavior_loss: 0.00e+00, train/strict_loss: 1.02e+00, val/iit_loss: 9.93e-01, val/IIA: 65.76, val/accuracy: 76.58, val/strict_accuracy: 73.49
Epoch 4: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 9.34e-01, train/behavior_loss: 0.00e+00, train/strict_loss: 8.71e-01, val/

# Duplicate remover
Corresponds to case 19 in circuits-bench.

In [None]:

# The component test prints a summary and returns True when it passes. Assert on the
# return value so that a non-truthy result stops Restart-&-Run-All instead of merely
# being displayed.
assert test_HL_duplicate_remover_components(), "DuplicateRemover high-level component tests failed"

[[4, 0, 0, 1, 2, 0, 1, 3, 3], [4, 0, 1, 2, 2, 2, 2, 2, 2], [4, 0, 1, 2, 3, 3, 3, 3, 3]]
[[False, False, True, False, False, False, False, False, True], [False, False, False, False, True, True, True, True, True], [False, False, False, False, False, True, True, True, True]]
All DuplicateRemover tests passed!


True

In [None]:
# DuplicateRemover: build the high-level model and its correspondence, plus a dataset
# (N_samples=1_000) at the low-level model's context length, split into IIT train/test.
# Print the HL hook names and the correspondence keys to check that they line up.
hl_model = HighLevelDuplicateRemover()
corr = hl_model.get_correspondence()
dataset = DuplicateRemoverDataset(N_samples=1_000, n_ctx=hl_model.get_ll_model_cfg().n_ctx, seed=42)
train_set, test_set = dataset.get_IIT_train_test_set()
print(hl_model.hook_dict)
print(list(corr.keys()))

(1000, 15, 5)
making IIT dataset
{'input_hook': HookPoint(), 'prev_token_hook': HookPoint(), 'prev_equal_hook': HookPoint(), 'output_hook': HookPoint()}
[input_hook, prev_token_hook, prev_equal_hook, output_hook]


In [None]:
# Train the DuplicateRemover low-level model with strict IIT.
# The commented-out dict is an inactive alternative config kept for reference; this cell
# uses the shared `training_args` defined earlier in the notebook.
# NOTE(review): the config printed in this cell's saved output (behavior_weight=0.0,
# strict_weight=1.0, end_factor=0.03) does not match the shared config
# (behavior_weight=1.0, strict_weight=0.4, end_factor=0.01). The saved output is stale
# relative to the code -- re-run before relying on it.
ll_model = hl_model.get_ll_model()


# n_epochs = 300
# training_args = {
#     "batch_size": 256,
#     "lr": 0.001,
#     "num_workers": 0,
#     "use_single_loss": True,
#     "behavior_weight": 1.0, #basically doubles the strict weight's job.
#     "iit_weight": 1.0,
#     "strict_weight": 0.4,
#     "clip_grad_norm": 1.0,
#     "iit_weight_schedule" : lambda s, i: s,
#     "strict_weight_schedule" : lambda s, i: s,
#     "behavior_weight_schedule" : lambda s, i: s, #0.955*s if 0.955**i > 0.01 else s, #have behavior weight decay over time
#     "early_stop" : True,
#     "lr_scheduler": torch.optim.lr_scheduler.LinearLR,
#     "scheduler_kwargs": dict(start_factor=1, end_factor=0.03, total_iters=int(0.8*n_epochs)),
#     "scheduler_val_metric": ["val/accuracy", "val/IIA"], #for ReduceLRonPlateau
#     "scheduler_mode": "max", #for ReduceLRonPlateau
#     "siit_sampling" : "sample_all"
# }


model_pair = StrictIITModelPair(hl_model=hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs,
    # optimizer_kwargs=dict(weight_decay=1e-4, betas=(0.9, 0.9)),
)

training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0.03, 'total_iters': 240}, 'optimizer_kwargs': {'lr': 0.001, 'weight_decay': 0.0001, 'betas': (0.9, 0.9)}, 'clip_grad_norm': 1.0, 'seed': 0, 'detach_while_caching': True, 'lr': 0.001, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 0.0, 'strict_weight': 1.0, 'siit_sampling': 'sample_all', 'iit_weight_schedule': <function <lambda> at 0x32dde89a0>, 'strict_weight_schedule': <function <lambda> at 0x32ddeb740>, 'behavior_weight_schedule': <function <lambda> at 0x372d92e80>}


VBox(children=(Training Epochs:   0%|          | 0/300 [00:00<?, ?it/s],))

Epoch 1: lr: 9.96e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 1.61e+00, train/behavior_loss: 0.00e+00, train/strict_loss: 1.59e+00, val/iit_loss: 1.46e+00, val/IIA: 23.42, val/accuracy: 25.09, val/strict_accuracy: 25.40
Epoch 2: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 1.40e+00, train/behavior_loss: 0.00e+00, train/strict_loss: 1.42e+00, val/iit_loss: 1.34e+00, val/IIA: 34.57, val/accuracy: 32.60, val/strict_accuracy: 33.15
Epoch 3: lr: 9.88e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 1.43e+00, train/behavior_loss: 0.00e+00, train/strict_loss: 1.29e+00, val/iit_loss: 1.23e+00, val/IIA: 48.34, val/accuracy: 49.05, val/strict_accuracy: 48.95
Epoch 4: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 1.44e+00, train/behavior_loss: 0.00e+00, train/strict_loss: 1.16e+00, val/

In [None]:
# The component test prints a summary and returns True when it passes. Assert on the
# return value so that a non-truthy result stops Restart-&-Run-All instead of merely
# being displayed.
assert test_HL_unique_extractor_components(), "UniqueExtractor high-level component tests failed"

All UniqueExtractor tests passed!


True

In [None]:
# UniqueExtractor: build the high-level model and its correspondence, plus a dataset
# (N_samples=1_000) at the low-level model's context length, split into IIT train/test.
# Print the HL hook names and the correspondence keys to check that they line up.
hl_model = HighLevelUniqueExtractor()
corr = hl_model.get_correspondence()
dataset = UniqueExtractorDataset(N_samples=1_000, n_ctx=hl_model.get_ll_model_cfg().n_ctx, seed=42)
train_set, test_set = dataset.get_IIT_train_test_set()
print(hl_model.hook_dict)
print(list(corr.keys()))

making IIT dataset
{'input_hook': HookPoint(), 'counter_head': HookPoint(), 'appeared_mlp': HookPoint(), 'mask_mlp': HookPoint(), 'output_mlp': HookPoint()}
[input_hook, counter_head, appeared_mlp, mask_mlp, output_mlp]


In [None]:
# Train the UniqueExtractor low-level model with strict IIT.
# The commented-out dict is an inactive alternative config kept for reference; this cell
# uses the shared `training_args` defined earlier in the notebook.
# NOTE(review): the config printed in this cell's saved output (behavior_weight=0.0,
# strict_weight=1.0, end_factor=0.03) does not match the shared config
# (behavior_weight=1.0, strict_weight=0.4, end_factor=0.01). The saved output is stale
# relative to the code -- re-run before relying on it.
ll_model = hl_model.get_ll_model()

# n_epochs = 300
# training_args = {
#     "batch_size": 256,
#     "lr": 0.001,
#     "num_workers": 0,
#     "use_single_loss": True,
#     "behavior_weight": 1., #basically doubles the strict weight's job.
#     "iit_weight": 1.,
#     "strict_weight": 0.4,
#     "clip_grad_norm": 1.0,
#     "iit_weight_schedule" : lambda s, i: s,
#     "strict_weight_schedule" : lambda s, i: s,
#     "behavior_weight_schedule" : lambda s, i: s, #0.955*s if 0.955**i > 0.01 else s, #have behavior weight decay over time
#     "early_stop" : True,
#     "lr_scheduler": torch.optim.lr_scheduler.LinearLR,
#     "scheduler_kwargs": dict(start_factor=1, end_factor=0.03, total_iters=int(0.8*n_epochs)),
#     "optimizer_kwargs": dict(weight_decay=1e-4, betas=(0.9, 0.9)),
#     "scheduler_val_metric": ["val/accuracy", "val/IIA"], #for ReduceLRonPlateau
#     "scheduler_mode": "max", #for ReduceLRonPlateau
#     "siit_sampling" : "sample_all"
# }
model_pair = StrictIITModelPair(hl_model=hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs,
    # optimizer_kwargs=dict(weight_decay=1e-4, betas=(0.9, 0.9)),
)

training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0.03, 'total_iters': 240}, 'optimizer_kwargs': {'lr': 0.001, 'weight_decay': 0.0001, 'betas': (0.9, 0.9)}, 'clip_grad_norm': 1.0, 'seed': 0, 'detach_while_caching': True, 'lr': 0.001, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 0.0, 'strict_weight': 1.0, 'siit_sampling': 'sample_all', 'iit_weight_schedule': <function <lambda> at 0x32dde89a0>, 'strict_weight_schedule': <function <lambda> at 0x32ddeb740>, 'behavior_weight_schedule': <function <lambda> at 0x372d92e80>}


VBox(children=(Training Epochs:   0%|          | 0/300 [00:00<?, ?it/s],))

Epoch 1: lr: 9.96e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 1.48e+00, train/behavior_loss: 0.00e+00, train/strict_loss: 1.42e+00, val/iit_loss: 1.08e+00, val/IIA: 62.55, val/accuracy: 73.40, val/strict_accuracy: 73.54
Epoch 2: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 7.21e-01, train/behavior_loss: 0.00e+00, train/strict_loss: 7.42e-01, val/iit_loss: 6.58e-01, val/IIA: 82.98, val/accuracy: 84.56, val/strict_accuracy: 84.42
Epoch 3: lr: 9.88e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 5.89e-01, train/behavior_loss: 0.00e+00, train/strict_loss: 5.80e-01, val/iit_loss: 5.88e-01, val/IIA: 83.02, val/accuracy: 85.70, val/strict_accuracy: 85.52
Epoch 4: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 5.10e-01, train/behavior_loss: 0.00e+00, train/strict_loss: 4.98e-01, val/

# Poly Model

## DuplicateRemover + LeftGreater

In [None]:
# Two-task poly model: DuplicateRemover + LeftGreater packed into one high-level model.
cases = [HighLevelDuplicateRemover, HighLevelLeftGreater]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=2)
corr = poly_hl_model.get_correspondence()

# Poly HL hook -> set of low-level nodes it is aligned with.
for hl_hook, ll_nodes in corr.items():
    print(hl_hook, ll_nodes)

print()

# Low-level location -> one entry per case (in `cases` order): the HL hooks of that
# case mapped onto this location, or None when the case does not use it.
for ll_location, hooks_per_case in poly_hl_model.corr_mapping.items():
    print(ll_location, hooks_per_case)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [mlp0_hook]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None]
blocks.0.attn.hook_z.1 [None, [paren_counts_hook]]
blocks.0.attn.hook_z.2 [None, None]
blocks.0.attn.hook_z.3 [None, None]
blocks.1.mlp.hook_post [[output_hook], None]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [None, None]
blocks.1.attn.hook_z.2 [None, None]
blocks.1.attn.hook_z.3 [None, None]


In [None]:
dataset1 = DuplicateRemoverDataset(N_samples=1_000, n_ctx=15, seed=42)
dataset2 = LeftGreaterDataset(N_samples=1_000, n_ctx=15, seed=42)
dsets = [dataset1, dataset2]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

# Sanity check: the poly high-level model should reproduce the combined dataset's
# labels on the first `n_samples` examples.
# `poly_inputs` (not `input`) avoids shadowing the builtin for the rest of the kernel
# session, and the dataset slice is taken once instead of calling get_dataset() twice.
n_samples = 100
poly_samples = poly_dataset.get_dataset()[:n_samples]
poly_inputs, expected = poly_samples[0], poly_samples[1]
output = poly_hl_model((poly_inputs, None, None)).cpu()
print(poly_inputs)
print(torch.allclose(output, expected))

(1000, 15, 5)
(1000, 16, 5)
tensor([[1, 3, 1,  ..., 1, 1, 1],
        [1, 3, 0,  ..., 0, 1, 0],
        [1, 3, 1,  ..., 0, 0, 0],
        ...,
        [1, 3, 1,  ..., 0, 1, 1],
        [1, 3, 0,  ..., 1, 1, 0],
        [1, 3, 1,  ..., 1, 1, 1]])
True


  _, cache = hl_model.run_with_cache((t.tensor(sample).unsqueeze(0), None, None))
  tokens = torch.cat([torch.tensor(dataset.tokens) for dataset in datasets], dim=0)
  labels = torch.cat([torch.tensor(dataset.labels) for dataset in datasets], dim=0)
  self.inputs = t.tensor(inputs).to(int)


In [None]:
# Train a low-level model for the two-task (DuplicateRemover + LeftGreater) poly model.
# Print the correspondence, then `nodes_not_in_circuit`, which appears to list the LL
# attention heads that no HL hook maps to.
# The commented-out dict is an inactive alternative config kept for reference; this cell
# uses the shared `training_args` defined earlier in the notebook.
# NOTE(review): the saved training log shows behavior_weight=0.0 / strict_weight=1.0,
# which does not match the shared config (1.0 / 0.4). The saved output is stale.
ll_model = poly_hl_model.get_ll_model()
corr = poly_hl_model.get_correspondence()
for k, v in corr.items():
    print(k, v)

# n_epochs = 100
# training_args = {
#     "batch_size": 256,
#     "lr": 3e-4,
#     "num_workers": 0,
#     "use_single_loss": False,
#     "behavior_weight": 1., #basically doubles the strict weight's job.
#     "iit_weight": 1.,
#     "strict_weight": 0.4,
#     "clip_grad_norm": 1.0,
#     "iit_weight_schedule" : lambda s, i: s,
#     "strict_weight_schedule" : lambda s, i: s,
#     "behavior_weight_schedule" : lambda s, i: s, #0.955*s if 0.955**i > 0.01 else s, #have behavior weight decay over time
#     "early_stop" : True,
#     "lr_scheduler": torch.optim.lr_scheduler.LinearLR,
#     "scheduler_kwargs": dict(start_factor=1, end_factor=0, total_iters=n_epochs),
#     "scheduler_val_metric": ["val/accuracy", "val/IIA"], #for ReduceLRonPlateau
#     "scheduler_mode": "max", #for ReduceLRonPlateau
# }

# NOTE(review): if re-enabled, these lines mutate the shared training_args dict in place,
# which changes the config for every cell run afterwards.
# training_args['iit_weight'] = 1.
# training_args['strict_weight'] = 1.
# training_args['optimizer_kwargs']['betas'] = (0.8, 0.9)
# training_args['optimizer_kwargs']['weight_decay'] = 1e-4
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(hl_model=poly_hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
print(model_pair.nodes_not_in_circuit)
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs,
    # optimizer_kwargs=dict(weight_decay=1e-4),
)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None)]
training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_va

VBox(children=(Training Epochs:   0%|          | 0/300 [00:00<?, ?it/s],))

Epoch 1: lr: 9.96e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 1.50e+00, train/behavior_loss: 0.00e+00, train/strict_loss: 1.46e+00, val/iit_loss: 1.23e+00, val/IIA: 46.04, val/accuracy: 52.39, val/strict_accuracy: 49.62
Epoch 2: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 1.14e+00, train/behavior_loss: 0.00e+00, train/strict_loss: 1.06e+00, val/iit_loss: 9.10e-01, val/IIA: 58.61, val/accuracy: 60.63, val/strict_accuracy: 60.76
Epoch 3: lr: 9.88e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 1.03e+00, train/behavior_loss: 0.00e+00, train/strict_loss: 8.18e-01, val/iit_loss: 8.53e-01, val/IIA: 63.54, val/accuracy: 70.67, val/strict_accuracy: 70.46
Epoch 4: lr: 9.83e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 9.45e-01, train/behavior_loss: 0.00e+00, train/strict_loss: 6.66e-01, val/

## Duplicate Remover + LeftGreater + ParensBalancer

In [None]:
# Three-task poly model: DuplicateRemover + LeftGreater + ParensBalancer.
# Each per-task dataset is built with a hard-coded n_ctx=15, while the combined
# PolyModelDataset uses the poly model's own n_ctx. The printed shapes differ by one
# position (15 vs 16), presumably for a prepended task-id token -- confirm in PolyModelDataset.
cases = [HighLevelDuplicateRemover, HighLevelLeftGreater, HighLevelParensBalanceChecker]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=2)
corr = poly_hl_model.get_correspondence()

dataset1 = DuplicateRemoverDataset(N_samples=1_000, n_ctx=15, seed=42)
dataset2 = LeftGreaterDataset(N_samples=1_000, n_ctx=15, seed=42)
dataset3 = BalancedParensDataset(N_samples=1_000, n_ctx=15, seed=42)
dsets = [dataset1,dataset2, dataset3]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)


(1000, 15, 5)
(1000, 16, 5)


In [None]:
# Train a low-level model for the three-task poly model.
# Print the correspondence, then `nodes_not_in_circuit`, which appears to list the LL
# attention heads that no HL hook maps to.
# The commented-out dict (n_epochs = 1000) is inactive; this cell uses the shared
# `n_epochs` and `training_args` defined earlier in the notebook.
# NOTE(review): the saved training log shows behavior_weight=0.0 / strict_weight=1.0,
# which does not match the shared config (1.0 / 0.4). The saved output is stale.
ll_model = poly_hl_model.get_ll_model()
corr = poly_hl_model.get_correspondence()
for k, v in corr.items():
    print(k, v)

# n_epochs = 1000
# training_args = {
#     "batch_size": 256,
#     "lr": 1e-3,
#     "num_workers": 0,
#     "use_single_loss": True,
#     "behavior_weight": 0., #basically doubles the strict weight's job.
#     "iit_weight": 1.,
#     "strict_weight": 1.,
#     "clip_grad_norm": 1.0,
#     "iit_weight_schedule" : lambda s, i: s,
#     "strict_weight_schedule" : lambda s, i: s,
#     "behavior_weight_schedule" : lambda s, i: s, #0.955*s if 0.955**i > 0.01 else s, #have behavior weight decay over time
#     "early_stop" : True,
#     "lr_scheduler": torch.optim.lr_scheduler.LinearLR,
#     "scheduler_kwargs": dict(start_factor=1, end_factor=0, total_iters=n_epochs),
#     "scheduler_val_metric": ["val/accuracy", "val/IIA"], #for ReduceLRonPlateau
#     "scheduler_mode": "max", #for ReduceLRonPlateau
# }
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(hl_model=poly_hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
print(model_pair.nodes_not_in_circuit)
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs,
    # optimizer_kwargs=dict(weight_decay=1e-4),
)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', 

VBox(children=(Training Epochs:   0%|          | 0/300 [00:00<?, ?it/s],))

Epoch 1: lr: 9.96e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 1.18e+00, train/behavior_loss: 0.00e+00, train/strict_loss: 1.14e+00, val/iit_loss: 8.69e-01, val/IIA: 62.22, val/accuracy: 67.58, val/strict_accuracy: 66.98
Epoch 2: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 7.42e-01, train/behavior_loss: 0.00e+00, train/strict_loss: 6.38e-01, val/iit_loss: 6.33e-01, val/IIA: 72.23, val/accuracy: 77.25, val/strict_accuracy: 77.04
Epoch 3: lr: 9.88e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 7.96e-01, train/behavior_loss: 0.00e+00, train/strict_loss: 4.68e-01, val/iit_loss: 6.95e-01, val/IIA: 72.98, val/accuracy: 87.82, val/strict_accuracy: 86.46
Epoch 4: lr: 9.83e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 6.37e-01, train/behavior_loss: 0.00e+00, train/strict_loss: 3.93e-01, val/

## Duplicate Remover + LeftGreater + ParensBalancer + Unique Extractor

In [None]:
# Four-task poly model: DuplicateRemover + LeftGreater + ParensBalancer + UniqueExtractor.
# Each per-task dataset is built with a hard-coded n_ctx=15, while the combined
# PolyModelDataset uses the poly model's own n_ctx. The printed shapes differ by one
# position (15 vs 16), presumably for a prepended task-id token -- confirm in PolyModelDataset.
cases = [HighLevelDuplicateRemover, HighLevelLeftGreater, HighLevelParensBalanceChecker, HighLevelUniqueExtractor]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=2)
corr = poly_hl_model.get_correspondence()

dataset1 = DuplicateRemoverDataset(N_samples=1_000, n_ctx=15, seed=42)
dataset2 = LeftGreaterDataset(N_samples=1_000, n_ctx=15, seed=42)
dataset3 = BalancedParensDataset(N_samples=1_000, n_ctx=15, seed=42)
dataset4 = UniqueExtractorDataset(N_samples=1_000, n_ctx=15, seed=42)
dsets = [dataset1,dataset2, dataset3, dataset4]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

(1000, 15, 5)
(1000, 16, 5)


  _, cache = hl_model.run_with_cache((t.tensor(sample).unsqueeze(0), None, None))


In [None]:
# Train a low-level model for the four-task poly model with the shared `training_args`.
# Print the correspondence, then `nodes_not_in_circuit`, which appears to list the LL
# attention heads that no HL hook maps to.
# NOTE(review): the saved training log shows behavior_weight=0.0 / strict_weight=1.0,
# which does not match the shared config (1.0 / 0.4). The saved output is stale, and the
# run was interrupted manually (KeyboardInterrupt).
ll_model = poly_hl_model.get_ll_model()
corr = poly_hl_model.get_correspondence()
for k, v in corr.items():
    print(k, v)

train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(hl_model=poly_hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
print(model_pair.nodes_not_in_circuit)
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs,
)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.

VBox(children=(Training Epochs:   0%|          | 0/300 [00:00<?, ?it/s],))

Epoch 1: lr: 9.96e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 1.13e+00, train/behavior_loss: 0.00e+00, train/strict_loss: 1.05e+00, val/iit_loss: 9.06e-01, val/IIA: 64.20, val/accuracy: 78.51, val/strict_accuracy: 77.54
Epoch 2: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 7.17e-01, train/behavior_loss: 0.00e+00, train/strict_loss: 4.01e-01, val/iit_loss: 8.33e-01, val/IIA: 72.94, val/accuracy: 88.31, val/strict_accuracy: 88.43
Epoch 3: lr: 9.88e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 5.29e-01, train/behavior_loss: 0.00e+00, train/strict_loss: 2.28e-01, val/iit_loss: 6.98e-01, val/IIA: 75.25, val/accuracy: 92.24, val/strict_accuracy: 92.19
Epoch 4: lr: 9.83e-04, iit_weight: 1.00e+00, behavior_weight: 0.00e+00, strict_weight: 1.00e+00, train/iit_loss: 4.22e-01, train/behavior_loss: 0.00e+00, train/strict_loss: 1.73e-01, val/

KeyboardInterrupt: 