In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import iit
print(iit.__file__)

/Users/evananders/far_cluster/iit/iit/__init__.py


In [3]:
import transformer_lens as tl
import numpy as np
import torch as t
import wandb

from iit.model_pairs.ioi_model_pair import IOI_ModelPair
from iit.utils.iit_dataset import train_test_split
from iit.utils.iit_dataset import IITDataset
# from iit.model_pairs.base_model_pair import *
# from iit.utils.metric import *
from iit.tasks.ioi import (
    NAMES,
    make_ioi_dataset_and_hl,
    make_corr_dict,
    ioi_cfg,
    suffixes
)
from iit.utils.correspondence import Correspondence
from iit.utils.argparsing import IOIArgParseNamespace


def train_ioi(args: IOIArgParseNamespace) -> IOI_ModelPair:
    device = args.device
    num_samples = args.num_samples
    epochs = args.epochs
    use_wandb = args.use_wandb

    training_args = {
        "batch_size": args.batch_size,
        "lr": args.lr,
        "iit_weight": args.iit,
        "behavior_weight": args.b,
        "strict_weight": args.s,
        "next_token": args.next_token,
        "lr_scheduler": None,
        "clip_grad_norm": args.clip_grad_norm,
        "early_stop": True,
        "use_single_loss": args.use_single_loss,
    }
    t.manual_seed(0)
    np.random.seed(0)

    ll_cfg = tl.HookedTransformer.from_pretrained(
        "gpt2"
    ).cfg.to_dict()
    ll_cfg.update(ioi_cfg)

    ll_cfg["init_weights"] = True
    ll_model = tl.HookedTransformer(ll_cfg).to(device)
    print("making ioi dataset and hl")
    ioi_dataset, hl_model = make_ioi_dataset_and_hl(
        num_samples, ll_model, NAMES, device=args.device, verbose=True
    )
    print("making IIT dataset")
    train_ioi_dataset, test_ioi_dataset = train_test_split(
        ioi_dataset, test_size=0.2, random_state=42
    )
    train_set = IITDataset(train_ioi_dataset, train_ioi_dataset, seed=0)
    test_set = IITDataset(test_ioi_dataset, test_ioi_dataset, seed=0)
    print("making ioi model pair")
    corr_dict = make_corr_dict(include_mlp=args.include_mlp)
    corr = Correspondence.make_corr_from_dict(corr_dict, suffixes=suffixes)
    model_pair = IOI_ModelPair(
        ll_model=ll_model,
        hl_model=hl_model,
        corr=corr,
        training_args=training_args,
    )
    print("training ioi model pair")
    model_pair.train(train_set, test_set, epochs=epochs, use_wandb=use_wandb)
    print(f"done training")

    if use_wandb:
        wandb.finish()
    return model_pair

In [4]:
train_args = IOIArgParseNamespace(
    include_mlp = True,
    use_wandb = False,
    num_samples = 120_000,
    batch_size = 128,
    next_token = False,
    
    epochs = 100,
    lr = 1e-3,
    iit = 1.0,
    b = 1.0,
    s = 1.0,
    clip_grad_norm = 1.0,
    use_single_loss = True,
    save_to_wandb = False
)

D_MODEL = 32
N_CTX = 23
N_LAYERS = 3
N_HEADS = 4
D_VOCAB = 4

#Want to specify this somewhere central, e.g., iit.tasks.ioi but iit.tasks.parens
ll_cfg = tl.HookedTransformerConfig(
        n_layers = N_LAYERS,
        d_model = D_MODEL,
        n_ctx = N_CTX,
        d_head = D_MODEL // N_HEADS,
        d_vocab = D_VOCAB,
        act_fn = "relu",
)

ll_model = tl.HookedTransformer(ll_cfg).to(train_args.device)

Moving model to device:  cpu


In [5]:
from paren_checker import HighLevelParensBalanceChecker, TwoTaskParensDataset
from torch.utils.data import Dataset

hl_model = HighLevelParensBalanceChecker(device=train_args.device)
dataset = TwoTaskParensDataset(
    N_samples = 20_000,
    n_ctx = N_CTX,
    seed = 42,
)

class CustomDataset(Dataset):
    def __init__(self, data, targets, markers):
        """
        Args:
            data (list or numpy array): List or array of input data.
            targets (list or numpy array): List or array of target data.
        """
        self.data = t.tensor(data).to(int)
        self.targets = t.tensor(targets).to(int)
        self.markers = t.tensor(markers).to(int)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index
        Returns:
            tuple: (input tensor, target tensor)
        """
        return self.data[idx], self.targets[idx], self.markers[idx]


decorated_dset = CustomDataset(
    data = dataset.get_dataset()['tokens'],
    targets = np.array(dataset.get_dataset()['labels'])[:, None],
    markers = np.array(dataset.get_dataset()['markers'])[:, None]
)


print("making IIT dataset")
train_dataset, test_dataset = train_test_split(
    decorated_dset, test_size=0.2, random_state=42
)
train_set = IITDataset(train_dataset, train_dataset, seed=0)
test_set = IITDataset(test_dataset, test_dataset, seed=0)

making IIT dataset


In [6]:
print(train_set[0])

((tensor([3, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2]), tensor([1]), tensor([1])), (tensor([3, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2]), tensor([0]), tensor([2])))


In [41]:
from typing import Callable
from dataclasses import asdict
from iit.utils.index import Ix, TorchIndex
from iit.utils.nodes import HLNode

all_attns = [f"blocks.{i}.attn.hook_z" for i in range(ll_cfg.n_layers)]
all_mlps = [f"blocks.{i}.mlp.hook_post" for i in range(ll_cfg.n_layers)]
all_nodes_hook = "blocks.0.hook_resid_pre"
head_index: Callable[[int], TorchIndex] = lambda idx: Ix[[None, None, idx, None]]
#make correlation
corr_dict = {
        'input_hook' :           [(all_nodes_hook,  Ix[[None]], None)],
        'left_parens_hook' :     [(all_attns[0],    head_index(0), None)],
        'right_parens_hook' :    [(all_attns[0],    head_index(1), None)],
        'task_hook':             [(all_attns[0],    head_index(2), None)],
        'mlp0_hook':             [(all_mlps[0],     Ix[[None]], None)],
        'mlp1_hook' :            [(all_mlps[1],     Ix[[None]], None)],
        'horizon_lookback_hook': [(all_attns[2],    head_index(3), None)],
        'output_check_hook' :    [(all_mlps[2],     Ix[[None]], None)]
    }


print("making model pair")
corr = Correspondence.make_corr_from_dict(corr_dict)

class ParensModelPair(IOI_ModelPair):

    @property
    def loss_fn(self) -> Callable[[t.Tensor, t.Tensor], t.Tensor]:

        def per_token_weighted_cross_entropy(output: t.Tensor, target: t.Tensor) -> t.Tensor:
            if len(target.shape) == 2 and target.shape[1] == 2: #dumb one-hot fix
                true_target = t.zeros(target.shape[0])
                true_target[target[:, 1] == 1] = 1
                target = true_target
            return t.nn.BCEWithLogitsLoss()(output[:,-1], target.to(float).squeeze())
            
        self.__loss_fn = per_token_weighted_cross_entropy
        return self.__loss_fn

    def get_behaviour_loss_over_batch(
        self,
        base_input: tuple[t.Tensor, t.Tensor, t.Tensor],
        loss_fn: Callable[[t.Tensor, t.Tensor], t.Tensor],
    ) -> t.Tensor:
        x, y = base_input[0:2]
        ll_output = self.ll_model(x)
        # hl_argmax = t.argmax(hl_output[:, -1, :], dim=-1)

        loss = loss_fn(ll_output[:, -1, :], y[:,0])
        return loss
    
    #TODO: Fix this so that it works with ParensChecker.
    def run_eval_step(
        self,
        base_input: tuple[t.Tensor, t.Tensor, t.Tensor],
        ablation_input: tuple[t.Tensor, t.Tensor, t.Tensor],
        loss_fn: Callable[[t.Tensor, t.Tensor], t.Tensor],
    ) -> dict:
        # compute IIT loss and accuracy on last token position only
        hl_node = self.sample_hl_name()
        hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node)
        # CrossEntropyLoss needs target probs, not logits
        # hl_output = t.nn.functional.softmax(hl_output, dim=-1)
        hl_argmax = t.argmax(hl_output[:, -1, :], dim=-1)
        hl_one_hot = t.nn.functional.one_hot(hl_argmax, num_classes=hl_output.shape[-1])
        hl_probs = hl_one_hot.float()
        assert self.hl_model.is_categorical()
        loss = loss_fn(ll_output[:, -1, :], hl_probs)
        if ll_output.shape == hl_output.shape:
            # To handle the case when labels are one-hot
            hl_output = t.argmax(hl_output, dim=-1)
        top1 = t.argmax(ll_output, dim=-1)
        accuracy = (top1[:, -1] == hl_output[:, -1]).float().mean().item()
        IIA = accuracy

        # compute behavioral accuracy
        base_x, base_y = base_input[0:2]
        output = self.ll_model(base_x)
        top1 = t.argmax(output, dim=-1)  # batch n_ctx
        if output.shape == base_y.shape:
            # To handle the case when labels are one-hot
            # TODO: is there a better way?
            base_y = t.argmax(base_y, dim=-1)  # batch n_ctx
        per_token_accuracy = (top1 == base_y).float().mean(dim=0).cpu().numpy()


        # strict accuracy
        base_x, base_y = base_input[0:2]
        ablation_x, ablation_y = ablation_input[0:2]
        # ll_node = self.sample_ll_node() 
        _, cache = self.ll_model.run_with_cache(ablation_x)
        self.ll_cache = cache
        label_idx = self.get_label_idxs()
        base_y = base_y[label_idx.as_index].to(self.ll_model.cfg.device)
        if self.hl_model.is_categorical:
            if len(base_y.shape) == 2:
                base_y = t.argmax(base_y, dim=-1)
        accuracies = []
        for node in self.nodes_not_in_circuit:
            out = self.ll_model.run_with_hooks(
                base_x, fwd_hooks=[(node.name, self.make_ll_ablation_hook(node))]
            )
            ll_output = out[label_idx.as_index]
            if self.hl_model.is_categorical:
                top1 = t.argmax(ll_output, dim=-1)
                accuracy = (top1 == base_y).float().mean().item()
            else:
                accuracy = ((ll_output - base_y).abs() < self.training_args["atol"]).float().mean().item()
            accuracies.append(accuracy)
        strict_accuracy = np.mean(accuracies)

        return {
            "val/iit_loss": loss.item(),
            "val/IIA": IIA,
            "val/accuracy": (
                per_token_accuracy.mean().item()
                if self.next_token
                else per_token_accuracy[-1]
            ),
            "val/strict_accuracy": strict_accuracy,
            "val/per_token_accuracy": per_token_accuracy,
        }

model_pair = ParensModelPair(
    ll_model=ll_model,
    hl_model=hl_model,
    corr=corr,
    training_args=asdict(train_args),
)

making model pair
{'input_hook': HookPoint(), 'left_parens_hook': HookPoint(), 'right_parens_hook': HookPoint(), 'task_hook': HookPoint(), 'greater_hook': HookPoint(), 'elevation_hook': HookPoint(), 'mlp0_hook': HookPoint(), 'mlp1_hook': HookPoint(), 'horizon_lookback_hook': HookPoint(), 'output_check_hook': HookPoint()}
dict_keys([input_hook, left_parens_hook, right_parens_hook, task_hook, mlp0_hook, mlp1_hook, horizon_lookback_hook, output_check_hook])


In [42]:
print("training model pair")
model_pair.train(train_set, test_set, epochs=train_args.epochs, use_wandb=train_args.use_wandb)

training model pair
training_args={'next_token': False, 'non_ioi_thresh': 0.65, 'use_per_token_check': False, 'batch_size': 128, 'lr': 0.001, 'num_workers': 0, 'early_stop': True, 'lr_scheduler': None, 'scheduler_val_metric': ['val/accuracy', 'val/IIA'], 'scheduler_mode': 'max', 'clip_grad_norm': 1.0, 'seed': 0, 'detach_while_caching': True, 'atol': 0.05, 'use_single_loss': True, 'iit_weight': 1.0, 'behavior_weight': 1.0, 'strict_weight': 1.0, 'output_dir': './results', 'include_mlp': True, 'use_wandb': False, 'num_samples': 120000, 'device': 'cpu', 'weights': '100_100_40', 'mean': True, 'load_from_wandb': False, 'epochs': 100, 'iit': 1.0, 'b': 1.0, 's': 1.0, 'save_to_wandb': False}


100%|██████████| 125/125 [00:17<00:00,  7.03it/s]
  0%|          | 0/100 [00:17<?, ?it/s]


RuntimeError: The size of tensor a (128) must match the size of tensor b (2) at non-singleton dimension 1