In [1]:
# Load IPython's autoreload extension; mode 2 re-imports edited modules before each
# cell executes, so changes to the project modules are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from cases.paren_checker import HighLevelParensBalanceChecker, test_HL_parens_balancer_components, BalancedParensDataset
from cases.left_greater import HighLevelLeftGreater, test_HL_left_greater_components, LeftGreaterDataset
from cases.duplicate_remover import HighLevelDuplicateRemover, test_HL_duplicate_remover_components, DuplicateRemoverDataset
from cases.unique_extractor import HighLevelUniqueExtractor, test_HL_unique_extractor_components, UniqueExtractorDataset
from poly_hl_model import PolyHLModel, PolyModelDataset

from iit.model_pairs.strict_iit_model_pair import StrictIITModelPair
from iit.utils.index import Ix
import torch

In [9]:
# Shared run configuration: epoch budget, dataset size, and the `training_args` dict that
# the training cells below pass to StrictIITModelPair.
n_epochs = 300
n_samples = 1_000
# Earlier note: iit_weight = 1. / siit_weight = 0.4 / behavior_weight = 1. worked.
# NOTE(review): the values set below differ from that note -- only the behavior loss has a
# non-zero weight here (iit_weight and strict_weight are both 0).
training_args = {
    "batch_size": 256,
    "num_workers": 0,
    "use_single_loss": True,
    "behavior_weight": 1., # original note: "basically doubles the strict weight's job"
    "iit_weight": 0.,
    "strict_weight": 0.,
    "clip_grad_norm": 1.0,
    # Each *_weight_schedule takes (s, i) and returns s unchanged, i.e. a constant schedule.
    # Presumably s is the current weight and i the epoch index -- confirm against the trainer.
    "iit_weight_schedule" : lambda s, i: s,
    "strict_weight_schedule" : lambda s, i: s,
    "behavior_weight_schedule" : lambda s, i: s, # disabled alternative (decays the behavior weight over time): 0.955*s if 0.955**i > 0.01 else s
    "early_stop" : True,
    # LinearLR scales the base lr linearly from start_factor to end_factor over total_iters steps.
    "lr_scheduler": torch.optim.lr_scheduler.LinearLR,
    "scheduler_kwargs": dict(start_factor=1, end_factor=0.1, total_iters=int(n_epochs)),
    "optimizer_kwargs": dict(lr=1e-3, weight_decay=1e-8, betas=(0.9, 0.9)), # earlier note: "1e-8 was 237 epochs" (presumably refers to weight_decay -- confirm)
    # The next two keys were annotated as being for ReduceLROnPlateau; the scheduler selected
    # above is LinearLR, so they are presumably unused in this configuration -- verify.
    "scheduler_val_metric": ["val/accuracy", "val/IIA"],
    "scheduler_mode": "max",
    "siit_sampling" : "sample_all",
    "seed" : 42
}

# Paren Checker

In [10]:
# Run the component self-tests shipped with the parens-balance case before building on it;
# the call's return value is displayed as the cell output.
test_HL_parens_balancer_components()

All Balance tests passed!


True

In [11]:
# High-level parens-balance model and the correspondence object it exposes.
hl_model = HighLevelParensBalanceChecker()
corr = hl_model.get_correspondence()

# Build n_samples examples; the context length is read from the low-level model config.
dataset = BalancedParensDataset(
    N_samples=n_samples,
    n_ctx=hl_model.get_ll_model_cfg().n_ctx,
    seed=42,
)
train_set, test_set = dataset.get_IIT_train_test_set()

making IIT dataset


In [12]:
# Print one raw sample (index 2) to inspect the fields stored per example.
print(dataset.get_dataset()[2])

{'tokens': [0, 4, 4, 4, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], 'str_tokens': [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None], 'labels': [[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], 'markers': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}


In [13]:
# Sanity check: compare the stored labels of the first few samples against what the
# high-level model produces for the same tokens, and print every position that disagrees.
n_preview = 10

# Fetch the dataset once instead of calling get_dataset() for every field access.
samples = dataset.get_dataset()
print(samples.shape)
preview = samples[:n_preview]
print(preview['tokens'])
print(preview['labels'])

for i in range(n_preview):
    sample = samples[i]
    tokens, labels = sample['tokens'], sample['labels']
    # The HL model is called with a 3-tuple; only the first slot (a batch of one token
    # sequence) is populated here.
    hl_outputs = hl_model((torch.tensor(tokens)[None, :], None, None))
    # Non-zero entries of (labels - HL output) mark elements where the two disagree.
    nonzero = (torch.tensor(labels) - hl_outputs[0].cpu()).nonzero()
    if nonzero.numel() > 0:
        # Column 0 of the nonzero index matrix is the sequence position; deduplicate it.
        bad_positions = torch.unique(nonzero[:, 0])
        print(tokens, bad_positions)
        bad_indices = bad_positions.tolist()
        for idx in bad_indices:
            print(labels[idx], hl_outputs[0, idx])

(1000, 4)
[[0, 3, 4, 3, 4, 3, 3, 3, 4, 4, 3, 3, 3, 3, 3], [0, 3, 4, 3, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 4], [0, 4, 4, 4, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], [0, 4, 4, 4, 3, 4, 4, 4, 4, 4, 3, 3, 4, 3, 3], [0, 3, 4, 3, 3, 4, 4, 3, 3, 3, 3, 3, 3, 3, 4], [0, 4, 3, 3, 4, 3, 4, 3, 4, 3, 3, 4, 3, 4, 3], [0, 3, 4, 3, 3, 3, 4, 3, 4, 3, 4, 4, 3, 4, 4], [0, 3, 4, 3, 4, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4], [0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], [0, 4, 3, 4, 3, 4, 4, 3, 4, 3, 3, 3, 3, 4, 4]]
[[[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 1.0

In [14]:

# Seeded low-level model for the parens-balance case, paired with its high-level model.
ll_model = hl_model.get_ll_model(seed=42)
model_pair = StrictIITModelPair(
    hl_model=hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)

# Train with AdamW; all other hyperparameters come from the shared training_args dict.
model_pair.train(train_set=train_set, test_set=test_set, optimizer_cls=torch.optim.AdamW, epochs=n_epochs)

training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0.1, 'total_iters': 300}, 'optimizer_kwargs': {'lr': 0.001, 'weight_decay': 1e-08, 'betas': (0.9, 0.9)}, 'clip_grad_norm': 1.0, 'seed': 42, 'detach_while_caching': True, 'lr': 0.001, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 0.0, 'behavior_weight': 1.0, 'strict_weight': 0.0, 'siit_sampling': 'sample_all', 'iit_weight_schedule': <function <lambda> at 0x1061fa8e0>, 'strict_weight_schedule': <function <lambda> at 0x35c1d60c0>, 'behavior_weight_schedule': <function <lambda> at 0x35c1d58a0>}


VBox(children=(Training Epochs:   0%|          | 0/300 [00:00<?, ?it/s],))

Epoch 1: lr: 9.97e-04, iit_weight: 0.00e+00, behavior_weight: 1.00e+00, strict_weight: 0.00e+00, train/iit_loss: 0.00e+00, train/behavior_loss: 1.15e+00, train/strict_loss: 0.00e+00, val/iit_loss: 9.81e-01, val/IIA: 55.51, val/accuracy: 60.13, val/strict_accuracy: 60.13
Epoch 2: lr: 9.94e-04, iit_weight: 0.00e+00, behavior_weight: 1.00e+00, strict_weight: 0.00e+00, train/iit_loss: 0.00e+00, train/behavior_loss: 8.75e-01, train/strict_loss: 0.00e+00, val/iit_loss: 8.23e-01, val/IIA: 95.04, val/accuracy: 95.01, val/strict_accuracy: 95.01
Epoch 3: lr: 9.91e-04, iit_weight: 0.00e+00, behavior_weight: 1.00e+00, strict_weight: 0.00e+00, train/iit_loss: 0.00e+00, train/behavior_loss: 7.41e-01, train/strict_loss: 0.00e+00, val/iit_loss: 6.89e-01, val/IIA: 95.04, val/accuracy: 95.01, val/strict_accuracy: 95.01
Epoch 4: lr: 9.88e-04, iit_weight: 0.00e+00, behavior_weight: 1.00e+00, strict_weight: 0.00e+00, train/iit_loss: 0.00e+00, train/behavior_loss: 6.34e-01, train/strict_loss: 0.00e+00, val/

KeyboardInterrupt: 

# Left > Right

In [13]:
# HighLevelLeftGreater, test_HL_left_greater_components and LeftGreaterDataset are already
# imported from cases.left_greater in the notebook's top import cell, so they are not
# re-imported here. Run the component self-tests for the left-greater case.
test_HL_left_greater_components()

All left greater tests passed!


True

In [14]:
# High-level left-greater model and the correspondence object it exposes.
hl_model = HighLevelLeftGreater()
corr = hl_model.get_correspondence()

# Build n_samples examples; the context length is read from the low-level model config.
dataset = LeftGreaterDataset(
    N_samples=n_samples,
    n_ctx=hl_model.get_ll_model_cfg().n_ctx,
    seed=42,
)
train_set, test_set = dataset.get_IIT_train_test_set()

# Show the high-level hook points and the keys of the correspondence.
print(hl_model.hook_dict)
print([*corr.keys()])

making IIT dataset
{'input_hook': HookPoint(), 'paren_counts_hook': HookPoint(), 'mlp0_hook': HookPoint()}
[input_hook, paren_counts_hook, mlp0_hook]


In [15]:
# Seeded low-level model for the left-greater case, paired with its high-level model.
ll_model = hl_model.get_ll_model(seed=42)
model_pair = StrictIITModelPair(
    hl_model=hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)

# Train with AdamW; all other hyperparameters come from the shared training_args dict.
model_pair.train(train_set=train_set, test_set=test_set, optimizer_cls=torch.optim.AdamW, epochs=n_epochs)

training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 1, 'total_iters': 300}, 'optimizer_kwargs': {'lr': 0.001, 'weight_decay': 1e-08, 'betas': (0.9, 0.9)}, 'clip_grad_norm': 1.0, 'seed': 42, 'detach_while_caching': True, 'lr': 0.001, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 0.5, 'strict_weight': 0.5, 'siit_sampling': 'sample_all', 'iit_weight_schedule': <function <lambda> at 0x105d7e7a0>, 'strict_weight_schedule': <function <lambda> at 0x33a3f5120>, 'behavior_weight_schedule': <function <lambda> at 0x33a3f49a0>}


VBox(children=(Training Epochs:   0%|          | 0/300 [00:00<?, ?it/s],))

Epoch 1: lr: 1.00e-03, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 1.29e+00, train/behavior_loss: 6.68e-01, train/strict_loss: 3.33e-01, val/iit_loss: 1.21e+00, val/IIA: 46.04, val/accuracy: 37.42, val/strict_accuracy: 39.02
Epoch 2: lr: 1.00e-03, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 1.19e+00, train/behavior_loss: 6.08e-01, train/strict_loss: 3.03e-01, val/iit_loss: 1.14e+00, val/IIA: 54.93, val/accuracy: 54.87, val/strict_accuracy: 54.87
Epoch 3: lr: 1.00e-03, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 1.12e+00, train/behavior_loss: 5.64e-01, train/strict_loss: 2.82e-01, val/iit_loss: 1.09e+00, val/IIA: 54.14, val/accuracy: 54.87, val/strict_accuracy: 54.87
Epoch 4: lr: 1.00e-03, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 1.06e+00, train/behavior_loss: 5.37e-01, train/strict_loss: 2.69e-01, val/

# Duplicate Remover

Case 19 in circuits-bench.

In [16]:

# Run the component self-tests shipped with the duplicate-remover case before building on it;
# the call's return value is displayed as the cell output.
test_HL_duplicate_remover_components()

[[4, 0, 0, 1, 2, 0, 1, 3, 3], [4, 0, 1, 2, 2, 2, 2, 2, 2], [4, 0, 1, 2, 3, 3, 3, 3, 3]]
[[False, False, True, False, False, False, False, False, True], [False, False, False, False, True, True, True, True, True], [False, False, False, False, False, True, True, True, True]]
All DuplicateRemover tests passed!


True

In [17]:
# High-level duplicate-remover model and the correspondence object it exposes.
hl_model = HighLevelDuplicateRemover()
corr = hl_model.get_correspondence()

# Build n_samples examples; the context length is read from the low-level model config.
dataset = DuplicateRemoverDataset(
    N_samples=n_samples,
    n_ctx=hl_model.get_ll_model_cfg().n_ctx,
    seed=42,
)
train_set, test_set = dataset.get_IIT_train_test_set()

# Show the high-level hook points and the keys of the correspondence.
print(hl_model.hook_dict)
print([*corr.keys()])

making IIT dataset
{'input_hook': HookPoint(), 'prev_token_hook': HookPoint(), 'prev_equal_hook': HookPoint(), 'output_hook': HookPoint()}
[input_hook, prev_token_hook, prev_equal_hook, output_hook]


In [18]:
# Seeded low-level model for the duplicate-remover case, paired with its high-level model.
ll_model = hl_model.get_ll_model(seed=42)
model_pair = StrictIITModelPair(
    hl_model=hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)

# Train with AdamW; all other hyperparameters come from the shared training_args dict.
model_pair.train(train_set=train_set, test_set=test_set, optimizer_cls=torch.optim.AdamW, epochs=n_epochs)

training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 1, 'total_iters': 300}, 'optimizer_kwargs': {'lr': 0.001, 'weight_decay': 1e-08, 'betas': (0.9, 0.9)}, 'clip_grad_norm': 1.0, 'seed': 42, 'detach_while_caching': True, 'lr': 0.001, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 0.5, 'strict_weight': 0.5, 'siit_sampling': 'sample_all', 'iit_weight_schedule': <function <lambda> at 0x105d7e7a0>, 'strict_weight_schedule': <function <lambda> at 0x33a3f5120>, 'behavior_weight_schedule': <function <lambda> at 0x33a3f49a0>}


VBox(children=(Training Epochs:   0%|          | 0/300 [00:00<?, ?it/s],))

Epoch 1: lr: 1.00e-03, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 1.60e+00, train/behavior_loss: 8.04e-01, train/strict_loss: 4.02e-01, val/iit_loss: 1.45e+00, val/IIA: 46.55, val/accuracy: 50.22, val/strict_accuracy: 50.29
Epoch 2: lr: 1.00e-03, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 1.40e+00, train/behavior_loss: 6.92e-01, train/strict_loss: 3.46e-01, val/iit_loss: 1.37e+00, val/IIA: 47.80, val/accuracy: 53.13, val/strict_accuracy: 53.03
Epoch 3: lr: 1.00e-03, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 1.31e+00, train/behavior_loss: 6.33e-01, train/strict_loss: 3.17e-01, val/iit_loss: 1.32e+00, val/IIA: 57.24, val/accuracy: 65.86, val/strict_accuracy: 66.16
Epoch 4: lr: 1.00e-03, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 1.29e+00, train/behavior_loss: 5.93e-01, train/strict_loss: 2.97e-01, val/

# Unique Extractor

In [23]:
# Run the component self-tests shipped with the unique-extractor case before building on it;
# the call's return value is displayed as the cell output.
test_HL_unique_extractor_components()

All UniqueExtractor tests passed!


True

In [24]:
# High-level unique-extractor model and the correspondence object it exposes.
hl_model = HighLevelUniqueExtractor()
corr = hl_model.get_correspondence()

# Build n_samples examples; the context length is read from the low-level model config.
dataset = UniqueExtractorDataset(
    N_samples=n_samples,
    n_ctx=hl_model.get_ll_model_cfg().n_ctx,
    seed=42,
)
train_set, test_set = dataset.get_IIT_train_test_set()

# Show the high-level hook points and the keys of the correspondence.
print(hl_model.hook_dict)
print([*corr.keys()])

making IIT dataset
{'input_hook': HookPoint(), 'counter_head': HookPoint(), 'appeared_mlp': HookPoint(), 'mask_mlp': HookPoint(), 'output_mlp': HookPoint()}
[input_hook, counter_head, appeared_mlp, mask_mlp, output_mlp]


In [25]:
# Seeded low-level model for the unique-extractor case, paired with its high-level model.
ll_model = hl_model.get_ll_model(seed=42)
model_pair = StrictIITModelPair(
    hl_model=hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)

# Train with AdamW; all other hyperparameters come from the shared training_args dict.
model_pair.train(train_set=train_set, test_set=test_set, optimizer_cls=torch.optim.AdamW, epochs=n_epochs)

training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0.1, 'total_iters': 300}, 'optimizer_kwargs': {'lr': 0.001, 'weight_decay': 1e-08, 'betas': (0.9, 0.9)}, 'clip_grad_norm': 1.0, 'seed': 42, 'detach_while_caching': True, 'lr': 0.001, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 0.5, 'strict_weight': 0.5, 'siit_sampling': 'sample_all', 'iit_weight_schedule': <function <lambda> at 0x33a3f5440>, 'strict_weight_schedule': <function <lambda> at 0x33a3f6700>, 'behavior_weight_schedule': <function <lambda> at 0x33a3f5b20>}


VBox(children=(Training Epochs:   0%|          | 0/300 [00:00<?, ?it/s],))

Epoch 1: lr: 9.97e-04, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 1.22e+00, train/behavior_loss: 6.10e-01, train/strict_loss: 3.05e-01, val/iit_loss: 1.01e+00, val/IIA: 77.49, val/accuracy: 80.00, val/strict_accuracy: 80.00
Epoch 2: lr: 9.94e-04, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 9.53e-01, train/behavior_loss: 4.73e-01, train/strict_loss: 2.37e-01, val/iit_loss: 9.04e-01, val/IIA: 82.87, val/accuracy: 86.67, val/strict_accuracy: 86.67
Epoch 3: lr: 9.91e-04, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 8.92e-01, train/behavior_loss: 4.17e-01, train/strict_loss: 2.08e-01, val/iit_loss: 8.23e-01, val/IIA: 83.51, val/accuracy: 86.67, val/strict_accuracy: 86.67
Epoch 4: lr: 9.88e-04, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 7.48e-01, train/behavior_loss: 3.67e-01, train/strict_loss: 1.83e-01, val/

# Poly Model

In [17]:
# NOTE(review): every line of this cell is commented out, so running it defines nothing.
# The poly-model cells below therefore use whatever `n_epochs` / `training_args` are
# currently defined in the kernel. Disabled alternative configuration follows.
# n_epochs = 1000
# #iit_weight = 1. / siit_weight = 0.4 / behavior_weight = 1. works!
# training_args = {
#     "batch_size": 256,
#     "num_workers": 0,
#     "use_single_loss": True,
#     "behavior_weight": 0.5, #basically doubles the strict weight's job.
#     "iit_weight": 1.,
#     "strict_weight": 0.5,
#     "clip_grad_norm": 1.0,
#     "iit_weight_schedule" : lambda s, i: s,
#     "strict_weight_schedule" : lambda s, i: s,
#     "behavior_weight_schedule" : lambda s, i: s, #0.955*s if 0.955**i > 0.01 else s, #have behavior weight decay over time
#     "early_stop" : True,
#     "lr_scheduler": torch.optim.lr_scheduler.LinearLR,
#     "scheduler_kwargs": dict(start_factor=1, end_factor=0.01, total_iters=int(0.8*n_epochs)),
#     "optimizer_kwargs": dict(lr=1e-3, weight_decay=1e-4, betas=(0.9, 0.7)), #betas=(0.9, 0.7) trains better than (0.9, 0.9)
#     "scheduler_val_metric": ["val/accuracy", "val/IIA"], #for ReduceLRonPlateau
#     "scheduler_mode": "max", #for ReduceLRonPlateau
#     "siit_sampling" : "sample_all"
# }

## LeftGreater + ParensBalancer

In [5]:
# Combine the left-greater and parens-balance cases into a single poly high-level model.
cases = [HighLevelLeftGreater, HighLevelParensBalanceChecker]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=2)
corr = poly_hl_model.get_correspondence()

print(n_samples)
# One dataset per case, in the same order as `cases`, all built with identical arguments.
dataset1, dataset2 = (
    dataset_cls(N_samples=n_samples, n_ctx=15, seed=42)
    for dataset_cls in (LeftGreaterDataset, BalancedParensDataset)
)
dsets = [dataset1, dataset2]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)


1000


  tokens = torch.cat([torch.tensor(dataset.tokens) for dataset in datasets], dim=0)
  self.inputs = t.tensor(inputs).to(int)


In [6]:
# Seeded low-level model for the poly task, with the correspondence re-derived from the HL model.
ll_model = poly_hl_model.get_ll_model(seed=42)
corr = poly_hl_model.get_correspondence()
# Debug dumps, left disabled:
# for k, v in corr.items():
#     print(k, v)
#
# for k, v in poly_hl_model.corr_mapping.items():
#     print(k, v)

train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# List the low-level nodes the pair reports as not being in the circuit.
print(model_pair.nodes_not_in_circuit)

# Train with AdamW; all other hyperparameters come from the shared training_args dict.
model_pair.train(train_set=train_set, test_set=test_set, optimizer_cls=torch.optim.AdamW, epochs=n_epochs)

[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.3.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.3.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.3.attn.hook_z', index=[:, :, 2, :], subspace=None)]
training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True

VBox(children=(Training Epochs:   0%|          | 0/300 [00:00<?, ?it/s],))

Epoch 1: lr: 9.96e-04, iit_weight: 1.00e+00, behavior_weight: 2.50e-01, strict_weight: 7.50e-01, train/iit_loss: 8.08e-01, train/behavior_loss: 1.75e-01, train/strict_loss: 4.40e-01, val/iit_loss: 3.64e-01, val/IIA: 87.25, val/accuracy: 90.25, val/strict_accuracy: 88.53
Epoch 2: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 2.50e-01, strict_weight: 7.50e-01, train/iit_loss: 3.44e-01, train/behavior_loss: 5.97e-02, train/strict_loss: 2.51e-01, val/iit_loss: 2.37e-01, val/IIA: 89.06, val/accuracy: 98.28, val/strict_accuracy: 95.12
Epoch 3: lr: 9.88e-04, iit_weight: 1.00e+00, behavior_weight: 2.50e-01, strict_weight: 7.50e-01, train/iit_loss: 2.75e-01, train/behavior_loss: 3.51e-02, train/strict_loss: 3.04e-01, val/iit_loss: 2.01e-01, val/IIA: 90.62, val/accuracy: 94.64, val/strict_accuracy: 92.45
Epoch 4: lr: 9.83e-04, iit_weight: 1.00e+00, behavior_weight: 2.50e-01, strict_weight: 7.50e-01, train/iit_loss: 2.48e-01, train/behavior_loss: 3.14e-02, train/strict_loss: 2.92e-01, val/

KeyboardInterrupt: 

## DuplicateRemover + LeftGreater

In [None]:
# Combine the duplicate-remover and left-greater cases into a single poly high-level model.
cases = [HighLevelDuplicateRemover, HighLevelLeftGreater]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=2)
corr = poly_hl_model.get_correspondence()

# Dump each correspondence entry (key, then the value it maps to).
for corr_key, corr_value in corr.items():
    print(corr_key, corr_value)

print()
# Dump corr_mapping; judging by the printed output, each value holds one entry per case.
for mapping_key, mapping_value in poly_hl_model.corr_mapping.items():
    print(mapping_key, mapping_value)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}

blocks.0.mlp.hook_post [[prev_equal_hook], [mlp0_hook]]
blocks.0.attn.hook_z.0 [[prev_token_hook], None]
blocks.0.attn.hook_z.1 [None, [paren_counts_hook]]
blocks.0.attn.hook_z.2 [None, None]
blocks.0.attn.hook_z.3 [None, None]
blocks.1.mlp.hook_post [[output_hook], None]
blocks.1.attn.hook_z.0 [None, None]
blocks.1.attn.hook_z.1 [None, None]
blocks.1.attn.hook_z.2 [None, None]
blocks.1.attn.hook_z.3 [None, None]


In [None]:
# One dataset per case, in the same order as the cases given to the poly HL model.
dataset1 = DuplicateRemoverDataset(N_samples=n_samples, n_ctx=15, seed=42)
dataset2 = LeftGreaterDataset(N_samples=n_samples, n_ctx=15, seed=42)
dsets = [dataset1, dataset2]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

# Sanity check: the poly high-level model should reproduce the dataset's expected outputs
# on the first `test_samples` examples.
test_samples = 100
# Slice the dataset once; element 0 of the slice is fed to the model and element 1 is the
# target it is compared against.
poly_batch = poly_dataset.get_dataset()[:test_samples]
# Named poly_inputs / poly_outputs (not input / output) so the builtin `input` is not
# shadowed for the rest of the kernel session.
poly_inputs = poly_batch[0]
expected = poly_batch[1]
poly_outputs = poly_hl_model((poly_inputs, None, None)).cpu()
print(poly_inputs)
print(torch.allclose(poly_outputs, expected))

  _, cache = hl_model.run_with_cache((t.tensor(sample).unsqueeze(0), None, None))


tensor([[1, 3, 0,  ..., 1, 0, 1],
        [1, 3, 0,  ..., 1, 0, 0],
        [0, 4, 0,  ..., 0, 0, 1],
        ...,
        [1, 3, 1,  ..., 0, 0, 1],
        [0, 4, 0,  ..., 0, 1, 0],
        [0, 4, 0,  ..., 2, 2, 2]])
True


  labels = torch.cat([torch.tensor(dataset.labels) for dataset in datasets], dim=0)


In [None]:
# Build the low-level model with the same explicit seed as every other run in this
# notebook, so the initialisation is pinned and runs are comparable.
ll_model = poly_hl_model.get_ll_model(seed=42)
corr = poly_hl_model.get_correspondence()
# Dump each correspondence entry (high-level hook name, then the LLNode set it maps to).
for hl_node, ll_nodes in corr.items():
    print(hl_node, ll_nodes)

# n_epochs = 100
# training_args = {
#     "batch_size": 256,
#     "lr": 3e-4,
#     "num_workers": 0,
#     "use_single_loss": False,
#     "behavior_weight": 1., #basically doubles the strict weight's job.
#     "iit_weight": 1.,
#     "strict_weight": 0.4,
#     "clip_grad_norm": 1.0,
#     "iit_weight_schedule" : lambda s, i: s,
#     "strict_weight_schedule" : lambda s, i: s,
#     "behavior_weight_schedule" : lambda s, i: s, #0.955*s if 0.955**i > 0.01 else s, #have behavior weight decay over time
#     "early_stop" : True,
#     "lr_scheduler": torch.optim.lr_scheduler.LinearLR,
#     "scheduler_kwargs": dict(start_factor=1, end_factor=0, total_iters=n_epochs),
#     "scheduler_val_metric": ["val/accuracy", "val/IIA"], #for ReduceLRonPlateau
#     "scheduler_mode": "max", #for ReduceLRonPlateau
# }

# training_args['iit_weight'] = 1.
# training_args['strict_weight'] = 1.
# training_args['optimizer_kwargs']['betas'] = (0.8, 0.9)
# training_args['optimizer_kwargs']['weight_decay'] = 1e-4
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# List the low-level nodes the pair reports as not being in the circuit.
print(model_pair.nodes_not_in_circuit)

# Train with AdamW; all other hyperparameters come from the shared training_args dict.
# Disabled extra argument: optimizer_kwargs=dict(weight_decay=1e-4)
model_pair.train(train_set=train_set, test_set=test_set, optimizer_cls=torch.optim.AdamW, epochs=n_epochs)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
[LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None)]
training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_va

VBox(children=(Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s],))

Epoch 1: lr: 9.99e-04, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 1.28e+00, train/behavior_loss: 5.97e-01, train/strict_loss: 3.14e-01, val/iit_loss: 1.01e+00, val/IIA: 58.38, val/accuracy: 69.71, val/strict_accuracy: 68.23
Epoch 2: lr: 9.98e-04, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 8.94e-01, train/behavior_loss: 3.31e-01, train/strict_loss: 1.94e-01, val/iit_loss: 8.65e-01, val/IIA: 63.64, val/accuracy: 83.48, val/strict_accuracy: 80.62
Epoch 3: lr: 9.96e-04, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 5.51e-01, train/behavior_loss: 2.04e-01, train/strict_loss: 1.53e-01, val/iit_loss: 7.91e-01, val/IIA: 70.08, val/accuracy: 87.35, val/strict_accuracy: 84.66
Epoch 4: lr: 9.95e-04, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 7.25e-01, train/behavior_loss: 1.51e-01, train/strict_loss: 1.29e-01, val/

## Duplicate Remover + LeftGreater + ParensBalancer

In [None]:
# Combine three cases into a single poly high-level model.
cases = [HighLevelDuplicateRemover, HighLevelLeftGreater, HighLevelParensBalanceChecker]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=2)
corr = poly_hl_model.get_correspondence()

# One dataset per case, in the same order as `cases`, all built with identical arguments.
dataset1, dataset2, dataset3 = (
    dataset_cls(N_samples=n_samples, n_ctx=15, seed=42)
    for dataset_cls in (DuplicateRemoverDataset, LeftGreaterDataset, BalancedParensDataset)
)
dsets = [dataset1, dataset2, dataset3]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)


In [None]:
# Build the low-level model with the same explicit seed as every other run in this
# notebook, so the initialisation is pinned and runs are comparable.
ll_model = poly_hl_model.get_ll_model(seed=42)
corr = poly_hl_model.get_correspondence()
# Dump each correspondence entry (high-level hook name, then the LLNode set it maps to).
for hl_node, ll_nodes in corr.items():
    print(hl_node, ll_nodes)

# n_epochs = 1000
# training_args = {
#     "batch_size": 256,
#     "lr": 1e-3,
#     "num_workers": 0,
#     "use_single_loss": True,
#     "behavior_weight": 0., #basically doubles the strict weight's job.
#     "iit_weight": 1.,
#     "strict_weight": 1.,
#     "clip_grad_norm": 1.0,
#     "iit_weight_schedule" : lambda s, i: s,
#     "strict_weight_schedule" : lambda s, i: s,
#     "behavior_weight_schedule" : lambda s, i: s, #0.955*s if 0.955**i > 0.01 else s, #have behavior weight decay over time
#     "early_stop" : True,
#     "lr_scheduler": torch.optim.lr_scheduler.LinearLR,
#     "scheduler_kwargs": dict(start_factor=1, end_factor=0, total_iters=n_epochs),
#     "scheduler_val_metric": ["val/accuracy", "val/IIA"], #for ReduceLRonPlateau
#     "scheduler_mode": "max", #for ReduceLRonPlateau
# }
train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# List the low-level nodes the pair reports as not being in the circuit.
print(model_pair.nodes_not_in_circuit)

# Train with AdamW; all other hyperparameters come from the shared training_args dict.
# Disabled extra argument: optimizer_kwargs=dict(weight_decay=1e-4)
model_pair.train(train_set=train_set, test_set=test_set, optimizer_cls=torch.optim.AdamW, epochs=n_epochs)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
task_hook {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', 

VBox(children=(Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s],))

Epoch 1: lr: 9.99e-04, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 1.30e+00, train/behavior_loss: 6.19e-01, train/strict_loss: 3.14e-01, val/iit_loss: 9.41e-01, val/IIA: 59.83, val/accuracy: 63.96, val/strict_accuracy: 63.69
Epoch 2: lr: 9.98e-04, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 8.16e-01, train/behavior_loss: 3.61e-01, train/strict_loss: 1.93e-01, val/iit_loss: 5.76e-01, val/IIA: 78.92, val/accuracy: 80.19, val/strict_accuracy: 79.14
Epoch 3: lr: 9.96e-04, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 7.98e-01, train/behavior_loss: 2.12e-01, train/strict_loss: 1.41e-01, val/iit_loss: 7.64e-01, val/IIA: 70.80, val/accuracy: 89.67, val/strict_accuracy: 89.09
Epoch 4: lr: 9.95e-04, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 6.30e-01, train/behavior_loss: 1.49e-01, train/strict_loss: 1.04e-01, val/

## Duplicate Remover + LeftGreater + ParensBalancer + Unique Extractor

In [None]:
# Combine all four cases into a single poly high-level model.
cases = [
    HighLevelDuplicateRemover,
    HighLevelLeftGreater,
    HighLevelParensBalanceChecker,
    HighLevelUniqueExtractor,
]
poly_hl_model = PolyHLModel(hl_classes=cases, size_expansion=2)
corr = poly_hl_model.get_correspondence()

# One dataset per case, in the same order as `cases`, all built with identical arguments.
dataset1, dataset2, dataset3, dataset4 = (
    dataset_cls(N_samples=n_samples, n_ctx=15, seed=42)
    for dataset_cls in (
        DuplicateRemoverDataset,
        LeftGreaterDataset,
        BalancedParensDataset,
        UniqueExtractorDataset,
    )
)
dsets = [dataset1, dataset2, dataset3, dataset4]
poly_dataset = PolyModelDataset(dsets, n_ctx=poly_hl_model.cfg.n_ctx)

  _, cache = hl_model.run_with_cache((t.tensor(sample).unsqueeze(0), None, None))


In [None]:
# Build the low-level model with the same explicit seed as every other run in this
# notebook, so the initialisation is pinned and runs are comparable.
ll_model = poly_hl_model.get_ll_model(seed=42)
corr = poly_hl_model.get_correspondence()
# Dump each correspondence entry (high-level hook name, then the LLNode set it maps to).
for hl_node, ll_nodes in corr.items():
    print(hl_node, ll_nodes)

train_set, test_set = poly_dataset.get_IIT_train_test_set()
model_pair = StrictIITModelPair(
    hl_model=poly_hl_model,
    ll_model=ll_model,
    corr=corr,
    training_args=training_args,
)
# List the low-level nodes the pair reports as not being in the circuit.
print(model_pair.nodes_not_in_circuit)

# Train with AdamW; all other hyperparameters come from the shared training_args dict.
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs,
)

input_hook {LLNode(name='blocks.0.hook_resid_pre', index=[:], subspace=None)}
mlp_hooks.0 {LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.0.0 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 0, :], subspace=None)}
attn_hooks.0.1 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 1, :], subspace=None)}
attn_hooks.0.2 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 2, :], subspace=None)}
attn_hooks.0.3 {LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)}
mlp_hooks.1 {LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)}
task_hook {LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None)}
mlp_hooks.2 {LLNode(name='blocks.2.mlp.hook_post', index=[:], subspace=None)}
attn_hooks.2.3 {LLNode(name='blocks.2.attn.hook_z', index=[:, :, 3, :], subspace=None)}
[LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.

VBox(children=(Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s],))

Epoch 1: lr: 9.99e-04, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 1.07e+00, train/behavior_loss: 4.72e-01, train/strict_loss: 2.49e-01, val/iit_loss: 9.37e-01, val/IIA: 64.07, val/accuracy: 77.50, val/strict_accuracy: 76.54
Epoch 2: lr: 9.98e-04, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 7.76e-01, train/behavior_loss: 2.10e-01, train/strict_loss: 1.25e-01, val/iit_loss: 7.57e-01, val/IIA: 71.59, val/accuracy: 87.03, val/strict_accuracy: 86.17
Epoch 3: lr: 9.96e-04, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 5.58e-01, train/behavior_loss: 1.38e-01, train/strict_loss: 9.69e-02, val/iit_loss: 8.75e-01, val/IIA: 70.24, val/accuracy: 89.79, val/strict_accuracy: 89.78
Epoch 4: lr: 9.95e-04, iit_weight: 1.00e+00, behavior_weight: 5.00e-01, strict_weight: 5.00e-01, train/iit_loss: 5.27e-01, train/behavior_loss: 1.04e-01, train/strict_loss: 7.06e-02, val/