In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from iit.model_pairs.strict_iit_model_pair import StrictIITModelPair

from poly_bench.cases.paren_checker import HighLevelParensBalanceChecker, BalancedParensDataset
from poly_bench.cases.left_greater import HighLevelLeftGreater, LeftGreaterDataset
from poly_bench.cases.duplicate_remover import HighLevelDuplicateRemover, DuplicateRemoverDataset
from poly_bench.cases.unique_extractor import HighLevelUniqueExtractor, UniqueExtractorDataset
from poly_bench.io import save_model_to_dir, save_to_hf


In [3]:
n_epochs = 100
n_samples = 10_000
training_args = {
    "batch_size": 256,
    "num_workers": 0,
    "use_single_loss": True,
    "behavior_weight": 0.4, #basically doubles the strict weight's job.
    "iit_weight": 1.,
    "strict_weight": 0.4,
    "clip_grad_norm": 1.0,
    "early_stop" : True,
    "lr_scheduler": torch.optim.lr_scheduler.LinearLR,
    "scheduler_kwargs": dict(start_factor=1, end_factor=0.2, total_iters=int(n_epochs)),
    "optimizer_kwargs": dict(lr=1e-3, betas=(0.9, 0.9)),
    "scheduler_val_metric": ["val/accuracy", "val/IIA"], #for ReduceLRonPlateau
    "scheduler_mode": "max", #for ReduceLRonPlateau
    "siit_sampling" : "sample_all",
    "val_IIA_sampling": "all", # random or all
    "seed" : 42
}

# Paren Checker

In [14]:
hl_model = HighLevelParensBalanceChecker()
corr = hl_model.get_correspondence()
dataset = BalancedParensDataset(N_samples=n_samples, n_ctx=hl_model.get_ll_model_cfg().n_ctx, seed=42)
train_set, test_set = dataset.get_IIT_train_test_set()

making IIT dataset


In [15]:
i = 6
print(dataset.get_dataset()[i]['tokens'], dataset.get_dataset()[i]['labels'])

[0, 2, 2, 3, 2, 2, 3, 2, 3, 2, 3, 2, 2, 2, 2] [[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]


In [16]:
print(dataset.get_dataset().shape)
print(dataset.get_dataset()[:10]['tokens'])
print(dataset.get_dataset()[:10]['labels'])
for i in range(10):
    tokens, labels, hl_outputs = dataset.get_dataset()[i]['tokens'], dataset.get_dataset()[i]['labels'], hl_model((torch.tensor(dataset.get_dataset()[i]['tokens'])[None,:], None, None))
    nonzero = (torch.tensor(labels) - hl_outputs[0].cpu()).nonzero()
    if nonzero.numel() > 0:
        print(tokens, torch.unique(nonzero[:,0]))
        bad_indices = torch.unique(nonzero[:,0]).tolist()
        for idx in bad_indices:
            print(labels[idx], hl_outputs[0,idx])

(7007, 4)
[[0, 2, 2, 3, 3, 3, 2, 2, 3, 3, 2, 3, 2, 2, 3], [0, 3, 2, 2, 2, 2, 2, 2, 3, 2, 2, 3, 2, 2, 3], [0, 3, 2, 2, 3, 3, 2, 2, 3, 2, 3, 2, 2, 2, 3], [0, 3, 2, 2, 3, 3, 3, 2, 3, 3, 3, 3, 2, 2, 2], [0, 2, 3, 2, 2, 2, 3, 3, 3, 3, 2, 3, 3, 2, 2], [0, 2, 2, 3, 2, 2, 3, 2, 3, 2, 2, 3, 2, 2, 2], [0, 2, 2, 3, 2, 2, 3, 2, 3, 2, 3, 2, 2, 2, 2], [0, 2, 3, 3, 3, 2, 3, 3, 3, 3, 2, 2, 2, 2, 2], [0, 2, 2, 2, 2, 3, 3, 2, 3, 2, 2, 3, 2, 2, 2], [0, 3, 2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 2, 3]]
[[[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0

In [17]:
l = [torch.tensor(0.4473), torch.tensor(0.4485)]
torch.stack(l)

tensor([0.4473, 0.4485])

In [19]:
ll_model = hl_model.get_ll_model(seed=42).to(hl_model.device)
ll_model.device = hl_model.device
model_pair = StrictIITModelPair(hl_model=hl_model, ll_model=ll_model, corr=corr, training_args=training_args)

model_pair.train(
    train_set=train_set,
    test_set=test_set,
    epochs=n_epochs,
)

Moving model to device:  mps
training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0.2, 'total_iters': 100}, 'clip_grad_norm': 1.0, 'seed': 42, 'detach_while_caching': True, 'optimizer_cls': <class 'torch.optim.adam.Adam'>, 'optimizer_kwargs': {'lr': 0.001, 'betas': (0.9, 0.9)}, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 0.4, 'val_IIA_sampling': 'all', 'strict_weight': 0.4, 'siit_sampling': 'sample_all'}


Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.6369, train/behavior_loss: 0.2491, train/strict_loss: 0.0997, val/iit_loss: 0.4413, val/IIA: 95.80%, val/accuracy: 95.82%, val/strict_accuracy: 95.82%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3699, train/behavior_loss: 0.1422, train/strict_loss: 0.0569, val/iit_loss: 0.2894, val/IIA: 95.76%, val/accuracy: 95.76%, val/strict_accuracy: 95.76%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.2525, train/behavior_loss: 0.0887, train/strict_loss: 0.0359, val/iit_loss: 0.2065, val/IIA: 96.08%, val/accuracy: 97.42%, val/strict_accuracy: 97.05%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1947, train/behavior_loss: 0.0579, train/strict_loss: 0.0250, val/iit_loss: 0.1687, val

In [8]:
save_model_to_dir(ll_model, f"./saved_models/{str(hl_model)}")

PosixPath('saved_models/parens_checker_model')

In [9]:
i = 6
input = torch.tensor(dataset.get_dataset()[i]['tokens'])
print(input)
print(torch.round(torch.nn.functional.softmax(model_pair.ll_model.forward(input), dim=-1)))
print(hl_model((torch.tensor(input)[None,:], None, None))) 
print(dataset.get_dataset()[i]['labels'])

tensor([0, 3, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 2, 2, 2])
tensor([[[0., 0., 1., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.]]], device='mps:0', grad_fn=<RoundBackward0>)
tensor([[[0., 0., 1., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.]]], device='mps:0')
[[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0

  print(hl_model((torch.tensor(input)[None,:], None, None)))


# Left > Right

In [20]:
hl_model = HighLevelLeftGreater()
corr = hl_model.get_correspondence()
dataset = LeftGreaterDataset(N_samples=n_samples, n_ctx=hl_model.get_ll_model_cfg().n_ctx, seed=42)
train_set, test_set = dataset.get_IIT_train_test_set()
print(hl_model.hook_dict)
print(list(corr.keys()))

making IIT dataset
{'input_hook': HookPoint(), 'paren_counts_hook': HookPoint(), 'mlp0_hook': HookPoint()}
[input_hook, paren_counts_hook, mlp0_hook]


In [21]:
ll_model = hl_model.get_ll_model(seed=42).to(hl_model.device)
ll_model.device = hl_model.device

model_pair = StrictIITModelPair(hl_model=hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    # optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs,
)

Moving model to device:  mps
training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0.2, 'total_iters': 100}, 'clip_grad_norm': 1.0, 'seed': 42, 'detach_while_caching': True, 'optimizer_cls': <class 'torch.optim.adam.Adam'>, 'optimizer_kwargs': {'lr': 0.001, 'betas': (0.9, 0.9)}, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 0.4, 'val_IIA_sampling': 'all', 'strict_weight': 0.4, 'siit_sampling': 'sample_all'}


Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/22 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.1604, train/behavior_loss: 0.4518, train/strict_loss: 0.1816, val/iit_loss: 1.0236, val/IIA: 56.51%, val/accuracy: 66.84%, val/strict_accuracy: 64.26%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.9398, train/behavior_loss: 0.3593, train/strict_loss: 0.1456, val/iit_loss: 0.8540, val/IIA: 67.08%, val/accuracy: 80.21%, val/strict_accuracy: 76.87%


KeyboardInterrupt: 

In [12]:
save_model_to_dir(ll_model, f"./saved_models/{str(hl_model)}")

PosixPath('saved_models/left_greater_model')

In [13]:
# from poly_bench.utils import load_from_hf
# loaded_model = load_from_hf(model_name="left_greater_model")

# Duplicate remover
case 19 in circuits-bench

In [24]:
hl_model = HighLevelDuplicateRemover()
corr = hl_model.get_correspondence()
dataset = DuplicateRemoverDataset(N_samples=n_samples, n_ctx=hl_model.get_ll_model_cfg().n_ctx, seed=42)
train_set, test_set = dataset.get_IIT_train_test_set()
print(hl_model.hook_dict)
print(list(corr.keys()))

making IIT dataset
{'input_hook': HookPoint(), 'prev_token_hook': HookPoint(), 'prev_equal_hook': HookPoint(), 'output_hook': HookPoint()}
[input_hook, prev_token_hook, prev_equal_hook, output_hook]


In [29]:
ll_model = hl_model.get_ll_model(seed=42).to(hl_model.device)
ll_model.device = hl_model.device

model_pair = StrictIITModelPair(hl_model=hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    epochs=n_epochs,
)

Moving model to device:  mps
training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0.2, 'total_iters': 100}, 'clip_grad_norm': 1.0, 'seed': 42, 'detach_while_caching': True, 'optimizer_cls': <class 'torch.optim.adam.Adam'>, 'optimizer_kwargs': {'lr': 0.001, 'betas': (0.9, 0.9)}, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 0.4, 'val_IIA_sampling': 'all', 'strict_weight': 0.4, 'siit_sampling': 'sample_all'}


Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.2145, train/behavior_loss: 0.4685, train/strict_loss: 0.1875, val/iit_loss: 0.9727, val/IIA: 71.06%, val/accuracy: 71.06%, val/strict_accuracy: 71.06%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.8526, train/behavior_loss: 0.3247, train/strict_loss: 0.1334, val/iit_loss: 0.7254, val/IIA: 77.13%, val/accuracy: 83.15%, val/strict_accuracy: 81.46%


KeyboardInterrupt: 

In [16]:
save_model_to_dir(ll_model, f"./saved_models/{str(hl_model)}")

PosixPath('saved_models/duplicate_remover_model')

# Unique Extractor

In [30]:
hl_model = HighLevelUniqueExtractor()
corr = hl_model.get_correspondence()
dataset = UniqueExtractorDataset(N_samples=n_samples, n_ctx=hl_model.get_ll_model_cfg().n_ctx, seed=42)
train_set, test_set = dataset.get_IIT_train_test_set()
print(hl_model.hook_dict)
print(list(corr.keys()))

making IIT dataset
{'input_hook': HookPoint(), 'counter_head': HookPoint(), 'appeared_mlp': HookPoint(), 'mask_mlp': HookPoint(), 'output_mlp': HookPoint()}
[input_hook, counter_head, appeared_mlp, mask_mlp, output_mlp]


In [32]:
ll_model = hl_model.get_ll_model(seed=42).to(hl_model.device)
ll_model.device = hl_model.device

model_pair = StrictIITModelPair(hl_model=hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    # optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs,
)

Moving model to device:  mps
training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0.2, 'total_iters': 100}, 'clip_grad_norm': 1.0, 'seed': 42, 'detach_while_caching': True, 'optimizer_cls': <class 'torch.optim.adam.Adam'>, 'optimizer_kwargs': {'lr': 0.001, 'betas': (0.9, 0.9)}, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 0.4, 'val_IIA_sampling': 'all', 'strict_weight': 0.4, 'siit_sampling': 'sample_all'}


Training Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Epoch 1: lr: 9.92e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.7626, train/behavior_loss: 0.2869, train/strict_loss: 0.1154, val/iit_loss: 0.5199, val/IIA: 88.10%, val/accuracy: 92.56%, val/strict_accuracy: 92.56%
Epoch 2: lr: 9.84e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.3963, train/behavior_loss: 0.1192, train/strict_loss: 0.0567, val/iit_loss: 0.2905, val/IIA: 93.55%, val/accuracy: 98.89%, val/strict_accuracy: 98.52%
Epoch 3: lr: 9.76e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.2357, train/behavior_loss: 0.0508, train/strict_loss: 0.0243, val/iit_loss: 0.1563, val/IIA: 96.20%, val/accuracy: 98.29%, val/strict_accuracy: 98.12%
Epoch 4: lr: 9.68e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 0.1216, train/behavior_loss: 0.0272, train/strict_loss: 0.0139, val/iit_loss: 0.0972, val

KeyboardInterrupt: 

In [19]:
save_model_to_dir(ll_model, f"./saved_models/{str(hl_model)}")

PosixPath('saved_models/unique_extractor_model')

# Push to HF

In [20]:
save_to_hf(local_dir="saved_models", message="pushes all monosemantic models")

model.safetensors:   0%|          | 0.00/46.5k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/10.8k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/68.1k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

Upload 5 LFS files:   0%|          | 0/5 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/161k [00:00<?, ?B/s]