In [1]:
%load_ext autoreload
%autoreload 2

In [24]:
import torch

from iit.model_pairs.strict_iit_model_pair import StrictIITModelPair

from poly_bench.cases.paren_checker import HighLevelParensBalanceChecker, test_HL_parens_balancer_components, BalancedParensDataset
from poly_bench.cases.left_greater import HighLevelLeftGreater, test_HL_left_greater_components, LeftGreaterDataset
from poly_bench.cases.duplicate_remover import HighLevelDuplicateRemover, test_HL_duplicate_remover_components, DuplicateRemoverDataset
from poly_bench.cases.unique_extractor import HighLevelUniqueExtractor, test_HL_unique_extractor_components, UniqueExtractorDataset
from poly_bench.utils import save_model_to_dir, save_to_hf


In [3]:
n_epochs = 1_000
n_samples = 1_000
#iit_weight = 1. / siit_weight = 0.4 / behavior_weight = 1. works!
training_args = {
    "batch_size": 256,
    "num_workers": 0,
    "use_single_loss": True,
    "behavior_weight": 0.4, #basically doubles the strict weight's job.
    "iit_weight": 1.,
    "strict_weight": 0.4,
    "clip_grad_norm": 1.0,
    "iit_weight_schedule" : lambda s, i: s,
    "strict_weight_schedule" : lambda s, i: s,
    "behavior_weight_schedule" : lambda s, i: s, #0.955*s if 0.955**i > 0.01 else s, #have behavior weight decay over time
    "early_stop" : True,
    "lr_scheduler": torch.optim.lr_scheduler.LinearLR,
    "scheduler_kwargs": dict(start_factor=1, end_factor=0.2, total_iters=int(n_epochs)),
    "optimizer_kwargs": dict(lr=1e-3, betas=(0.9, 0.9)),
    "scheduler_val_metric": ["val/accuracy", "val/IIA"], #for ReduceLRonPlateau
    "scheduler_mode": "max", #for ReduceLRonPlateau
    "siit_sampling" : "sample_all",
    "seed" : 42
}

# Paren Checker

In [4]:
test_HL_parens_balancer_components()

All Balance tests passed!


True

In [5]:
hl_model = HighLevelParensBalanceChecker()
corr = hl_model.get_correspondence()
dataset = BalancedParensDataset(N_samples=n_samples, n_ctx=hl_model.get_ll_model_cfg().n_ctx, seed=42)
train_set, test_set = dataset.get_IIT_train_test_set()

making IIT dataset


In [6]:
i = 6
print(dataset.get_dataset()[i]['tokens'], dataset.get_dataset()[i]['labels'])

[0, 2, 3, 2, 2, 2, 3, 3, 2, 2, 3, 3, 3, 2, 3] [[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]


In [7]:
print(dataset.get_dataset().shape)
print(dataset.get_dataset()[:10]['tokens'])
print(dataset.get_dataset()[:10]['labels'])
for i in range(10):
    tokens, labels, hl_outputs = dataset.get_dataset()[i]['tokens'], dataset.get_dataset()[i]['labels'], hl_model((torch.tensor(dataset.get_dataset()[i]['tokens'])[None,:], None, None))
    nonzero = (torch.tensor(labels) - hl_outputs[0].cpu()).nonzero()
    if nonzero.numel() > 0:
        print(tokens, torch.unique(nonzero[:,0]))
        bad_indices = torch.unique(nonzero[:,0]).tolist()
        for idx in bad_indices:
            print(labels[idx], hl_outputs[0,idx])

(1000, 4)
[[0, 3, 3, 2, 3, 2, 2, 3, 3, 3, 3, 2, 2, 2, 2], [0, 3, 2, 3, 3, 2, 3, 3, 2, 3, 3, 3, 3, 3, 3], [0, 3, 2, 2, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 3], [0, 2, 2, 3, 2, 3, 2, 2, 3, 2, 2, 3, 3, 3, 3], [0, 2, 2, 3, 2, 3, 2, 3, 2, 2, 3, 3, 3, 2, 3], [0, 2, 2, 3, 3, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2], [0, 2, 3, 2, 2, 2, 3, 3, 2, 2, 3, 3, 3, 2, 3], [0, 3, 3, 3, 2, 3, 2, 3, 2, 3, 2, 2, 2, 2, 3], [0, 3, 3, 2, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3], [0, 3, 2, 2, 3, 2, 3, 2, 2, 3, 2, 3, 3, 3, 2]]
[[[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0

In [8]:

ll_model = hl_model.get_ll_model(seed=42)
model_pair = StrictIITModelPair(hl_model=hl_model, ll_model=ll_model, corr=corr, training_args=training_args)

model_pair.train(
    train_set=train_set,
    test_set=test_set,
    # optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs,
)

training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0.2, 'total_iters': 1000}, 'optimizer_kwargs': {'lr': 0.001, 'betas': (0.9, 0.9)}, 'clip_grad_norm': 1.0, 'seed': 42, 'detach_while_caching': True, 'lr': 0.001, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 0.4, 'strict_weight': 0.4, 'siit_sampling': 'sample_all', 'iit_weight_schedule': <function <lambda> at 0x14779c680>, 'strict_weight_schedule': <function <lambda> at 0x14779c540>, 'behavior_weight_schedule': <function <lambda> at 0x14779c5e0>}


Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 1: lr: 9.99e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 9.46e-01, train/behavior_loss: 3.71e-01, train/strict_loss: 1.48e-01, val/iit_loss: 7.37e-01, val/IIA: 87.83, val/accuracy: 87.64, val/strict_accuracy: 87.64
Epoch 2: lr: 9.98e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 6.92e-01, train/behavior_loss: 2.73e-01, train/strict_loss: 1.09e-01, val/iit_loss: 6.31e-01, val/IIA: 94.49, val/accuracy: 94.30, val/strict_accuracy: 94.30
Epoch 3: lr: 9.98e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 6.08e-01, train/behavior_loss: 2.41e-01, train/strict_loss: 9.63e-02, val/iit_loss: 5.70e-01, val/IIA: 94.49, val/accuracy: 94.30, val/strict_accuracy: 94.30
Epoch 4: lr: 9.97e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 5.92e-01, train/behavior_loss: 2.22e-01, train/strict_loss: 8.86e-02, val/

In [9]:
i = 6
input = torch.tensor(dataset.get_dataset()[i]['tokens'])
print(input)
print(torch.round(torch.nn.functional.softmax(model_pair.ll_model.forward(input), dim=-1)))
print(hl_model((torch.tensor(input)[None,:], None, None))) 
print(dataset.get_dataset()[i]['labels'])

tensor([0, 2, 3, 2, 2, 2, 3, 3, 2, 2, 3, 3, 3, 2, 3])
tensor([[[0., 0., 1., 0.],
         [1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [1., 0., 0., 0.],
         [0., 1., 0., 0.]]], device='mps:0', grad_fn=<RoundBackward0>)
tensor([[[0., 0., 1., 0.],
         [1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [1., 0., 0., 0.],
         [0., 1., 0., 0.]]], device='mps:0')
[[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [1.0, 0

  print(hl_model((torch.tensor(input)[None,:], None, None)))


# Left > Right

In [20]:
test_HL_left_greater_components()

All left greater tests passed!


True

In [21]:
hl_model = HighLevelLeftGreater()
corr = hl_model.get_correspondence()
dataset = LeftGreaterDataset(N_samples=n_samples, n_ctx=hl_model.get_ll_model_cfg().n_ctx, seed=42)
train_set, test_set = dataset.get_IIT_train_test_set()
print(hl_model.hook_dict)
print(list(corr.keys()))

making IIT dataset
{'input_hook': HookPoint(), 'paren_counts_hook': HookPoint(), 'mlp0_hook': HookPoint()}
[input_hook, paren_counts_hook, mlp0_hook]


In [22]:
ll_model = hl_model.get_ll_model(seed=42)

model_pair = StrictIITModelPair(hl_model=hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    # optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs,
)

training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0.2, 'total_iters': 1000}, 'optimizer_kwargs': {'lr': 0.001, 'betas': (0.9, 0.9)}, 'clip_grad_norm': 1.0, 'seed': 42, 'detach_while_caching': True, 'lr': 0.001, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 0.4, 'strict_weight': 0.4, 'siit_sampling': 'sample_all', 'iit_weight_schedule': <function <lambda> at 0x147783600>, 'strict_weight_schedule': <function <lambda> at 0x166967ba0>, 'behavior_weight_schedule': <function <lambda> at 0x166967600>}


Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 1: lr: 9.99e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.33e+00, train/behavior_loss: 5.24e-01, train/strict_loss: 2.10e-01, val/iit_loss: 1.26e+00, val/IIA: 42.66, val/accuracy: 53.87, val/strict_accuracy: 52.04
Epoch 2: lr: 9.98e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.21e+00, train/behavior_loss: 4.68e-01, train/strict_loss: 1.88e-01, val/iit_loss: 1.16e+00, val/IIA: 55.53, val/accuracy: 66.10, val/strict_accuracy: 65.73
Epoch 3: lr: 9.98e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.13e+00, train/behavior_loss: 4.39e-01, train/strict_loss: 1.77e-01, val/iit_loss: 1.10e+00, val/IIA: 60.64, val/accuracy: 68.91, val/strict_accuracy: 67.47
Epoch 4: lr: 9.97e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.10e+00, train/behavior_loss: 4.24e-01, train/strict_loss: 1.71e-01, val/

In [34]:
!huggingface-cli login

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)



    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): Traceback (most recent call last):
  File "/Users/evananders/far_cluster/polysemantic-benchmark/.venv/bin/huggingface-cli", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/Users/evananders/far_cluster/polysemantic-benchmark/.venv/lib/python3.12/site-pack

In [36]:
save_model_to_dir(ll_model, "./saved_models/left_greater_model")
save_to_hf(local_dir="saved_models", message="first push of left greater model")

In [41]:
from poly_bench.utils import load_from_hf
loaded_model = load_from_hf(model_name="left_greater_model")

# Duplicate remover
case 19 in circuits-bench

In [13]:

test_HL_duplicate_remover_components()

All DuplicateRemover tests passed!




True

In [14]:
hl_model = HighLevelDuplicateRemover()
corr = hl_model.get_correspondence()
dataset = DuplicateRemoverDataset(N_samples=n_samples, n_ctx=hl_model.get_ll_model_cfg().n_ctx, seed=42)
train_set, test_set = dataset.get_IIT_train_test_set()
print(hl_model.hook_dict)
print(list(corr.keys()))

making IIT dataset
{'input_hook': HookPoint(), 'prev_token_hook': HookPoint(), 'prev_equal_hook': HookPoint(), 'output_hook': HookPoint()}
[input_hook, prev_token_hook, prev_equal_hook, output_hook]


In [15]:
ll_model = hl_model.get_ll_model(seed=42)

model_pair = StrictIITModelPair(hl_model=hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    # optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs,
)

training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0.2, 'total_iters': 1000}, 'optimizer_kwargs': {'lr': 0.001, 'betas': (0.9, 0.9)}, 'clip_grad_norm': 1.0, 'seed': 42, 'detach_while_caching': True, 'lr': 0.001, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 0.4, 'strict_weight': 0.4, 'siit_sampling': 'sample_all', 'iit_weight_schedule': <function <lambda> at 0x14779c680>, 'strict_weight_schedule': <function <lambda> at 0x14779c540>, 'behavior_weight_schedule': <function <lambda> at 0x14779c5e0>}


Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 1: lr: 9.99e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.53e+00, train/behavior_loss: 6.09e-01, train/strict_loss: 2.44e-01, val/iit_loss: 1.47e+00, val/IIA: 53.28, val/accuracy: 58.12, val/strict_accuracy: 58.69
Epoch 2: lr: 9.98e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.38e+00, train/behavior_loss: 5.34e-01, train/strict_loss: 2.14e-01, val/iit_loss: 1.39e+00, val/IIA: 61.48, val/accuracy: 71.06, val/strict_accuracy: 71.08
Epoch 3: lr: 9.98e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.28e+00, train/behavior_loss: 4.87e-01, train/strict_loss: 1.95e-01, val/iit_loss: 1.34e+00, val/IIA: 61.82, val/accuracy: 71.09, val/strict_accuracy: 71.09
Epoch 4: lr: 9.97e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.26e+00, train/behavior_loss: 4.56e-01, train/strict_loss: 1.82e-01, val/

# Unique Extractor

In [16]:
n_epochs = 1_000
n_samples = 1_000
#iit_weight = 1. / siit_weight = 0.4 / behavior_weight = 1. works!
training_args = {
    "batch_size": 256,
    "num_workers": 0,
    "use_single_loss": True,
    "behavior_weight": 0.4, #basically doubles the strict weight's job.
    "iit_weight": 1.,
    "strict_weight": 0.4,
    "clip_grad_norm": 1.0,
    "iit_weight_schedule" : lambda s, i: s,
    "strict_weight_schedule" : lambda s, i: s,
    "behavior_weight_schedule" : lambda s, i: s, #0.955*s if 0.955**i > 0.01 else s, #have behavior weight decay over time
    "early_stop" : True,
    "lr_scheduler": torch.optim.lr_scheduler.LinearLR,
    "scheduler_kwargs": dict(start_factor=1, end_factor=0.2, total_iters=int(n_epochs)),
    "optimizer_kwargs": dict(lr=1e-3, betas=(0.9, 0.9)),
    "scheduler_val_metric": ["val/accuracy", "val/IIA"], #for ReduceLRonPlateau
    "scheduler_mode": "max", #for ReduceLRonPlateau
    "siit_sampling" : "sample_all",
    "seed" : 42
}

In [17]:
test_HL_unique_extractor_components()

All UniqueExtractor tests passed!


True

In [18]:
hl_model = HighLevelUniqueExtractor()
corr = hl_model.get_correspondence()
dataset = UniqueExtractorDataset(N_samples=n_samples, n_ctx=hl_model.get_ll_model_cfg().n_ctx, seed=42)
train_set, test_set = dataset.get_IIT_train_test_set()
print(hl_model.hook_dict)
print(list(corr.keys()))

making IIT dataset
{'input_hook': HookPoint(), 'counter_head': HookPoint(), 'appeared_mlp': HookPoint(), 'mask_mlp': HookPoint(), 'output_mlp': HookPoint()}
[input_hook, counter_head, appeared_mlp, mask_mlp, output_mlp]


In [19]:
ll_model = hl_model.get_ll_model(seed=42)

model_pair = StrictIITModelPair(hl_model=hl_model, ll_model=ll_model, corr=corr, training_args=training_args)
model_pair.train(
    train_set=train_set,
    test_set=test_set,
    # optimizer_cls=torch.optim.AdamW,
    epochs=n_epochs,
)

training_args={'batch_size': 256, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': <class 'torch.optim.lr_scheduler.LinearLR'>, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'scheduler_kwargs': {'start_factor': 1, 'end_factor': 0.2, 'total_iters': 1000}, 'optimizer_kwargs': {'lr': 0.001, 'betas': (0.9, 0.9)}, 'clip_grad_norm': 1.0, 'seed': 42, 'detach_while_caching': True, 'lr': 0.001, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 0.4, 'strict_weight': 0.4, 'siit_sampling': 'sample_all', 'iit_weight_schedule': <function <lambda> at 0x147783600>, 'strict_weight_schedule': <function <lambda> at 0x166967ba0>, 'behavior_weight_schedule': <function <lambda> at 0x166967600>}


Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s]

Training Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 1: lr: 9.99e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 1.20e+00, train/behavior_loss: 4.77e-01, train/strict_loss: 1.91e-01, val/iit_loss: 9.70e-01, val/IIA: 77.61, val/accuracy: 80.10, val/strict_accuracy: 80.10
Epoch 2: lr: 9.98e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 9.05e-01, train/behavior_loss: 3.58e-01, train/strict_loss: 1.43e-01, val/iit_loss: 8.46e-01, val/IIA: 82.20, val/accuracy: 84.66, val/strict_accuracy: 84.66
Epoch 3: lr: 9.98e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 8.45e-01, train/behavior_loss: 3.17e-01, train/strict_loss: 1.27e-01, val/iit_loss: 7.80e-01, val/IIA: 83.32, val/accuracy: 86.87, val/strict_accuracy: 86.76
Epoch 4: lr: 9.97e-04, iit_weight: 1.00e+00, behavior_weight: 4.00e-01, strict_weight: 4.00e-01, train/iit_loss: 7.25e-01, train/behavior_loss: 2.83e-01, train/strict_loss: 1.14e-01, val/