In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import iit
print(iit.__file__)

/Users/evananders/far_cluster/iit/iit/__init__.py


In [3]:
import transformer_lens as tl
import numpy as np
import torch as t
import wandb

from iit.utils.iit_dataset import train_test_split
from iit.utils.iit_dataset import IITDataset
from iit.utils.correspondence import Correspondence
from iit.utils.argparsing import IOIArgParseNamespace

device = t.device("cuda" if t.cuda.is_available() else "cpu")

In [4]:
D_MODEL = 48
N_CTX = 23
N_LAYERS = 3
N_HEADS = 4
D_VOCAB = 4

#Want to specify this somewhere central, e.g., iit.tasks.ioi but iit.tasks.parens
ll_cfg = tl.HookedTransformerConfig(
        n_layers = N_LAYERS,
        d_model = D_MODEL,
        n_ctx = N_CTX,
        d_head = D_MODEL // N_HEADS,
        d_vocab = D_VOCAB,
        act_fn = "relu",
)

ll_model = tl.HookedTransformer(ll_cfg).to(device)

Moving model to device:  cpu


In [5]:
from paren_checker import HighLevelParensBalanceChecker, TwoTaskParensDataset, LeftGreaterParensDataset, BalancedParensDataset
from torch.utils.data import Dataset

hl_model = HighLevelParensBalanceChecker(device=device)
dataset = TwoTaskParensDataset(
    N_samples = 20_000,
    n_ctx = N_CTX,
    seed = 42,
)

class CustomDataset(Dataset):
    def __init__(self, data, targets, markers):
        """
        Args:
            data (list or numpy array): List or array of input data.
            targets (list or numpy array): List or array of target data.
        """
        self.data = t.tensor(data).to(int)
        self.targets = t.tensor(targets).to(int)
        self.markers = t.tensor(markers).to(int)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index
        Returns:
            tuple: (input tensor, target tensor)
        """
        return self.data[idx], self.targets[idx], self.markers[idx]


decorated_dset = CustomDataset(
    data = dataset.get_dataset()['tokens'],
    targets = np.array(dataset.get_dataset()['labels'])[:, None],
    markers = np.array(dataset.get_dataset()['markers'])[:, None]
)


print("making IIT dataset")
train_dataset, test_dataset = train_test_split(
    decorated_dset, test_size=0.2, random_state=42
)
train_set = IITDataset(train_dataset, train_dataset, seed=0)
test_set = IITDataset(test_dataset, test_dataset, seed=0)

making IIT dataset


In [6]:
print(train_set[0])

((tensor([3, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2]), tensor([0]), tensor([0])), (tensor([3, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 2]), tensor([0]), tensor([2])))


In [39]:
from typing import Callable
from dataclasses import asdict
from iit.utils.index import Ix, TorchIndex
from iit_repo_paren_checker import ParensModelPair

all_attns = [f"blocks.{i}.attn.hook_z" for i in range(ll_cfg.n_layers)]
all_mlps = [f"blocks.{i}.mlp.hook_post" for i in range(ll_cfg.n_layers)]
all_nodes_hook = "blocks.0.hook_resid_pre"
head_index: Callable[[int], TorchIndex] = lambda idx: Ix[[None, None, idx, None]]
#make correlation
corr_dict = {
        'input_hook' :           [(all_nodes_hook,  Ix[[None]], None)],
        'left_parens_hook' :     [(all_attns[0],    head_index(0), None)],
        'right_parens_hook' :    [(all_attns[0],    head_index(1), None)],
        'task_hook':             [(all_attns[0],    head_index(2), None)],
        'mlp0_hook':             [(all_mlps[0],     Ix[[None]], None)],
        'mlp1_hook' :            [(all_mlps[1],     Ix[[None]], None)],
        'horizon_lookback_hook': [(all_attns[2],    head_index(3), None)],
        'output_check_hook' :    [(all_mlps[2],     Ix[[None]], None)]
    }


print("making model pair")
corr = Correspondence.make_corr_from_dict(corr_dict, suffixes={"attn": "attn.hook_z", "mlp": "mlp.hook_post"})

train_args = IOIArgParseNamespace(
    include_mlp = True,
    use_wandb = False,
    num_samples = 20_000,
    batch_size = 128,
    next_token = False,
    
    epochs = 100,
    lr = 3e-5,
    iit = 1,
    b = 0.5,
    s = 0.5,
    clip_grad_norm = False,
    use_single_loss = True,
    save_to_wandb = False,
    device = device,
)
train_args_dict = asdict(train_args)
train_args_dict['siit_sample_strategy'] = 'sample_all' #or individual

model_pair = ParensModelPair(
    ll_model=ll_model,
    hl_model=hl_model,
    corr=corr,
    training_args=train_args_dict,
)
# for k, v in model_pair.training_args.items():
#     print(k, v)

making model pair
{'input_hook': HookPoint(), 'left_parens_hook': HookPoint(), 'right_parens_hook': HookPoint(), 'task_hook': HookPoint(), 'greater_hook': HookPoint(), 'elevation_hook': HookPoint(), 'mlp0_hook': HookPoint(), 'mlp1_hook': HookPoint(), 'horizon_lookback_hook': HookPoint(), 'output_check_hook': HookPoint()}
dict_keys([input_hook, left_parens_hook, right_parens_hook, task_hook, mlp0_hook, mlp1_hook, horizon_lookback_hook, output_check_hook])


In [40]:
for node in model_pair.nodes_not_in_circuit:
    print(node)

LLNode(name='blocks.0.attn.hook_z', index=[:, :, 3, :], subspace=None)
LLNode(name='blocks.1.attn.hook_z', index=[:, :, 0, :], subspace=None)
LLNode(name='blocks.1.attn.hook_z', index=[:, :, 1, :], subspace=None)
LLNode(name='blocks.1.attn.hook_z', index=[:, :, 2, :], subspace=None)
LLNode(name='blocks.1.attn.hook_z', index=[:, :, 3, :], subspace=None)
LLNode(name='blocks.2.attn.hook_z', index=[:, :, 0, :], subspace=None)
LLNode(name='blocks.2.attn.hook_z', index=[:, :, 1, :], subspace=None)
LLNode(name='blocks.2.attn.hook_z', index=[:, :, 2, :], subspace=None)


In [41]:
print("training model pair")
model_pair.train(train_set, test_set, epochs=train_args.epochs, use_wandb=train_args.use_wandb, optimizer_cls=t.optim.AdamW, optimizer_kwargs={'weight_decay' : 1e-3})

training model pair
training_args={'next_token': False, 'non_ioi_thresh': 0.65, 'use_per_token_check': False, 'batch_size': 128, 'lr': 3e-05, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': None, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'clip_grad_norm': False, 'seed': 0, 'detach_while_caching': True, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 1.0, 'output_dir': './results', 'include_mlp': True, 'use_wandb': False, 'num_samples': 20000, 'device': device(type='cpu'), 'weights': '100_100_40', 'mean': True, 'load_from_wandb': False, 'epochs': 100, 'iit': 1, 'b': 0.5, 's': 0.5, 'save_to_wandb': False, 'siit_sample_strategy': 'sample_all'}


100%|██████████| 125/125 [00:23<00:00,  5.26it/s]
  1%|          | 1/100 [00:30<50:26, 30.57s/it]


Epoch 0: train/iit_loss: 0.0314, train/behavior_loss: 0.0014, train/strict_loss: 0.0026, val/iit_loss: 0.0343, val/IIA: 98.75%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:23<00:00,  5.35it/s]
  2%|▏         | 2/100 [01:00<49:03, 30.04s/it]


Epoch 1: train/iit_loss: 0.0280, train/behavior_loss: 0.0014, train/strict_loss: 0.0031, val/iit_loss: 0.0203, val/IIA: 99.34%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:21<00:00,  5.72it/s]
  3%|▎         | 3/100 [01:27<46:31, 28.78s/it]


Epoch 2: train/iit_loss: 0.0298, train/behavior_loss: 0.0013, train/strict_loss: 0.0023, val/iit_loss: 0.0255, val/IIA: 99.15%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:20<00:00,  6.09it/s]
  4%|▍         | 4/100 [01:53<44:28, 27.80s/it]


Epoch 3: train/iit_loss: 0.0276, train/behavior_loss: 0.0011, train/strict_loss: 0.0022, val/iit_loss: 0.0307, val/IIA: 99.12%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:21<00:00,  5.91it/s]
  5%|▌         | 5/100 [02:20<43:37, 27.56s/it]


Epoch 4: train/iit_loss: 0.0266, train/behavior_loss: 0.0011, train/strict_loss: 0.0020, val/iit_loss: 0.0253, val/IIA: 99.15%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.10it/s]
  6%|▌         | 6/100 [02:47<42:37, 27.21s/it]


Epoch 5: train/iit_loss: 0.0287, train/behavior_loss: 0.0010, train/strict_loss: 0.0020, val/iit_loss: 0.0253, val/IIA: 99.07%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:20<00:00,  6.02it/s]
  7%|▋         | 7/100 [03:13<41:41, 26.90s/it]


Epoch 6: train/iit_loss: 0.0227, train/behavior_loss: 0.0009, train/strict_loss: 0.0016, val/iit_loss: 0.0302, val/IIA: 98.83%, val/accuracy: 100.00%, val/strict_accuracy: 99.93%, 


100%|██████████| 125/125 [00:20<00:00,  6.01it/s]
  8%|▊         | 8/100 [03:40<41:00, 26.75s/it]


Epoch 7: train/iit_loss: 0.0251, train/behavior_loss: 0.0009, train/strict_loss: 0.0018, val/iit_loss: 0.0300, val/IIA: 99.02%, val/accuracy: 100.00%, val/strict_accuracy: 99.94%, 


100%|██████████| 125/125 [00:20<00:00,  6.16it/s]
  9%|▉         | 9/100 [04:06<40:11, 26.50s/it]


Epoch 8: train/iit_loss: 0.0196, train/behavior_loss: 0.0008, train/strict_loss: 0.0014, val/iit_loss: 0.0219, val/IIA: 99.15%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.04it/s]
 10%|█         | 10/100 [04:32<39:53, 26.59s/it]


Epoch 9: train/iit_loss: 0.0229, train/behavior_loss: 0.0007, train/strict_loss: 0.0014, val/iit_loss: 0.0209, val/IIA: 99.39%, val/accuracy: 100.00%, val/strict_accuracy: 99.97%, 


100%|██████████| 125/125 [00:21<00:00,  5.91it/s]
 11%|█         | 11/100 [04:59<39:30, 26.64s/it]


Epoch 10: train/iit_loss: 0.0212, train/behavior_loss: 0.0007, train/strict_loss: 0.0013, val/iit_loss: 0.0210, val/IIA: 99.19%, val/accuracy: 100.00%, val/strict_accuracy: 99.98%, 


100%|██████████| 125/125 [00:20<00:00,  6.14it/s]
 12%|█▏        | 12/100 [05:25<38:43, 26.40s/it]


Epoch 11: train/iit_loss: 0.0190, train/behavior_loss: 0.0006, train/strict_loss: 0.0013, val/iit_loss: 0.0245, val/IIA: 99.34%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:21<00:00,  5.76it/s]
 13%|█▎        | 13/100 [05:52<38:41, 26.68s/it]


Epoch 12: train/iit_loss: 0.0170, train/behavior_loss: 0.0006, train/strict_loss: 0.0012, val/iit_loss: 0.0115, val/IIA: 99.68%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:22<00:00,  5.68it/s]
 14%|█▍        | 14/100 [06:21<39:09, 27.32s/it]


Epoch 13: train/iit_loss: 0.0196, train/behavior_loss: 0.0006, train/strict_loss: 0.0011, val/iit_loss: 0.0234, val/IIA: 99.05%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:22<00:00,  5.68it/s]
 15%|█▌        | 15/100 [06:50<39:13, 27.69s/it]


Epoch 14: train/iit_loss: 0.0184, train/behavior_loss: 0.0005, train/strict_loss: 0.0010, val/iit_loss: 0.0141, val/IIA: 99.41%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.78it/s]
 16%|█▌        | 16/100 [07:17<38:46, 27.69s/it]


Epoch 15: train/iit_loss: 0.0154, train/behavior_loss: 0.0005, train/strict_loss: 0.0014, val/iit_loss: 0.0124, val/IIA: 99.68%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.12it/s]
 17%|█▋        | 17/100 [07:43<37:35, 27.17s/it]


Epoch 16: train/iit_loss: 0.0181, train/behavior_loss: 0.0005, train/strict_loss: 0.0010, val/iit_loss: 0.0162, val/IIA: 99.41%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.13it/s]
 18%|█▊        | 18/100 [08:10<36:43, 26.87s/it]


Epoch 17: train/iit_loss: 0.0180, train/behavior_loss: 0.0005, train/strict_loss: 0.0012, val/iit_loss: 0.0142, val/IIA: 99.44%, val/accuracy: 100.00%, val/strict_accuracy: 99.98%, 


100%|██████████| 125/125 [00:20<00:00,  6.02it/s]
 19%|█▉        | 19/100 [08:37<36:26, 26.99s/it]


Epoch 18: train/iit_loss: 0.0163, train/behavior_loss: 0.0004, train/strict_loss: 0.0009, val/iit_loss: 0.0141, val/IIA: 99.49%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.93it/s]
 20%|██        | 20/100 [09:04<35:55, 26.94s/it]


Epoch 19: train/iit_loss: 0.0184, train/behavior_loss: 0.0004, train/strict_loss: 0.0011, val/iit_loss: 0.0184, val/IIA: 99.29%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:20<00:00,  6.10it/s]
 21%|██        | 21/100 [09:29<35:02, 26.62s/it]


Epoch 20: train/iit_loss: 0.0185, train/behavior_loss: 0.0004, train/strict_loss: 0.0010, val/iit_loss: 0.0178, val/IIA: 99.24%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:20<00:00,  6.07it/s]
 22%|██▏       | 22/100 [09:56<34:24, 26.47s/it]


Epoch 21: train/iit_loss: 0.0215, train/behavior_loss: 0.0006, train/strict_loss: 0.0024, val/iit_loss: 0.0142, val/IIA: 99.58%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:20<00:00,  5.98it/s]
 23%|██▎       | 23/100 [10:22<34:06, 26.58s/it]


Epoch 22: train/iit_loss: 0.0151, train/behavior_loss: 0.0004, train/strict_loss: 0.0008, val/iit_loss: 0.0168, val/IIA: 99.41%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:21<00:00,  5.92it/s]
 24%|██▍       | 24/100 [10:49<33:45, 26.66s/it]


Epoch 23: train/iit_loss: 0.0157, train/behavior_loss: 0.0004, train/strict_loss: 0.0009, val/iit_loss: 0.0255, val/IIA: 99.24%, val/accuracy: 100.00%, val/strict_accuracy: 99.95%, 


100%|██████████| 125/125 [00:21<00:00,  5.95it/s]
 25%|██▌       | 25/100 [11:17<33:36, 26.89s/it]


Epoch 24: train/iit_loss: 0.0156, train/behavior_loss: 0.0004, train/strict_loss: 0.0008, val/iit_loss: 0.0133, val/IIA: 99.58%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.77it/s]
 26%|██▌       | 26/100 [11:45<33:39, 27.29s/it]


Epoch 25: train/iit_loss: 0.0101, train/behavior_loss: 0.0003, train/strict_loss: 0.0005, val/iit_loss: 0.0119, val/IIA: 99.58%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.08it/s]
 27%|██▋       | 27/100 [12:11<32:54, 27.05s/it]


Epoch 26: train/iit_loss: 0.0121, train/behavior_loss: 0.0003, train/strict_loss: 0.0005, val/iit_loss: 0.0176, val/IIA: 99.32%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  5.98it/s]
 28%|██▊       | 28/100 [12:38<32:16, 26.90s/it]


Epoch 27: train/iit_loss: 0.0171, train/behavior_loss: 0.0003, train/strict_loss: 0.0007, val/iit_loss: 0.0154, val/IIA: 99.34%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.94it/s]
 29%|██▉       | 29/100 [13:05<32:00, 27.05s/it]


Epoch 28: train/iit_loss: 0.0178, train/behavior_loss: 0.0003, train/strict_loss: 0.0009, val/iit_loss: 0.0124, val/IIA: 99.54%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.74it/s]
 30%|███       | 30/100 [13:34<31:58, 27.41s/it]


Epoch 29: train/iit_loss: 0.0155, train/behavior_loss: 0.0003, train/strict_loss: 0.0010, val/iit_loss: 0.0103, val/IIA: 99.68%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.89it/s]
 31%|███       | 31/100 [14:00<31:17, 27.22s/it]


Epoch 30: train/iit_loss: 0.0135, train/behavior_loss: 0.0003, train/strict_loss: 0.0005, val/iit_loss: 0.0151, val/IIA: 99.37%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.18it/s]
 32%|███▏      | 32/100 [14:26<30:21, 26.79s/it]


Epoch 31: train/iit_loss: 0.0130, train/behavior_loss: 0.0003, train/strict_loss: 0.0005, val/iit_loss: 0.0088, val/IIA: 99.61%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.94it/s]
 33%|███▎      | 33/100 [14:54<30:06, 26.97s/it]


Epoch 32: train/iit_loss: 0.0118, train/behavior_loss: 0.0002, train/strict_loss: 0.0004, val/iit_loss: 0.0136, val/IIA: 99.46%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.89it/s]
 34%|███▍      | 34/100 [15:20<29:39, 26.96s/it]


Epoch 33: train/iit_loss: 0.0134, train/behavior_loss: 0.0002, train/strict_loss: 0.0005, val/iit_loss: 0.0099, val/IIA: 99.63%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.13it/s]
 35%|███▌      | 35/100 [15:47<28:59, 26.76s/it]


Epoch 34: train/iit_loss: 0.0124, train/behavior_loss: 0.0002, train/strict_loss: 0.0005, val/iit_loss: 0.0147, val/IIA: 99.56%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.94it/s]
 36%|███▌      | 36/100 [16:14<28:38, 26.85s/it]


Epoch 35: train/iit_loss: 0.0114, train/behavior_loss: 0.0002, train/strict_loss: 0.0005, val/iit_loss: 0.0115, val/IIA: 99.63%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:21<00:00,  5.95it/s]
 37%|███▋      | 37/100 [16:40<28:07, 26.78s/it]


Epoch 36: train/iit_loss: 0.0134, train/behavior_loss: 0.0002, train/strict_loss: 0.0005, val/iit_loss: 0.0064, val/IIA: 99.76%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:20<00:00,  6.15it/s]
 38%|███▊      | 38/100 [17:06<27:24, 26.53s/it]


Epoch 37: train/iit_loss: 0.0133, train/behavior_loss: 0.0002, train/strict_loss: 0.0005, val/iit_loss: 0.0163, val/IIA: 99.44%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.12it/s]
 39%|███▉      | 39/100 [17:32<26:49, 26.39s/it]


Epoch 38: train/iit_loss: 0.0100, train/behavior_loss: 0.0002, train/strict_loss: 0.0004, val/iit_loss: 0.0099, val/IIA: 99.63%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.01it/s]
 40%|████      | 40/100 [17:59<26:29, 26.50s/it]


Epoch 39: train/iit_loss: 0.0097, train/behavior_loss: 0.0002, train/strict_loss: 0.0004, val/iit_loss: 0.0079, val/IIA: 99.78%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  5.99it/s]
 41%|████      | 41/100 [18:26<26:08, 26.58s/it]


Epoch 40: train/iit_loss: 0.0107, train/behavior_loss: 0.0002, train/strict_loss: 0.0004, val/iit_loss: 0.0126, val/IIA: 99.63%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.15it/s]
 42%|████▏     | 42/100 [18:52<25:31, 26.41s/it]


Epoch 41: train/iit_loss: 0.0106, train/behavior_loss: 0.0002, train/strict_loss: 0.0004, val/iit_loss: 0.0110, val/IIA: 99.54%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  5.98it/s]
 43%|████▎     | 43/100 [19:19<25:11, 26.52s/it]


Epoch 42: train/iit_loss: 0.0110, train/behavior_loss: 0.0002, train/strict_loss: 0.0005, val/iit_loss: 0.0103, val/IIA: 99.58%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.09it/s]
 44%|████▍     | 44/100 [19:45<24:40, 26.43s/it]


Epoch 43: train/iit_loss: 0.0109, train/behavior_loss: 0.0002, train/strict_loss: 0.0003, val/iit_loss: 0.0110, val/IIA: 99.54%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.83it/s]
 45%|████▌     | 45/100 [20:13<24:34, 26.82s/it]


Epoch 44: train/iit_loss: 0.0125, train/behavior_loss: 0.0002, train/strict_loss: 0.0005, val/iit_loss: 0.0145, val/IIA: 99.46%, val/accuracy: 100.00%, val/strict_accuracy: 99.96%, 


100%|██████████| 125/125 [00:21<00:00,  5.93it/s]
 46%|████▌     | 46/100 [20:40<24:11, 26.87s/it]


Epoch 45: train/iit_loss: 0.0099, train/behavior_loss: 0.0002, train/strict_loss: 0.0004, val/iit_loss: 0.0150, val/IIA: 99.29%, val/accuracy: 100.00%, val/strict_accuracy: 99.98%, 


100%|██████████| 125/125 [00:20<00:00,  6.10it/s]
 47%|████▋     | 47/100 [21:07<23:43, 26.86s/it]


Epoch 46: train/iit_loss: 0.0089, train/behavior_loss: 0.0002, train/strict_loss: 0.0003, val/iit_loss: 0.0151, val/IIA: 99.41%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:20<00:00,  6.19it/s]
 48%|████▊     | 48/100 [21:33<23:07, 26.68s/it]


Epoch 47: train/iit_loss: 0.0117, train/behavior_loss: 0.0002, train/strict_loss: 0.0004, val/iit_loss: 0.0081, val/IIA: 99.63%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.11it/s]
 49%|████▉     | 49/100 [21:59<22:33, 26.55s/it]


Epoch 48: train/iit_loss: 0.0109, train/behavior_loss: 0.0001, train/strict_loss: 0.0004, val/iit_loss: 0.0110, val/IIA: 99.66%, val/accuracy: 100.00%, val/strict_accuracy: 99.98%, 


100%|██████████| 125/125 [00:20<00:00,  6.16it/s]
 50%|█████     | 50/100 [22:25<22:05, 26.51s/it]


Epoch 49: train/iit_loss: 0.0108, train/behavior_loss: 0.0002, train/strict_loss: 0.0003, val/iit_loss: 0.0095, val/IIA: 99.66%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:20<00:00,  6.03it/s]
 51%|█████     | 51/100 [22:52<21:40, 26.55s/it]


Epoch 50: train/iit_loss: 0.0104, train/behavior_loss: 0.0001, train/strict_loss: 0.0003, val/iit_loss: 0.0101, val/IIA: 99.66%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.09it/s]
 52%|█████▏    | 52/100 [23:18<21:06, 26.39s/it]


Epoch 51: train/iit_loss: 0.0119, train/behavior_loss: 0.0002, train/strict_loss: 0.0006, val/iit_loss: 0.0070, val/IIA: 99.73%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.19it/s]
 53%|█████▎    | 53/100 [23:45<20:48, 26.56s/it]


Epoch 52: train/iit_loss: 0.0088, train/behavior_loss: 0.0002, train/strict_loss: 0.0004, val/iit_loss: 0.0114, val/IIA: 99.61%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.12it/s]
 54%|█████▍    | 54/100 [24:11<20:15, 26.42s/it]


Epoch 53: train/iit_loss: 0.0101, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0128, val/IIA: 99.51%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.93it/s]
 55%|█████▌    | 55/100 [24:38<19:54, 26.55s/it]


Epoch 54: train/iit_loss: 0.0091, train/behavior_loss: 0.0001, train/strict_loss: 0.0003, val/iit_loss: 0.0082, val/IIA: 99.66%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.94it/s]
 56%|█████▌    | 56/100 [25:05<19:36, 26.74s/it]


Epoch 55: train/iit_loss: 0.0088, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0088, val/IIA: 99.58%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  5.98it/s]
 57%|█████▋    | 57/100 [25:32<19:11, 26.79s/it]


Epoch 56: train/iit_loss: 0.0087, train/behavior_loss: 0.0001, train/strict_loss: 0.0003, val/iit_loss: 0.0156, val/IIA: 99.29%, val/accuracy: 100.00%, val/strict_accuracy: 99.98%, 


100%|██████████| 125/125 [00:20<00:00,  6.08it/s]
 58%|█████▊    | 58/100 [25:58<18:36, 26.58s/it]


Epoch 57: train/iit_loss: 0.0130, train/behavior_loss: 0.0002, train/strict_loss: 0.0004, val/iit_loss: 0.0202, val/IIA: 99.29%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:21<00:00,  5.92it/s]
 59%|█████▉    | 59/100 [26:26<18:22, 26.89s/it]


Epoch 58: train/iit_loss: 0.0149, train/behavior_loss: 0.0002, train/strict_loss: 0.0009, val/iit_loss: 0.0058, val/IIA: 99.71%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:22<00:00,  5.62it/s]
 60%|██████    | 60/100 [26:55<18:18, 27.47s/it]


Epoch 59: train/iit_loss: 0.0103, train/behavior_loss: 0.0001, train/strict_loss: 0.0004, val/iit_loss: 0.0134, val/IIA: 99.49%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:21<00:00,  5.95it/s]
 61%|██████    | 61/100 [27:21<17:39, 27.17s/it]


Epoch 60: train/iit_loss: 0.0095, train/behavior_loss: 0.0001, train/strict_loss: 0.0003, val/iit_loss: 0.0044, val/IIA: 99.83%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.19it/s]
 62%|██████▏   | 62/100 [27:47<16:55, 26.72s/it]


Epoch 61: train/iit_loss: 0.0091, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0054, val/IIA: 99.78%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.24it/s]
 63%|██████▎   | 63/100 [28:12<16:14, 26.35s/it]


Epoch 62: train/iit_loss: 0.0099, train/behavior_loss: 0.0001, train/strict_loss: 0.0003, val/iit_loss: 0.0109, val/IIA: 99.68%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.90it/s]
 64%|██████▍   | 64/100 [28:39<15:56, 26.58s/it]


Epoch 63: train/iit_loss: 0.0087, train/behavior_loss: 0.0003, train/strict_loss: 0.0006, val/iit_loss: 0.0220, val/IIA: 99.19%, val/accuracy: 100.00%, val/strict_accuracy: 99.96%, 


100%|██████████| 125/125 [00:21<00:00,  5.90it/s]
 65%|██████▌   | 65/100 [29:07<15:44, 26.98s/it]


Epoch 64: train/iit_loss: 0.0112, train/behavior_loss: 0.0002, train/strict_loss: 0.0006, val/iit_loss: 0.0112, val/IIA: 99.73%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:20<00:00,  6.02it/s]
 66%|██████▌   | 66/100 [29:34<15:09, 26.75s/it]


Epoch 65: train/iit_loss: 0.0091, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0101, val/IIA: 99.73%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.00it/s]
 67%|██████▋   | 67/100 [30:00<14:42, 26.76s/it]


Epoch 66: train/iit_loss: 0.0073, train/behavior_loss: 0.0001, train/strict_loss: 0.0003, val/iit_loss: 0.0076, val/IIA: 99.71%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:20<00:00,  6.08it/s]
 68%|██████▊   | 68/100 [30:26<14:10, 26.59s/it]


Epoch 67: train/iit_loss: 0.0084, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0056, val/IIA: 99.73%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.84it/s]
 69%|██████▉   | 69/100 [30:54<13:49, 26.77s/it]


Epoch 68: train/iit_loss: 0.0114, train/behavior_loss: 0.0001, train/strict_loss: 0.0006, val/iit_loss: 0.0099, val/IIA: 99.56%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.05it/s]
 70%|███████   | 70/100 [31:20<13:20, 26.69s/it]


Epoch 69: train/iit_loss: 0.0064, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0120, val/IIA: 99.46%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.82it/s]
 71%|███████   | 71/100 [31:47<12:58, 26.85s/it]


Epoch 70: train/iit_loss: 0.0092, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0085, val/IIA: 99.61%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.83it/s]
 72%|███████▏  | 72/100 [32:15<12:39, 27.11s/it]


Epoch 71: train/iit_loss: 0.0080, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0198, val/IIA: 99.34%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.90it/s]
 73%|███████▎  | 73/100 [32:42<12:13, 27.16s/it]


Epoch 72: train/iit_loss: 0.0101, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0074, val/IIA: 99.63%, val/accuracy: 100.00%, val/strict_accuracy: 99.98%, 


100%|██████████| 125/125 [00:20<00:00,  6.07it/s]
 74%|███████▍  | 74/100 [33:09<11:38, 26.86s/it]


Epoch 73: train/iit_loss: 0.0112, train/behavior_loss: 0.0001, train/strict_loss: 0.0003, val/iit_loss: 0.0080, val/IIA: 99.73%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.04it/s]
 75%|███████▌  | 75/100 [33:35<11:06, 26.64s/it]


Epoch 74: train/iit_loss: 0.0088, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0072, val/IIA: 99.76%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.94it/s]
 76%|███████▌  | 76/100 [34:01<10:40, 26.68s/it]


Epoch 75: train/iit_loss: 0.0065, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0044, val/IIA: 99.83%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  5.96it/s]
 77%|███████▋  | 77/100 [34:28<10:13, 26.66s/it]


Epoch 76: train/iit_loss: 0.0062, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0054, val/IIA: 99.85%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.01it/s]
 78%|███████▊  | 78/100 [34:54<09:43, 26.51s/it]


Epoch 77: train/iit_loss: 0.0081, train/behavior_loss: 0.0001, train/strict_loss: 0.0003, val/iit_loss: 0.0060, val/IIA: 99.76%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.08it/s]
 79%|███████▉  | 79/100 [35:21<09:17, 26.55s/it]


Epoch 78: train/iit_loss: 0.0150, train/behavior_loss: 0.0002, train/strict_loss: 0.0007, val/iit_loss: 0.0164, val/IIA: 99.56%, val/accuracy: 100.00%, val/strict_accuracy: 99.98%, 


100%|██████████| 125/125 [00:20<00:00,  6.14it/s]
 80%|████████  | 80/100 [35:47<08:48, 26.44s/it]


Epoch 79: train/iit_loss: 0.0085, train/behavior_loss: 0.0001, train/strict_loss: 0.0003, val/iit_loss: 0.0094, val/IIA: 99.58%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.03it/s]
 81%|████████  | 81/100 [36:14<08:23, 26.48s/it]


Epoch 80: train/iit_loss: 0.0087, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0084, val/IIA: 99.73%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.77it/s]
 82%|████████▏ | 82/100 [36:41<08:03, 26.86s/it]


Epoch 81: train/iit_loss: 0.0067, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0071, val/IIA: 99.73%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.74it/s]
 83%|████████▎ | 83/100 [37:11<07:48, 27.54s/it]


Epoch 82: train/iit_loss: 0.0091, train/behavior_loss: 0.0001, train/strict_loss: 0.0004, val/iit_loss: 0.0095, val/IIA: 99.54%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:22<00:00,  5.66it/s]
 84%|████████▍ | 84/100 [37:39<07:26, 27.91s/it]


Epoch 83: train/iit_loss: 0.0100, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0083, val/IIA: 99.63%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:21<00:00,  5.75it/s]
 85%|████████▌ | 85/100 [38:07<06:56, 27.77s/it]


Epoch 84: train/iit_loss: 0.0080, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0088, val/IIA: 99.63%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:21<00:00,  5.93it/s]
 86%|████████▌ | 86/100 [38:33<06:23, 27.43s/it]


Epoch 85: train/iit_loss: 0.0073, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0060, val/IIA: 99.80%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.10it/s]
 87%|████████▋ | 87/100 [38:59<05:51, 27.04s/it]


Epoch 86: train/iit_loss: 0.0067, train/behavior_loss: 0.0001, train/strict_loss: 0.0001, val/iit_loss: 0.0075, val/IIA: 99.66%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.19it/s]
 88%|████████▊ | 88/100 [39:25<05:19, 26.64s/it]


Epoch 87: train/iit_loss: 0.0082, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0043, val/IIA: 99.83%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.24it/s]
 89%|████████▉ | 89/100 [39:51<04:49, 26.28s/it]


Epoch 88: train/iit_loss: 0.0126, train/behavior_loss: 0.0002, train/strict_loss: 0.0005, val/iit_loss: 0.0116, val/IIA: 99.51%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.19it/s]
 90%|█████████ | 90/100 [40:16<04:20, 26.07s/it]


Epoch 89: train/iit_loss: 0.0079, train/behavior_loss: 0.0001, train/strict_loss: 0.0001, val/iit_loss: 0.0053, val/IIA: 99.78%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.18it/s]
 91%|█████████ | 91/100 [40:42<03:53, 25.99s/it]


Epoch 90: train/iit_loss: 0.0104, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0097, val/IIA: 99.66%, val/accuracy: 100.00%, val/strict_accuracy: 99.99%, 


100%|██████████| 125/125 [00:20<00:00,  6.11it/s]
 92%|█████████▏| 92/100 [41:09<03:29, 26.20s/it]


Epoch 91: train/iit_loss: 0.0098, train/behavior_loss: 0.0001, train/strict_loss: 0.0003, val/iit_loss: 0.0127, val/IIA: 99.41%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.03it/s]
 93%|█████████▎| 93/100 [41:35<03:03, 26.27s/it]


Epoch 92: train/iit_loss: 0.0130, train/behavior_loss: 0.0001, train/strict_loss: 0.0010, val/iit_loss: 0.0062, val/IIA: 99.68%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.10it/s]
 94%|█████████▍| 94/100 [42:02<02:39, 26.51s/it]


Epoch 93: train/iit_loss: 0.0094, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0064, val/IIA: 99.68%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.03it/s]
 95%|█████████▌| 95/100 [42:29<02:12, 26.46s/it]


Epoch 94: train/iit_loss: 0.0079, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0037, val/IIA: 99.88%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.15it/s]
 96%|█████████▌| 96/100 [42:54<01:45, 26.29s/it]


Epoch 95: train/iit_loss: 0.0079, train/behavior_loss: 0.0001, train/strict_loss: 0.0003, val/iit_loss: 0.0113, val/IIA: 99.51%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.24it/s]
 97%|█████████▋| 97/100 [43:20<01:18, 26.16s/it]


Epoch 96: train/iit_loss: 0.0084, train/behavior_loss: 0.0001, train/strict_loss: 0.0002, val/iit_loss: 0.0057, val/IIA: 99.76%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.12it/s]
 98%|█████████▊| 98/100 [43:46<00:52, 26.06s/it]


Epoch 97: train/iit_loss: 0.0089, train/behavior_loss: 0.0001, train/strict_loss: 0.0004, val/iit_loss: 0.0071, val/IIA: 99.73%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  5.98it/s]
 99%|█████████▉| 99/100 [44:13<00:26, 26.21s/it]


Epoch 98: train/iit_loss: 0.0061, train/behavior_loss: 0.0001, train/strict_loss: 0.0001, val/iit_loss: 0.0050, val/IIA: 99.85%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 


100%|██████████| 125/125 [00:20<00:00,  6.13it/s]
100%|██████████| 100/100 [44:39<00:00, 26.79s/it]


Epoch 99: train/iit_loss: 0.0059, train/behavior_loss: 0.0001, train/strict_loss: 0.0001, val/iit_loss: 0.0073, val/IIA: 99.76%, val/accuracy: 100.00%, val/strict_accuracy: 100.00%, 



